From 665aa62dec28615cb4d41c544359a6deb77cf94c Mon Sep 17 00:00:00 2001 From: pesekon2 Date: Sat, 23 Sep 2023 10:50:14 +0200 Subject: [PATCH] detect.py: support more TF versions for setting seed --- src/detect.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/detect.py b/src/detect.py index 67c79ea0..178c7f89 100644 --- a/src/detect.py +++ b/src/detect.py @@ -34,7 +34,12 @@ def main(data_dir, model, in_weights_path, visualization_path, batch_size, os.path.join(data_dir, 'label_colors.txt')) # set TensorFlow seed - tf.random.set_seed(seed) + if seed is not None: + import sys + if int(tf.__version__.split('.')[1]) < 4: + tf.random.set_seed(seed) + else: + tf.keras.utils.set_random_seed(seed) model = create_model(model, len(id2code), nr_bands, tensor_shape, backbone=backbone)