diff --git a/assets/banner_gif.gif b/assets/banner_gif.gif new file mode 100644 index 0000000..c87ff2a Binary files /dev/null and b/assets/banner_gif.gif differ diff --git a/training/old/vision.py b/training/old/vision.py index ac96073..2d3963a 100644 --- a/training/old/vision.py +++ b/training/old/vision.py @@ -115,7 +115,9 @@ # Train model epochs = 20 -with tf.device('/GPU:0'): +device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0' +print("Using Device:", device) +with tf.device(device): history = model.fit( train_ds, validation_data=val_ds, diff --git a/training/old/with_augmentation.py b/training/old/with_augmentation.py index 6b53b7b..03dce1c 100644 --- a/training/old/with_augmentation.py +++ b/training/old/with_augmentation.py @@ -69,7 +69,9 @@ # Train model epochs = 20 -with tf.device('/GPU:1'): +device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0' +print("Using Device:", device) +with tf.device(device): history = model.fit( train_ds, validation_data=val_ds, diff --git a/training/old/without_augmentation.py b/training/old/without_augmentation.py index e82cd49..535fb26 100644 --- a/training/old/without_augmentation.py +++ b/training/old/without_augmentation.py @@ -62,7 +62,9 @@ # Train model epochs = 20 -with tf.device('/GPU:1'): +device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0' +print("Using Device:", device) +with tf.device(device): history = model.fit( train_ds, validation_data=val_ds, diff --git a/training/pre_filter.py b/training/pre_filter.py index 41605a3..490bc4c 100644 --- a/training/pre_filter.py +++ b/training/pre_filter.py @@ -115,7 +115,9 @@ # Train model epochs = 15 -with tf.device('/GPU:0'): +device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0' +print("Using Device:", device) +with tf.device(device): history = model.fit( train_ds, validation_data=val_ds, diff --git a/training/train.py b/training/train.py index aaac8d5..6212471 100644 --- a/training/train.py +++ b/training/train.py @@ -1,8 +1,10 @@ # This file contains the code for training a model with data augmentation and a pretrained base. # Import libraries +import keras +from keras import layers from keras.models import Sequential from keras.applications import EfficientNetV2B1 -from utilities.tools import * +from utilities.tools import get_data_path_addon, get_base_path, suppress_tf_warnings, load_dataset, show_augmented_batch, create_augmentation_layer, plot_model_score, show_sample_batch, show_batch_shape from utilities.discord_callback import DiscordCallback from keras.optimizers import AdamW from keras.regularizers import l1_l2 @@ -10,6 +12,7 @@ import os import random +import tensorflow as tf # Ignore warnings import warnings @@ -32,9 +35,10 @@ # Set seed for reproducibility random_seed = True # Config +base_path = get_base_path() path_addon = get_data_path_addon(model_type) config = { - "path": f"C:/Users\phili/.keras/datasets/resized_DVM/{path_addon}", + "path": f"{base_path}/{path_addon}", "batch_size": 32, "img_height": img_height, "img_width": img_width, diff --git a/utilities/tools.py b/utilities/tools.py index 33e72fd..34cd29a 100644 --- a/utilities/tools.py +++ b/utilities/tools.py @@ -8,6 +8,7 @@ from keras import layers import os import logging +import platform def load_dataset(path: str, batch_size: int, img_height: int, img_width: int, seed: int) -> tuple[tf.data.Dataset, tf.data.Dataset, list]: @@ -317,3 +318,17 @@ def get_data_path_addon(name: str) -> str: return "pre_filter" else: raise ValueError("Invalid model name") + + +def get_base_path(): + # Determine the base path depending on the operating system + if platform.system() == 'Windows': + base_path = r"C:/Users\phili/.keras/datasets/resized_DVM" + elif platform.system() == 'Linux': + base_path = "/home/luke/datasets/" + elif platform.system() == 'Darwin': # Darwin is the system name for macOS + base_path = "/Users/flippchen/datasets/" + else: + raise ValueError("Operating system not supported.") + + return base_path