A Replication Study of Decoupling Feature Extraction and Classification Layers for Calibrated Neural Networks
This project builds on the paper "Decoupling Feature Extraction and Classification Layers for Calibrated Neural Networks" by Jordahn and Olmos. We aimed to reproduce its main experiments and extended the code base in the following ways:
- Added a simple CNN architecture
- Added a ViT architecture and integrated the Tiny ImageNet dataset
- Added different TST and V-TST MLP head sizes
- Added V-TST training with multiple samples
- Added a ResNet-50 architecture
- Added Focal Loss and Adaptive Focal Loss
We used CIFAR-10 and CIFAR-100, as well as CIFAR-10-C and CIFAR-100-C, to reproduce the main results. For finetuning the ViT, we used Tiny ImageNet. For out-of-distribution detection, we included SVHN for evaluation.
We ran the following experiments:
- Reproduce the first two rows of Tables 1 and 2 from the original paper.
- Reproduce Tables 3 and 6 from the original paper.
- Four additional ablation studies with CIFAR-10:
  - Dependence on network architecture: run the experiments of Tables 1 and 2 with ResNet-50 to see whether the results extend to a new architecture.
  - Second-stage network architecture: run the experiments of Tables 1 and 2 with differently sized MLPs for the second stage to see their effect.
  - V-TST dependence on the number of training samples: run the experiments of Tables 1 and 2 for V-TST with different numbers of training samples to see their effect.
  - Focal Loss: run the experiments of Tables 1 and 2 with Focal Loss and Adaptive Focal Loss to see how TST and V-TST perform in combination with another implicit regularization method for calibration.
TBD
To get started, follow the instructions of the original authors. To run our additional experiments, use the additional arguments described below.
For the model you can now use the following arguments: `WRN`, `VIT`, `CNN`, and `ResNet50`. The `VIT` is supposed to be used in combination with `--dataset TINYIMAGENET`, and through `--vitbase` you can specify which weights you want to load for finetuning.
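As a sketch, a ViT finetuning run on Tiny ImageNet could look as shown below, where `<VIT_WEIGHTS>`, `<EPOCHS>`, `<BATCH_SIZE>`, and `<SEED>` are placeholders you fill in yourself and the combination of flags mirrors the training commands further down:
python3 src/experiments/00_train_models.py \
--model VIT \
--dataset TINYIMAGENET \
--vitbase <VIT_WEIGHTS> \
--model_name VIT_TINYIMAGENET_Base \
--epochs <EPOCHS> \
--batch_size <BATCH_SIZE> \
--accelerator gpu \
--seed <SEED>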
If you want to experiment with TST or V-TST MLP sizes or the number of training samples, you can use the argument `--model TSTEXP` or `--model VTSTEXP` together with the arguments `--mlp_size` and `--train_samples`.
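For instance, a V-TST experiment with a custom MLP head size and multiple training samples could be launched as sketched below; this mirrors the V-TST command further down, and `<MLP_SIZE>` and `<N_SAMPLES>` are placeholders for the values you want to test:
python3 src/experiments/00_train_models.py \
--freeze_qyx \
--model VTSTEXP \
--model_name VTSTEXP_<DATASET>_Z<Z> \
--mlp_size <MLP_SIZE> \
--train_samples <N_SAMPLES> \
--epochs 40 \
--accelerator gpu \
--latent_dim <Z> \
--seed <SEED> \
--pretrained_qyx <PATH_TO_TRAINED_MODEL> \
--dataset <DATASET>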
If you want to use different loss functions, we provide three examples below of how to specify them in your `experiment.sh` bash script.
To first train a base model, e.g. a WRN 28-10, run the following command, replacing `<MODEL>` with one of the model arguments above and `<DATASET>` with `CIFAR10`, `SVHN`, or `CIFAR100`:
python3 src/experiments/00_train_models.py \
--model <MODEL> \
--epochs 600 \
--accelerator gpu \
--seed <SEED> \
--dataset <DATASET> \
--model_name <MODEL>_<DATASET>_Base \
--batch_size 256 \
--loss <loss>
where `<loss>` is one of the following: `ce`, `fl`, or `fla`. If the loss is `fl` or `fla`, the arguments `--gammas` and `--probs` have to be included. The default loss is `ce`.
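For example, to train the base model with Focal Loss, the same command can be extended as sketched below; the gamma and probability values are only illustrative and taken from the V-TST command further down:
python3 src/experiments/00_train_models.py \
--model <MODEL> \
--epochs 600 \
--accelerator gpu \
--seed <SEED> \
--dataset <DATASET> \
--model_name <MODEL>_<DATASET>_Base \
--batch_size 256 \
--loss fl \
--gammas 3 \
--probs 1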
The trained model with the best validation performance should be saved in `./experiment_results/<DATASET>_<MODEL>/checkpoints`. Now, to run TST using Adaptive Focal Loss (`fla`), run the following command:
python3 src/experiments/00_train_models.py \
--freeze_qyx \
--model_name TST_<DATASET>_Z<Z> \
--model TST \
--epochs 40 \
--accelerator gpu \
--latent_dim <Z> \
--seed <SEED> \
--pretrained_qyx <PATH_TO_TRAINED_MODEL> \
--dataset <DATASET> \
--loss fla \
--gammas 5 3 \
--probs 0.2 1
To similarly train V-TST using Focal Loss (`fl`) with constant gamma, run:
python3 src/experiments/00_train_models.py \
--freeze_qyx \
--model_name VTST_<DATASET>_Z<Z> \
--model VTST \
--epochs 40 \
--accelerator gpu \
--latent_dim <Z> \
--seed <SEED> \
--pretrained_qyx <PATH_TO_TRAINED_MODEL> \
--dataset <DATASET> \
--loss fl \
--gammas 3 \
--probs 1
**Tip:** If you need help setting up a deep learning VM instance on Google Cloud Platform and connecting it to your local VS Code, check out this repository for detailed guidance.
To run experiments in a background process, use `nohup`; a small example is provided in `train_scripts/example_train_run.ipynb`.
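A minimal `nohup` invocation could look like the sketch below, where the argument placeholders are filled in as before and `train.log` is just an example name for the log file:
nohup python3 src/experiments/00_train_models.py \
--model <MODEL> \
--dataset <DATASET> \
--model_name <MODEL>_<DATASET>_Base \
--epochs 600 \
--accelerator gpu \
--seed <SEED> \
--batch_size 256 > train.log 2>&1 &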