Official PyTorch implementation of the MixMo framework | paper | docs
Alexandre Ramé, Rémy Sun, Matthieu Cord
If you find this code useful for your research, please cite:
@inproceedings{rame2021ixmo,
title={MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks},
author={Alexandre Rame and Remy Sun and Matthieu Cord},
year={2021},
booktitle={ICCV 2021}
}
Recent strategies achieved ensembling “for free” by fitting concurrently diverse subnetworks inside a single base network. The main idea during training is that each subnetwork learns to classify only one of the multiple inputs simultaneously provided. However, the question of how to best mix these multiple inputs has not been studied so far.
In this paper, we introduce MixMo, a new generalized framework for learning multi-input multi-output deep subnetworks. Our key motivation is to replace the suboptimal summing operation hidden in previous approaches by a more appropriate mixing mechanism. For that purpose, we draw inspiration from successful mixed sample data augmentations. We show that binary mixing in features - particularly with rectangular patches from CutMix - enhances results by making subnetworks stronger and more diverse.
We improve state of the art for image classification on CIFAR-100 and Tiny ImageNet datasets. Our easy to implement models notably outperform data augmented deep ensembles, without the inference and memory overheads. As we operate in features and simply better leverage the expressiveness of large networks, we open a new line of research complementary to previous works.
This repository provides a general wrapper over PyTorch to reproduce the main results from the paper. The code sections specific to MixMo can be found in:
- mixmo.loaders.dataset_wrapper.py, and specifically MixMoDataset, to create batches with multiple inputs and multiple outputs.
- mixmo.augmentations.mixing_blocks.py, where we create the mixing masks, e.g. via linear summing (_mixup_mask) or via patch mixing (_cutmix_mask).
- mixmo.networks.resnet.py and mixmo.networks.wrn.py, where we adapt the network structures to handle:
  - multiple inputs via multiple conv1s encoders (one for each input). The function mixmo.augmentations.mixing_blocks.mix_manifold is used to mix the extracted representations according to the masks provided in metadata from MixMoDataset.
  - multiple outputs via multiple predictions.

This translates to additional tensor management in mixmo.learners.learner.py.
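To make the difference between linear summing and patch mixing concrete, here is a minimal pure-Python sketch on a single 2D feature map. It is not the repository's _mixup_mask, _cutmix_mask or mix_manifold: the function names, the nested-list representation, the rectangle placement (kept fully inside the map) and the rescaling factor of 2 for M=2 subnetworks are all assumptions made for illustration.

```python
import random


def mixup_mask(height, width, ratio):
    """Constant mask: every position gives weight `ratio` to the first input."""
    return [[ratio] * width for _ in range(height)]


def cutmix_mask(height, width, ratio, rng):
    """Binary mask with a random rectangle of ones covering about `ratio` of the area."""
    # Side lengths scale with sqrt(ratio) so that the rectangle area is ~ratio.
    cut_h = int(round(height * ratio ** 0.5))
    cut_w = int(round(width * ratio ** 0.5))
    # Simplification: the rectangle is always placed fully inside the map.
    top = rng.randint(0, height - cut_h)
    left = rng.randint(0, width - cut_w)
    return [
        [1.0 if top <= i < top + cut_h and left <= j < left + cut_w else 0.0
         for j in range(width)]
        for i in range(height)
    ]


def mix_features(feats_0, feats_1, mask, scale=2.0):
    """Mix two feature maps position by position: scale * (m * f0 + (1 - m) * f1)."""
    return [
        [scale * (m * a + (1.0 - m) * b) for a, b, m in zip(row_a, row_b, row_m)]
        for row_a, row_b, row_m in zip(feats_0, feats_1, mask)
    ]


rng = random.Random(0)
ratio = rng.betavariate(2, 2)  # mixing ratio drawn from Beta(alpha, alpha) with alpha=2
feats_0 = [[1.0] * 8 for _ in range(8)]
feats_1 = [[0.0] * 8 for _ in range(8)]
patch_mixed = mix_features(feats_0, feats_1, cutmix_mask(8, 8, ratio, rng))
linear_mixed = mix_features(feats_0, feats_1, mixup_mask(8, 8, ratio))
```

With a binary mask, each position of the mixed map comes from exactly one input, whereas the constant mask blends both inputs everywhere.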
Our MixMoDataset wraps a PyTorch Dataset. The batch_repetition_sampler repeats the same index b times in each batch. Moreover, we provide SoftCrossEntropyLoss, which handles the soft labels required by mixed sample data augmentations such as CutMix.
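As an illustration of batch repetition, the sketch below yields, for each batch, one index list per subnetwork input. It is a simplified stand-in, not the repository's batch_repetition_sampler: the function name, its arguments (num_members, seed) and the choice to shuffle the repeated batch independently for each input are assumptions.

```python
import random


def repeated_index_batches(max_index, batch_size, b, num_members=2, seed=0):
    """Yield tuples of index lists in which every sampled index appears b times."""
    rng = random.Random(seed)
    unique_per_batch = batch_size // b  # number of distinct samples in a batch
    order = list(range(max_index))
    rng.shuffle(order)
    # Drop the last incomplete group of indexes, if any.
    for start in range(0, max_index - unique_per_batch + 1, unique_per_batch):
        repeated = order[start:start + unique_per_batch] * b
        per_member = []
        for _ in range(num_members):
            # Each input sees the same samples, in its own random order.
            shuffled = repeated[:]
            rng.shuffle(shuffled)
            per_member.append(shuffled)
        yield tuple(per_member)


for indexes_0, indexes_1 in repeated_index_batches(max_index=8, batch_size=8, b=4):
    print(indexes_0, indexes_1)
```

With batch_size=64 and b=4, each batch would therefore contain 16 distinct samples, each seen 4 times.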
import os

from torchvision.datasets import CIFAR100

from mixmo.loaders import (dataset_wrapper, batch_repetition_sampler)
from mixmo.networks.wrn import WideResNetMixMo
from mixmo.core.loss import SoftCrossEntropyLoss as criterion
...
# Simplified pseudocode: dataplace, num_epochs and optimizer are defined in the elided parts.
# cf mixmo.loaders.loader
train_dataset = dataset_wrapper.MixMoDataset(
    dataset=CIFAR100(os.path.join(dataplace, "cifar100-data")),
    num_members=2,  # we use M=2 subnetworks
    mixmo_mix_method="cutmix",  # patch mixing, linked to mixmo.augmentations.mixing_blocks._cutmix_mask
    mixmo_alpha=2,  # mixing ratio sampled from a Beta distribution with concentration 2
    mixmo_weight_root=3  # root for reweighting of loss components (here 3)
)
network = WideResNetMixMo(depth=28, widen_factor=10, num_classes=100)
...
# cf mixmo.learners.learner and mixmo.learners.model_wrapper
for _ in range(num_epochs):
    for indexes_0, indexes_1 in batch_repetition_sampler(batch_size=64, b=4, max_index=len(train_dataset)):
        for (inputs_0, inputs_1, targets_0, targets_1, metadata_mixmo_masks) in train_dataset(indexes_0, indexes_1):
            # One forward pass on both inputs gives one prediction per subnetwork.
            outputs_0, outputs_1 = network([inputs_0, inputs_1], metadata_mixmo_masks)
            loss = criterion(outputs_0, targets_0) + criterion(outputs_1, targets_1)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
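The criterion in the loop above has to accept soft labels. As a minimal sketch of what a cross-entropy with soft targets computes (not the repository's SoftCrossEntropyLoss; the function names and the plain-list inputs are assumptions), the loss for one sample is the negative sum of the target probabilities times the log-softmax of the logits:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(value - peak) for value in logits))
    return [value - log_norm for value in logits]


def soft_cross_entropy(logits, soft_targets):
    """Cross-entropy between a target distribution and softmax(logits)."""
    return -sum(t * lp for t, lp in zip(soft_targets, log_softmax(logits)))


logits = [2.0, 0.5, -1.0]
hard_loss = soft_cross_entropy(logits, [1.0, 0.0, 0.0])  # one-hot label
soft_loss = soft_cross_entropy(logits, [0.7, 0.3, 0.0])  # label mixed by CutMix or Mixup
```

With a one-hot target this reduces to the usual cross-entropy; with a mixed target it is the correspondingly weighted sum of the per-class losses.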
Our code heavily relies on yaml config files. In the mixmo-pytorch/config folder, we provide the configs to reproduce the main paper results.
For example, the state-of-the-art exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4 means that:
- cifar100: dataset is CIFAR-100.
- wrn2810-2: WideResNet-28-10 network architecture with M=2 subnetworks.
- cutmixmo-p5: mixing block is patch mixing with probability p=0.5, else linear mixing.
- msdacutmix: use CutMix mixed sample data augmentation.
- bar4: batch repetition to b=4.
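The naming convention above can be decoded mechanically. The helper below is a hypothetical illustration that is not part of the repository; it assumes every config name has six underscore-separated fields, as in the names listed in this README, and that the digits after p are the decimals of the probability (p5 for 0.5).

```python
def parse_config_name(name):
    """Split a config name such as exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4."""
    if name.endswith(".yaml"):
        name = name[: -len(".yaml")]
    prefix, dataset, network, method, msda, bar = name.split("_")
    if prefix != "exp" or not bar.startswith("bar"):
        raise ValueError("unexpected config name: " + name)
    # "wrn2810-2" -> architecture "wrn2810" with 2 subnetworks; no suffix means 1.
    architecture, _, members = network.partition("-")
    # "cutmixmo-p5" -> mixing "cutmixmo" with patch probability 0.5.
    mixing, _, patch = method.partition("-")
    return {
        "dataset": dataset,
        "architecture": architecture,
        "num_members": int(members) if members else 1,
        "mixing": mixing,
        "patch_probability": float("0." + patch[1:]) if patch.startswith("p") else None,
        "msda": msda[len("msda"):] if msda.startswith("msda") else None,
        "batch_repetition": int(bar[len("bar"):]),
    }


print(parse_config_name("exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4"))
```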
Subnetwork method | MSDA | Top-1 Accuracy | config file in mixmo-pytorch/config/cifar100 |
---|---|---|---|
-- | Vanilla | 81.79 | exp_cifar100_wrn2810_1net_standard_bar1.yaml |
-- | Mixup | 83.43 | exp_cifar100_wrn2810_1net_msdamixup_bar1.yaml |
-- | CutMix | 83.95 | exp_cifar100_wrn2810_1net_msdacutmix_bar1.yaml |
MIMO | -- | 82.92 | exp_cifar100_wrn2810-2_mimo_standard_bar4.yaml |
Linear-MixMo | -- | 82.96 | exp_cifar100_wrn2810-2_linearmixmo_standard_bar4.yaml |
Cut-MixMo | -- | 85.52 - 85.59 | exp_cifar100_wrn2810-2_cutmixmo-p5_standard_bar4.yaml |
Linear-MixMo | CutMix | 85.36 - 85.57 | exp_cifar100_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml |
Cut-MixMo | CutMix | 85.77 - 85.92 | exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml |
Subnetwork method | MSDA | Top-1 Accuracy | config file in mixmo-pytorch/config/cifar10 |
---|---|---|---|
-- | Vanilla | 96.37 | exp_cifar10_wrn2810_1net_standard_bar1.yaml |
-- | Mixup | 97.07 | exp_cifar10_wrn2810_1net_msdamixup_bar1.yaml |
-- | CutMix | 97.28 | exp_cifar10_wrn2810_1net_msdacutmix_bar1.yaml |
MIMO | -- | 96.71 | exp_cifar10_wrn2810-2_mimo_standard_bar4.yaml |
Linear-MixMo | -- | 96.88 | exp_cifar10_wrn2810-2_linearmixmo_standard_bar4.yaml |
Cut-MixMo | -- | 97.52 | exp_cifar10_wrn2810-2_cutmixmo-p5_standard_bar4.yaml |
Linear-MixMo | CutMix | 97.73 | exp_cifar10_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml |
Cut-MixMo | CutMix | 97.83 | exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml |
Method | Width | Top-1 Accuracy | config file in mixmo-pytorch/config/tiny |
---|---|---|---|
Vanilla | 1 | 62.75 | exp_tinyimagenet_res18_1net_standard_bar1.yaml |
Linear-MixMo | 1 | 62.91 | exp_tinyimagenet_res18-2_linearmixmo_standard_bar4.yaml |
Cut-MixMo | 1 | 64.32 | exp_tinyimagenet_res18-2_cutmixmo-p5_standard_bar4.yaml |
Vanilla | 2 | 64.91 | exp_tinyimagenet_res182_1net_standard_bar1.yaml |
Linear-MixMo | 2 | 67.03 | exp_tinyimagenet_res182-2_linearmixmo_standard_bar4.yaml |
Cut-MixMo | 2 | 69.12 | exp_tinyimagenet_res182-2_cutmixmo-p5_standard_bar4.yaml |
Vanilla | 3 | 65.84 | exp_tinyimagenet_res183_1net_standard_bar1.yaml |
Linear-MixMo | 3 | 68.36 | exp_tinyimagenet_res183-2_linearmixmo_standard_bar4.yaml |
Cut-MixMo | 3 | 70.23 | exp_tinyimagenet_res183-2_cutmixmo-p5_standard_bar4.yaml |
- python >= 3.6
- torch >= 1.4.0
- torchsummary >= 1.5.1
- torchvision >= 0.5.0
- tensorboard >= 1.14.0
- Clone the repo:
$ git clone https://github.com/alexrame/mixmo-pytorch.git
- Install this repository and the dependencies using pip:
$ conda create --name mixmo python=3.6.10
$ conda activate mixmo
$ cd mixmo-pytorch
$ pip install -r requirements.txt
With this, you can edit the MixMo code on the fly.
We advise you to first create a dedicated data folder dataplace, which will be provided as an argument in the subsequent scripts.
- CIFAR
The CIFAR-10 and CIFAR-100 datasets are managed by the PyTorch dataloader. The first time you run a script, the dataloader will download the dataset into the dataplace folder you provided.
- Tiny-ImageNet
The Tiny-ImageNet dataset needs to be downloaded beforehand. The following process is forked from manifold mixup.
- Download the zipped data from https://tiny-imagenet.herokuapp.com/.
- Extract the zipped data in the folder dataplace.
- Run the following script (this will arrange the validation data in the format required by the PyTorch loader):
$ python scripts/script_load_tiny_data.py --dataplace $dataplace
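For readers curious about what such a rearrangement involves, here is a rough sketch. It is not the content of scripts/script_load_tiny_data.py. It assumes the extracted archive has a val/images folder and a tab-separated val/val_annotations.txt file whose first two columns are the image filename and its class id, and it moves each image into one subfolder per class, the layout expected by loaders such as torchvision.datasets.ImageFolder. The function name is hypothetical.

```python
import os
import shutil


def arrange_val_by_class(val_dir):
    """Move validation images into one subfolder per class; return how many were moved."""
    annotations = os.path.join(val_dir, "val_annotations.txt")
    images_dir = os.path.join(val_dir, "images")
    moved = 0
    with open(annotations) as handle:
        for line in handle:
            fields = line.strip().split("\t")
            if len(fields) < 2:
                continue  # skip empty or malformed lines
            filename, class_id = fields[0], fields[1]
            source = os.path.join(images_dir, filename)
            if not os.path.exists(source):
                continue  # already moved, or missing from the archive
            class_dir = os.path.join(val_dir, class_id)
            os.makedirs(class_dir, exist_ok=True)
            shutil.move(source, os.path.join(class_dir, filename))
            moved += 1
    return moved
```

Calling the function a second time on the same folder moves nothing, since images already in place are skipped.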
First, to train a baseline model, simply execute the following command:
$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810_1net_standard_bar1.yaml --dataplace $dataplace --saveplace $saveplace
It will create an output folder exp_cifar100_wrn2810_1net_standard_bar1 located in the parent folder saveplace. This folder includes model checkpoints, a copy of your config file, logs and tensorboard logs. By default, if the output folder already exists, training will load the weights from the last saved epoch and continue from there. If you want to forcefully restart training, simply add --from_scratch as an argument.
When training MixMo, you just need to select the appropriate config file. For example, to obtain state-of-the-art results on CIFAR-100 by combining Cut-MixMo and CutMix, just execute:
$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --saveplace $saveplace
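When several configs need to be trained in a row, the commands above can be generated programmatically. The helper below is a hypothetical convenience, not something shipped with the repository; it only assembles the flags documented in this README (--config_path, --dataplace, --saveplace, --from_scratch), and the data and runs folder names are placeholders.

```python
import shlex


def build_train_command(config_path, dataplace, saveplace, from_scratch=False):
    """Return the argument list for scripts/train.py with the documented flags."""
    command = [
        "python3", "scripts/train.py",
        "--config_path", config_path,
        "--dataplace", dataplace,
        "--saveplace", saveplace,
    ]
    if from_scratch:
        command.append("--from_scratch")  # restart instead of resuming
    return command


configs = [
    "config/cifar100/exp_cifar100_wrn2810_1net_standard_bar1.yaml",
    "config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml",
]
for config in configs:
    command = build_train_command(config, "data", "runs")
    print(" ".join(shlex.quote(part) for part in command))
```

The returned list can be passed as-is to subprocess.run to launch each training from the repository root.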
To evaluate the accuracy of a given strategy, you can train your own model, or just download our pretrained checkpoints:
$ python3 scripts/evaluate.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --checkpoint $checkpoint --tempscal
- checkpoint can be either:
  - a path towards a checkpoint.
  - an int matching the training epoch you wish to evaluate. In that case, you need to provide --saveplace $saveplace.
  - the string best: we then automatically select the best training epoch. In that case, you need to provide --saveplace $saveplace.
- --tempscal: indicates that you will apply temperature scaling.
Results will be printed at the end of the script.
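For context on the --tempscal flag, temperature scaling divides the logits by a single scalar T, chosen to minimize the negative log-likelihood on held-out data, before applying the softmax. The sketch below illustrates the idea with a simple grid search over T; it is not the repository's implementation, and the function names, toy logits and candidate grid are assumptions.

```python
import math


def negative_log_likelihood(logits, labels, temperature):
    """Average NLL of the labels under softmax(logits / temperature)."""
    total = 0.0
    for sample_logits, label in zip(logits, labels):
        scaled = [value / temperature for value in sample_logits]
        peak = max(scaled)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in scaled))
        total -= scaled[label] - log_norm
    return total / len(labels)


def fit_temperature(logits, labels, candidates):
    """Pick the candidate temperature with the lowest NLL."""
    return min(candidates, key=lambda t: negative_log_likelihood(logits, labels, t))


# Toy validation set: the model is very confident in class 0 but right only 3 times out of 4.
val_logits = [[4.0, 0.0], [4.0, 0.0], [4.0, 0.0], [4.0, 0.0]]
val_labels = [0, 0, 0, 1]
candidates = [0.5 + 0.1 * step for step in range(46)]  # 0.5, 0.6, ..., 5.0
best_temperature = fit_temperature(val_logits, val_labels, candidates)
```

Because the toy model is overconfident, the selected temperature is larger than 1, which softens the predicted probabilities without changing the predicted class.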
If you wish to test the models against common corruptions and perturbations, download the CIFAR-100-C dataset in your dataplace. Then use --corruptions at evaluation.
You can create new configs automatically via:
$ python3 scripts/templateutils_mixmo.py --template_path scripts/exp_mixmo_template.yaml --config_dir config/$your_config_dir --dataset $dataset
- Our implementation is based on the repository: https://github.com/valeoai/ConfidNet. We thus thank Charles Corbière for his work Addressing Failure Prediction by Learning Model Confidence.
- MIMO: https://github.com/google/edward2/
- CutMix: https://github.com/ildoonet/cutmix/
- Mixup: https://github.com/facebookresearch/mixup-cifar10
- AugMix: https://github.com/google-research/augmix/
- Temperature Scaling: https://github.com/gpleiss/temperature_scaling/
- Metrics: