Authors: Shrinivas Ramasubramanian*, Harsh Rangwani*, Sho Takemori*, Kunal Samanta, Yuhei Umeda, Venkatesh Babu Radhakrishnan
This repository contains the code for our ICLR'24 spotlight [paper] "Selective Mixup Fine-Tuning for Optimizing Non-Decomposable Objectives".
Fig. 1: We illustrate the effect of mixup variants on feature representations (a). With standard Mixup, the feature representation of a class receives equal contributions from all other classes (b). In contrast, SelMix (c) selects specific class mixups at each timestep t so as to optimize the desired metric. Above is an overview of how the SelMix distribution is obtained at timestep t.
@inproceedings{
ramasubramanian2024selective,
title={Selective Mixup Fine-Tuning for Optimizing Non-Decomposable Metrics},
author={Shrinivas Ramasubramanian and Harsh Rangwani and Sho Takemori and Kunal Samanta and Yuhei Umeda and Venkatesh Babu Radhakrishnan},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
}
SelMix is a novel method for optimizing non-decomposable objectives in long-tail semi-supervised and supervised learning tasks. It formulates the problem as a multi-armed bandit, where each arm represents a pair of classes to perform mixup on. SelMix introduces a selection policy that assigns probabilities to class pairs based on their estimated gain in the objective, updated using validation feedback.
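As a rough, illustrative sketch (not the repository's implementation), the snippet below shows one way such a selection policy could turn estimated per-pair gains into a sampling distribution, assuming a simple softmax with a scaling factor `s`; the gain matrix and `s` here are placeholders, while in SelMix the gains come from validation feedback.

```python
import numpy as np

def selmix_distribution(gains: np.ndarray, s: float = 10.0) -> np.ndarray:
    """Toy sketch: map a (K x K) matrix of estimated metric gains for mixing
    class i with class j to a probability distribution over class pairs.
    In SelMix the gains are estimated from validation feedback; here they
    are placeholder numbers and `s` is a hypothetical scaling factor."""
    logits = s * gains                    # sharpen towards high-gain pairs
    p = np.exp(logits - logits.max())     # numerically stable exponentiation
    return p / p.sum()                    # normalize over all (i, j) pairs

# Hypothetical example with 3 classes: pair (2, 0) has the largest gain,
# so it receives the highest mixup probability.
gains = np.array([[0.0, 0.1, 0.0],
                  [0.2, 0.0, 0.1],
                  [0.5, 0.3, 0.0]])
print(selmix_distribution(gains).round(3))
```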
To install the latest version of PyTorch for your GPU, please refer to the official docs. After installing PyTorch, create and activate your conda environment:
conda create --name your_env_name --file requirements.txt
conda activate your_env_name
Install the necessary libraries using
pip install -r requirements.txt
Since we extensively log various detailed performance measures for the model, we strongly recommend installing wandb before proceeding further. You can find the instructions here.
Usage guidelines for custom models are provided in the Model Readme.
Overview of the training pipeline
import torch.nn.functional as F

# Define your datasets for mixup; for the supervised case, both are the same labelled dataset
dataset1 = None
dataset2 = None
# Placeholders for the optimizer and the Lagrange multipliers used by the objective
optimizer = None
lagrange_multipliers = None

for epoch in range(num_epochs):
    # Run validation to obtain the confusion matrix and class prototypes
    confusion_matrix, prototypes = validation(valset, model)

    # Evaluate the MinRecall objective and update the Lagrange multipliers
    objective = MinRecall(confusion_matrix, prototypes, lagrange_multipliers)
    lagrange_multipliers = objective.lambdas

    # Obtain the SelMix distribution P and build the sampler that draws class pairs from it
    P_selmix = objective.P
    SelMix_dataloader = FastJointSampler(dataset1, dataset2, model, P_selmix)

    for step in range(num_steps_per_epoch):
        # Draw a batch of sample pairs according to the SelMix distribution
        (x1, y1), (x2, y2) = SelMix_dataloader.get_batch()

        # Forward pass: the model mixes the two inputs and returns logits
        logits = model(x1, x2)

        # Cross-entropy loss against the labels of the first batch
        loss = F.cross_entropy(logits, y1)

        # Backward pass and optimizer update
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
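For intuition only, here is a minimal hypothetical sketch of what a joint sampler driven by `P_selmix` might do internally: draw a class pair (i, j) from the SelMix distribution, then draw one example of class i and one of class j to be mixed. The grouping logic and `get_batch` internals below are illustrative and do not reflect the actual `FastJointSampler` implementation in this repository.

```python
import random
from collections import defaultdict

class ToyJointSampler:
    """Illustrative only: groups two labelled datasets by class and samples
    batches of (x1, y1), (x2, y2) pairs according to a class-pair distribution P."""

    def __init__(self, dataset1, dataset2, P, batch_size=64):
        self.P = P                      # (K x K) matrix of pair probabilities, sums to 1
        self.batch_size = batch_size
        self.by_class1 = defaultdict(list)
        self.by_class2 = defaultdict(list)
        # Assumes both datasets are iterables of (x, y) and cover all K classes
        for x, y in dataset1:
            self.by_class1[y].append(x)
        for x, y in dataset2:
            self.by_class2[y].append(x)

    def get_batch(self):
        K = len(self.P)
        pairs = [(i, j) for i in range(K) for j in range(K)]
        weights = [self.P[i][j] for i, j in pairs]
        x1s, y1s, x2s, y2s = [], [], [], []
        for _ in range(self.batch_size):
            # Sample a class pair from the SelMix distribution, then one example per class
            i, j = random.choices(pairs, weights=weights, k=1)[0]
            x1s.append(random.choice(self.by_class1[i])); y1s.append(i)
            x2s.append(random.choice(self.by_class2[j])); y2s.append(j)
        return (x1s, y1s), (x2s, y2s)
```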
We provide the code for pre-training your model using FixMatch and FixMatch w/ LA on various dataset configurations. The complete set of config files can be found in the ./configs directory, which has the following structure. For each dataset, there is a directory per data distribution that contains the fine-tuning config for each objective, along with a pretraining subdirectory holding the config for each pre-training method.
.
├── cifar10/
│   ├── DataDistribution/
│   │   ├── GeometricMean.yaml
│   │   ├── ...
│   │   ├── OtherObjectives.yaml
│   │   ├── ...
│   │   ├── MinRecall.yaml
│   │   └── pretraining/
│   │       ├── fixmatchLA.yaml
│   │       └── fixmatchOriginal.yaml
│   ├── ...
│   └── ...
├── cifar100/
├── imagenet1k/
└── stl10/
To start pre-training with vanilla FixMatch on CIFAR-10, run:
python pretrain.py --config_file configs/cifar10/$N_1$-1500_$M_1$-3000_IBRL-100_IBRU-100/pretraining/fixmatchOriginal.yaml
We also provide the config file for pre-training with FixMatch w/ LA on the same dataset:
python pretrain.py --config_file configs/cifar10/$N_1$-1500_$M_1$-3000_IBRL-100_IBRU-100/pretraining/fixmatchLA.yaml
We provide the seed-0 pre-training checkpoints for FixMatch and FixMatch w/ LA loss:
| Dataset | CIFAR-10 | CIFAR-10 | CIFAR-10 |
|---|---|---|---|
| FixMatch | Google Drive Link | Google Drive Link | Google Drive Link |
| w/ LA | Google Drive Link | Google Drive Link | Google Drive Link |

| Dataset | STL-10 | Imagenet-100 | CIFAR-100 |
|---|---|---|---|
| FixMatch | Google Drive Link | Google Drive Link | Google Drive Link |
| w/ LA | Google Drive Link | Google Drive Link | Google Drive Link |
To start fine-tuning, set the hyperparameters for the specific objective and dataset in the corresponding config file and run the following:
python trainMetricOpt.py --config_file configs/cifar10/$N_1$-1500_$M_1$-3000_IBRL-100_IBRU-100/MinRecall.yaml
We obtain the following results when running the fine-tuning code. The first table reports mean recall and the second reports min recall for the fine-tuned FixMatch w/ LA checkpoints.
| Dataset | CIFAR-10 | CIFAR-10 | CIFAR-10 | CIFAR-100 | STL-10 | Imagenet-100 |
|---|---|---|---|---|---|---|
| FixMatch w/ LA | 80.1 | 93.5 | 80.6 | 55.7 | | |
| w/ SelMix | 85.3 | 93.8 | 81.4 | 56.1 | | |
| Dataset | CIFAR-10 | CIFAR-10 | CIFAR-10 | CIFAR-100 (Min HT) | STL-10 | Imagenet-100 (Min HT) |
|---|---|---|---|---|---|---|
| FixMatch w/ LA | 69.3 | 83.3 | 63.1 | 32.2 | | |
| w/ SelMix | 79.7 | 88.4 | 72.7 | 56.4 | | |
Our pre-training code is based on the FixMatch-pytorch implementation.