From 4a6c6fd0720c964f7aa74f079a200fad42f49f6d Mon Sep 17 00:00:00 2001 From: Sterling Date: Fri, 3 Mar 2023 13:26:26 +0000 Subject: [PATCH 1/3] raise error if num-eval-batches is too large for dataset --- armory/utils/config_loading.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/armory/utils/config_loading.py b/armory/utils/config_loading.py index eb0101923..2cae1b99d 100644 --- a/armory/utils/config_loading.py +++ b/armory/utils/config_loading.py @@ -79,6 +79,11 @@ def load_dataset(dataset_config, *args, num_batches=None, check_run=False, **kwa if check_run: return EvalGenerator(dataset, num_eval_batches=1) if num_batches: + if num_batches > dataset.batches_per_epoch: + # since num-eval-batches only applies at test time, we can assume there is only 1 epoch + raise ValueError( + f"{num_batches} eval batches were requested, but dataset has only {dataset.batches_per_epoch} batches of size {dataset.batch_size}" + ) return EvalGenerator(dataset, num_eval_batches=num_batches) return dataset From 9da2f245a5cc5840fe2e010c7938ed81912ed66a Mon Sep 17 00:00:00 2001 From: Sterling Date: Fri, 3 Mar 2023 13:34:38 +0000 Subject: [PATCH 2/3] correct the names of some carla datasets in comments and error messages --- armory/data/adversarial_datasets.py | 18 +++++++++--------- armory/data/datasets.py | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/armory/data/adversarial_datasets.py b/armory/data/adversarial_datasets.py index db4e36327..2e4f6ba98 100644 --- a/armory/data/adversarial_datasets.py +++ b/armory/data/adversarial_datasets.py @@ -661,15 +661,15 @@ def carla_over_obj_det_dev( **kwargs, ): """ - Dev set for CARLA object detection dataset, containing RGB and depth channels. The dev + Dev set for CARLA overhead object detection dataset, containing RGB and depth channels. The dev set also contains green screens for adversarial patch insertion. """ if "class_ids" in kwargs: raise ValueError( - "Filtering by class is not supported for the carla_obj_det_dev dataset" + "Filtering by class is not supported for the carla_over_obj_det_dev dataset" ) if batch_size != 1: - raise ValueError("carla_obj_det_dev batch size must be set to 1") + raise ValueError("carla_over_obj_det_dev batch size must be set to 1") modality = kwargs.pop("modality", "rgb") if modality not in ["rgb", "depth", "both"]: @@ -729,15 +729,15 @@ def carla_over_obj_det_test( **kwargs, ): """ - Dev set for CARLA object detection dataset, containing RGB and depth channels. The test + Test set for CARLA overhead object detection dataset, containing RGB and depth channels. The test set also contains green screens for adversarial patch insertion. """ if "class_ids" in kwargs: raise ValueError( - "Filtering by class is not supported for the carla_obj_det_test dataset" + "Filtering by class is not supported for the carla_over_obj_det_test dataset" ) if batch_size != 1: - raise ValueError("carla_obj_det_test batch size must be set to 1") + raise ValueError("carla_over_obj_det_test batch size must be set to 1") modality = kwargs.pop("modality", "rgb") if modality not in ["rgb", "depth", "both"]: @@ -924,7 +924,7 @@ def carla_video_tracking_dev( "Filtering by class is not supported for the carla_video_tracking_dev dataset" ) if batch_size != 1: - raise ValueError("carla_obj_det_dev batch size must be set to 1") + raise ValueError("carla_video_tracking_dev batch size must be set to 1") if max_frames: clip = datasets.ClipFrames(max_frames) @@ -975,10 +975,10 @@ def carla_video_tracking_test( """ if "class_ids" in kwargs: raise ValueError( - "Filtering by class is not supported for the carla_video_tracking_dev dataset" + "Filtering by class is not supported for the carla_video_tracking_test dataset" ) if batch_size != 1: - raise ValueError("carla_obj_det_dev batch size must be set to 1") + raise ValueError("carla_video_tracking_test batch size must be set to 1") if max_frames: clip = datasets.ClipFrames(max_frames) diff --git a/armory/data/datasets.py b/armory/data/datasets.py index 5cb2afe3e..9c7f0071e 100644 --- a/armory/data/datasets.py +++ b/armory/data/datasets.py @@ -1034,11 +1034,11 @@ def carla_over_obj_det_train( **kwargs, ) -> ArmoryDataGenerator: """ - Training set for CARLA object detection dataset, containing RGB and depth channels. + Training set for CARLA overhead object detection dataset, containing RGB and depth channels. """ if "class_ids" in kwargs: raise ValueError( - "Filtering by class is not supported for the carla_obj_det_train dataset" + "Filtering by class is not supported for the carla_over_obj_det_train dataset" ) modality = kwargs.pop("modality", "rgb") if modality not in ["rgb", "depth", "both"]: From 5abe034292dae96a9af557f28c41ee3f541ef9ce Mon Sep 17 00:00:00 2001 From: Sterling Date: Fri, 3 Mar 2023 14:10:32 +0000 Subject: [PATCH 3/3] add a tf.function decorator to halve training time of audio resnet --- armory/baseline_models/tf_graph/audio_resnet50.py | 1 + 1 file changed, 1 insertion(+) diff --git a/armory/baseline_models/tf_graph/audio_resnet50.py b/armory/baseline_models/tf_graph/audio_resnet50.py index c8049ddef..49aa70cd0 100644 --- a/armory/baseline_models/tf_graph/audio_resnet50.py +++ b/armory/baseline_models/tf_graph/audio_resnet50.py @@ -53,6 +53,7 @@ def get_art_model( loss_object = losses.SparseCategoricalCrossentropy() + @tf.function def train_step(model, samples, labels): with tf.GradientTape() as tape: predictions = model(samples, training=True)