Skip to content

Commit

Permalink
Update ImageNet training script to support Keras Core (#1976)
Browse files · Browse the repository at this point in the history
  • Loading branch information
ianstenbit authored Jul 25, 2023
1 parent 90ccc1f commit 79a3cbe
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions examples/training/classification/imagenet/basic_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,10 @@

import tensorflow as tf
from absl import flags
from tensorflow import keras
from tensorflow.keras import callbacks
from tensorflow.keras import losses
from tensorflow.keras import metrics
from tensorflow.keras import optimizers

import keras_cv
from keras_cv import models
from keras_cv.backend import keras
from keras_cv.datasets import imagenet

"""
Expand All @@ -58,9 +54,6 @@
flags.DEFINE_string(
"imagenet_path", None, "Directory from which to load Imagenet."
)
flags.DEFINE_string(
"backup_path", None, "Directory which will be used for training backups."
)
flags.DEFINE_string(
"weights_path",
None,
Expand Down Expand Up @@ -367,14 +360,14 @@ def __call__(self, step):

with strategy.scope():
if FLAGS.learning_rate_schedule == COSINE_DECAY_WITH_WARMUP:
optimizer = optimizers.SGD(
optimizer = keras.optimizers.SGD(
weight_decay=FLAGS.weight_decay,
learning_rate=schedule,
momentum=0.9,
use_ema=FLAGS.use_ema,
)
else:
optimizer = optimizers.SGD(
optimizer = keras.optimizers.SGD(
weight_decay=FLAGS.weight_decay,
learning_rate=INITIAL_LEARNING_RATE,
momentum=0.9,
Expand All @@ -386,7 +379,7 @@ def __call__(self, step):
Next, we pick a loss function. We use CategoricalCrossentropy with label
smoothing.
"""
loss_fn = losses.CategoricalCrossentropy(label_smoothing=0.1)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=0.1)


"""
Expand All @@ -395,28 +388,27 @@ def __call__(self, step):
"""
with strategy.scope():
training_metrics = [
metrics.CategoricalAccuracy(),
metrics.TopKCategoricalAccuracy(k=5),
keras.metrics.CategoricalAccuracy(),
keras.metrics.TopKCategoricalAccuracy(k=5),
]

"""
As a last piece of configuration, we configure callbacks for the method.
We use EarlyStopping, BackupAndRestore, and a model checkpointing callback.
"""
model_callbacks = [
callbacks.EarlyStopping(patience=20),
callbacks.BackupAndRestore(FLAGS.backup_path),
callbacks.ModelCheckpoint(
keras.callbacks.EarlyStopping(patience=20),
keras.callbacks.ModelCheckpoint(
FLAGS.weights_path, save_weights_only=True, save_best_only=True
),
callbacks.TensorBoard(
keras.callbacks.TensorBoard(
log_dir=FLAGS.tensorboard_path, write_steps_per_second=True
),
]

if FLAGS.learning_rate_schedule == REDUCE_ON_PLATEAU:
model_callbacks.append(
callbacks.ReduceLROnPlateau(
keras.callbacks.ReduceLROnPlateau(
monitor="val_loss",
factor=0.1,
patience=10,
Expand Down

0 comments on commit 79a3cbe

Please sign in to comment.