Skip to content

Commit

Permalink
Merge pull request #95 from Flippchen/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
Flippchen authored Feb 15, 2024
2 parents 2e7955c + 22c3c67 commit 30e521d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ tensorflow>=2.10.0
matplotlib>=3.7.2
Pillow>=10.0.0
keras>=2.10.0
# Hyperparameter optimization
keras-tuner<=1.4.6
# Inference
rembg>=2.0.50
onnxruntime>=1.15.1
Expand Down
32 changes: 22 additions & 10 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from keras.callbacks import EarlyStopping, ModelCheckpoint
import os
import random
import keras_tuner

import tensorflow as tf
# Ignore warnings
Expand Down Expand Up @@ -45,6 +46,8 @@
"img_width": img_width,
"seed": random.randint(0, 1000) if random_seed else 123
}
tune_model = False
train = True

# Load dataset and classes
train_ds, val_ds, class_names = load_dataset(**config)
Expand Down Expand Up @@ -125,16 +128,25 @@
device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0'
print("Using Device:", device)

with tf.device(device):
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs,
callbacks=[lr, early_stopping, model_checkpoint, discord_callback],
class_weight=class_weights
)
# Plot and save model score
plot_model_score(history, name, model_type)
if tune_model:
tuner = keras_tuner.RandomSearch(
model,
objective='val_loss',
max_trials=5)
tuner.search(train_ds, epochs=5, validation_data=val_ds)
model = tuner.get_best_models()[0]

if train or not tune_model:
with tf.device(device):
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs,
callbacks=[lr, early_stopping, model_checkpoint, discord_callback],
class_weight=class_weights
)
# Plot and save model score
plot_model_score(history, name, model_type)

# Save model
model.save(f"{save_path}{name}.h5")
Expand Down

0 comments on commit 30e521d

Please sign in to comment.