Maybe More Efficiency?

Online EWC

The last section mentioned that the space and time complexity of Offline EWC can become unacceptable as the number of tasks grows.

In light of this, Online EWC was introduced as a variant of the EWC technique. Online EWC trades some performance for lower complexity than the offline version, so it makes sense to view Online EWC as a product of the trade-off between performance and complexity.

How Online EWC works

Online EWC realizes multi-task continual learning by maintaining a single FIM (call it the online FIM to distinguish it from the per-task FIMs of Offline EWC). This online FIM gets updated each time a new task is trained. Denote the online FIM before the update as $F_{\text{old}}$ and the online FIM after the update as $F_{\text{new}}$. Let $F_c$ be the FIM corresponding to the current task, and $\alpha$ be the importance coefficient controlling the weight of previous tasks. The update of the online FIM can then be formulated as follows:

$$F_{\text{new}} = \alpha F_{\text{old}} + (1 - \alpha) F_c$$
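
As a minimal per-parameter sketch of this update rule (update_online_fim is an illustrative helper, not part of the implementation shown below):

import torch

def update_online_fim(fim_old, fim_current, alpha=0.5):
    """Blend the previous online FIM with the current task's diagonal FIM."""
    if fim_old is None:        # first task: no history to carry over yet
        return fim_current
    return alpha * fim_old + (1 - alpha) * fim_current

# e.g. with diagonal FIM estimates stored as flat tensors
online_fim = update_online_fim(None, torch.tensor([0.4, 0.1]))
online_fim = update_online_fim(online_fim, torch.tensor([0.2, 0.3]), alpha=0.5)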

Since only a single FIM is maintained, suppose we are trying to learn the $K$th task; the loss function $L$ using Online EWC would then be

$$L(\theta) = L_K(\theta) + \lambda \sum_i F_i \left(\theta_i - \theta_i^{*}\right)^2,$$

where $L_K$ is the loss on the $K$th task, $F$ is the online FIM accumulated through task $K-1$, $\theta^{*}$ are the parameters learned after task $K-1$, and $\lambda$ controls the strength of the penalty.
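
A minimal sketch of this penalty term, assuming the online FIM and the parameters stored after task $K-1$ are kept as dicts keyed by parameter name (ewc_penalty is an illustrative helper; the actual implementation appears in the next section):

import torch
import torch.nn as nn

def ewc_penalty(model: nn.Module, online_fim: dict, old_params: dict, lam: float):
    """Compute lambda * sum_i F_i * (theta_i - theta*_i)^2 over all parameters."""
    penalty = torch.zeros(())
    for n, p in model.named_parameters():
        diff_sq = (p.reshape(-1) - old_params[n].reshape(-1)) ** 2
        penalty = penalty + torch.sum(online_fim[n].reshape(-1) * diff_sq)
    return lam * penalty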

Implementation of Online EWC

Below we show our implementation of Online EWC using PyTorch.

from copy import deepcopy

import torch
import torch.nn as nn


class OnlineEWC:
    def __init__(self, model: nn.Module, loss=nn.MSELoss()):
        self._model = model
        self._params = {}
        self._fim = {}
        self._loss = loss
        self._inputs = {}
        self._labels = {}
        self._loss_lst = {}
        self._optim = None
        self._lambda = 0

    def train(self, inputs, labels, index, lr, alpha=0.5, lam=0, epochs=500):
        self._optim = torch.optim.Adam(self._model.parameters(), lr=lr)

        loss_values_x1 = []
        self._lambda = lam
        self._inputs[index] = inputs
        self._labels[index] = labels

        # training
        for _ in range(epochs):
            f = self._model(inputs.float())
            regularizer = 0
            if len(self._params) != 0:
                # quadratic penalty: sum_i F_i * (theta_i - theta*_i)^2, scaled by lambda
                loss_ewc = 0
                for n, p in self._model.named_parameters():
                    loss_ewc += torch.matmul(self._fim[n].T, (torch.reshape(p, (-1,1)) - torch.reshape(self._params[n], (-1,1))) ** 2)
                regularizer += self._lambda * loss_ewc

            loss = self._loss(f, labels.unsqueeze(1).float()) + regularizer
            self._optim.zero_grad()
            loss.backward()
            self._optim.step()

            # store loss
            loss_values_x1.append(loss.item())


            # track the loss on previously trained tasks (no gradients needed)
            with torch.no_grad():
                for n in self._loss_lst:
                    tmp_f = self._model(self._inputs[n].float())
                    tmp_loss = self._loss(tmp_f, self._labels[n].unsqueeze(1).float())
                    self._loss_lst[n].append(tmp_loss.item())


        # snapshot the parameters learned after this task
        for n, p in deepcopy(self._model).named_parameters():
            if p.requires_grad:
                self._params[n] = p

        # update fisher information matrix
        f = self._model(inputs.float())
        loss = self._loss(f, labels.unsqueeze(1).float())
        self._optim.zero_grad()
        loss.backward()

        temp_fisher = {}
        for n, p in self._model.named_parameters():
            temp_fisher[n] = torch.reshape(p.grad.data, (-1,1))

        # online update: F_new = alpha * F_old + (1 - alpha) * F_current,
        # where squared gradients approximate the diagonal FIM of the current task
        for n, p in temp_fisher.items():
            if n in self._fim:
                self._fim[n] = self._fim[n] * alpha + p ** 2 * (1 - alpha)
            else:
                self._fim[n] = p ** 2
        self._loss_lst[index] = loss_values_x1
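
As an illustrative usage sketch (the two toy regression tasks below are made up for demonstration and are not the sample data used in the experiments):

import torch
import torch.nn as nn

# a small regression model (hypothetical; the experiments below use a deeper MLP)
model = nn.Sequential(nn.Linear(1, 100), nn.ReLU(), nn.Linear(100, 1))
ewc = OnlineEWC(model)

# task 1: fit sin(x) on [-2, 2]; task 2: fit cos(x) on [2, 6]
x1 = torch.linspace(-2, 2, 200).unsqueeze(1)
x2 = torch.linspace(2, 6, 200).unsqueeze(1)
ewc.train(x1, torch.sin(x1).squeeze(1), index=1, lr=1e-3, lam=100)
ewc.train(x2, torch.cos(x2).squeeze(1), index=2, lr=1e-3, alpha=0.5, lam=100)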

To compare Online EWC with Offline EWC, it is a good idea to run experiments on Online EWC with the same sample data used for Offline EWC. The sample data we use is as follows.

[Figure: online4_data, the sample data for the four tasks]

Just like what we did for Offline EWC, we use an MLP with four hidden layers for Online EWC, with layer widths of 1, 100, 100, 100, 100, and 1 (one input unit, four hidden layers of 100 units each, and one output unit).
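
A minimal sketch of this architecture in PyTorch (the activation function is not specified here, so ReLU is assumed):

import torch.nn as nn

# widths 1, 100, 100, 100, 100, 1: one input, four hidden layers, one output
mlp = nn.Sequential(
    nn.Linear(1, 100), nn.ReLU(),
    nn.Linear(100, 100), nn.ReLU(),
    nn.Linear(100, 100), nn.ReLU(),
    nn.Linear(100, 100), nn.ReLU(),
    nn.Linear(100, 1),
)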

Below is the trace of the experiments after each individual task is trained.

Task 1:

[Figures: loss1_task4 and task1_online4, the training loss and model fit after Task 1]

Task 2:

[Figures: loss2_task4 and task2_online4, the training loss and model fit after Task 2]

Task 3:

[Figures: loss3_task4 and task3_online4, the training loss and model fit after Task 3]

Task 4:

[Figures: loss4_task4 and task4_online4, the training loss and model fit after Task 4]

Not bad, right? But can we do better? Online EWC is clearly not the end of the story; the next section will focus on possible improvements to the EWC techniques.

