diff --git a/armory/baseline_models/tf_graph/audio_resnet50.py b/armory/baseline_models/tf_graph/audio_resnet50.py index c8049ddef..49aa70cd0 100644 --- a/armory/baseline_models/tf_graph/audio_resnet50.py +++ b/armory/baseline_models/tf_graph/audio_resnet50.py @@ -53,6 +53,7 @@ def get_art_model( loss_object = losses.SparseCategoricalCrossentropy() + @tf.function def train_step(model, samples, labels): with tf.GradientTape() as tape: predictions = model(samples, training=True) diff --git a/armory/data/adversarial_datasets.py b/armory/data/adversarial_datasets.py index db4e36327..2e4f6ba98 100644 --- a/armory/data/adversarial_datasets.py +++ b/armory/data/adversarial_datasets.py @@ -661,15 +661,15 @@ def carla_over_obj_det_dev( **kwargs, ): """ - Dev set for CARLA object detection dataset, containing RGB and depth channels. The dev + Dev set for CARLA overhead object detection dataset, containing RGB and depth channels. The dev set also contains green screens for adversarial patch insertion. """ if "class_ids" in kwargs: raise ValueError( - "Filtering by class is not supported for the carla_obj_det_dev dataset" + "Filtering by class is not supported for the carla_over_obj_det_dev dataset" ) if batch_size != 1: - raise ValueError("carla_obj_det_dev batch size must be set to 1") + raise ValueError("carla_over_obj_det_dev batch size must be set to 1") modality = kwargs.pop("modality", "rgb") if modality not in ["rgb", "depth", "both"]: @@ -729,15 +729,15 @@ def carla_over_obj_det_test( **kwargs, ): """ - Dev set for CARLA object detection dataset, containing RGB and depth channels. The test + Test set for CARLA overhead object detection dataset, containing RGB and depth channels. The test set also contains green screens for adversarial patch insertion. """ if "class_ids" in kwargs: raise ValueError( - "Filtering by class is not supported for the carla_obj_det_test dataset" + "Filtering by class is not supported for the carla_over_obj_det_test dataset" ) if batch_size != 1: - raise ValueError("carla_obj_det_test batch size must be set to 1") + raise ValueError("carla_over_obj_det_test batch size must be set to 1") modality = kwargs.pop("modality", "rgb") if modality not in ["rgb", "depth", "both"]: @@ -924,7 +924,7 @@ def carla_video_tracking_dev( "Filtering by class is not supported for the carla_video_tracking_dev dataset" ) if batch_size != 1: - raise ValueError("carla_obj_det_dev batch size must be set to 1") + raise ValueError("carla_video_tracking_dev batch size must be set to 1") if max_frames: clip = datasets.ClipFrames(max_frames) @@ -975,10 +975,10 @@ def carla_video_tracking_test( """ if "class_ids" in kwargs: raise ValueError( - "Filtering by class is not supported for the carla_video_tracking_dev dataset" + "Filtering by class is not supported for the carla_video_tracking_test dataset" ) if batch_size != 1: - raise ValueError("carla_obj_det_dev batch size must be set to 1") + raise ValueError("carla_video_tracking_test batch size must be set to 1") if max_frames: clip = datasets.ClipFrames(max_frames) diff --git a/armory/data/datasets.py b/armory/data/datasets.py index 5cb2afe3e..9c7f0071e 100644 --- a/armory/data/datasets.py +++ b/armory/data/datasets.py @@ -1034,11 +1034,11 @@ def carla_over_obj_det_train( **kwargs, ) -> ArmoryDataGenerator: """ - Training set for CARLA object detection dataset, containing RGB and depth channels. + Training set for CARLA overhead object detection dataset, containing RGB and depth channels. """ if "class_ids" in kwargs: raise ValueError( - "Filtering by class is not supported for the carla_obj_det_train dataset" + "Filtering by class is not supported for the carla_over_obj_det_train dataset" ) modality = kwargs.pop("modality", "rgb") if modality not in ["rgb", "depth", "both"]: diff --git a/armory/utils/config_loading.py b/armory/utils/config_loading.py index eb0101923..2cae1b99d 100644 --- a/armory/utils/config_loading.py +++ b/armory/utils/config_loading.py @@ -79,6 +79,11 @@ def load_dataset(dataset_config, *args, num_batches=None, check_run=False, **kwa if check_run: return EvalGenerator(dataset, num_eval_batches=1) if num_batches: + if num_batches > dataset.batches_per_epoch: + # since num-eval-batches only applies at test time, we can assume there is only 1 epoch + raise ValueError( + f"{num_batches} eval batches were requested, but dataset has only {dataset.batches_per_epoch} batches of size {dataset.batch_size}" + ) return EvalGenerator(dataset, num_eval_batches=num_batches) return dataset