Online EWC
The last section mentioned that the space and time complexity of Offline EWC can become unacceptable as the number of tasks grows.
In light of this, Online EWC was introduced as a variant of the EWC technique. Online EWC gives up some performance in exchange for lower complexity than the Offline version, so it is best understood as a trade-off between performance and complexity.
How Online EWC works
Online EWC achieves multi-task continual learning by maintaining a single FIM (we call it the online FIM to distinguish it from the per-task FIMs). The online FIM is updated each time a new task is trained. Denote the online FIM before the update as F_old and the online FIM after the update as F_new. Let F_c be the FIM of the current task, and let α be the importance coefficient that controls the weight given to previous tasks. The update of the online FIM can then be formulated as follows:
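Written out in a form consistent with the update step in the implementation below (a reconstruction from that code, using the symbols defined above):

$$
F_{\text{new}} = \alpha \, F_{\text{old}} + (1 - \alpha) \, F_c
$$

For the very first task there is no F_old yet, and the implementation simply sets F_new = F_c. With α = 0.5, for example, each update averages the accumulated FIM with the current task's FIM, so the contribution of a task trained n updates ago is scaled by roughly 0.5^n.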
Because only a single FIM is maintained, the loss function L for learning the K-th task with Online EWC is then:
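A sketch of this loss, reconstructed from the implementation below. The names L_K (the plain loss on task K), θ* (the parameters saved after task K-1), and F_i (the i-th diagonal entry of the online FIM) are notation assumed here for illustration, and λ corresponds to the `lam` argument in the code:

$$
L(\theta) = L_K(\theta) + \lambda \sum_i F_i \left( \theta_i - \theta^{*}_i \right)^2
$$

Some formulations of EWC write the penalty with a factor of λ/2 instead of λ. The implementation below uses λ directly, so the form above follows that choice.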
Implementation of Online EWC
Below we show our implementation of Online EWC using PyTorch:
import torch
from torch import nn


class OnlineEWC:
    def __init__(self, model: nn.Module, loss=nn.MSELoss()):
        self._model = model
        self._params = {}    # parameters saved after the most recent task
        self._fim = {}       # online FIM (diagonal), stored as column vectors
        self._loss = loss
        self._inputs = {}    # per-task inputs, kept to trace old-task losses
        self._labels = {}    # per-task labels, kept to trace old-task losses
        self._loss_lst = {}  # per-task loss trace
        self._optim = None
        self._lambda = 0

    def train(self, inputs, labels, index, lr, alpha=0.5, lam=0, epochs=500):
        self._optim = torch.optim.Adam(self._model.parameters(), lr=lr)
        loss_values = []
        self._lambda = lam
        self._inputs[index] = inputs
        self._labels[index] = labels

        # training
        for _ in range(epochs):
            f = self._model(inputs.float())
            regularizer = 0
            if len(self._params) != 0:
                # EWC penalty: sum over parameters of F * (theta - theta_old)^2
                loss_ewc = 0
                for n, p in self._model.named_parameters():
                    if not p.requires_grad:
                        continue
                    diff = p.reshape(-1, 1) - self._params[n].reshape(-1, 1)
                    loss_ewc = loss_ewc + (self._fim[n] * diff ** 2).sum()
                regularizer = self._lambda * loss_ewc
            loss = self._loss(f, labels.unsqueeze(1).float()) + regularizer
            self._optim.zero_grad()
            loss.backward()
            self._optim.step()

            # store loss
            loss_values.append(loss.item())
            # trace the loss on every previously trained task
            with torch.no_grad():
                for n in self._loss_lst:
                    tmp_f = self._model(self._inputs[n].float())
                    tmp_loss = self._loss(tmp_f, self._labels[n].unsqueeze(1).float())
                    self._loss_lst[n].append(tmp_loss.item())

        # snapshot the trained parameters as the anchor for the next task
        for n, p in self._model.named_parameters():
            if p.requires_grad:
                self._params[n] = p.detach().clone()

        # update fisher information matrix
        f = self._model(inputs.float())
        loss = self._loss(f, labels.unsqueeze(1).float())
        self._optim.zero_grad()
        loss.backward()
        temp_fisher = {}
        for n, p in self._model.named_parameters():
            if p.requires_grad:
                temp_fisher[n] = p.grad.detach().reshape(-1, 1)
        for n, g in temp_fisher.items():
            # squared gradient of the batch loss serves as the current task's FIM
            if n in self._fim:
                self._fim[n] = self._fim[n] * alpha + g ** 2 * (1 - alpha)
            else:
                self._fim[n] = g ** 2
        self._loss_lst[index] = loss_values
To compare Online EWC with Offline EWC fairly, we run the Online EWC experiments on the same sample data that we used for Offline EWC. The sample data is as follows:
Just as we did for Offline EWC, we use an MLP with four hidden layers for Online EWC. The layer widths are 1, 100, 100, 100, 100, and 1 (input, four hidden layers, output).
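A minimal sketch of such a network in PyTorch is shown here. The helper name `build_mlp` is hypothetical, and the ReLU activation is an assumption, since the text specifies only the layer widths:

```python
import torch
from torch import nn


def build_mlp(width: int = 100, hidden_layers: int = 4) -> nn.Sequential:
    """Build an MLP with layer widths 1, width (x hidden_layers), 1.

    ReLU is an assumed activation; the article does not name one.
    """
    layers = [nn.Linear(1, width), nn.ReLU()]
    for _ in range(hidden_layers - 1):
        layers += [nn.Linear(width, width), nn.ReLU()]
    layers.append(nn.Linear(width, 1))
    return nn.Sequential(*layers)


model = build_mlp()
x = torch.linspace(-1.0, 1.0, steps=8).unsqueeze(1)  # 8 samples, 1 feature each
y = model(x)                                          # one prediction per sample
n_linear = sum(isinstance(m, nn.Linear) for m in model)
n_params = sum(p.numel() for p in model.parameters())
```

A model built this way can be passed to `OnlineEWC(model)`. The `train` method above expects inputs of shape (N, 1) and labels of shape (N,), since it calls `labels.unsqueeze(1)` before computing the loss.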
Below are the traces of the experiments after each task is trained:
Task 1:
Task 2:
Task 3:
Task 4:
Not bad, right? But can we do better? Online EWC is not the end of the story: the next section will focus on possible improvements to EWC techniques.