-
Notifications
You must be signed in to change notification settings - Fork 4
/
ex_trans_poisoning.py
54 lines (42 loc) · 1.57 KB
/
ex_trans_poisoning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import logging
from torchtyping import patch_typeguard
from typeguard import typechecked
from gb.exutil import *
from gb.util import fetch
# Let typeguard's @typechecked (applied to do_run below) understand torchtyping annotations.
patch_typeguard()
# Experiment object named "trans_poisoning"; it provides the @ex.config, @ex.capture and
# @ex.automain decorators used below. make_experiment comes from the gb.exutil star import
# (presumably it builds a sacred Experiment — confirm).
ex = make_experiment("trans_poisoning")
# Config scope: every local assigned here becomes a config entry that @ex.capture
# functions (do_run) receive by parameter name. Do not rename these locals.
@ex.config
def config():
    # `_id` of the stored "evasion" run whose config, result and perturbations are reused.
    evasion_run_id = 0
    # Forwarded unchanged to run_poisoning; presumably makes poisoning reuse the evasion
    # run's random seeds — confirm against run_poisoning.
    use_evasion_seeds = False
@ex.capture
@typechecked
def do_run(evasion_run_id: int, use_evasion_seeds: bool) -> NonPrintingDict:
    """Run poisoning with the perturbations stored by an earlier evasion run.

    Parameters are injected from the experiment config by ``@ex.capture``.

    Args:
        evasion_run_id: ``_id`` of the "evasion" run to load (its config, its result
            and the "perturbations" file).
        use_evasion_seeds: forwarded unchanged to ``run_poisoning``.

    Returns:
        A ``NonPrintingDict`` that always holds ``"test_accuracy"``. For an attack whose
        scope is ``"global"``, the scores and probability margins are written as npz
        artifacts instead. For any other scope, they are returned in the dict under
        ``"scores"`` / ``"proba_margins"`` after conversion to nested lists.

    Raises:
        ValueError: if no evasion run with ``evasion_run_id`` exists.
    """
    logging.info(f"Loading config and result of evasion run with ID {evasion_run_id}...")
    matching_runs = fetch(
        "evasion",
        ["config", "result"],
        filter={"_id": evasion_run_id},
        incl_files={"perturbations"},
    )
    if len(matching_runs) == 0:
        raise ValueError(f"There is no evasion experiment with ID {evasion_run_id}")

    source_run = matching_runs[0]
    source_cfg = source_run.config
    attack = source_cfg.attack

    # The attack config may lack "targets", so .get() yields None in that case.
    test_acc, scores, margins = run_poisoning(
        source_cfg.dataset,
        attack.scope,
        attack.get("targets"),
        source_cfg.model,
        source_cfg.training,
        source_run.result.perturbations,
        use_evasion_seeds,
    )
    logging.info("Done! Collecting results...")

    # Non-global attacks: report scores/margins inline as plain nested lists.
    if attack.scope != "global":
        return NonPrintingDict({
            "test_accuracy": test_acc,
            "scores": recursive_tensors_to_lists(scores),
            "proba_margins": recursive_tensors_to_lists(margins),
        })

    # Global attacks: store scores/margins as npz artifacts, report only the accuracy.
    add_npz_artifact("scores", scores)
    add_npz_artifact("proba_margins", margins)
    return NonPrintingDict({
        "test_accuracy": test_acc
    })
@ex.automain
def run() -> NonPrintingDict:
    """Experiment entry point: call the config-captured ``do_run`` and return its results."""
    # do_run's arguments are filled in from the experiment config via @ex.capture.
    results = do_run()
    return results