This repository provides TensorFlow / Keras implementations of different CaiT
[1] variants from Touvron et al. It also provides the TensorFlow / Keras models that have been
populated with the original CaiT pre-trained params available from [2]. These
models are not blackbox SavedModels i.e., they can be fully expanded into tf.keras.Model
objects and one can call all the utility functions on them (example: .summary()
).
As of today, all the TensorFlow / Keras variants of the CaiT models listed here are available in this repository.
Refer to the "Using the models" section to get started.
Updates Oct 8 2022: This project received the Kaggle ML Research Spotlight Prize (September 2022)
Updates Sept 25 2022: Blog post on CaiT
- Conversion
- Collection of pre-trained models (converted from PyTorch to TensorFlow)
- Results of the converted models
- How to use the models?
- References
- Acknowledgements
TensorFlow / Keras implementations are available in cait/models.py
. Conversion
utilities are in convert.py
.
Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/cait/1. You can fully inspect the architecture of the TF-Hub models like so:
import tensorflow as tf
model_gcs_path = "gs://tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)
dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)
print(model.summary(expand_nested=True))
Results are on ImageNet-1k validation set (top-1 and top-5 accuracies).
model_name | top1_acc(%) | top5_acc(%) |
---|---|---|
cait_s24_224 | 83.368 | 96.576 |
cait_xxs24_224 | 78.524 | 94.212 |
cait_xxs36_224 | 79.76 | 94.876 |
cait_xxs36_384 | 81.976 | 96.064 |
cait_xxs24_384 | 80.648 | 95.516 |
cait_xs24_384 | 83.738 | 96.756 |
cait_s24_384 | 84.944 | 97.212 |
cait_s36_384 | 85.192 | 97.372 |
cait_m36_384 | 85.924 | 97.598 |
cait_m48_448 | 86.066 | 97.590 |
Results can be verified with the code in i1k_eval
. Results are in line with [1].
Slight differences in the results stemmed
from the fact that I used a different set of augmentation transformations. Original
transformations suggested by the authors can be found here.
Pre-trained models:
- Off-the-shelf classification: Colab Notebook
- Fine-tuning: Colab Notebook
These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details. Additionally, the notebook shows how to visualize the attention maps for a given image (following figures 6 and 7 of the original paper).
Original Image | Class Attention Maps | Class Saliency Map |
---|---|---|
For the best quality, refer to the assets
directory. You can also generate these plots using the following interactive demos on
Hugging Face Spaces:
Randomly initialized models:
from cait.model_configs import base_config
from cait.models import CaiT
import tensorflow as tf
config = base_config.get_config(
model_name="cait_xxs24_224"
)
cait_xxs24_224 = CaiT(config)
dummy_inputs = tf.ones((2, 224, 224, 3))
_ = cait_xxs24_224(dummy_inputs)
print(cait_xxs24_224.summary(expand_nested=True))
To initialize a network with say, 5 classes, do:
config = base_config.get_config(
model_name="cait_xxs24_224"
)
with config.unlocked():
config.num_classes = 5
cait_xxs24_224 = CaiT(config)
To view different model configurations, refer to convert_all_models.py
.
[1] CaiT paper: https://arxiv.org/abs/2103.17239
[2] Official CaiT code: https://github.com/facebookresearch/deit
timm
library source code for the awesome codebase.- ML-GDE program for providing GCP credits that supported my experiments.