diff --git a/training/make_nsfw_model.py b/training/make_nsfw_model.py index 82b375b..36b5257 100644 --- a/training/make_nsfw_model.py +++ b/training/make_nsfw_model.py @@ -128,6 +128,47 @@ flags.DEFINE_bool( "is_deprecated_tfhub_module", False, "Whether or not the supplied TF hub module is old and from Tensorflow 1.") +flags.DEFINE_float( + "label_smoothing", _DEFAULT_HPARAMS.label_smoothing, + "The degree of label smoothing to use.") +flags.DEFINE_float( + "validation_split", _DEFAULT_HPARAMS.validation_split, + "The percentage of data to use for validation.") +flags.DEFINE_string( + 'optimizer', _DEFAULT_HPARAMS.optimizer, + 'The name of the optimizer, one of "adadelta", "adagrad", "adam",' + '"ftrl", "momentum", "sgd" or "rmsprop".') +flags.DEFINE_float( + 'adadelta_rho', _DEFAULT_HPARAMS.adadelta_rho, + 'The decay rate for adadelta.') +flags.DEFINE_float( + 'adagrad_initial_accumulator_value', _DEFAULT_HPARAMS.adagrad_initial_accumulator_value, + 'Starting value for the AdaGrad accumulators.') +flags.DEFINE_float( + 'adam_beta1', _DEFAULT_HPARAMS.adam_beta1, + 'The exponential decay rate for the 1st moment estimates.') +flags.DEFINE_float( + 'adam_beta2', _DEFAULT_HPARAMS.adam_beta2, + 'The exponential decay rate for the 2nd moment estimates.') +flags.DEFINE_float('opt_epsilon', _DEFAULT_HPARAMS.opt_epsilon, 'Epsilon term for the optimizer.') +flags.DEFINE_float('ftrl_learning_rate_power', _DEFAULT_HPARAMS.ftrl_learning_rate_power, + 'The learning rate power.') +flags.DEFINE_float( + 'ftrl_initial_accumulator_value', _DEFAULT_HPARAMS.ftrl_initial_accumulator_value, + 'Starting value for the FTRL accumulators.') +flags.DEFINE_float( + 'ftrl_l1', _DEFAULT_HPARAMS.ftrl_l1, 'The FTRL l1 regularization strength.') + +flags.DEFINE_float( + 'ftrl_l2', _DEFAULT_HPARAMS.ftrl_l2, 'The FTRL l2 regularization strength.') +flags.DEFINE_float('rmsprop_momentum', _DEFAULT_HPARAMS.rmsprop_momentum, 'Momentum.') +flags.DEFINE_float('rmsprop_decay', _DEFAULT_HPARAMS.rmsprop_decay, 'Decay term for RMSProp.') +flags.DEFINE_bool( + "do_data_augmentation", False, + "Whether or not to do data augmentation.") +flags.DEFINE_bool( + "use_mixed_precision", False, + "Whether or not to use NVIDIA mixed precision. Requires NVIDIA card with at least compute level 7.0") FLAGS = flags.FLAGS @@ -140,7 +181,26 @@ def _get_hparams_from_flags(): batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate, momentum=FLAGS.momentum, - dropout_rate=FLAGS.dropout_rate) + dropout_rate=FLAGS.dropout_rate, + label_smoothing=FLAGS.label_smoothing, + validation_split=FLAGS.validation_split, + optimizer=FLAGS.optimizer, + adadelta_rho=FLAGS.adadelta_rho, + adagrad_initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value, + adam_beta1=FLAGS.adam_beta1, + adam_beta2=FLAGS.adam_beta2, + opt_epsilon=FLAGS.opt_epsilon, + ftrl_learning_rate_power=FLAGS.ftrl_learning_rate_power, + ftrl_initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value, + ftrl_l1=FLAGS.ftrl_l1, + ftrl_l2=FLAGS.ftrl_l2, + rmsprop_momentum=FLAGS.rmsprop_momentum, + rmsprop_decay=FLAGS.rmsprop_decay, + do_data_augmentation=FLAGS.do_data_augmentation, + use_mixed_precision=FLAGS.use_mixed_precision + ) + + def _check_keras_dependencies(): @@ -178,6 +238,9 @@ def main(args): """Main function to be called by absl.app.run() after flag parsing.""" del args + #policy = mixed_precision.Policy('mixed_float16') + #mixed_precision.set_policy(policy) + #tf.config.gpu.set_per_process_memory_fraction(0.75) #tf.config.gpu.set_per_process_memory_growth(False) physical_devices = tf.config.list_physical_devices('GPU') diff --git a/training/make_nsfw_model_lib.py b/training/make_nsfw_model_lib.py index 4e59d97..ee65059 100644 --- a/training/make_nsfw_model_lib.py +++ b/training/make_nsfw_model_lib.py @@ -21,11 +21,13 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals +import multiprocessing from pathlib import Path from absl import app from absl import flags from absl import logging from tensorflow import keras +from tensorflow.keras.mixed_precision import experimental as mixed_precision from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 @@ -67,11 +69,57 @@ def get_default_image_dir(): return tf.keras.utils.get_file("flower_photos", _DEFAULT_IMAGE_URL, untar=True) +def configure_optimizer(hparams): + """Configures the optimizer used for training. + + Args: + learning_rate: A scalar or `Tensor` learning rate. + + Returns: + An instance of an optimizer. + + Raises: + ValueError: if hparams.optimizer is not recognized. + """ + if hparams.optimizer == 'adadelta': + optimizer = tf.keras.optimizers.Adadelta( + hparams.learning_rate, + rho=hparams.adadelta_rho, + epsilon=hparams.opt_epsilon) + elif hparams.optimizer == 'adagrad': + optimizer = tf.keras.optimizers.Adagrad( + hparams.learning_rate, + initial_accumulator_value=hparams.adagrad_initial_accumulator_value) + elif hparams.optimizer == 'adam': + optimizer = tf.keras.optimizers.Adam( + hparams.learning_rate, + beta_1=hparams.adam_beta1, + beta_2=hparams.adam_beta2, + epsilon=hparams.opt_epsilon) + elif hparams.optimizer == 'ftrl': + optimizer = tf.keras.optimizers.Ftrl( + hparams.learning_rate, + learning_rate_power=hparams.ftrl_learning_rate_power, + initial_accumulator_value=hparams.ftrl_initial_accumulator_value, + l1_regularization_strength=hparams.ftrl_l1, + l2_regularization_strength=hparams.ftrl_l2) + elif hparams.optimizer == 'rmsprop': + optimizer = tf.keras.optimizers.RMSprop(learning_rate=hparams.learning_rate, epsilon=hparams.opt_epsilon, momentum=hparams.rmsprop_momentum) + elif hparams.optimizer == 'sgd': + optimizer = tf.keras.optimizers.SGD(learning_rate=hparams.learning_rate, momentum=hparams.momentum) + else: + raise ValueError('Optimizer [%s] was not recognized' % hparams.optimizer) + return optimizer + class HParams( collections.namedtuple("HParams", [ "train_epochs", "do_fine_tuning", "batch_size", "learning_rate", - "momentum", "dropout_rate" + "momentum", "dropout_rate", "label_smoothing", "validation_split", + "optimizer", "adadelta_rho", "adagrad_initial_accumulator_value", + "adam_beta1", "adam_beta2", "opt_epsilon", "ftrl_learning_rate_power", + "ftrl_initial_accumulator_value", "ftrl_l1", "ftrl_l2", "rmsprop_momentum", + "rmsprop_decay", "do_data_augmentation", "use_mixed_precision" ])): """The hyperparameters for make_image_classifier. @@ -93,11 +141,28 @@ def get_default_hparams(): batch_size=32, learning_rate=0.005, momentum=0.9, - dropout_rate=0.2) + dropout_rate=0.2, + label_smoothing=0.1, + validation_split=.20, + optimizer='rmsprop', + adadelta_rho=0.95, + adagrad_initial_accumulator_value=0.1, + adam_beta1=0.9, + adam_beta2=0.999, + opt_epsilon=1.0, + ftrl_learning_rate_power=-0.5, + ftrl_initial_accumulator_value=0.1, + ftrl_l1=0.0, + ftrl_l2=0.0, + rmsprop_momentum=0.9, + rmsprop_decay=0.9, + do_data_augmentation=False, + use_mixed_precision=False + ) def _get_data_with_keras(image_dir, image_size, batch_size, - do_data_augmentation=False): + validation_size=0.2, do_data_augmentation=False): """Gets training and validation data via keras_preprocessing. Args: @@ -126,7 +191,7 @@ def _get_data_with_keras(image_dir, image_size, batch_size, """ datagen_kwargs = dict(rescale=1./255, # TODO(b/139467904): Expose this as a flag. - validation_split=.20) + validation_split=validation_size) dataflow_kwargs = dict(target_size=image_size, batch_size=batch_size, interpolation="bilinear") @@ -143,7 +208,8 @@ def _get_data_with_keras(image_dir, image_size, batch_size, **datagen_kwargs) else: train_datagen = valid_datagen - train_generator = train_datagen.flow_from_directory( + + train_generator = train_datagen.flow_from_directory( image_dir, subset="training", shuffle=True, **dataflow_kwargs) indexed_labels = [(index, label) @@ -215,14 +281,27 @@ def build_model(module_layer, hparams, image_size, num_classes): The full classifier model. """ # TODO(b/139467904): Expose the hyperparameters below as flags. - model = tf.keras.Sequential([ - tf.keras.Input(shape=(image_size[0], image_size[1], 3), name='input', dtype='float32'), module_layer, - tf.keras.layers.Dropout(rate=hparams.dropout_rate), - tf.keras.layers.Dense( - num_classes, - kernel_regularizer=tf.keras.regularizers.l2(0.0001)), - tf.keras.layers.Activation('softmax', dtype='float32', name='prediction') - ]) + + if hparams.dropout_rate is not None and hparams.dropout_rate > 0: + model = tf.keras.Sequential([ + tf.keras.Input(shape=(image_size[0], image_size[1], 3), name='input', dtype='float32'), + module_layer, + tf.keras.layers.Dropout(rate=hparams.dropout_rate), + tf.keras.layers.Dense( + num_classes, + kernel_regularizer=tf.keras.regularizers.l2(0.0001)), + tf.keras.layers.Activation('softmax', dtype='float32', name='prediction') + ]) + else: + model = tf.keras.Sequential([ + tf.keras.Input(shape=(image_size[0], image_size[1], 3), name='input', dtype='float32'), + module_layer, + tf.keras.layers.Dense( + num_classes, + kernel_regularizer=None), + tf.keras.layers.Activation('softmax', dtype='float32', name='prediction') + ]) + print(model.summary()) return model @@ -249,20 +328,33 @@ def train_model(model, hparams, train_data_and_size, valid_data_and_size): Returns: The tf.keras.callbacks.History object returned by tf.keras.Model.fit(). """ + + earlystop_callback = tf.keras.callbacks.EarlyStopping( + monitor='val_accuracy', min_delta=0.0001, + patience=1) + train_data, train_size = train_data_and_size valid_data, valid_size = valid_data_and_size # TODO(b/139467904): Expose this hyperparameter as a flag. - loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1) + loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=hparams.label_smoothing) + + if hparams.use_mixed_precision is True: + optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(configure_optimizer(hparams)) + else: + optimizer = configure_optimizer(hparams) + model.compile( - optimizer=tf.keras.optimizers.SGD( - lr=hparams.learning_rate, momentum=hparams.momentum), + optimizer=optimizer, loss=loss, metrics=["accuracy"]) steps_per_epoch = train_size // hparams.batch_size validation_steps = valid_size // hparams.batch_size return model.fit( train_data, + use_multiprocessing=False, + workers=multiprocessing.cpu_count() -1, epochs=hparams.train_epochs, + callbacks=[earlystop_callback], steps_per_epoch=steps_per_epoch, validation_data=valid_data, validation_steps=validation_steps) @@ -404,13 +496,17 @@ def make_image_classifier(tfhub_module, image_dir, hparams, must be omitted or set to that same value. """ + print("Using hparams:") + for key, value in hparams._asdict().items(): + print("\t{0} : {1}".format(key, value)) + module_layer = hub.KerasLayer(tfhub_module, trainable=hparams.do_fine_tuning) image_size = _image_size_for_module(module_layer, requested_image_size) print("Using module {} with image size {}".format( tfhub_module, image_size)) train_data_and_size, valid_data_and_size, labels = _get_data_with_keras( - image_dir, image_size, hparams.batch_size) + image_dir, image_size, hparams.batch_size, hparams.validation_split, hparams.do_data_augmentation) print("Found", len(labels), "classes:", ", ".join(labels)) model = build_model(module_layer, hparams, image_size, len(labels)) diff --git a/training/train_all_models.cmd b/training/train_all_models.cmd index aa88bc0..b31548f 100644 --- a/training/train_all_models.cmd +++ b/training/train_all_models.cmd @@ -1,27 +1,18 @@ :: You can add more models types from here: https://tfhub.dev/s?module-type=image-classification&tf-version=tf2 :: However, you must choose Tensorflow 2 models. V1 models will not work here. -:: https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/4 -:: https://tfhub.dev/google/imagenet/resnet_v2_50/classification/4 -:: https://tfhub.dev/google/imagenet/inception_v3/classification/4 -:: https://tfhub.dev/google/imagenet/nasnet_mobile/classification/4 -:: https://tfhub.dev/tensorflow/efficientnet/b0/classification/1 +:: https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/feature_vector/4 +:: https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4 +:: https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4 +:: https://tfhub.dev/google/imagenet/nasnet_mobile/feature_vector/4 :: :: If you get CUDA_OUT_OF_MEMORY crash, you need to pass --batch_size NUMBER, reducing until you don't get this error. :: It is advised by Google not to have a batch size < 8. -:: Train EfficientNet B0 -python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\efficientnet_b0_224 --labels_output_file %cd%\..\trained_models\efficientnet_b0_224\class_labels.txt --tfhub_module https://tfhub.dev/tensorflow/efficientnet/b0/classification/1 --tflite_output_file %cd%\..\trained_models\efficientnet_b0_224\saved_model.tflite --train_epochs 5 --batch_size 16 --do_fine_tuning --learning_rate 0.05 --dropout_rate 0.0 --momentum 0.9 -:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. -:: tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\efficientnet_b0_224 %cd%\..\trained_models\efficientnet_b0_224\web_model -:: Or, for a quantized (1 byte) version -:: tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\efficientnet_b0_224 %cd%\..\trained_models\efficientnet_b0_224\web_model_quantized --quantization_bytes 1 - -:: Wait for Python/CUDA/GPU to recover. Seems to die without this. -Timeout /T 60 /Nobreak +:: Note that we set all of our target epochs to over 9000. This is because the trainer just uses early stopping internally. :: Train Mobilenet V2 140 -python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\mobilenet_v2_140_224 --labels_output_file %cd%\..\trained_models\mobilenet_v2_140_224\class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/4 --tflite_output_file %cd%\..\trained_models\mobilenet_v2_140_224\saved_model.tflite --train_epochs 5 --batch_size 32 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --momentum 0.9 -:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. +python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\mobilenet_v2_140_224 --labels_output_file %cd%\..\trained_models\mobilenet_v2_140_224\class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/feature_vector/4 --tflite_output_file %cd%\..\trained_models\mobilenet_v2_140_224\saved_model.tflite --train_epochs 9001 --batch_size 32 --do_fine_tuning --dropout_rate 0.0 --label_smoothing=0.0 --validation_split=0.1 --do_data_augmentation=True --use_mixed_precision=True --rmsprop_momentum=0.0 +:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. So make sure you perform this step after all your training is done, and then restore a GPU version of TF. :: tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\mobilenet_v2_140_224 %cd%\..\trained_models\mobilenet_v2_140_224\web_model :: Or, for a quantized (1 byte) version :: tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\mobilenet_v2_140_224 %cd%\..\trained_models\mobilenet_v2_140_224\web_model_quantized --quantization_bytes 1 @@ -30,8 +21,8 @@ python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_mo Timeout /T 60 /Nobreak :: Train Resnet V2 50 -python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\resnet_v2_50_224 --labels_output_file %cd%\..\trained_models\resnet_v2_50_224\class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/resnet_v2_50/classification/4 --tflite_output_file %cd%\..\trained_models\resnet_v2_50_224\saved_model.tflite --train_epochs 5 --batch_size 16 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --momentum 0.9 -:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. +python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\resnet_v2_50_224 --labels_output_file %cd%\..\trained_models\resnet_v2_50_224\class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4 --tflite_output_file %cd%\..\trained_models\resnet_v2_50_224\saved_model.tflite --train_epochs 9001 --batch_size 16 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --label_smoothing=0.0 --validation_split=0.1 --do_data_augmentation=True --use_mixed_precision=True --rmsprop_momentum=0.0 +:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. So make sure you perform this step after all your training is done, and then restore a GPU version of TF. ::tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\resnet_v2_50_224 %cd%\..\trained_models\resnet_v2_50_224\web_model :: Or, for a quantized (1 byte) version ::tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\resnet_v2_50_224 %cd%\..\trained_models\resnet_v2_50_224\web_model_quantized --quantization_bytes 1 @@ -40,8 +31,8 @@ python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_mo Timeout /T 60 /Nobreak :: Train Inception V3 -python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\inception_v3_224 --labels_output_file %cd%\..\trained_models\inception_v3_224\class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/inception_v3/classification/4 --tflite_output_file %cd%\..\trained_models\inception_v3_224\saved_model.tflite --train_epochs 5 --batch_size 16 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --momentum 0.9 -:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. +python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\inception_v3_224 --labels_output_file %cd%\..\trained_models\inception_v3_224\class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4 --tflite_output_file %cd%\..\trained_models\inception_v3_224\saved_model.tflite --train_epochs 9001 --batch_size 16 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --label_smoothing=0.0 --validation_split=0.1 --do_data_augmentation=True --use_mixed_precision=True --rmsprop_momentum=0.0 +:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. So make sure you perform this step after all your training is done, and then restore a GPU version of TF. ::tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\inception_v3_224 %cd%\..\trained_models\inception_v3_224\web_model :: Or, for a quantized (1 byte) version ::tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\inception_v3_224 %cd%\..\trained_models\inception_v3_224\web_model_quantized --quantization_bytes 1 @@ -50,8 +41,8 @@ python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_mo Timeout /T 60 /Nobreak :: Train NasNetMobile -python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\nasnet_a_224 --labels_output_file %cd%\..\trained_models\nasnet_a_224\class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/nasnet_mobile/classification/4 --tflite_output_file %cd%\..\trained_models\nasnet_a_224\saved_model.tflite --train_epochs 5 --batch_size 24 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --momentum 0.9 -:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. +python make_nsfw_model.py --image_dir %cd%\..\images --image_size 224 --saved_model_dir %cd%\..\trained_models\nasnet_a_224 --labels_output_file %cd%\..\trained_models\nasnet_a_224\class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/nasnet_mobile/feature_vector/4 --tflite_output_file %cd%\..\trained_models\nasnet_a_224\saved_model.tflite --train_epochs 9001 --batch_size 24 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --label_smoothing=0.0 --validation_split=0.1 --do_data_augmentation=True --use_mixed_precision=True --rmsprop_momentum=0.0 +:: Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. So make sure you perform this step after all your training is done, and then restore a GPU version of TF. ::tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\nasnet_a_224 %cd%\..\trained_models\nasnet_a_224\web_modely :: Or, for a quantized (1 byte) version ::tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve %cd%\..\trained_models\nasnet_a_224 %cd%\..\trained_models\nasnet_a_224\web_modely_quantized --quantization_bytes 1 \ No newline at end of file diff --git a/training/train_all_models.sh b/training/train_all_models.sh index 746cd51..a43db37 100644 --- a/training/train_all_models.sh +++ b/training/train_all_models.sh @@ -1,28 +1,19 @@ #!/bin/sh # You can add more models types from here: https://tfhub.dev/s?module-type=image-classification&tf-version=tf2 # However, you must choose Tensorflow 2 models. V1 models will not work here. -# https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/4 -# https://tfhub.dev/google/imagenet/resnet_v2_50/classification/4 -# https://tfhub.dev/google/imagenet/inception_v3/classification/4 -# https://tfhub.dev/google/imagenet/nasnet_mobile/classification/4 -# https://tfhub.dev/tensorflow/efficientnet/b0/classification/1 -# +# https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/feature_vector/4 +# https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4 +# https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4 +# https://tfhub.dev/google/imagenet/nasnet_mobile/feature_vector/4 +# # If you get CUDA_OUT_OF_MEMORY crash, you need to pass --batch_size NUMBER, reducing until you don't get this error. # It is advised by Google not to have a batch size < 8. -# Train EfficientNet B0 -python3 make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/efficientnet_b0_224 --labels_output_file $PWD/../trained_models/efficientnet_b0_224/class_labels.txt --tfhub_module https://tfhub.dev/tensorflow/efficientnet/b0/classification/1 --tflite_output_file $PWD/../trained_models/efficientnet_b0_224/saved_model.tflite --train_epochs 5 --batch_size 16 --do_fine_tuning --learning_rate 0.05 --dropout_rate 0.0 --momentum 0.9 -# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. -# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/efficientnet_b0_224 $PWD/../trained_models/efficientnet_b0_224/web_model -# Or, for a quantized (1 byte) version -# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/efficientnet_b0_224 $PWD/../trained_models/efficientnet_b0_224/web_model_quantized --quantization_bytes 1 - -# Wait for Python/CUDA/GPU to recover. Seems to die without this. -sleep 60 +# Note that we set all of our target epochs to over 9000. This is because the trainer just uses early stopping internally. # Train Mobilenet V2 140 -python3 make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/mobilenet_v2_140_224 --labels_output_file $PWD/../trained_models/mobilenet_v2_140_224/class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/4 --tflite_output_file $PWD/../trained_models/mobilenet_v2_140_224/saved_model.tflite --train_epochs 5 --batch_size 32 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --momentum 0.9 -# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. +python make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/mobilenet_v2_140_224 --labels_output_file $PWD/../trained_models/mobilenet_v2_140_224/class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/feature_vector/4 --tflite_output_file $PWD/../trained_models/mobilenet_v2_140_224/saved_model.tflite --train_epochs 9001 --batch_size 32 --do_fine_tuning --dropout_rate 0.0 --label_smoothing=0.0 --validation_split=0.1 --do_data_augmentation=True --use_mixed_precision=True --rmsprop_momentum=0.0 +# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. So make sure you perform this step after all your training is done, and then restore a GPU version of TF. # tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/mobilenet_v2_140_224 $PWD/../trained_models/mobilenet_v2_140_224/web_model # Or, for a quantized (1 byte) version # tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/mobilenet_v2_140_224 $PWD/../trained_models/mobilenet_v2_140_224/web_model_quantized --quantization_bytes 1 @@ -31,28 +22,28 @@ python3 make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_m sleep 60 # Train Resnet V2 50 -python3 make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/resnet_v2_50_224 --labels_output_file $PWD/../trained_models/resnet_v2_50_224/class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/resnet_v2_50/classification/4 --tflite_output_file $PWD/../trained_models/resnet_v2_50_224/saved_model.tflite --train_epochs 5 --batch_size 16 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --momentum 0.9 -# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. -#tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/resnet_v2_50_224 $PWD/../trained_models/resnet_v2_50_224/web_model +python make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/resnet_v2_50_224 --labels_output_file $PWD/../trained_models/resnet_v2_50_224/class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4 --tflite_output_file $PWD/../trained_models/resnet_v2_50_224/saved_model.tflite --train_epochs 9001 --batch_size 16 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --label_smoothing=0.0 --validation_split=0.1 --do_data_augmentation=True --use_mixed_precision=True --rmsprop_momentum=0.0 +# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. So make sure you perform this step after all your training is done, and then restore a GPU version of TF. +# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/resnet_v2_50_224 $PWD/../trained_models/resnet_v2_50_224/web_model # Or, for a quantized (1 byte) version -#tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/resnet_v2_50_224 $PWD/../trained_models/resnet_v2_50_224/web_model_quantized --quantization_bytes 1 +# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/resnet_v2_50_224 $PWD/../trained_models/resnet_v2_50_224/web_model_quantized --quantization_bytes 1 # Wait for Python/CUDA/GPU to recover. Seems to die without this. sleep 60 # Train Inception V3 -python3 make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/inception_v3_224 --labels_output_file $PWD/../trained_models/inception_v3_224/class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/inception_v3/classification/4 --tflite_output_file $PWD/../trained_models/inception_v3_224/saved_model.tflite --train_epochs 5 --batch_size 16 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --momentum 0.9 -# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. -#tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/inception_v3_224 $PWD/../trained_models/inception_v3_224/web_model +python make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/inception_v3_224 --labels_output_file $PWD/../trained_models/inception_v3_224/class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4 --tflite_output_file $PWD/../trained_models/inception_v3_224/saved_model.tflite --train_epochs 9001 --batch_size 16 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --label_smoothing=0.0 --validation_split=0.1 --do_data_augmentation=True --use_mixed_precision=True --rmsprop_momentum=0.0 +# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. So make sure you perform this step after all your training is done, and then restore a GPU version of TF. +# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/inception_v3_224 $PWD/../trained_models/inception_v3_224/web_model # Or, for a quantized (1 byte) version -#tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/inception_v3_224 $PWD/../trained_models/inception_v3_224/web_model_quantized --quantization_bytes 1 +# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/inception_v3_224 $PWD/../trained_models/inception_v3_224/web_model_quantized --quantization_bytes 1 # Wait for Python/CUDA/GPU to recover. Seems to die without this. sleep 60 # Train NasNetMobile -python3 make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/nasnet_a_224 --labels_output_file $PWD/../trained_models/nasnet_a_224/class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/nasnet_mobile/classification/4 --tflite_output_file $PWD/../trained_models/nasnet_a_224/saved_model.tflite --train_epochs 5 --batch_size 24 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --momentum 0.9 -# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. -#tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/nasnet_a_224 $PWD/../trained_models/nasnet_a_224/web_modely +python make_nsfw_model.py --image_dir $PWD/../images --image_size 224 --saved_model_dir $PWD/../trained_models/nasnet_a_224 --labels_output_file $PWD/../trained_models/nasnet_a_224/class_labels.txt --tfhub_module https://tfhub.dev/google/imagenet/nasnet_mobile/feature_vector/4 --tflite_output_file $PWD/../trained_models/nasnet_a_224/saved_model.tflite --train_epochs 9001 --batch_size 24 --do_fine_tuning --learning_rate 0.001 --dropout_rate 0.0 --label_smoothing=0.0 --validation_split=0.1 --do_data_augmentation=True --use_mixed_precision=True --rmsprop_momentum=0.0 +# Note that installing tensorflowjs also installs tensorflow-cpu A.K.A. bye-bye-training. So make sure you perform this step after all your training is done, and then restore a GPU version of TF. +# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/nasnet_a_224 $PWD/../trained_models/nasnet_a_224/web_modely # Or, for a quantized (1 byte) version -#tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/nasnet_a_224 $PWD/../trained_models/nasnet_a_224/web_modely_quantized --quantization_bytes 1 \ No newline at end of file +# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve $PWD/../trained_models/nasnet_a_224 $PWD/../trained_models/nasnet_a_224/web_modely_quantized --quantization_bytes 1 \ No newline at end of file