From 10361cea4c574710d02fc5a04a345a8391b97083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Morales?= Date: Mon, 28 Feb 2022 19:21:16 +0000 Subject: [PATCH] Add Maml++ example (#290) * Implement derivative order annealing for MAML * Add MAML++ example for miniImageNet * Fix _split_batch() * Fix labels * Fix labels * Fix cuda issues * Fix val/test set parsing * Implement derivative-order annealing independently of MAML * Clean up * Add contribution * Lint * Implement BNRS and BNWB * Finish implementing MAML++! * Add clone_named_parameters() in utils * Export maml_pp_update * Clean up * Lint * Move all changes to examples/vision/mamlpp * Revert changes in maml.py * Move contribution to Unreleased --- CHANGELOG.md | 3 + examples/vision/mamlpp/MAMLpp.py | 305 +++++++++++++++++ examples/vision/mamlpp/cnn4_bnrs.py | 321 ++++++++++++++++++ examples/vision/mamlpp/maml++_miniimagenet.py | 316 +++++++++++++++++ learn2learn/utils/__init__.py | 4 + learn2learn/vision/models/__init__.py | 1 + 6 files changed, 950 insertions(+) create mode 100644 examples/vision/mamlpp/MAMLpp.py create mode 100644 examples/vision/mamlpp/cnn4_bnrs.py create mode 100755 examples/vision/mamlpp/maml++_miniimagenet.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a06cb5bc..3d9eda5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +* New vision example: MAML++. (@[DubiousCactus](https://github.com/DubiousCactus)) * Add tutorial: "Demystifying Task Transforms", ([Varad Pimpalkhute](https://github.com/nightlessbaron/)) ### Changed @@ -28,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Pretrained weights for vision models with: `l2l.vision.models.get_pretrained_backbone()`. * Add `keep_requires_grad` flag to `detach_module`. ([Zhaofeng Wu](https://github.com/ZhaofengWu)) +### Changed + ### Fixed * Fix arguments when instantiating `l2l.nn.Scale`. diff --git a/examples/vision/mamlpp/MAMLpp.py b/examples/vision/mamlpp/MAMLpp.py new file mode 100644 index 00000000..533193c9 --- /dev/null +++ b/examples/vision/mamlpp/MAMLpp.py @@ -0,0 +1,305 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# + +""" +MAML++ wrapper. +""" + +import torch +import traceback + +from torch.autograd import grad + +from learn2learn.algorithms.base_learner import BaseLearner +from learn2learn.utils import clone_module, update_module, clone_named_parameters + + +def maml_pp_update(model, step=None, lrs=None, grads=None): + """ + + **Description** + + Performs a MAML++ update on model using grads and lrs. + The function re-routes the Python object, thus avoiding in-place + operations. + + NOTE: The model itself is updated in-place (no deepcopy), but the + parameters' tensors are not. + + **Arguments** + + * **model** (Module) - The model to update. + * **lrs** (list) - The meta-learned learning rates used to update the model. + * **grads** (list, *optional*, default=None) - A list of gradients for each layer + of the model. If None, will use the gradients in .grad attributes. 
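+    * **step** (int, *optional*, default=None) - The current adaptation step, used to index
+    the per-layer per-step learning rates in `lrs`.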
+ + **Example** + ~~~python + maml_pp = l2l.algorithms.MAMLpp(Model(), lr=1.0) + lslr = torch.nn.ParameterDict() + for layer_name, layer in model.named_modules(): + # If the layer has learnable parameters + if ( + len( + [ + name + for name, param in layer.named_parameters(recurse=False) + if param.requires_grad + ] + ) + > 0 + ): + lslr[layer_name.replace(".", "-")] = torch.nn.Parameter( + data=torch.ones(adaptation_steps) * init_lr, + requires_grad=True, + ) + model = maml_pp.clone() # The next two lines essentially implement model.adapt(loss) + for inner_step in range(5): + loss = criterion(model(x), y) + grads = autograd.grad(loss, model.parameters(), create_graph=True) + maml_pp_update(model, inner_step, lrs=lslr, grads=grads) + ~~~ + """ + if grads is not None and lrs is not None: + params = list(model.parameters()) + if not len(grads) == len(list(params)): + msg = "WARNING:maml_update(): Parameters and gradients have different length. (" + msg += str(len(params)) + " vs " + str(len(grads)) + ")" + print(msg) + # TODO: Why doesn't this work?? I can't assign p.grad when zipping like this... Is this + # because I'm using a tuple? + # for named_param, g in zip( + # [(k, v) for k, l in model.named_parameters() for v in l], grads + # ): + # p_name, p = named_param + it = 0 + for name, p in model.named_parameters(): + if grads[it] is not None: + lr = None + layer_name = name[: name.rfind(".")].replace( + ".", "-" + ) # Extract the layer name from the named parameter + lr = lrs[layer_name][step] + assert ( + lr is not None + ), f"Parameter {name} does not have a learning rate in LSLR dict!" + p.grad = grads[it] + p._lr = lr + it += 1 + + # Update the params + for param_key in model._parameters: + p = model._parameters[param_key] + if p is not None and p.grad is not None: + model._parameters[param_key] = p - p._lr * p.grad + p.grad = None + p._lr = None + + # Second, handle the buffers if necessary + for buffer_key in model._buffers: + buff = model._buffers[buffer_key] + if buff is not None and buff.grad is not None and buff._lr is not None: + model._buffers[buffer_key] = buff - buff._lr * buff.grad + buff.grad = None + buff._lr = None + + # Then, recurse for each submodule + for module_key in model._modules: + model._modules[module_key] = maml_pp_update(model._modules[module_key]) + return model + + +class MAMLpp(BaseLearner): + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/maml.py) + + **Description** + + High-level implementation of *Model-Agnostic Meta-Learning*. + + This class wraps an arbitrary nn.Module and augments it with `clone()` and `adapt()` + methods. + + For the first-order version of MAML (i.e. FOMAML), set the `first_order` flag to `True` + upon initialization. + + **Arguments** + + * **model** (Module) - Module to be wrapped. + * **lr** (float) - Fast adaptation learning rate. + * **lslr** (bool) - Whether to use Per-Layer Per-Step Learning Rates and Gradient Directions + (LSLR) or not. + * **lrs** (list of Parameters, *optional*, default=None) - If not None, overrides `lr`, and uses the list + as learning rates for fast-adaptation. + * **first_order** (bool, *optional*, default=False) - Whether to use the first-order + approximation of MAML. (FOMAML) + * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation + of unused parameters. Defaults to `allow_nograd`. + * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with + parameters that have `requires_grad = False`. 
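+    * **adaptation_steps** (int, *optional*, default=1) - Number of inner-loop adaptation steps,
+    used to initialise one learning rate per layer and per step (LSLR).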
+ + **References** + + 1. Finn et al. 2017. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks." + + **Example** + + ~~~python + linear = l2l.algorithms.MAML(nn.Linear(20, 10), lr=0.01) + clone = linear.clone() + error = loss(clone(X), y) + clone.adapt(error) + error = loss(clone(X), y) + error.backward() + ~~~ + """ + + def __init__( + self, + model, + lr, + lrs=None, + adaptation_steps=1, + first_order=False, + allow_unused=None, + allow_nograd=False, + ): + super().__init__() + self.module = model + self.lr = lr + if lrs is None: + lrs = self._init_lslr_parameters(model, adaptation_steps, lr) + self.lrs = lrs + self.first_order = first_order + self.allow_nograd = allow_nograd + if allow_unused is None: + allow_unused = allow_nograd + self.allow_unused = allow_unused + + def _init_lslr_parameters( + self, model: torch.nn.Module, adaptation_steps: int, init_lr: float + ) -> torch.nn.ParameterDict: + lslr = torch.nn.ParameterDict() + for layer_name, layer in model.named_modules(): + # If the layer has learnable parameters + if ( + len( + [ + name + for name, param in layer.named_parameters(recurse=False) + if param.requires_grad + ] + ) + > 0 + ): + lslr[layer_name.replace(".", "-")] = torch.nn.Parameter( + data=torch.ones(adaptation_steps) * init_lr, + requires_grad=True, + ) + return lslr + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + def adapt(self, loss, step=None, first_order=None, allow_unused=None, allow_nograd=None): + """ + **Description** + + Takes a gradient step on the loss and updates the cloned parameters in place. + + **Arguments** + + * **loss** (Tensor) - Loss to minimize upon update. + * **step** (int) - Current inner loop step. Used to fetch the corresponding learning rate. + * **first_order** (bool, *optional*, default=None) - Whether to use first- or + second-order updates. Defaults to self.first_order. + * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation + of unused parameters. Defaults to self.allow_unused. + * **allow_nograd** (bool, *optional*, default=None) - Whether to allow adaptation with + parameters that have `requires_grad = False`. Defaults to self.allow_nograd. + """ + if first_order is None: + first_order = self.first_order + if allow_unused is None: + allow_unused = self.allow_unused + if allow_nograd is None: + allow_nograd = self.allow_nograd + second_order = not first_order + + gradients = [] + if allow_nograd: + # Compute relevant gradients + diff_params = [p for p in self.module.parameters() if p.requires_grad] + grad_params = grad( + loss, + diff_params, + retain_graph=second_order, + create_graph=second_order, + allow_unused=allow_unused, + ) + grad_counter = 0 + + # Handles gradients for non-differentiable parameters + for param in self.module.parameters(): + if param.requires_grad: + gradient = grad_params[grad_counter] + grad_counter += 1 + else: + gradient = None + gradients.append(gradient) + else: + try: + gradients = grad( + loss, + self.module.parameters(), + retain_graph=second_order, + create_graph=second_order, + allow_unused=allow_unused, + ) + except RuntimeError: + traceback.print_exc() + print( + "learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?" + ) + + # Update the module + assert step is not None, "step cannot be None when using LSLR!" 
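+        # LSLR update: each parameter is stepped with its layer's meta-learned learning rate for this inner step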
+ self.module = maml_pp_update(self.module, step, lrs=self.lrs, grads=gradients) + + def clone(self, first_order=None, allow_unused=None, allow_nograd=None): + """ + **Description** + + Returns a `MAMLpp`-wrapped copy of the module whose parameters and buffers + are `torch.clone`d from the original module. + + This implies that back-propagating losses on the cloned module will + populate the buffers of the original module. + For more information, refer to learn2learn.clone_module(). + + **Arguments** + + * **first_order** (bool, *optional*, default=None) - Whether the clone uses first- + or second-order updates. Defaults to self.first_order. + * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation + of unused parameters. Defaults to self.allow_unused. + * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with + parameters that have `requires_grad = False`. Defaults to self.allow_nograd. + + """ + if first_order is None: + first_order = self.first_order + if allow_unused is None: + allow_unused = self.allow_unused + if allow_nograd is None: + allow_nograd = self.allow_nograd + return MAMLpp( + clone_module(self.module), + lr=self.lr, + lrs=clone_named_parameters(self.lrs), + first_order=first_order, + allow_unused=allow_unused, + allow_nograd=allow_nograd, + ) diff --git a/examples/vision/mamlpp/cnn4_bnrs.py b/examples/vision/mamlpp/cnn4_bnrs.py new file mode 100644 index 00000000..28f9666f --- /dev/null +++ b/examples/vision/mamlpp/cnn4_bnrs.py @@ -0,0 +1,321 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# + +""" +CNN4 extended with Batch-Norm Running Statistics. +""" + +import torch +import torch.nn.functional as F + +from copy import deepcopy +from learn2learn.vision.models.cnn4 import maml_init_, fc_init_ + + +class MetaBatchNormLayer(torch.nn.Module): + """ + An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running + Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in + MAML++ by Antoniou et al. It is adapted from the original Pytorch implementation at + https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch, + with heavy refactoring and a bug fix + (https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42). + """ + + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + meta_batch_norm=True, + adaptation_steps: int = 1, + ): + super(MetaBatchNormLayer, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + self.meta_batch_norm = meta_batch_norm + self.num_features = num_features + self.running_mean = torch.nn.Parameter( + torch.zeros(adaptation_steps, num_features), requires_grad=False + ) + self.running_var = torch.nn.Parameter( + torch.ones(adaptation_steps, num_features), requires_grad=False + ) + self.bias = torch.nn.Parameter( + torch.zeros(adaptation_steps, num_features), requires_grad=True + ) + self.weight = torch.nn.Parameter( + torch.ones(adaptation_steps, num_features), requires_grad=True + ) + self.backup_running_mean = torch.zeros(self.running_mean.shape) + self.backup_running_var = torch.ones(self.running_var.shape) + self.momentum = momentum + + def forward( + self, + input, + step, + ): + """ + :param input: input data batch, size either can be any. + :param step: The current inner loop step being taken. This is used when to learn per step params and + collecting per step batch statistics. 
+ :return: The result of the batch norm operation. + """ + assert ( + step < self.running_mean.shape[0] + ), f"Running forward with step={step} when initialised with {self.running_mean.shape[0]} steps!" + return F.batch_norm( + input, + self.running_mean[step], + self.running_var[step], + self.weight[step], + self.bias[step], + training=True, + momentum=self.momentum, + eps=self.eps, + ) + + def backup_stats(self): + self.backup_running_mean.data = deepcopy(self.running_mean.data) + self.backup_running_var.data = deepcopy(self.running_var.data) + + def restore_backup_stats(self): + """ + Resets batch statistics to their backup values which are collected after each forward pass. + """ + self.running_mean = torch.nn.Parameter( + self.backup_running_mean, requires_grad=False + ) + self.running_var = torch.nn.Parameter( + self.backup_running_var, requires_grad=False + ) + + def extra_repr(self): + return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format( + **self.__dict__ + ) + + +class LinearBlock_BNRS(torch.nn.Module): + def __init__(self, input_size, output_size, adaptation_steps): + super(LinearBlock_BNRS, self).__init__() + self.relu = torch.nn.ReLU() + self.normalize = MetaBatchNormLayer( + output_size, + affine=True, + momentum=0.999, + eps=1e-3, + adaptation_steps=adaptation_steps, + ) + self.linear = torch.nn.Linear(input_size, output_size) + fc_init_(self.linear) + + def forward(self, x, step): + x = self.linear(x) + x = self.normalize(x, step) + x = self.relu(x) + return x + + +class ConvBlock_BNRS(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + max_pool=True, + max_pool_factor=1.0, + adaptation_steps=1, + ): + super(ConvBlock_BNRS, self).__init__() + stride = (int(2 * max_pool_factor), int(2 * max_pool_factor)) + if max_pool: + self.max_pool = torch.nn.MaxPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=False, + ) + stride = (1, 1) + else: + self.max_pool = lambda x: x + self.normalize = MetaBatchNormLayer( + out_channels, + affine=True, + adaptation_steps=adaptation_steps, + # eps=1e-3, + # momentum=0.999, + ) + torch.nn.init.uniform_(self.normalize.weight) + self.relu = torch.nn.ReLU() + + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=1, + bias=True, + ) + maml_init_(self.conv) + + def forward(self, x, step): + x = self.conv(x) + x = self.normalize(x, step) + x = self.relu(x) + x = self.max_pool(x) + return x + + +class ConvBase_BNRS(torch.nn.Sequential): + + # NOTE: + # Omniglot: hidden=64, channels=1, no max_pool + # MiniImagenet: hidden=32, channels=3, max_pool + + def __init__( + self, hidden=64, channels=1, max_pool=False, layers=4, max_pool_factor=1.0, + adaptation_steps=1 + ): + core = [ + ConvBlock_BNRS( + channels, + hidden, + (3, 3), + max_pool=max_pool, + max_pool_factor=max_pool_factor, + adaptation_steps=adaptation_steps + ), + ] + for _ in range(layers - 1): + core.append( + ConvBlock_BNRS( + hidden, + hidden, + kernel_size=(3, 3), + max_pool=max_pool, + max_pool_factor=max_pool_factor, + adaptation_steps=adaptation_steps + ) + ) + super(ConvBase_BNRS, self).__init__(*core) + + def forward(self, x, step): + for module in self: + x = module(x, step) + return x + + +class CNN4Backbone_BNRS(ConvBase_BNRS): + def __init__( + self, + hidden_size=64, + layers=4, + channels=3, + max_pool=True, + max_pool_factor=None, + adaptation_steps=1, + ): + if max_pool_factor is None: + max_pool_factor = 4 // layers + super(CNN4Backbone_BNRS, 
self).__init__( + hidden=hidden_size, + layers=layers, + channels=channels, + max_pool=max_pool, + max_pool_factor=max_pool_factor, + adaptation_steps=adaptation_steps + ) + + def forward(self, x, step): + x = super(CNN4Backbone_BNRS, self).forward(x, step) + x = x.reshape(x.size(0), -1) + return x + + +class CNN4_BNRS(torch.nn.Module): + """ + + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/cnn4.py) + + **Description** + + The convolutional network commonly used for MiniImagenet, as described by Ravi et Larochelle, 2017. + + This network assumes inputs of shapes (3, 84, 84). + + Instantiate `CNN4Backbone` if you only need the feature extractor. + + **References** + + 1. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR. + + **Arguments** + + * **output_size** (int) - The dimensionality of the network's output. + * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden representation. + * **layers** (int, *optional*, default=4) - The number of convolutional layers. + * **channels** (int, *optional*, default=3) - The number of channels in input. + * **max_pool** (bool, *optional*, default=True) - Whether ConvBlocks use max-pooling. + * **embedding_size** (int, *optional*, default=None) - Size of feature embedding. + Defaults to 25 * hidden_size (for mini-Imagenet). + + **Example** + ~~~python + model = CNN4(output_size=20, hidden_size=128, layers=3) + ~~~ + """ + + def __init__( + self, + output_size, + hidden_size=64, + layers=4, + channels=3, + max_pool=True, + embedding_size=None, + adaptation_steps=1, + ): + super(CNN4_BNRS, self).__init__() + if embedding_size is None: + embedding_size = 25 * hidden_size + self.features = CNN4Backbone_BNRS( + hidden_size=hidden_size, + channels=channels, + max_pool=max_pool, + layers=layers, + max_pool_factor=4 // layers, + adaptation_steps=adaptation_steps, + ) + self.classifier = torch.nn.Linear( + embedding_size, + output_size, + bias=True, + ) + maml_init_(self.classifier) + self.hidden_size = hidden_size + + def backup_stats(self): + """ + Backup stored batch statistics before running a validation epoch. + """ + for layer in self.features.modules(): + if type(layer) is MetaBatchNormLayer: + layer.backup_stats() + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. + """ + for layer in self.features.modules(): + if type(layer) is MetaBatchNormLayer: + layer.restore_backup_stats() + + def forward(self, x, step): + x = self.features(x, step) + x = self.classifier(x) + return x diff --git a/examples/vision/mamlpp/maml++_miniimagenet.py b/examples/vision/mamlpp/maml++_miniimagenet.py new file mode 100755 index 00000000..78085bf9 --- /dev/null +++ b/examples/vision/mamlpp/maml++_miniimagenet.py @@ -0,0 +1,316 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# +# Copyright © 2021 Théo Morales +# +# Distributed under terms of the MIT license. + +""" +Example implementation of MAML++ on miniImageNet. +""" + + +import learn2learn as l2l +import numpy as np +import random +import torch + +from collections import namedtuple +from typing import Tuple +from tqdm import tqdm + +from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS +from examples.vision.mamlpp.MAMLpp import MAMLpp + + +MetaBatch = namedtuple("MetaBatch", "support query") + +train_samples, val_samples, test_samples = 38400, 9600, 12000 # Is that correct? 
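+# Standard mini-ImageNet split: 64/16/20 classes x 600 images each -> 38400/9600/12000 samples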
+tasks = 600 + + +def accuracy(predictions, targets): + predictions = predictions.argmax(dim=1).view(targets.shape) + return (predictions == targets).sum().float() / targets.size(0) + + +class MAMLppTrainer: + def __init__( + self, + ways=5, + k_shots=10, + n_queries=30, + steps=5, + msl_epochs=25, + DA_epochs=50, + use_cuda=True, + seed=42, + ): + self._use_cuda = use_cuda + self._device = torch.device("cpu") + if self._use_cuda and torch.cuda.device_count(): + torch.cuda.manual_seed(seed) + self._device = torch.device("cuda") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + # Dataset + print("[*] Loading mini-ImageNet...") + ( + self._train_tasks, + self._valid_tasks, + self._test_tasks, + ) = l2l.vision.benchmarks.get_tasksets( + "mini-imagenet", + train_samples=k_shots, + train_ways=ways, + test_samples=n_queries, + test_ways=ways, + root="~/data", + ) + + # Model + self._model = CNN4_BNRS(ways, adaptation_steps=steps) + if self._use_cuda: + self._model.cuda() + + # Meta-Learning related + self._steps = steps + self._k_shots = k_shots + self._n_queries = n_queries + self._inner_criterion = torch.nn.CrossEntropyLoss(reduction="mean") + + # Multi-Step Loss + self._msl_epochs = msl_epochs + self._step_weights = torch.ones(steps, device=self._device) * (1.0 / steps) + self._msl_decay_rate = 1.0 / steps / msl_epochs + self._msl_min_value_for_non_final_losses = torch.tensor(0.03 / steps) + self._msl_max_value_for_final_loss = 1.0 - ( + (steps - 1) * self._msl_min_value_for_non_final_losses + ) + + # Derivative-Order Annealing (when to start using second-order opt) + self._derivative_order_annealing_from_epoch = DA_epochs + + def _anneal_step_weights(self): + self._step_weights[:-1] = torch.max( + self._step_weights[:-1] - self._msl_decay_rate, + self._msl_min_value_for_non_final_losses, + ) + self._step_weights[-1] = torch.min( + self._step_weights[-1] + ((self._steps - 1) * self._msl_decay_rate), + self._msl_max_value_for_final_loss, + ) + + def _split_batch(self, batch: tuple) -> MetaBatch: + """ + Separate data batch into adaptation/evalutation sets. + """ + images, labels = batch + batch_size = self._k_shots + self._n_queries + assert batch_size <= images.shape[0], "K+N are greater than the batch size!" 
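+        # Randomly permute the task samples, then take the first K as support and the rest as query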
+ indices = torch.randperm(batch_size) + support_indices = indices[: self._k_shots] + query_indices = indices[self._k_shots :] + return MetaBatch( + ( + images[support_indices], + labels[support_indices], + ), + (images[query_indices], labels[query_indices]), + ) + + def _training_step( + self, + batch: MetaBatch, + learner: MAMLpp, + msl: bool = True, + epoch: int = 0, + ) -> Tuple[torch.Tensor, float]: + s_inputs, s_labels = batch.support + q_inputs, q_labels = batch.query + query_loss = torch.tensor(.0, device=self._device) + + if self._use_cuda: + s_inputs = s_inputs.float().cuda(device=self._device) + s_labels = s_labels.cuda(device=self._device) + q_inputs = q_inputs.float().cuda(device=self._device) + q_labels = q_labels.cuda(device=self._device) + + # Derivative-Order Annealing + second_order = True + if epoch < self._derivative_order_annealing_from_epoch: + second_order = False + + # Adapt the model on the support set + for step in range(self._steps): + # forward + backward + optimize + pred = learner(s_inputs, step) + support_loss = self._inner_criterion(pred, s_labels) + learner.adapt(support_loss, first_order=not second_order, step=step) + # Multi-Step Loss + if msl: + q_pred = learner(q_inputs, step) + query_loss += self._step_weights[step] * self._inner_criterion( + q_pred, q_labels + ) + + # Evaluate the adapted model on the query set + if not msl: + q_pred = learner(q_inputs, self._steps-1) + query_loss = self._inner_criterion(q_pred, q_labels) + acc = accuracy(q_pred, q_labels).detach() + + return query_loss, acc + + def _testing_step( + self, batch: MetaBatch, learner: MAMLpp + ) -> Tuple[torch.Tensor, float]: + s_inputs, s_labels = batch.support + q_inputs, q_labels = batch.query + + if self._use_cuda: + s_inputs = s_inputs.float().cuda(device=self._device) + s_labels = s_labels.cuda(device=self._device) + q_inputs = q_inputs.float().cuda(device=self._device) + q_labels = q_labels.cuda(device=self._device) + + # Adapt the model on the support set + for step in range(self._steps): + # forward + backward + optimize + pred = learner(s_inputs, step) + support_loss = self._inner_criterion(pred, s_labels) + learner.adapt(support_loss, step=step) + + # Evaluate the adapted model on the query set + q_pred = learner(q_inputs, self._steps-1) + query_loss = self._inner_criterion(q_pred, q_labels).detach() + acc = accuracy(q_pred, q_labels) + + return query_loss, acc + + def train( + self, + meta_lr=0.001, + fast_lr=0.01, + meta_bsz=5, + epochs=100, + val_interval=1, + ): + print("[*] Training...") + maml = MAMLpp( + self._model, + lr=fast_lr, # Initialisation LR for all layers and steps + adaptation_steps=self._steps, # For LSLR + first_order=False, + allow_nograd=True, # For the parameters of the MetaBatchNorm layers + ) + opt = torch.optim.AdamW(maml.parameters(), meta_lr, betas=(0, 0.999)) + + iter_per_epoch = ( + train_samples // (meta_bsz * (self._k_shots + self._n_queries)) + ) + 1 + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + opt, + T_max=epochs * iter_per_epoch, + eta_min=0.00001, + ) + + for epoch in range(epochs): + epoch_meta_train_loss, epoch_meta_train_acc = 0.0, 0.0 + for _ in tqdm(range(iter_per_epoch)): + opt.zero_grad() + meta_train_losses, meta_train_accs = [], [] + + for _ in range(meta_bsz): + meta_batch = self._split_batch(self._train_tasks.sample()) + meta_loss, meta_acc = self._training_step( + meta_batch, + maml.clone(), + msl=(epoch < self._msl_epochs), + epoch=epoch, + ) + meta_loss.backward() + meta_train_losses.append(meta_loss.detach()) + 
meta_train_accs.append(meta_acc) + + epoch_meta_train_loss += torch.Tensor(meta_train_losses).mean().item() + epoch_meta_train_acc += torch.Tensor(meta_train_accs).mean().item() + + # Average the accumulated gradients and optimize + with torch.no_grad(): + for p in maml.parameters(): + # Remember the MetaBatchNorm layer has parameters that don't require grad! + if p.requires_grad: + p.grad.data.mul_(1.0 / meta_bsz) + + opt.step() + scheduler.step() + # Multi-Step Loss + self._anneal_step_weights() + + epoch_meta_train_loss /= iter_per_epoch + epoch_meta_train_acc /= iter_per_epoch + print(f"==========[Epoch {epoch}]==========") + print(f"Meta-training Loss: {epoch_meta_train_loss:.6f}") + print(f"Meta-training Acc: {epoch_meta_train_acc:.6f}") + + # ======= Validation ======== + if (epoch + 1) % val_interval == 0: + # Backup the BatchNorm layers' running statistics + maml.backup_stats() + + # Compute the meta-validation loss + # TODO: Go through the entire validation set, which shouldn't be shuffled, and + # which tasks should not be continuously resampled from! + meta_val_losses, meta_val_accs = [], [] + for _ in tqdm(range(val_samples // tasks)): + meta_batch = self._split_batch(self._valid_tasks.sample()) + loss, acc = self._testing_step(meta_batch, maml.clone()) + meta_val_losses.append(loss) + meta_val_accs.append(acc) + meta_val_loss = float(torch.Tensor(meta_val_losses).mean().item()) + meta_val_acc = float(torch.Tensor(meta_val_accs).mean().item()) + print(f"Meta-validation Loss: {meta_val_loss:.6f}") + print(f"Meta-validation Accuracy: {meta_val_acc:.6f}") + # Restore the BatchNorm layers' running statistics + maml.restore_backup_stats() + print("============================================") + + return self._model.state_dict() + + def test( + self, + model_state_dict, + meta_lr=0.001, + fast_lr=0.01, + meta_bsz=5, + ): + self._model.load_state_dict(model_state_dict) + maml = MAMLpp( + self._model, + lr=fast_lr, + adaptation_steps=self._steps, + first_order=False, + allow_nograd=True, + ) + opt = torch.optim.AdamW(maml.parameters(), meta_lr, betas=(0, 0.999)) + + meta_losses, meta_accs = [], [] + for _ in tqdm(range(test_samples // tasks)): + meta_batch = self._split_batch(self._test_tasks.sample()) + loss, acc = self._testing_step(meta_batch, maml.clone()) + meta_losses.append(loss) + meta_accs.append(acc) + loss = float(torch.Tensor(meta_losses).mean().item()) + acc = float(torch.Tensor(meta_accs).mean().item()) + print(f"Meta-training Loss: {loss:.6f}") + print(f"Meta-training Acc: {acc:.6f}") + + +if __name__ == "__main__": + mamlPlusPlus = MAMLppTrainer() + model = mamlPlusPlus.train() + mamlPlusPlus.test(model) diff --git a/learn2learn/utils/__init__.py b/learn2learn/utils/__init__.py index 406e325d..9d464036 100644 --- a/learn2learn/utils/__init__.py +++ b/learn2learn/utils/__init__.py @@ -50,6 +50,10 @@ def clone_parameters(param_list): return [p.clone() for p in param_list] +def clone_named_parameters(param_dict): + return {k: p.clone() for k, p in param_dict.items()} + + def clone_module(module, memo=None): """ diff --git a/learn2learn/vision/models/__init__.py b/learn2learn/vision/models/__init__.py index fe921384..54cdeec4 100644 --- a/learn2learn/vision/models/__init__.py +++ b/learn2learn/vision/models/__init__.py @@ -30,6 +30,7 @@ def forward(self, x): CNN4, CNN4Backbone, ) + from .resnet12 import ResNet12, ResNet12Backbone from .wrn28 import WRN28, WRN28Backbone