diff --git a/requirements.txt b/requirements.txt
index 8cdba01..03e2997 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,8 @@ tensorflow>=2.10.0
 matplotlib>=3.7.2
 Pillow>=10.0.0
 keras>=2.10.0
+# Hyperparameter optimization
+keras-tuner<=1.4.6
 # Inference
 rembg>=2.0.50
 onnxruntime>=1.15.1
diff --git a/training/train.py b/training/train.py
index 59f9a56..b4c7c86 100644
--- a/training/train.py
+++ b/training/train.py
@@ -12,6 +12,7 @@
 from keras.callbacks import EarlyStopping, ModelCheckpoint
 import os
 import random
+import keras_tuner
 import tensorflow as tf
 
 # Ignore warnings
@@ -45,6 +46,8 @@
     "img_width": img_width,
     "seed": random.randint(0, 1000) if random_seed else 123
 }
+tune_model = False
+train = True
 
 # Load dataset and classes
 train_ds, val_ds, class_names = load_dataset(**config)
@@ -125,16 +128,25 @@
 device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0'
 print("Using Device:", device)
 
-with tf.device(device):
-    history = model.fit(
-        train_ds,
-        validation_data=val_ds,
-        epochs=epochs,
-        callbacks=[lr, early_stopping, model_checkpoint, discord_callback],
-        class_weight=class_weights
-    )
-# Plot and save model score
-plot_model_score(history, name, model_type)
+if tune_model:
+    tuner = keras_tuner.RandomSearch(
+        model,
+        objective='val_loss',
+        max_trials=5)
+    tuner.search(train_ds, epochs=5, validation_data=val_ds)
+    model = tuner.get_best_models()[0]
+
+if train or not tune_model:
+    with tf.device(device):
+        history = model.fit(
+            train_ds,
+            validation_data=val_ds,
+            epochs=epochs,
+            callbacks=[lr, early_stopping, model_checkpoint, discord_callback],
+            class_weight=class_weights
+        )
+    # Plot and save model score
+    plot_model_score(history, name, model_type)
 
 # Save model
 model.save(f"{save_path}{name}.h5")
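
Note on the tuning branch above: `keras_tuner.RandomSearch` expects a hypermodel as its first argument, i.e. a `keras_tuner.HyperModel` or a function that takes a `HyperParameters` object and returns a compiled model, rather than an already-built `model` instance. A minimal sketch of that pattern is below; the `build_model` function, its hyperparameter names, the input shape, and the class count are hypothetical, since the actual model construction in `train.py` is not shown in this diff.

```python
import keras
import keras_tuner

# Hypothetical model-building function: KerasTuner calls it once per trial,
# passing a HyperParameters object (hp) used to sample values for that trial.
def build_model(hp):
    model = keras.Sequential([
        keras.layers.Input(shape=(180, 180, 3)),       # assumed img_height/img_width
        keras.layers.Rescaling(1.0 / 255),
        keras.layers.Conv2D(
            filters=hp.Int("filters", min_value=16, max_value=64, step=16),
            kernel_size=3,
            activation="relu",
        ),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(hp.Int("units", 32, 128, step=32), activation="relu"),
        keras.layers.Dense(5, activation="softmax"),    # assumed number of classes
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(
            learning_rate=hp.Choice("lr", [1e-2, 1e-3, 1e-4])
        ),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

tuner = keras_tuner.RandomSearch(
    build_model,          # a callable hypermodel, not a built model instance
    objective="val_loss",
    max_trials=5,
)
# The search itself mirrors the diff's usage and would reuse the script's datasets:
# tuner.search(train_ds, epochs=5, validation_data=val_ds)
# model = tuner.get_best_models()[0]
```

With this shape, each trial rebuilds and recompiles the model with freshly sampled hyperparameters, which is what the random search needs in order to compare configurations on `val_loss`.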