-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: teacher training notebook. * chore: fixed the random seed. * fix: loss function/ * refactor: layers. * chore: aligned the initialization with timm. * fix: stochastic depth block. * chore: added deit trainer and modified readme.
- Loading branch information
Showing
8 changed files
with
695 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c8be0c8a", | ||
"metadata": {}, | ||
"source": [ | ||
"## Imports" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "bc1ff8d8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import tensorflow as tf\n", | ||
"import tensorflow_datasets as tfds\n", | ||
"import tensorflow_hub as hub\n", | ||
"\n", | ||
"from tensorflow import keras\n", | ||
"\n", | ||
"tfds.disable_progress_bar()\n", | ||
"tf.random.set_seed(42)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4d46385b", | ||
"metadata": {}, | ||
"source": [ | ||
"## Constants" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2ec55567", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"MODULE_URL = \"https://tfhub.dev/google/bit/m-r50x3/1\"\n", | ||
"\n", | ||
"BATCH_SIZE = 128\n", | ||
"SZ = 224\n", | ||
"NUM_EPOCHS = 10\n", | ||
"\n", | ||
"AUTO = tf.data.AUTOTUNE\n", | ||
"NB_CLASSES = 5" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7b39b885", | ||
"metadata": {}, | ||
"source": [ | ||
"## Data preprocessing and loading" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "5c83598e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def preprocess_dataset(is_training=True):\n", | ||
" def _pp(image, label):\n", | ||
" if is_training:\n", | ||
" # Resize to a bigger spatial resolution and take the random\n", | ||
" # crops.\n", | ||
" image = tf.image.resize(image, (SZ + 20, SZ + 20))\n", | ||
" image = tf.image.random_crop(image, (SZ, SZ, 3))\n", | ||
" image = tf.image.random_flip_left_right(image)\n", | ||
" else:\n", | ||
" image = tf.image.resize(image, (SZ, SZ))\n", | ||
" label = tf.one_hot(label, depth=NB_CLASSES)\n", | ||
" return image, label\n", | ||
"\n", | ||
" return _pp\n", | ||
"\n", | ||
"\n", | ||
"def prepare_dataset(dataset, is_training=True):\n", | ||
" if is_training:\n", | ||
" dataset = dataset.shuffle(BATCH_SIZE * 10)\n", | ||
" dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)\n", | ||
" return dataset.batch(BATCH_SIZE).prefetch(AUTO)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b9b01e60", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_dataset, val_dataset = tfds.load(\n", | ||
" \"tf_flowers\", split=[\"train[:90%]\", \"train[90%:]\"], as_supervised=True\n", | ||
")\n", | ||
"num_train = train_dataset.cardinality()\n", | ||
"num_val = val_dataset.cardinality()\n", | ||
"print(f\"Number of training examples: {num_train}\")\n", | ||
"print(f\"Number of validation examples: {num_val}\")\n", | ||
"\n", | ||
"train_dataset = prepare_dataset(train_dataset, is_training=True)\n", | ||
"val_dataset = prepare_dataset(val_dataset, is_training=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "969e65eb", | ||
"metadata": {}, | ||
"source": [ | ||
"## Model initialization" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "60bc5d85", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"hub_module = hub.KerasLayer(MODULE_URL)\n", | ||
"\n", | ||
"model = keras.Sequential(\n", | ||
" [\n", | ||
" keras.Input((SZ, SZ, 3)),\n", | ||
" keras.layers.Rescaling(scale=1.0 / 255),\n", | ||
" hub_module,\n", | ||
" keras.layers.Dense(NB_CLASSES, kernel_initializer=\"zeros\"),\n", | ||
" ],\n", | ||
" name=\"bit_teacher_flowers\",\n", | ||
")\n", | ||
"print(f\"Number of parameters (millions): {model.count_params() / 1e6}.\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5ee776ae", | ||
"metadata": {}, | ||
"source": [ | ||
"## Optimizer and loss function" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ac3e5526", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"SCHEDULE_LENGTH = 500\n", | ||
"SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE\n", | ||
"\n", | ||
"SCHEDULE_BOUNDARIES = [200, 300, 400]\n", | ||
"lr = 0.003 * BATCH_SIZE / 512\n", | ||
"\n", | ||
"# Decay learning rate by a factor of 10 at SCHEDULE_BOUNDARIES.\n", | ||
"lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(\n", | ||
" boundaries=SCHEDULE_BOUNDARIES, values=[lr, lr * 0.1, lr * 0.001, lr * 0.0001]\n", | ||
")\n", | ||
"optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)\n", | ||
"\n", | ||
"loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "01a9c53a", | ||
"metadata": {}, | ||
"source": [ | ||
"## Train the model and save it" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6b037387", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.compile(optimizer=optimizer, loss=loss_fn, metrics=[\"accuracy\"])\n", | ||
"\n", | ||
"history = model.fit(\n", | ||
" train_dataset.repeat(),\n", | ||
" batch_size=BATCH_SIZE,\n", | ||
" steps_per_epoch=10,\n", | ||
" epochs=NUM_EPOCHS,\n", | ||
" validation_data=val_dataset,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9a27e3ec", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.save(\"bit_teacher_flowers\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "862ff7a3", | ||
"metadata": {}, | ||
"source": [ | ||
"## References\n", | ||
"\n", | ||
"* [Official Colab Notebook from BiT authors](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.