Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assorted improvements #1894

Merged
merged 3 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions armory/baseline_models/tf_graph/audio_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_art_model(

loss_object = losses.SparseCategoricalCrossentropy()

@tf.function
def train_step(model, samples, labels):
with tf.GradientTape() as tape:
predictions = model(samples, training=True)
Expand Down
18 changes: 9 additions & 9 deletions armory/data/adversarial_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,15 +661,15 @@ def carla_over_obj_det_dev(
**kwargs,
):
"""
Dev set for CARLA object detection dataset, containing RGB and depth channels. The dev
Dev set for CARLA overhead object detection dataset, containing RGB and depth channels. The dev
set also contains green screens for adversarial patch insertion.
"""
if "class_ids" in kwargs:
raise ValueError(
"Filtering by class is not supported for the carla_obj_det_dev dataset"
"Filtering by class is not supported for the carla_over_obj_det_dev dataset"
)
if batch_size != 1:
raise ValueError("carla_obj_det_dev batch size must be set to 1")
raise ValueError("carla_over_obj_det_dev batch size must be set to 1")

modality = kwargs.pop("modality", "rgb")
if modality not in ["rgb", "depth", "both"]:
Expand Down Expand Up @@ -729,15 +729,15 @@ def carla_over_obj_det_test(
**kwargs,
):
"""
Dev set for CARLA object detection dataset, containing RGB and depth channels. The test
Test set for CARLA overhead object detection dataset, containing RGB and depth channels. The test
set also contains green screens for adversarial patch insertion.
"""
if "class_ids" in kwargs:
raise ValueError(
"Filtering by class is not supported for the carla_obj_det_test dataset"
"Filtering by class is not supported for the carla_over_obj_det_test dataset"
)
if batch_size != 1:
raise ValueError("carla_obj_det_test batch size must be set to 1")
raise ValueError("carla_over_obj_det_test batch size must be set to 1")

modality = kwargs.pop("modality", "rgb")
if modality not in ["rgb", "depth", "both"]:
Expand Down Expand Up @@ -924,7 +924,7 @@ def carla_video_tracking_dev(
"Filtering by class is not supported for the carla_video_tracking_dev dataset"
)
if batch_size != 1:
raise ValueError("carla_obj_det_dev batch size must be set to 1")
raise ValueError("carla_video_tracking_dev batch size must be set to 1")

if max_frames:
clip = datasets.ClipFrames(max_frames)
Expand Down Expand Up @@ -975,10 +975,10 @@ def carla_video_tracking_test(
"""
if "class_ids" in kwargs:
raise ValueError(
"Filtering by class is not supported for the carla_video_tracking_dev dataset"
"Filtering by class is not supported for the carla_video_tracking_test dataset"
)
if batch_size != 1:
raise ValueError("carla_obj_det_dev batch size must be set to 1")
raise ValueError("carla_video_tracking_test batch size must be set to 1")

if max_frames:
clip = datasets.ClipFrames(max_frames)
Expand Down
4 changes: 2 additions & 2 deletions armory/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,11 +1034,11 @@ def carla_over_obj_det_train(
**kwargs,
) -> ArmoryDataGenerator:
"""
Training set for CARLA object detection dataset, containing RGB and depth channels.
Training set for CARLA overhead object detection dataset, containing RGB and depth channels.
"""
if "class_ids" in kwargs:
raise ValueError(
"Filtering by class is not supported for the carla_obj_det_train dataset"
"Filtering by class is not supported for the carla_over_obj_det_train dataset"
)
modality = kwargs.pop("modality", "rgb")
if modality not in ["rgb", "depth", "both"]:
Expand Down
5 changes: 5 additions & 0 deletions armory/utils/config_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def load_dataset(dataset_config, *args, num_batches=None, check_run=False, **kwa
if check_run:
return EvalGenerator(dataset, num_eval_batches=1)
if num_batches:
if num_batches > dataset.batches_per_epoch:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested that this works, but this function has become obsolete on the tfdsv4 branch in lieu of this one. I'd suggest also incorporating this change on a branch based off tfdsv4

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would occur in a separate PR into tfdsv4, though, so I'm good merging this one

# since num-eval-batches only applies at test time, we can assume there is only 1 epoch
raise ValueError(
f"{num_batches} eval batches were requested, but dataset has only {dataset.batches_per_epoch} batches of size {dataset.batch_size}"
)
return EvalGenerator(dataset, num_eval_batches=num_batches)
return dataset

Expand Down