Skip to content

Commit

Permalink
Dataset index and slicing (#878)
Browse files Browse the repository at this point in the history
* update split naming

* update dataset loading

* updated docs

* updated token tests

* working tests for slicing

* allow ordering and duplicates
  • Loading branch information
davidslater authored Nov 4, 2020
1 parent 4324438 commit 6da24a4
Show file tree
Hide file tree
Showing 16 changed files with 271 additions and 135 deletions.
26 changes: 13 additions & 13 deletions armory/data/adversarial_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def apricot_canonical_preprocessing(batch):


def imagenet_adversarial(
split_type: str = "adversarial",
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
Expand Down Expand Up @@ -86,7 +86,7 @@ def imagenet_adversarial(

return datasets._generator_from_tfds(
"imagenet_adversarial:1.1.0",
split_type=split_type,
split=split,
batch_size=batch_size,
epochs=epochs,
dataset_dir=dataset_dir,
Expand All @@ -99,7 +99,7 @@ def imagenet_adversarial(


def librispeech_adversarial(
split_type: str = "adversarial",
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
Expand All @@ -115,7 +115,7 @@ def librispeech_adversarial(
Adversarial dataset based on Librispeech-dev-clean including clean,
Universal Perturbation using PGD, and PGD.
split_type - one of ("adversarial")
split - one of ("adversarial")
returns:
Generator
Expand All @@ -130,7 +130,7 @@ def librispeech_adversarial(

return datasets._generator_from_tfds(
"librispeech_adversarial:1.1.0",
split_type=split_type,
split=split,
batch_size=batch_size,
epochs=epochs,
dataset_dir=dataset_dir,
Expand All @@ -146,7 +146,7 @@ def librispeech_adversarial(


def resisc45_adversarial_224x224(
split_type: str = "adversarial",
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
Expand Down Expand Up @@ -185,7 +185,7 @@ def lambda_map(x, y):

return datasets._generator_from_tfds(
"resisc45_densenet121_univpatch_and_univperturbation_adversarial224x224:1.0.2",
split_type=split_type,
split=split,
batch_size=batch_size,
epochs=epochs,
dataset_dir=dataset_dir,
Expand All @@ -201,7 +201,7 @@ def lambda_map(x, y):


def ucf101_adversarial_112x112(
split_type: str = "adversarial",
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
Expand Down Expand Up @@ -242,7 +242,7 @@ def lambda_map(x, y):

return datasets._generator_from_tfds(
"ucf101_mars_perturbation_and_patch_adversarial112x112:1.1.0",
split_type=split_type,
split=split,
batch_size=batch_size,
epochs=epochs,
dataset_dir=dataset_dir,
Expand All @@ -258,7 +258,7 @@ def lambda_map(x, y):


def gtsrb_poison(
split_type: str = "poison",
split: str = "poison",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
Expand All @@ -277,7 +277,7 @@ def gtsrb_poison(
"""
return datasets._generator_from_tfds(
"gtsrb_bh_poison_micronnet:1.0.0",
split_type=split_type,
split=split,
batch_size=batch_size,
epochs=epochs,
dataset_dir=dataset_dir,
Expand All @@ -293,7 +293,7 @@ def gtsrb_poison(


def apricot_dev_adversarial(
split_type: str = "adversarial",
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
Expand Down Expand Up @@ -323,7 +323,7 @@ def replace_magic_val(data, raw_val, transformed_val, sub_key):

return datasets._generator_from_tfds(
"apricot_dev:1.0.1",
split_type=split_type,
split=split,
batch_size=batch_size,
epochs=epochs,
dataset_dir=dataset_dir,
Expand Down
Loading

0 comments on commit 6da24a4

Please sign in to comment.