-
Notifications
You must be signed in to change notification settings - Fork 3
/
gen_watermarks.py
160 lines (121 loc) · 6.86 KB
/
gen_watermarks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
""" generate watermarks """
import argparse
import time
import traceback
from babel.numbers import format_decimal
# import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
from torch.backends import cudnn
import models
import watermarks
from helpers.utils import *
from helpers.loaders import *
from helpers.image_folder_custom_class import *
from trainer import train_wo_wms
# Discover the selectable model architectures: every callable attribute of the
# models package whose name is all lower-case (matching the repo convention).
model_names = sorted(
    attr for attr, obj in models.__dict__.items()
    if attr.islower() and callable(obj)
)
print('models : ', model_names)
# Discover the available watermarking schemes: every callable attribute
# exposed by the watermarks package.
watermarking_methods = sorted(
    attr for attr, obj in watermarks.__dict__.items() if callable(obj)
)
print('watermarks: ', watermarking_methods)
# set up argument parser
parser = argparse.ArgumentParser(description='Train models with watermarks.')
# model and dataset
parser.add_argument('--dataset', default='cifar10', help='the dataset to train on [cifar10]')
parser.add_argument('--num_classes', default=10, type=int, help='number of classes for classification')
parser.add_argument('--arch', metavar='ARCH', default='simplenet', choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) + ' (default: simplenet)')
# watermark related
parser.add_argument('--method', default=None, choices=watermarking_methods,
                    help='watermarking method: ' + ' | '.join(
                        watermarking_methods) + ' (default: weakness_into_strength)')
parser.add_argument('--wm_type', default=None, help='wm type for ProtectingIPP: content, unrelated, noise')
parser.add_argument('--save_wm', action='store_true', help='save generated watermarks?')
parser.add_argument('--runname', default='train', help='the exp name')
parser.add_argument('--trg_set_size', default=100, type=int, help='the size of the trigger set (default: 100)')
parser.add_argument('--thresh', default=0.05, type=float, help='threshold for watermark verification')
parser.add_argument('--embed_type', default='', choices=['', 'fromscratch', 'pretrained', 'only_wm', 'augmented'],
                    help='either fromscratch or pretrained or only_wm or augmented')
parser.add_argument('--loadmodel', default='', help='path which model should be load for pretrained embed type')
# type=float so CLI-supplied values are not passed downstream as strings
parser.add_argument('--eps', default=0.1, type=float, help='epsilon for frontier stitching')
parser.add_argument('--lmbda', default=100, type=int, help='lambda for piracy resistant')
parser.add_argument('--pattern_size', default=6, type=int, help='patternsize for piracy resistant')
parser.add_argument('--test_quot', default=None, type=int,
                    help='the quotient of data subset (for testing reasons; default: None)')
# hyperparameters
parser.add_argument('--epochs_w_wm', default=0, type=int, help='number of epochs trained with watermarks')
parser.add_argument('--epochs_wo_wm', default=2, type=int, help='number of epochs trained without watermarks')
parser.add_argument('--batch_size', default=32, type=int, help='the batch size')
parser.add_argument('--wm_batch_size', default=16, type=int, help='the wm batch size')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
# lradj is a multiplicative factor (default 0.1); type=int would reject values
# such as 0.1 given on the command line, so it must be float
parser.add_argument('--lradj', default=0.1, type=float, help='multiple the lr by lradj every 20 epochs')
parser.add_argument('--optim', default='SGD', help='optimizer (default SGD)')
parser.add_argument('--sched', default='MultiStepLR', help='scheduler (default MultiStepLR)')
parser.add_argument('--patience', default=20, type=int, help='early stopping patience (default 20)')
# cuda
parser.add_argument('--cuda', default=None, help='set cuda (e.g. cuda:0)')
parser.add_argument('--save_file', default="save_results.csv", help='file for saving results')
args = parser.parse_args()
try:
    # Resolve the compute device: use the --cuda string only when one was
    # actually given AND CUDA is available; otherwise fall back to CPU.
    # (The original torch.device(args.cuda) raised when --cuda was left at
    # its None default on a CUDA machine.)
    device = torch.device(args.cuda) if (args.cuda and torch.cuda.is_available()) else 'cpu'
    cwd = os.getcwd()
    # set random seed for reproducibility across torch, random and numpy
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    random.seed(0)
    np.random.seed(0)
    # torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
except Exception as e:
    # Report the failure instead of silently swallowing it: the rest of the
    # script depends on `device` and `cwd` being defined, so hiding this
    # error would only surface later as a confusing NameError.
    msg = 'An error occured during setup: ' + str(e)
    print(msg)
    traceback.print_tb(e.__traceback__)
try:
    generation_time = 0
    # Instantiate the selected watermarking scheme by name from the
    # watermarks package (the --method choices were validated by argparse).
    wm_method = watermarks.__dict__[args.method](args)
    if args.method == 'ProtectingIPP':
        # ProtectingIPP only needs the device to generate its trigger set.
        start_time = time.time()
        wm_method.gen_watermarks(device)
        generation_time = time.time() - start_time
    elif args.method == 'PiracyResistant':
        # Build a trigger-set-sized loader over the training data; generation
        # time excludes this data preparation.
        transform = get_wm_transform('PiracyResistant', args.dataset)
        dataset, _, _ = get_dataset(args.dataset, os.path.join(cwd, 'data'), os.path.join(cwd, 'data'),
                                    transform, transform, valid_size=None,
                                    testquot=None, size_train=args.trg_set_size, size_test=args.trg_set_size)
        loader = torch.utils.data.DataLoader(dataset, batch_size=args.wm_batch_size, num_workers=0, shuffle=False,
                                             drop_last=True)
        start_time = time.time()
        wm_method.gen_watermarks(loader)
        generation_time = time.time() - start_time
    elif args.method == 'ExponentialWeighting':
        start_time = time.time()
        wm_method.gen_watermarks()
        generation_time = time.time() - start_time
    elif args.method == 'FrontierStitching':
        # Frontier stitching needs a trained model (adversarial examples are
        # derived from it), loaded from checkpoint/<loadmodel>.t7.
        net = models.__dict__[args.arch](num_classes=args.num_classes)
        net.load_state_dict(torch.load(os.path.join('checkpoint', args.loadmodel + '.t7')))
        net.to(device)
        criterion = nn.CrossEntropyLoss()
        _, transform = get_data_transforms(args.dataset)
        dataset, _, _ = get_dataset(args.dataset, os.path.join(cwd, 'data'), os.path.join(cwd, 'data'),
                                    transform, transform)
        loader = torch.utils.data.DataLoader(dataset, batch_size=args.wm_batch_size, num_workers=0, shuffle=False,
                                             drop_last=True)
        start_time = time.time()
        wm_method.gen_watermarks(net, criterion, device, loader, args.eps)
        generation_time = time.time() - start_time
    elif args.method == 'WMEmbeddedSystems':
        transform = get_wm_transform('WMEmbeddedSystems', args.dataset)
        dataset, _, _ = get_dataset(args.dataset, os.path.join(cwd, 'data'), os.path.join(cwd, 'data'),
                                    transform, transform)
        start_time = time.time()
        wm_method.gen_watermarks(dataset, device)
        generation_time = time.time() - start_time
    # Append one result row (semicolon-separated) to the results CSV.
    csv_args = [args.method, args.wm_type, args.dataset, args.arch, generation_time]
    with open(args.save_file, 'a') as file:
        writer = csv.writer(file, delimiter=';')
        writer.writerow(csv_args)
except Exception as e:
    # Print the error message too -- previously only the traceback frames were
    # shown, losing the exception text and the run name.
    msg = 'An error occured during watermark generation in ' + args.runname + ': ' + str(e)
    print(msg)
    traceback.print_tb(e.__traceback__)