Feat/deit trainer (#2)
* feat: teacher training notebook.

* chore: fixed the random seed.

* fix: loss function.

* refactor: layers.

* chore: aligned the initialization with timm.

* fix: stochastic depth block.

* chore: added deit trainer and modified readme.
sayakpaul authored Apr 4, 2022
1 parent 3572474 commit 762ac52
Showing 8 changed files with 695 additions and 89 deletions.
README.md — 27 changes: 14 additions & 13 deletions
@@ -19,8 +19,9 @@ section to get started.

* [Collection of pre-trained models (converted from PyTorch to TensorFlow)](https://github.com/sayakpaul/deit-tf#models)
* [Results of the converted models](https://github.com/sayakpaul/deit-tf#results)
* [How to use the models?](https://github.com/sayakpaul/deit-tf#using-the-models)
+* [Training with DeiT](https://github.com/sayakpaul/deit-tf#training-with-deit)
* [References](https://github.com/sayakpaul/deit-tf#references)
-* [Acknowledgements](https://github.com/sayakpaul/deit-tf#references)
+* [Acknowledgements](https://github.com/sayakpaul/deit-tf#acknowledgements)

## Conversion

@@ -80,9 +81,9 @@ for more details. Additionally, the notebook shows how to visualize the attention

<br>

**Randomly initialized models**:

```py
from vit.model_configs import base_config
from vit.deit_models import ViTDistilled

@@ -94,17 +95,17 @@ deit_tiny_distilled_patch16_224 = ViTDistilled(distilled_tiny_tf_config)
dummy_inputs = tf.ones((2, 224, 224, 3))
_ = deit_tiny_distilled_patch16_224(dummy_inputs)
print(deit_tiny_distilled_patch16_224.summary(expand_nested=True))
```

To view different model configurations, refer to `convert_all_models.py`.

-## Upcoming (contributions welcome)
-
-- [ ] Align layer initializers (useful if someone wanted to train the models
-from scratch)
-- [x] Fine-tuning notebook
-- [x] Off-the-shelf-classification notebook
-- [x] Publish models on TF-Hub
+## Training with DeiT
+
+You can refer to the `notebooks/deit-trainer.ipynb` notebook to get a sense of how
+distillation is actually performed using DeiT. That notebook also provides code in
+case you want to train a model from scratch instead of with distillation.

## References

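The new "Training with DeiT" section above points to `notebooks/deit-trainer.ipynb` for the actual training code. As a rough sketch of the hard-label distillation objective DeiT uses (the helper name and signature below are illustrative, not the repository's API), the student's class-token head learns from the ground-truth labels while its distillation-token head learns from the teacher's hard predictions:

```py
import tensorflow as tf
from tensorflow import keras

cce = keras.losses.CategoricalCrossentropy(from_logits=True)
scce = keras.losses.SparseCategoricalCrossentropy(from_logits=True)


def hard_distillation_loss(labels, cls_logits, dist_logits, teacher_logits):
    """DeiT-style hard distillation: average the two cross-entropies.

    `labels` are one-hot ground truths; `cls_logits` and `dist_logits` come
    from the student's classification and distillation heads, and
    `teacher_logits` from a frozen teacher such as the BiT model below.
    """
    hard_targets = tf.argmax(teacher_logits, axis=-1)
    return 0.5 * cce(labels, cls_logits) + 0.5 * scce(hard_targets, dist_logits)
```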
notebooks/bit-teacher.ipynb — 237 additions & 0 deletions (new file)
## Imports

```py
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tensorflow import keras

tfds.disable_progress_bar()
tf.random.set_seed(42)
```
## Constants

```py
MODULE_URL = "https://tfhub.dev/google/bit/m-r50x3/1"

BATCH_SIZE = 128
SZ = 224
NUM_EPOCHS = 10

AUTO = tf.data.AUTOTUNE
NB_CLASSES = 5
```
## Data preprocessing and loading

```py
def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random crops.
            image = tf.image.resize(image, (SZ + 20, SZ + 20))
            image = tf.image.random_crop(image, (SZ, SZ, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (SZ, SZ))
        label = tf.one_hot(label, depth=NB_CLASSES)
        return image, label

    return _pp


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)
```
```py
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
```
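As a quick sanity check (added here for illustration; this cell is not in the committed notebook), pulling a single batch confirms the shapes the pipeline produces:

```py
images, labels = next(iter(train_dataset))
print(images.shape)  # (128, 224, 224, 3) -- BATCH_SIZE, SZ, SZ, channels
print(labels.shape)  # (128, 5) -- one-hot over NB_CLASSES
```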
## Model initialization

```py
hub_module = hub.KerasLayer(MODULE_URL)

model = keras.Sequential(
    [
        keras.Input((SZ, SZ, 3)),
        keras.layers.Rescaling(scale=1.0 / 255),
        hub_module,
        keras.layers.Dense(NB_CLASSES, kernel_initializer="zeros"),
    ],
    name="bit_teacher_flowers",
)
print(f"Number of parameters (millions): {model.count_params() / 1e6}.")
```
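Zero-initializing the classification head mirrors the BiT fine-tuning recipe: all logits start at zero, so the model begins from a uniform prediction over the five classes. A quick forward pass (illustrative, not a committed cell) confirms the head's output shape:

```py
logits = model(tf.ones((1, SZ, SZ, 3)))
print(logits.shape)  # (1, 5) -- raw logits; the loss uses from_logits=True
```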
## Optimizer and loss function

```py
SCHEDULE_LENGTH = 500
# SCHEDULE_LENGTH follows the BiT recipe but is not used further in this notebook.
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE

SCHEDULE_BOUNDARIES = [200, 300, 400]
lr = 0.003 * BATCH_SIZE / 512

# Decay the learning rate at each step boundary in SCHEDULE_BOUNDARIES.
lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=SCHEDULE_BOUNDARIES, values=[lr, lr * 0.1, lr * 0.001, lr * 0.0001]
)
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
```
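The schedule can be inspected directly (an illustrative check, not a committed cell): with `BATCH_SIZE = 128` the base rate is `0.003 * 128 / 512 = 0.00075`, and it drops by 10x after step 200, a further 100x after step 300, and another 10x after step 400:

```py
for step in [0, 250, 350, 450]:
    print(step, float(lr_schedule(step)))
```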
## Train the model and save it

```py
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

# `train_dataset` is already batched, so `batch_size` must not be passed to
# `fit()` here; Keras rejects it for dataset inputs.
history = model.fit(
    train_dataset.repeat(),
    steps_per_epoch=10,
    epochs=NUM_EPOCHS,
    validation_data=val_dataset,
)
```
```py
model.save("bit_teacher_flowers")
```
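The saved model is intended to act as the teacher in the DeiT distillation runs. A minimal sketch of loading it back and producing teacher logits (illustrative; the actual wiring lives in `notebooks/deit-trainer.ipynb`):

```py
teacher = keras.models.load_model("bit_teacher_flowers")
teacher.trainable = False

# Hard targets for DeiT-style distillation come from the teacher's logits.
teacher_logits = teacher(tf.ones((2, SZ, SZ, 3)), training=False)
print(teacher_logits.shape)  # (2, 5)
```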
## References

* [Official Colab Notebook from BiT authors](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb)