From 5c5a23e53f530cc07501346d28aa8212f46bc2f4 Mon Sep 17 00:00:00 2001 From: ianjjohnson <3072903+ianstenbit@users.noreply.github.com> Date: Tue, 25 Jul 2023 10:25:51 -0600 Subject: [PATCH] Update ImageNet training script to support Keras Core --- .../classification/imagenet/basic_training.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/examples/training/classification/imagenet/basic_training.py b/examples/training/classification/imagenet/basic_training.py index a6912f6b84..cb273228cd 100644 --- a/examples/training/classification/imagenet/basic_training.py +++ b/examples/training/classification/imagenet/basic_training.py @@ -25,14 +25,10 @@ import tensorflow as tf from absl import flags -from tensorflow import keras -from tensorflow.keras import callbacks -from tensorflow.keras import losses -from tensorflow.keras import metrics -from tensorflow.keras import optimizers import keras_cv from keras_cv import models +from keras_cv.backend import keras from keras_cv.datasets import imagenet """ @@ -58,9 +54,6 @@ flags.DEFINE_string( "imagenet_path", None, "Directory from which to load Imagenet." ) -flags.DEFINE_string( - "backup_path", None, "Directory which will be used for training backups." -) flags.DEFINE_string( "weights_path", None, @@ -367,14 +360,14 @@ def __call__(self, step): with strategy.scope(): if FLAGS.learning_rate_schedule == COSINE_DECAY_WITH_WARMUP: - optimizer = optimizers.SGD( + optimizer = keras.optimizers.SGD( weight_decay=FLAGS.weight_decay, learning_rate=schedule, momentum=0.9, use_ema=FLAGS.use_ema, ) else: - optimizer = optimizers.SGD( + optimizer = keras.optimizers.SGD( weight_decay=FLAGS.weight_decay, learning_rate=INITIAL_LEARNING_RATE, momentum=0.9, @@ -386,7 +379,7 @@ def __call__(self, step): Next, we pick a loss function. We use CategoricalCrossentropy with label smoothing. """ -loss_fn = losses.CategoricalCrossentropy(label_smoothing=0.1) +loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=0.1) """ @@ -395,8 +388,8 @@ def __call__(self, step): """ with strategy.scope(): training_metrics = [ - metrics.CategoricalAccuracy(), - metrics.TopKCategoricalAccuracy(k=5), + keras.metrics.CategoricalAccuracy(), + keras.metrics.TopKCategoricalAccuracy(k=5), ] """ @@ -404,19 +397,18 @@ def __call__(self, step): We use EarlyStopping, BackupAndRestore, and a model checkpointing callback. """ model_callbacks = [ - callbacks.EarlyStopping(patience=20), - callbacks.BackupAndRestore(FLAGS.backup_path), - callbacks.ModelCheckpoint( + keras.callbacks.EarlyStopping(patience=20), + keras.callbacks.ModelCheckpoint( FLAGS.weights_path, save_weights_only=True, save_best_only=True ), - callbacks.TensorBoard( + keras.callbacks.TensorBoard( log_dir=FLAGS.tensorboard_path, write_steps_per_second=True ), ] if FLAGS.learning_rate_schedule == REDUCE_ON_PLATEAU: model_callbacks.append( - callbacks.ReduceLROnPlateau( + keras.callbacks.ReduceLROnPlateau( monitor="val_loss", factor=0.1, patience=10,