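"""Generate adversarial examples against a Banking77 intent classifier.

Runs prebuilt attack recipes from TextAttack and OpenAttack against the
philschmid/BERT-Banking77 model on 100 shuffled test examples, then writes
every successful attack to attacked.csv.
"""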
import ssl
# Disable certificate verification so model/dataset downloads work on machines
# with missing or outdated CA certificates.
ssl._create_default_https_context = ssl._create_unverified_context
import os
# Silence the HuggingFace tokenizers fork warning when attacks spawn workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import textattack as ta
import OpenAttack as oa
# List of prebuilt TextAttack attacks
Prebuilt_TextAttacks = [
    # ta.attack_recipes.a2t_yoo_2021.A2TYoo2021,  # Error in the library
    # ta.attack_recipes.faster_genetic_algorithm_jia_2019.FasterGeneticAlgorithmJia2019,
    # ta.attack_recipes.bae_garg_2019.BAEGarg2019,  # High memory usage
    ta.attack_recipes.bert_attack_li_2020.BERTAttackLi2020,
    ta.attack_recipes.checklist_ribeiro_2020.CheckList2020,
    ta.attack_recipes.deepwordbug_gao_2018.DeepWordBugGao2018,
    # ta.attack_recipes.hotflip_ebrahimi_2017.HotFlipEbrahimi2017,
    ta.attack_recipes.input_reduction_feng_2018.InputReductionFeng2018,
    ta.attack_recipes.pso_zang_2020.PSOZang2020,
    ta.attack_recipes.pwws_ren_2019.PWWSRen2019,
    ta.attack_recipes.textfooler_jin_2019.TextFoolerJin2019,
    ta.attack_recipes.textbugger_li_2018.TextBuggerLi2018,
    ta.attack_recipes.pruthi_2019.Pruthi2019,
    ta.attack_recipes.clare_li_2020.CLARE2020,
]
# List of prebuilt OpenAttack attacks
Prebuilt_OpenAttack = [
    # oa.attackers.BAEAttacker,  # Available in TextAttack; high memory usage
    # oa.attackers.BERTAttacker,  # Available in TextAttack
    # oa.attackers.DeepWordBugAttacker,  # Available in TextAttack
    oa.attackers.FDAttacker,
    # oa.attackers.GANAttacker,  # Poor output quality
    # oa.attackers.GEOAttacker,  # Undefined
    oa.attackers.GeneticAttacker,
    oa.attackers.HotFlipAttacker,  # Disabled in the TextAttack list above, so run here
    # oa.attackers.PSOAttacker,  # Available in TextAttack
    # oa.attackers.PWWSAttacker,  # Available in TextAttack
    oa.attackers.SCPNAttacker,
    # oa.attackers.TextBuggerAttacker,  # Available in TextAttack
    # oa.attackers.TextFoolerAttacker,  # Available in TextAttack
    oa.attackers.UATAttacker,
    oa.attackers.VIPERAttacker,
]
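# A minimal sketch of smoke-testing one recipe in isolation before a full run
# (assumes the loaders defined below; TextFoolerJin2019 is just an example pick):
#
#     model = load_model_TextAttack()
#     data = load_data_TextAttack()
#     attack = ta.attack_recipes.textfooler_jin_2019.TextFoolerJin2019.build(model)
#     ta.Attacker(attack, data, ta.AttackArgs(num_examples=5)).attack_dataset()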
def load_data_TextAttack():
    """
    Load the dataset for TextAttack.

    Returns:
    - tokenized_dataset: 100 shuffled Banking77 test examples wrapped for TextAttack.
    """
    dataset = load_dataset("PolyAI/banking77")
    tokenized_dataset = ta.datasets.HuggingFaceDataset(dataset["test"].shuffle(seed=0).select(range(100)))
    return tokenized_dataset
def load_data_OpenAttack():
    """
    Load the dataset for OpenAttack.

    Returns:
    - dataset: 100 shuffled Banking77 test examples with OpenAttack's "x"/"y" fields.
    """
    dataset = load_dataset("PolyAI/banking77")
    dataset = dataset["test"].shuffle(seed=0).select(range(100))

    def dataset_mapping(x):
        # OpenAttack expects the input text under "x" and the label under "y".
        return {
            "x": x["text"],
            "y": x["label"],
        }

    dataset = dataset.map(function=dataset_mapping, remove_columns=["text", "label"])
    return dataset
def load_model_TextAttack():
    """
    Load the TextAttack model for sequence classification.

    Returns:
    - model_wrapped: The wrapped HuggingFace model for TextAttack.
    """
    model_id = 'philschmid/BERT-Banking77'
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    model_wrapped = ta.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
    return model_wrapped
def load_model_OpenAttack():
    """
    Load the OpenAttack model for sequence classification.

    Returns:
    - model_wrapped: The wrapped HuggingFace model for OpenAttack.
    """
    model_id = 'philschmid/BERT-Banking77'
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    # The word-embedding layer is passed so gradient-based attackers can access it.
    model_wrapped = oa.classifiers.TransformersClassifier(model, tokenizer, model.bert.embeddings.word_embeddings)
    return model_wrapped
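# Quick sanity check of the wrapped classifier (a minimal sketch; assumes
# OpenAttack's classifier interface, where get_pred returns predicted label ids):
#
#     clf = load_model_OpenAttack()
#     print(clf.get_pred(["How do I reset my card PIN?"]))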
def apply_TextAttacks(model_wrapped, tokenized_dataset):
    """
    Apply prebuilt TextAttack attacks to the dataset.

    Args:
    - model_wrapped: The wrapped HuggingFace model for TextAttack.
    - tokenized_dataset: The tokenized dataset for TextAttack.

    Returns:
    - attacked: A list of attacked examples, including the attack type, original text, perturbed text, and ground-truth label.
    """
    attacked = []
    for _attack in Prebuilt_TextAttacks:
        # Derive a readable attack name from the recipe's module path.
        attack_module = _attack.__module__
        _type = attack_module.split('.')[-1]
        attack_args = ta.AttackArgs(
            num_examples=100,
            parallel=True,  # Distributes the attack across available GPUs.
            checkpoint_dir=None,
            disable_stdout=False
        )
        attack = _attack.build(model_wrapped)
        attacker = ta.Attacker(attack, tokenized_dataset, attack_args)
        results = attacker.attack_dataset()
        for res in results:
            # Wrongly classified examples are skipped by the TextAttack library.
            if 'SKIPPED' not in res.goal_function_result_str() and 'FAILED' not in res.goal_function_result_str():
                attacked.append([_type, res.original_text(), res.perturbed_text(), res.original_result.ground_truth_output])
    return attacked
def apply_OpenAttacks(model_wrapped, dataset):
    """
    Apply prebuilt OpenAttack attacks to the dataset.

    Args:
    - model_wrapped: The wrapped HuggingFace model for OpenAttack.
    - dataset: The dataset for OpenAttack.

    Returns:
    - attacked: A list of attacked examples, including the attack type, original text, perturbed text, and ground-truth label.
    """
    attacked = []
    for _attack in Prebuilt_OpenAttack:
        _type = _attack.__name__
        attacker = _attack()
        attack_eval = oa.AttackEval(attacker, model_wrapped)
        # ieval() yields one result dict per example as the attack runs.
        result = attack_eval.ieval(dataset)
        for res in result:
            if res['success']:
                attacked.append([_type, res['data']['x'], res['result'], res['data']['y']])
    return attacked
if __name__ == "__main__":
    # Load TextAttack model and dataset
    model_wrapped = load_model_TextAttack()
    tokenized_dataset = load_data_TextAttack()

    # Apply TextAttack attacks
    attacked_TextAttack = apply_TextAttacks(model_wrapped, tokenized_dataset)

    # Load OpenAttack model and dataset
    model_wrapped = load_model_OpenAttack()
    dataset = load_data_OpenAttack()

    # Apply OpenAttack attacks
    attacked_OpenAttack = apply_OpenAttacks(model_wrapped, dataset)

    # Merge attacked examples and save to CSV
    merge_attacked = attacked_TextAttack + attacked_OpenAttack
    df = pd.DataFrame(merge_attacked, columns=['Type', 'Original', 'Attacked', 'Original Label'])
    df.to_csv('attacked.csv', index=False)
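# A minimal sketch of inspecting the output, e.g. successful attacks per recipe
# (assumes attacked.csv was written by the run above):
#
#     df = pd.read_csv('attacked.csv')
#     print(df['Type'].value_counts())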