skip misclassified examples #1005

Merged: 39 commits, merged Mar 18, 2021 (changes shown from 30 commits)

Commits
3abfd11
added object detection metrics
Oct 20, 2020
22c8dd0
Merge remote-tracking branch 'upstream/master'
Oct 20, 2020
75efe00
Merge branch 'master' of https://github.com/twosixlabs/armory
Oct 22, 2020
bae155c
Merge remote-tracking branch 'upstream/master'
Oct 26, 2020
830152c
Merge remote-tracking branch 'upstream/master'
Oct 26, 2020
ce6b8d8
Merge remote-tracking branch 'upstream/master'
Oct 27, 2020
017b8eb
Merge remote-tracking branch 'upstream/master'
Oct 28, 2020
288bcb6
Merge remote-tracking branch 'upstream/master'
Oct 29, 2020
c161339
Merge remote-tracking branch 'upstream/master'
Nov 2, 2020
22d06d9
Merge remote-tracking branch 'upstream/master'
Nov 2, 2020
e117a55
Merge branch 'master' of https://github.com/twosixlabs/armory
Nov 4, 2020
3af2881
Merge branch 'master' of https://github.com/twosixlabs/armory
Nov 4, 2020
61074f3
Merge remote-tracking branch 'upstream/master'
Nov 5, 2020
9f6d42c
Merge remote-tracking branch 'upstream/master'
Nov 5, 2020
74386d7
Merge remote-tracking branch 'upstream/master'
Nov 5, 2020
ee1fee3
Merge branch 'master' of https://github.com/twosixlabs/armory
Nov 9, 2020
62cbde8
Merge branch 'master' of https://github.com/twosixlabs/armory
Nov 13, 2020
834ee3e
Merge branch 'master' of https://github.com/twosixlabs/armory
Nov 16, 2020
2b1cd3b
Merge branch 'master' of https://github.com/twosixlabs/armory
Nov 17, 2020
f3fe19d
Merge branch 'master' of https://github.com/twosixlabs/armory
Nov 19, 2020
3384dc8
Merge remote-tracking branch 'upstream/master'
Nov 30, 2020
fb36c0d
Merge branch 'master' of https://github.com/twosixlabs/armory
Dec 2, 2020
a2e7cb8
Merge branch 'master' of https://github.com/twosixlabs/armory
Dec 7, 2020
fe50057
Merge remote-tracking branch 'upstream/master'
Dec 18, 2020
1e0b7b0
Merge remote-tracking branch 'upstream/master'
Jan 4, 2021
9906d84
Merge remote-tracking branch 'upstream/master'
Jan 5, 2021
aab9be3
Merge branch 'master' of https://github.com/twosixlabs/armory
Feb 25, 2021
f0ffd60
Merge branch 'dev' of https://github.com/twosixlabs/armory
Mar 5, 2021
6de6d7c
adding proof of concept
Mar 5, 2021
05df706
black formatting
Mar 5, 2021
1d49dcf
move cli/config check to base scenario before _evaluate()
Mar 12, 2021
6e8fda5
refactor image_classification skip_misclassified
Mar 12, 2021
975cdb9
update __main__ with skip-misclassified
Mar 12, 2021
5ec6da1
record_metric_per_sample doesn't need to be true
Mar 12, 2021
7874237
add skip_misclassified to so2sat scenario
Mar 12, 2021
5ad3df8
adding skip_misclassified to scenarios where it shouldnt be used
Mar 12, 2021
4e141d9
add skip_misclassified to video
Mar 12, 2021
33dfb1e
minor refactor of audio_classification plus adding in skip_misclassified
Mar 12, 2021
35c850e
docs for skip-misclassified
Mar 15, 2021
20 changes: 18 additions & 2 deletions armory/eval/evaluator.py
@@ -160,6 +160,7 @@ def run(
num_eval_batches=None,
skip_benign=None,
skip_attack=None,
skip_misclassified=None,
validate_config=None,
) -> int:
exit_code = 0
@@ -176,6 +177,7 @@ def run(
num_eval_batches=num_eval_batches,
skip_benign=skip_benign,
skip_attack=skip_attack,
skip_misclassified=skip_misclassified,
validate_config=validate_config,
)
except KeyboardInterrupt:
@@ -213,6 +215,7 @@ def run(
num_eval_batches=num_eval_batches,
skip_benign=skip_benign,
skip_attack=skip_attack,
skip_misclassified=skip_misclassified,
validate_config=validate_config,
)
elif command:
@@ -224,6 +227,7 @@
num_eval_batches=num_eval_batches,
skip_benign=skip_benign,
skip_attack=skip_attack,
skip_misclassified=skip_misclassified,
validate_config=validate_config,
)
except KeyboardInterrupt:
@@ -263,6 +267,7 @@ def _run_config(
num_eval_batches=None,
skip_benign=None,
skip_attack=None,
skip_misclassified=None,
validate_config=None,
) -> int:
logger.info(bold(red("Running evaluation script")))
@@ -273,6 +278,7 @@
num_eval_batches=num_eval_batches,
skip_benign=skip_benign,
skip_attack=skip_attack,
skip_misclassified=skip_misclassified,
validate_config=validate_config,
)
if self.no_docker:
@@ -294,7 +300,7 @@ def get_id(self):
Return uid, gid
"""
# Windows docker does not require synchronizing file and
# directoriy permissions via uid and gid.
# directory permissions via uid and gid.
if os.name == "nt" or self.root:
user_id = 0
group_id = 0
@@ -310,6 +316,7 @@ def _run_interactive_bash(
num_eval_batches=None,
skip_benign=None,
skip_attack=None,
skip_misclassified=None,
validate_config=None,
) -> None:
user_group_id = self.get_id()
@@ -330,6 +337,7 @@
num_eval_batches=num_eval_batches,
skip_benign=skip_benign,
skip_attack=skip_attack,
skip_misclassified=skip_misclassified,
validate_config=validate_config,
)
tmp_dir = os.path.join(self.host_paths.tmp_dir, self.config["eval_id"])
@@ -385,7 +393,13 @@ def _run_jupyter(self, runner: ArmoryInstance, ports: dict) -> None:
)

def _build_options(
self, check_run, num_eval_batches, skip_benign, skip_attack, validate_config
self,
check_run,
num_eval_batches,
skip_benign,
skip_attack,
skip_misclassified,
validate_config,
):
options = ""
if self.no_docker:
@@ -400,6 +414,8 @@ def _build_options(
options += " --skip-benign"
if skip_attack:
options += " --skip-attack"
if skip_misclassified:
options += " --skip-misclassified"
if validate_config:
options += " --validate-config"
return options
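
To make the keyword-to-flag mapping explicit, here is a minimal standalone sketch of the option-assembly pattern _build_options follows after this change. build_options below is a hypothetical helper written for illustration, not code from the PR, and the check_run, num_eval_batches, and no-docker handling is omitted.

def build_options(skip_benign=False, skip_attack=False,
                  skip_misclassified=False, validate_config=False):
    # Mirror the concatenation above: each truthy keyword appends the
    # corresponding CLI flag that the scenario entrypoint later parses.
    options = ""
    if skip_benign:
        options += " --skip-benign"
    if skip_attack:
        options += " --skip-attack"
    if skip_misclassified:
        options += " --skip-misclassified"
    if validate_config:
        options += " --validate-config"
    return options

assert build_options(skip_misclassified=True) == " --skip-misclassified"

The string returned by the real method is appended to the scenario command run inside the container, which is how the new flag reaches the base.py entrypoint changed below.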
22 changes: 20 additions & 2 deletions armory/scenarios/base.py
@@ -57,6 +57,7 @@ def evaluate(
num_eval_batches: Optional[int],
skip_benign: Optional[bool],
skip_attack: Optional[bool],
skip_misclassified: Optional[bool],
):
"""
Evaluate a config for robustness against attack.
@@ -74,7 +75,9 @@
config["adhoc"]["train_epochs"] = 1

try:
results = self._evaluate(config, num_eval_batches, skip_benign, skip_attack)
results = self._evaluate(
config, num_eval_batches, skip_benign, skip_attack, skip_misclassified
)
except Exception as e:
if str(e) == "assignment destination is read-only":
logger.exception(
@@ -106,6 +109,7 @@ def _evaluate(
num_eval_batches: Optional[int],
skip_benign: Optional[bool],
skip_attack: Optional[bool],
skip_misclassified: Optional[bool],
) -> dict:
"""
Evaluate the config and return a results dict
@@ -253,6 +257,7 @@ def run_config(
num_eval_batches=None,
skip_benign=None,
skip_attack=None,
skip_misclassified=None,
):
config = _get_config(config_json, from_file=from_file)
scenario_config = config.get("scenario")
@@ -261,7 +266,14 @@
_scenario_setup(config)
scenario = config_loading.load(scenario_config)
scenario.set_check_run(check)
scenario.evaluate(config, mongo_host, num_eval_batches, skip_benign, skip_attack)
scenario.evaluate(
config,
mongo_host,
num_eval_batches,
skip_benign,
skip_attack,
skip_misclassified,
)


def init_interactive(config_json, from_file=True):
@@ -330,6 +342,11 @@ def init_interactive(config_json, from_file=True):
action="store_true",
help="Validate model configuration against several checks",
)
parser.add_argument(
"--skip-misclassified",
action="store_true",
help="Skip attack attacking of inputs that are already misclassified",
)
args = parser.parse_args()
coloredlogs.install(level=args.log_level)
calling_version = os.getenv(environment.ARMORY_VERSION, "UNKNOWN")
@@ -360,5 +377,6 @@ def init_interactive(config_json, from_file=True):
args.num_eval_batches,
args.skip_benign,
args.skip_attack,
args.skip_misclassified,
)
print(END_SENTINEL) # indicates to host that the scenario finished w/out error
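
For context only (not part of the diff): a small sketch of how the new store_true argument behaves once parsed. The parser here is a throwaway standalone one; in the PR, the parsed value is forwarded positionally, after skip_benign and skip_attack, into scenario.evaluate().

import argparse

# Standalone illustration of the argument added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--skip-misclassified",
    action="store_true",
    help="Skip attacking of inputs that are already misclassified",
)
args = parser.parse_args(["--skip-misclassified"])
print(args.skip_misclassified)  # True when the flag is present, False otherwise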
35 changes: 33 additions & 2 deletions armory/scenarios/image_classification.py
@@ -32,10 +32,15 @@ def _evaluate(
num_eval_batches: Optional[int],
skip_benign: Optional[bool],
skip_attack: Optional[bool],
skip_misclassified: Optional[bool],
) -> dict:
"""
Evaluate the config and return a results dict
"""
if skip_misclassified and sum([skip_benign, skip_attack]) > 0:
raise ValueError(
"Cannot pass skip_misclassified if skip_benign or skip_attack is also passed"
)

model_config = config["model"]
estimator, _ = load_model(model_config)
@@ -133,6 +138,17 @@ def _evaluate(
# Evaluate the ART estimator on adversarial test examples
logger.info("Generating or loading / testing adversarial examples...")

if skip_misclassified:
try:
acc_task_idx = [i.name for i in metrics_logger.tasks].index(
"categorical_accuracy"
)
except ValueError:
raise ValueError(
"Cannot pass skip_misclassified if 'categorical_accuracy' metric isn't enabled"
)
benign_acc = metrics_logger.tasks[acc_task_idx].values()

if targeted and attack_config.get("use_label"):
raise ValueError("Targeted attacks cannot have 'use_label'")
if attack_type == "preloaded":
@@ -170,7 +186,7 @@ def _evaluate(
else:
sample_exporter = None

for x, y in tqdm(test_data, desc="Attack"):
for batch_idx, (x, y) in enumerate(tqdm(test_data, desc="Attack")):
with metrics.resource_context(
name="Attack",
profiler=config["metric"].get("profiler_type"),
@@ -193,7 +209,22 @@
elif targeted:
y_target = label_targeter.generate(y)
generate_kwargs["y"] = y_target
x_adv = attack.generate(x=x, **generate_kwargs)

if skip_misclassified:
batch_size = x.shape[0]
if batch_size > 1:
logger.warning(
"Ignoring --skip-misclassified flag since batch_size is "
"greater than 1."
)
x_adv = attack.generate(x=x, **generate_kwargs)
else:
if benign_acc[batch_idx] == 0:
x_adv = x
else:
x_adv = attack.generate(x=x, **generate_kwargs)
else:
x_adv = attack.generate(x=x, **generate_kwargs)

# Ensure that input sample isn't overwritten by estimator
x_adv.flags.writeable = False
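
A condensed sketch of the per-batch decision added above, separated from the scenario plumbing. maybe_attack is a hypothetical helper written only for illustration; attack is assumed to be an ART-style attack exposing generate(), and the batch_size == 1 restriction from the loop above is assumed.

def maybe_attack(attack, x, benign_correct, **generate_kwargs):
    # benign_correct is the per-sample categorical accuracy recorded during
    # the benign pass (0 means the model already misclassified this input).
    if benign_correct == 0:
        return x  # skip the attack; reuse the benign input as x_adv
    return attack.generate(x=x, **generate_kwargs)

The rationale is that an untargeted attack gains nothing from an input the model already gets wrong, so reusing the benign input saves attack computation on those samples.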