Artyom Gadetsky*, Yulun Jiang*, Maria Brbić
Project page | Paper | BibTeX
This repo contains the source code of 🐢 TURTLE, an unsupervised learning algorithm written in PyTorch. 🔥 TURTLE achieves state-of-the-art unsupervised performance on a variety of benchmark datasets. For more details, please check our paper Let Go of Your Labels with Unsupervised Transfer (ICML '24).
The code is built with the following libraries:
- PyTorch - 2.2.1
- torchvision - 0.17.1
- numpy
- scipy
- scikit-learn
- clip
- tqdm
- cuml - 24.02
To install cuml, you can follow the instructions on this page.
In our paper, we consider 26 vision datasets studied in (Radford et al. 2021) and 9 different foundation models. As a running example, we present the full pipeline to train TURTLE on the CIFAR100 dataset.
- Precompute representations and save ground truth labels for the dataset
python precompute_representations.py --dataset cifar100 --phis clipvitL14
python precompute_representations.py --dataset cifar100 --phis dinov2
python precompute_labels.py --dataset cifar100
- Train TURTLE with 2 representation spaces
python run_turtle.py --dataset cifar100 --phis clipvitL14 dinov2
or with a single representation space
python run_turtle.py --dataset cifar100 --phis clipvitL14
python run_turtle.py --dataset cifar100 --phis dinov2
The results and the checkpoints will be saved at ./data/results and ./data/task_checkpoints. You can also use --root_dir in all scripts to specify a root directory other than ./data, which is used by default.
Most datasets can be downloaded automatically by running precompute_representations.py and precompute_labels.py. However, some datasets require manual downloading. Please check dataset_preparation/data_utils.py for a guide to preparing all the datasets used in our paper.
As an example, to prepare the pets dataset, which is not directly available in torchvision.datasets, one can run:
python dataset_preparation/prepare_pets.py -i ./data/datasets/pets -o ./data/datasets/pets -d
to download and extract the dataset at ./data/datasets/pets.
After downloading the dataset, run the following command to precompute the representations and labels:
python precompute_representations.py --dataset ${DATASET} --phis ${REPRESENTATION}
python precompute_labels.py --dataset ${DATASET}
Datasets and representations covered in this repo:
- 26 datasets: food101, cifar10, cifar100, birdsnap, sun397, cars, aircraft, dtd, pets, caltech101, flowers, mnist, fer2013, stl10, eurosat, resisc45, gtsrb, kitti, country211, pcam, ucf101, kinetics700, clevr, hatefulmemes, sst, imagenet.
- 9 representations: clipRN50, clipRN101, clipRN50x4, clipRN50x16, clipRN50x64, clipvitB32, clipvitB16, clipvitL14, dinov2.
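Since the scripts take dataset and representation names as plain strings, a typo is easy to make. The snippet below is a minimal sketch of a pre-flight check, using only the names listed above; the `check_args` helper is hypothetical and is not part of this repo.

```python
# Names copied from the lists above. The helper is a hypothetical
# convenience for catching typos; it is not part of the TURTLE repo.
DATASETS = (
    "food101", "cifar10", "cifar100", "birdsnap", "sun397", "cars",
    "aircraft", "dtd", "pets", "caltech101", "flowers", "mnist",
    "fer2013", "stl10", "eurosat", "resisc45", "gtsrb", "kitti",
    "country211", "pcam", "ucf101", "kinetics700", "clevr",
    "hatefulmemes", "sst", "imagenet",
)
REPRESENTATIONS = (
    "clipRN50", "clipRN101", "clipRN50x4", "clipRN50x16", "clipRN50x64",
    "clipvitB32", "clipvitB16", "clipvitL14", "dinov2",
)


def check_args(dataset, phis):
    """Return (dataset, list(phis)) or raise ValueError on an unlisted name."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    unknown = [phi for phi in phis if phi not in REPRESENTATIONS]
    if unknown:
        raise ValueError(f"unknown representation(s): {unknown}")
    return dataset, list(phis)
```

For example, `check_args("cifar100", ["clipvitL14", "dinov2"])` passes, while `check_args("cifar1000", ["dinov2"])` raises `ValueError`.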
Once the representations and labels are precomputed, to train TURTLE with a single space, run:
python run_turtle.py --dataset ${DATASET} --phis ${REPRESENTATION}
or, to train TURTLE with multiple representation spaces, run:
python run_turtle.py --dataset ${DATASET} --phis ${REPRESENTATION1} ${REPRESENTATION2}
You can also use --inner_lr, --outer_lr, and --warm_start to specify the inner step size, the outer step size, and whether to use cold-start or warm-start bilevel optimization. Furthermore, use --cross_val to compute the generalization score for the found labeling after training. You can perform a hyperparameter sweep and use the generalization score to select the best hyperparameters without using ground truth labels.
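A sweep of this kind can be scripted by generating one command per hyperparameter combination. The following is a minimal sketch, assuming that --inner_lr and --outer_lr each take a single numeric value and that --cross_val is a switch without an argument; check the argument parser in run_turtle.py before relying on it. The `sweep_commands` helper and the grid values are hypothetical placeholders, not settings from the paper.

```python
import itertools
import shlex


def sweep_commands(dataset, phis, inner_lrs, outer_lrs):
    """Build one run_turtle.py command line per (inner_lr, outer_lr) pair."""
    commands = []
    for inner_lr, outer_lr in itertools.product(inner_lrs, outer_lrs):
        argv = [
            "python", "run_turtle.py",
            "--dataset", dataset,
            "--phis", *phis,
            "--inner_lr", str(inner_lr),
            "--outer_lr", str(outer_lr),
            "--cross_val",  # assumed to be a value-less switch
        ]
        commands.append(shlex.join(argv))
    return commands


if __name__ == "__main__":
    # Placeholder grid; pick values appropriate for your setup.
    grid = sweep_commands("cifar100", ["clipvitL14", "dinov2"],
                          inner_lrs=[0.001, 0.01], outer_lrs=[0.001, 0.01])
    for cmd in grid:
        print(cmd)
```

The printed commands can be run one by one or handed to a job scheduler. How the generalization score is reported is not described here, so collecting the scores and choosing the best run is left out of the sketch.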
We also release the labelings found by TURTLE for all datasets and all model architectures used in our paper. To download pre-trained checkpoints, run:
wget https://brbiclab.epfl.ch/wp-content/uploads/2024/06/turtle_tasks.zip
unzip turtle_tasks.zip
Then, you can evaluate the pre-trained checkpoint of TURTLE with a single space by running:
python evaluate.py --dataset cifar100 --phis clipvitL14 --task_ckpt {PATH_TO_TURTLE_TASKS}/1space/clipvitL14/cifar100.pt
python evaluate.py --dataset cifar100 --phis dinov2 --task_ckpt {PATH_TO_TURTLE_TASKS}/1space/dinov2/cifar100.pt
or evaluate with two representation spaces by running:
python evaluate.py --dataset cifar100 --phis clipvitL14 dinov2 --task_ckpt {PATH_TO_TURTLE_TASKS}/2space/clipvitL14_dinov2/cifar100.pt
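The checkpoint paths in the commands above follow a regular pattern: a directory named after the number of spaces, a directory joining the representation names with underscores, and a file named after the dataset. A minimal sketch of that pattern is given below. The `task_ckpt_path` helper is hypothetical, the layout is inferred only from the three example paths, and the order of names in the multi-space directory is assumed to match the order shown there.

```python
from pathlib import PurePosixPath


def task_ckpt_path(tasks_root, dataset, phis):
    """Build <tasks_root>/<N>space/<phi1>_<phi2>/<dataset>.pt.

    The layout is inferred from the example commands, not from the repo code.
    """
    phis = list(phis)
    if not phis:
        raise ValueError("at least one representation is required")
    return (PurePosixPath(tasks_root)
            / f"{len(phis)}space"
            / "_".join(phis)
            / f"{dataset}.pt")
```

For example, `task_ckpt_path(root, "cifar100", ["clipvitL14", "dinov2"])` yields `<root>/2space/clipvitL14_dinov2/cifar100.pt`, where `root` is wherever turtle_tasks.zip was extracted.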
We also provide implementations of the Zero-shot Transfer with CLIP, Linear Probe, and K-Means baselines in the baselines folder. To implement the linear probe and K-Means baselines, we employ cuml for highly efficient CUDA implementations.
Precompute the representations and then perform linear probe evaluation by running:
python baselines/linear_probe.py --dataset ${DATASET} --phis ${REPRESENTATION}
To select the L2 regularization strength for better performance, run:
python baselines/linear_probe.py --dataset ${DATASET} --phis ${REPRESENTATION} --validation
Precompute the representations and run K-Means baseline:
python baselines/kmeans.py --dataset ${DATASET} --phis ${REPRESENTATION}
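K-Means assigns arbitrary cluster indices, so comparing its output with ground truth labels requires matching clusters to classes first. A common metric is accuracy under the best one-to-one matching. The sketch below illustrates that idea with a brute-force search over permutations, which is only practical for a handful of clusters. It is an illustration, not necessarily the computation performed by baselines/kmeans.py or evaluate.py, and the `clustering_accuracy` name is hypothetical.

```python
from itertools import permutations


def clustering_accuracy(true_labels, cluster_ids):
    """Accuracy under the best one-to-one mapping of cluster ids to classes.

    Brute force over permutations: suitable for a few clusters only.
    """
    if len(true_labels) != len(cluster_ids) or len(true_labels) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    classes = sorted(set(true_labels))
    clusters = sorted(set(cluster_ids))
    if len(clusters) > len(classes):
        raise ValueError("this sketch assumes no more clusters than classes")
    # counts[(cluster, cls)] = samples in `cluster` whose true label is `cls`
    counts = {}
    for cls, cluster in zip(true_labels, cluster_ids):
        counts[(cluster, cls)] = counts.get((cluster, cls), 0) + 1
    best = 0
    # Try every injective assignment of clusters to classes.
    for assigned in permutations(classes, len(clusters)):
        matched = sum(counts.get((cluster, cls), 0)
                      for cluster, cls in zip(clusters, assigned))
        best = max(best, matched)
    return best / len(true_labels)
```

For datasets with many classes, such as cifar100, the same assignment problem is usually solved with the Hungarian algorithm, for instance via `scipy.optimize.linear_sum_assignment` on the cluster-by-class count matrix.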
Run CLIP zero-shot transfer:
python baselines/clip_zs.py --dataset ${DATASET} --phis ${REPRESENTATION}
While developing TURTLE we greatly benefited from the open-source repositories:
If you find our code useful, please consider citing:
@inproceedings{
gadetsky2024let,
title={Let Go of Your Labels with Unsupervised Transfer},
author={Gadetsky, Artyom and Jiang, Yulun and Brbi\'c, Maria},
booktitle={International Conference on Machine Learning},
year={2024},
}