Skip to content

Commit

Permalink
index and class filtering from command line; also doc update (#1162)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidslater authored Oct 23, 2021
1 parent e122dd5 commit ebbc88e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 16 deletions.
64 changes: 64 additions & 0 deletions armory/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import json
import logging
import os
import re
import sys

import coloredlogs
Expand All @@ -38,6 +39,36 @@ def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values)


def sorted_unique_nonnegative_numbers(values, warning_string):
    """
    Parse a ','-separated string of nonnegative integers.

    values - the raw string, e.g. "1,3,7"; whitespace around numbers is allowed
    warning_string - label used in the warning printed when the input had to
        be reordered or de-duplicated

    Returns the numbers as a sorted list without duplicates.
    Raises ValueError if `values` is not a str or is not a ','-separated
        sequence of nonnegative integers.
    """
    if not isinstance(values, str):
        raise ValueError(f"{values} invalid.\n Must be a string input.")

    well_formed = re.match(r"^\s*\d+(\s*,\s*\d+)*\s*$", values)
    if well_formed is None:
        raise ValueError(
            f"{values} invalid. Must be ','-separated nonnegative integers"
        )

    parsed = list(map(int, values.split(",")))
    normalized = list(set(parsed))
    normalized.sort()

    # Tell the user when the result differs from what was literally typed.
    if normalized != parsed:
        print(f"WARNING: {warning_string} sorted and made unique: {normalized}")
    return normalized


class Index(argparse.Action):
    """
    argparse action for `--index`.

    Converts the raw command line string into a sorted list of unique
    nonnegative integers and stores it on the namespace.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Validation and normalization are delegated to the shared helper,
        # which raises ValueError on malformed input.
        setattr(
            namespace,
            self.dest,
            sorted_unique_nonnegative_numbers(values, "--index"),
        )


class Classes(argparse.Action):
    """
    argparse action for `--classes`.

    Converts the raw command line string into a sorted list of unique
    nonnegative integers and stores it on the namespace.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Validation and normalization are delegated to the shared helper,
        # which raises ValueError on malformed input.
        setattr(
            namespace,
            self.dest,
            sorted_unique_nonnegative_numbers(values, "--classes"),
        )


class Command(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if values not in COMMANDS:
Expand Down Expand Up @@ -192,6 +223,26 @@ def _root(parser):
)


def _index(parser):
    """
    Register the `--index` option on `parser`.

    The value is read as a string and processed by the `Index` action.
    """
    parser.add_argument(
        "--index",
        type=str,
        # Adjacent string literals are joined verbatim, so the separator
        # between the two fragments has to be written out explicitly.
        help="Comma-separated nonnegative index for evaluation data point filtering, "
        "e.g.: `2` or `1,3,7`",
        action=Index,
    )


def _classes(parser):
    """
    Register the `--classes` option on `parser`.

    The value is read as a string and processed by the `Classes` action.
    """
    parser.add_argument(
        "--classes",
        type=str,
        # Adjacent string literals are joined verbatim, so the separator
        # between the two fragments has to be written out explicitly.
        help="Comma-separated nonnegative class ids for filtering, "
        "e.g.: `2` or `1,3,7`",
        action=Classes,
    )


# Config


Expand Down Expand Up @@ -242,6 +293,8 @@ def run(command_args, prog, description):
_gpus(parser)
_no_docker(parser)
_root(parser)
_index(parser)
_classes(parser)
parser.add_argument(
"--output-dir", type=str, help="Override of default output directory prefix",
)
Expand Down Expand Up @@ -314,6 +367,17 @@ def run(command_args, prog, description):
(config, args) = arguments.merge_config_and_args(config, args)
logging.debug("unified sysconfig %s and args %s", config["sysconfig"], args)

if args.num_eval_batches and args.index:
raise ValueError("Cannot have --num-eval-batches and --index")
if args.index and config["dataset"].get("index"):
logging.info("Overriding index in config with command line argument")
if args.index:
config["dataset"]["index"] = args.index
if args.classes and config["dataset"].get("class_ids"):
logging.info("Overriding class_ids in config with command line argument")
if args.classes:
config["dataset"]["class_ids"] = args.classes

rig = Evaluator(config, no_docker=args.no_docker, root=args.root)
exit_code = rig.run(
interactive=args.interactive,
Expand Down
24 changes: 9 additions & 15 deletions armory/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,24 +280,18 @@ def filter_by_index(dataset: "tf.data.Dataset", index: list, dataset_size: int):
"""
logger.info(f"Filtering dataset to the following indices: {index}")
dataset_size = int(dataset_size)
if len(index) == 0:
raise ValueError(
"The specified dataset 'index' param must have at least one value"
)
valid_indices = sorted([int(x) for x in set(index) if int(x) < dataset_size])
num_valid_indices = len(valid_indices)
if num_valid_indices == 0:
raise ValueError(
f"The specified dataset 'index' param values all exceed dataset size of {dataset_size}"
)
elif index[0] < 0:
sorted_index = sorted([int(x) for x in set(index)])
if len(sorted_index) == 0:
raise ValueError("The specified dataset 'index' param must be nonempty")
if sorted_index[0] < 0:
raise ValueError("The specified dataset 'index' values must be nonnegative")
elif num_valid_indices != len(set(index)):
logger.warning(
f"All dataset 'index' values exceeding dataset size of {dataset_size} are being ignored"
if sorted_index[-1] >= dataset_size:
raise ValueError(
f"The specified dataset 'index' values exceed dataset size {dataset_size}"
)
num_valid_indices = len(sorted_index)

index_tensor = tf.constant(index, dtype=tf.int64)
index_tensor = tf.constant(sorted_index, dtype=tf.int64)

def enum_index(i, x):
i = tf.expand_dims(i, 0)
Expand Down
15 changes: 14 additions & 1 deletion docs/command_line.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ armory launch tf1 --gpus=1,4 --interactive
armory exec pytorch --gpus=0 -- nvidia-smi
```

## Check Runs and Number of Examples
## Check Runs, Number of Example Batches, Indexing, and Class Filtering
* `armory run <config> --check [...]`
* `armory run <config> --num-eval-batches=X [...]`
* `armory run <config> --index=a,b,c [...]`
* `armory run <config> --classes=x,y,z [...]`
Applies to `run` command.

The `--check` flag will make every dataset return a single batch,
Expand All @@ -68,6 +70,17 @@ both benign and adversarial test sets.
It is primarily designed for attack development iteration, where it is typically unhelpful
to run more than 10-100 examples.

The `--index` argument restricts evaluation to the samples at the given indices, provided as a comma-separated list of non-negative integers (e.g., `--index=1,3,7`).
Any duplicate numbers will be removed and the list will be sorted.
If indices beyond the size of the dataset are provided, an error will result at runtime.
Cannot be used with the `--num-eval-batches` argument.
Currently, batch size must be set to 1.

The `--classes` argument restricts evaluation to samples whose class id is in the given comma-separated list of non-negative integers (e.g., `--classes=1,3,7`).
Any duplicate numbers will be removed and the list will be sorted.
If indices beyond the size of the dataset are provided, an error will result at runtime.
Can be used with `--index` argument. In that case, indexing will be done after class filtering.

NOTE: `--check` will take precedence over the `--num-eval-batches` argument.

### Example Usage
Expand Down

0 comments on commit ebbc88e

Please sign in to comment.