-
Notifications
You must be signed in to change notification settings - Fork 37
/
train.py
1697 lines (1508 loc) · 74.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.backends import cudnn
from bisect import bisect_right
import math
import os
# Command-line interface of the training script.
# The parsed options are stored in the module-level ``args`` namespace, which
# the classes and set-up code further down in this file read directly.
parser = argparse.ArgumentParser(description='PyTorch local error training')

# Model / dataset selection
parser.add_argument('--model', default='vgg8b',
                    help='model, mlp, vgg13, vgg16, vgg19, vgg8b, vgg11b, resnet18, resnet34, wresnet28-10 and more (default: vgg8b)')
parser.add_argument('--dataset', default='CIFAR10',
                    help='dataset, MNIST, KuzushijiMNIST, FashionMNIST, CIFAR10, CIFAR100, SVHN, STL10 or ImageNet (default: CIFAR10)')
parser.add_argument('--batch-size', type=int, default=128,
                    help='input batch size for training (default: 128)')

# Architecture sizes
parser.add_argument('--num-layers', type=int, default=1,
                    help='number of hidden fully-connected layers for mlp and vgg models (default: 1)')
parser.add_argument('--num-hidden', type=int, default=1024,
                    help='number of hidden units for mlp model (default: 1024)')
parser.add_argument('--dim-in-decoder', type=int, default=4096,
                    help='input dimension of decoder_y used in pred and predsim loss (default: 4096)')
parser.add_argument('--feat-mult', type=float, default=1,
                    help='multiply number of CNN features with this number (default: 1)')

# Training schedule
parser.add_argument('--epochs', type=int, default=400,
                    help='number of epochs to train (default: 400)')
parser.add_argument('--classes-per-batch', type=int, default=0,
                    help='aim for this number of different classes per batch during training (default: 0, random batches)')
parser.add_argument('--classes-per-batch-until-epoch', type=int, default=0,
                    help='limit number of classes per batch until this epoch (default: 0, until end of training)')
parser.add_argument('--lr', type=float, default=5e-4,
                    help='initial learning rate (default: 5e-4)')
parser.add_argument('--lr-decay-milestones', nargs='+', type=int, default=[200,300,350,375],
                    help='decay learning rate at these milestone epochs (default: [200,300,350,375])')
parser.add_argument('--lr-decay-fact', type=float, default=0.25,
                    help='learning rate decay factor to use at milestone epochs (default: 0.25)')

# Optimizer
parser.add_argument('--optim', default='adam',
                    help='optimizer, adam, amsgrad or sgd (default: adam)')
parser.add_argument('--momentum', type=float, default=0.0,
                    help='SGD momentum (default: 0.0)')
parser.add_argument('--weight-decay', type=float, default=0.0,
                    help='weight decay (default: 0.0)')

# Local loss configuration
parser.add_argument('--alpha', type=float, default=0.0,
                    help='unsupervised fraction in similarity matching loss (default: 0.0)')
parser.add_argument('--beta', type=float, default=0.99,
                    help='fraction of similarity matching loss in predsim loss (default: 0.99)')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='dropout after each nonlinearity (default: 0.0)')
parser.add_argument('--loss-sup', default='predsim',
                    help='supervised local loss, sim, pred or predsim (default: predsim)')
parser.add_argument('--loss-unsup', default='none',
                    help='unsupervised local loss, none, sim or recon (default: none)')
parser.add_argument('--nonlin', default='relu',
                    help='nonlinearity, relu or leakyrelu (default: relu)')

# Feature switches
parser.add_argument('--no-similarity-std', action='store_true', default=False,
                    help='disable use of standard deviation in similarity matrix for feature maps')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disable CUDA training')
parser.add_argument('--backprop', action='store_true', default=False,
                    help='disable local loss training')
parser.add_argument('--no-batch-norm', action='store_true', default=False,
                    help='disable batch norm before non-linearities')
parser.add_argument('--no-detach', action='store_true', default=False,
                    help='do not detach computational graph')
parser.add_argument('--pre-act', action='store_true', default=False,
                    help='use pre-activation in ResNet')

# Reproducibility, checkpointing and logging
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--save-dir', default='/hdd/results/local-error', type=str,
                    help='the directory used to save the trained models')
parser.add_argument('--resume', default='', type=str,
                    help='checkpoint to resume training from')
parser.add_argument('--progress-bar', action='store_true', default=False,
                    help='show progress bar during training')
parser.add_argument('--no-print-stats', action='store_true', default=False,
                    help='do not print layerwise statistics during training with local loss')

# Biologically plausible variant
parser.add_argument('--bio', action='store_true', default=False,
                    help='use more biologically plausible versions of pred and sim loss (default: False)')
parser.add_argument('--target-proj-size', type=int, default=128,
                    help='size of target projection back to hidden layers for biologically plausible loss (default: 128)')

# Cutout augmentation
parser.add_argument('--cutout', action='store_true', default=False,
                    help='apply cutout regularization')
parser.add_argument('--n_holes', type=int, default=1,
                    help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=16,
                    help='length of the cutout holes in pixels')

args = parser.parse_args()
# CUDA is used only when it was not disabled on the command line and a
# CUDA device is actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed the default generator so runs with the same --seed are repeatable
torch.manual_seed(args.seed)

if args.cuda:
    # Switch on cuDNN and its benchmark mode, then seed the CUDA generator too
    cudnn.enabled, cudnn.benchmark = True, True
    torch.cuda.manual_seed(args.seed)
class Cutout(object):
    '''Transform that blanks out randomly placed square patches of an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    '''
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        '''Apply the cutout mask to one image.

        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        '''
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        def clamp(value, upper):
            # Restrict a coordinate to the closed range [0, upper]
            return min(max(value, 0), upper)

        # 1 keeps a pixel, 0 removes it; one mask is shared by all channels
        keep = np.ones((height, width), np.float32)
        for _ in range(self.n_holes):
            # Patch centre: the row is drawn first, then the column
            centre_y = np.random.randint(height)
            centre_x = np.random.randint(width)

            top = clamp(centre_y - half, height)
            bottom = clamp(centre_y + half, height)
            left = clamp(centre_x - half, width)
            right = clamp(centre_x + half, width)
            keep[top:bottom, left:right] = 0.

        return img * torch.from_numpy(keep).expand_as(img)
class NClassRandomSampler(torch.utils.data.sampler.Sampler):
    r'''Samples elements such that most batches have N classes per batch.
    Elements are shuffled before each epoch.

    Every dataset index is yielded exactly once per iteration. Indices are
    emitted in consecutive groups; each group is built so that it contains
    ``n_classes_per_batch`` different classes and is then filled up to
    ``batch_size`` with further examples of those same classes. When the
    remaining pool has no more matching examples, the group is filled with
    arbitrary remaining examples instead, so late batches may contain more
    classes.

    Arguments:
        targets: target class for each example in the dataset
        n_classes_per_batch: the number of classes we want to have per batch
        batch_size: the number of examples per batch
    '''

    def __init__(self, targets, n_classes_per_batch, batch_size):
        self.targets = targets
        # NOTE: this is the largest label value (not the number of distinct
        # classes); it is not used by the sampling logic below.
        self.n_classes = int(np.max(targets))
        self.n_classes_per_batch = n_classes_per_batch
        self.batch_size = batch_size

    def __iter__(self):
        n = self.n_classes_per_batch

        ts = list(self.targets)
        ts_i = list(range(len(self.targets)))
        # Shuffling uses numpy's global random generator
        np.random.shuffle(ts_i)

        #algorithm outline:
        #1) put n examples in batch
        #2) fill rest of batch with examples whose class is already in the batch
        while len(ts_i) > 0:
            idxs, ts_i = ts_i[:n], ts_i[n:]  #pop n off the list
            t_slice_set = {ts[i] for i in idxs}

            #fill up idxs until we have n different classes in it. this should be quick.
            #(the look-ahead is bounded by n*10 candidates so that this stays cheap)
            k = 0
            while len(t_slice_set) < n and k < n*10 and k < len(ts_i):
                if ts[ts_i[k]] not in t_slice_set:
                    new_idx = ts_i.pop(k)
                    idxs.append(new_idx)
                    # Only one class was added, so update the set incrementally
                    t_slice_set.add(ts[new_idx])
                else:
                    k += 1

            #fill up idxs with indexes whose classes are in t_slice_set.
            j = 0
            while j < len(ts_i) and len(idxs) < self.batch_size:
                if ts[ts_i[j]] in t_slice_set:
                    idxs.append(ts_i.pop(j))  #pop is O(n), can we do better?
                else:
                    j += 1

            #not enough examples of the chosen classes left: top up with whatever remains
            if len(idxs) < self.batch_size:
                needed = self.batch_size - len(idxs)
                idxs += ts_i[:needed]
                ts_i = ts_i[needed:]

            for i in idxs:
                yield i

    def __len__(self):
        return len(self.targets)
class KuzushijiMNIST(datasets.MNIST):
    '''Kuzushiji-MNIST dataset, implemented as a subclass of ``datasets.MNIST``.

    The only thing defined here is the list of archive URLs; everything else
    is inherited from ``datasets.MNIST``.
    '''
    # Locations of the four archives: training images, training labels,
    # test (t10k) images and test (t10k) labels.
    # NOTE(review): presumably read by the inherited MNIST download code -
    # confirm that the installed torchvision version still uses ``urls``.
    urls = [
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz'
    ]
# Extra DataLoader options, only applied when running on CUDA
kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}


def _train_labels(dataset):
    '''Return the class label of every example in ``dataset`` as a numpy array.

    The attribute holding the labels differs between dataset classes and
    torchvision versions (``targets``, ``labels`` or ``train_labels``), and a
    ``ConcatDataset`` exposes none of them, so the labels of its parts are
    concatenated in order.

    Raises:
        AttributeError: if no known label attribute is present.
    '''
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        return np.concatenate([_train_labels(part) for part in dataset.datasets])
    for name in ('targets', 'labels', 'train_labels'):
        found = getattr(dataset, name, None)
        if found is not None:
            if torch.is_tensor(found):
                found = found.numpy()
            return np.asarray(found)
    raise AttributeError('no class labels found on dataset of type {}'.format(type(dataset).__name__))


def _make_train_loader(dataset, labels=None):
    '''Build the training DataLoader for ``dataset``.

    With ``--classes-per-batch 0`` the data is shuffled randomly; otherwise a
    NClassRandomSampler decides the order. ``labels`` may be passed when the
    caller already has them, else they are looked up on the dataset (only when
    the sampler is actually needed).
    '''
    if args.classes_per_batch == 0:
        sampler = None
    else:
        if labels is None:
            labels = _train_labels(dataset)
        sampler = NClassRandomSampler(labels, args.classes_per_batch, args.batch_size)
    # shuffle and sampler are mutually exclusive in DataLoader
    return torch.utils.data.DataLoader(
        dataset, sampler=sampler,
        batch_size=args.batch_size, shuffle=sampler is None, **kwargs)


# Every branch defines input_dim, input_ch, num_classes, train_loader and test_loader.
# Cutout (if requested) is appended after normalisation, so it acts on tensors.
if args.dataset == 'MNIST':
    input_dim = 28
    input_ch = 1
    num_classes = 10
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.MNIST('../data/MNIST', train=True, download=True, transform=train_transform)
    train_loader = _make_train_loader(dataset_train)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'FashionMNIST':
    input_dim = 28
    input_ch = 1
    num_classes = 10
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.286,), (0.353,))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.FashionMNIST('../data/FashionMNIST', train=True, download=True, transform=train_transform)
    train_loader = _make_train_loader(dataset_train)
    test_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('../data/FashionMNIST', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.286,), (0.353,))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'KuzushijiMNIST':
    input_dim = 28
    input_ch = 1
    num_classes = 10
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1904,), (0.3475,))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = KuzushijiMNIST('../data/KuzushijiMNIST', train=True, download=True, transform=train_transform)
    train_loader = _make_train_loader(dataset_train)
    test_loader = torch.utils.data.DataLoader(
        KuzushijiMNIST('../data/KuzushijiMNIST', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1904,), (0.3475,))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'CIFAR10':
    input_dim = 32
    input_ch = 3
    num_classes = 10
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.424, 0.415, 0.384), (0.283, 0.278, 0.284))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.CIFAR10('../data/CIFAR10', train=True, download=True, transform=train_transform)
    train_loader = _make_train_loader(dataset_train)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../data/CIFAR10', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.424, 0.415, 0.384), (0.283, 0.278, 0.284))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'CIFAR100':
    input_dim = 32
    input_ch = 3
    num_classes = 100
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.438, 0.418, 0.377), (0.300, 0.287, 0.294))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.CIFAR100('../data/CIFAR100', train=True, download=True, transform=train_transform)
    train_loader = _make_train_loader(dataset_train)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('../data/CIFAR100', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.438, 0.418, 0.377), (0.300, 0.287, 0.294))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'SVHN':
    input_dim = 32
    input_ch = 3
    num_classes = 10
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.431, 0.430, 0.446), (0.197, 0.198, 0.199))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    # Training set is the union of the 'train' and 'extra' splits
    dataset_train = torch.utils.data.ConcatDataset((
        datasets.SVHN('../data/SVHN', split='train', download=True, transform=train_transform),
        datasets.SVHN('../data/SVHN', split='extra', download=True, transform=train_transform)))
    train_loader = _make_train_loader(dataset_train)
    test_loader = torch.utils.data.DataLoader(
        datasets.SVHN('../data/SVHN', split='test', download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.431, 0.430, 0.446), (0.197, 0.198, 0.199))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'STL10':
    input_dim = 96
    input_ch = 3
    num_classes = 10
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.447, 0.440, 0.407), (0.260, 0.257, 0.271))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.STL10('../data/STL10', split='train', download=True, transform=train_transform)
    train_loader = _make_train_loader(dataset_train)
    test_loader = torch.utils.data.DataLoader(
        datasets.STL10('../data/STL10', split='test',
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.447, 0.440, 0.407), (0.260, 0.257, 0.271))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
elif args.dataset == 'ImageNet':
    input_dim = 224
    input_ch = 3
    num_classes = 1000
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
    dataset_train = datasets.ImageFolder('../data/ImageNet/train', transform=train_transform)
    # ImageFolder keeps (path, class_index) pairs in ``samples``
    labels = np.array([a[1] for a in dataset_train.samples])
    train_loader = _make_train_loader(dataset_train, labels)
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('../data/ImageNet/val',
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)
else:
    print('No valid dataset is specified')
class LinearFAFunction(torch.autograd.Function):
    '''Autograd function for linear feedback alignment module.

    The forward pass is an ordinary linear map ``input @ weight.T + bias``.
    In the backward pass the gradient with respect to the input is computed
    with ``weight_fa`` instead of ``weight``; ``weight_fa`` itself never
    receives a gradient.
    '''
    @staticmethod
    def forward(context, input, weight, weight_fa, bias=None):
        '''Compute ``input.matmul(weight.t())`` and add ``bias`` if given.

        Args:
            context: autograd context used to stash tensors for backward.
            input (Tensor): 2-d input, one row per example.
            weight (Tensor): forward weight matrix.
            weight_fa (Tensor): feedback weight matrix, same layout as ``weight``.
            bias (Tensor, optional): bias vector with one entry per output feature.
        '''
        context.save_for_backward(input, weight, weight_fa, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(context, grad_output):
        '''Return gradients for (input, weight, weight_fa, bias), in that order.'''
        input, weight, weight_fa, bias = context.saved_tensors
        grad_input = grad_weight = grad_weight_fa = grad_bias = None
        if context.needs_input_grad[0]:
            # Feedback alignment: propagate the error through weight_fa, not weight
            grad_input = grad_output.matmul(weight_fa)
        if context.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input)
        # bias is the fourth forward argument, hence index 3 (index 2 is weight_fa)
        if bias is not None and context.needs_input_grad[3]:
            # Summing over the batch already yields the bias shape; no squeeze,
            # which would drop the only dimension when there is one output feature
            grad_bias = grad_output.sum(0)
        # grad_weight_fa stays None so the feedback weights are never updated
        return grad_input, grad_weight, grad_weight_fa, grad_bias
class LinearFA(nn.Module):
    '''Fully-connected layer trained with feedback alignment.

    Holds a forward weight matrix and a separate feedback matrix of the same
    shape; the actual computation is delegated to LinearFAFunction.

    Args:
        input_features (int): Number of input features to linear layer.
        output_features (int): Number of output features from linear layer.
        bias (bool): True if to use trainable bias.
    '''
    def __init__(self, input_features, output_features, bias=True):
        super(LinearFA, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        # Registration order (weight, weight_fa, bias) is kept as-is
        shape = (output_features, input_features)
        self.weight = nn.Parameter(torch.Tensor(*shape))
        self.weight_fa = nn.Parameter(torch.Tensor(*shape))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(output_features))

        self.reset_parameters()

        # Move the parameter storage to the GPU when CUDA training is active
        if args.cuda:
            for param in (self.weight, self.weight_fa, self.bias):
                if param is not None:
                    param.data = param.data.cuda()

    def reset_parameters(self):
        '''Draw both weight matrices uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)] and zero the bias.'''
        fan_in = self.weight.size(1)
        bound = 1.0 / math.sqrt(fan_in)
        # Forward weights are initialised first, then the feedback weights
        for matrix in (self.weight, self.weight_fa):
            matrix.data.uniform_(-bound, bound)
        if self.bias is None:
            return
        self.bias.data.zero_()

    def forward(self, input):
        '''Apply the feedback-alignment linear function to ``input``.'''
        params = (self.weight, self.weight_fa, self.bias)
        return LinearFAFunction.apply(input, *params)

    def __repr__(self):
        return '{}(in_features={}, out_features={}, bias={})'.format(
            self.__class__.__name__,
            str(self.input_features),
            str(self.output_features),
            str(self.bias is not None))
class LocalLossBlockLinear(nn.Module):
    '''A module containing nn.Linear -> nn.BatchNorm1d -> nonlinearity -> nn.Dropout
    The block can be trained by backprop or by locally generated error signal based on cross-entropy and/or similarity matching loss.

    Behaviour is controlled by the module-level ``args`` namespace (loss types,
    backprop/bio switches, optimizer choice). Each block owns its own optimizer,
    created with learning rate 0; call set_learning_rate() before training.

    Args:
        num_in (int): Number of input features to linear layer.
        num_out (int): Number of output features from linear layer.
        num_classes (int): Number of classes (used in local prediction loss).
        first_layer (bool): True if this is the first layer in the network (used in local reconstruction loss).
        dropout (float): Dropout rate, if None, read from args.dropout.
        batchnorm (bool): True if to use batchnorm, if None, read from args.no_batch_norm.
    '''
    def __init__(self, num_in, num_out, num_classes, first_layer=False, dropout=None, batchnorm=None):
        super(LocalLossBlockLinear, self).__init__()

        self.num_classes = num_classes
        self.first_layer = first_layer
        self.dropout_p = args.dropout if dropout is None else dropout
        self.batchnorm = not args.no_batch_norm if batchnorm is None else batchnorm
        self.encoder = nn.Linear(num_in, num_out, bias=True)

        # Auxiliary layers below exist only for local-loss training (not with --backprop)
        if not args.backprop and args.loss_unsup == 'recon':
            # Decoder used by the reconstruction loss: maps hidden units back to the input
            self.decoder_x = nn.Linear(num_out, num_in, bias=True)
        if not args.backprop and (args.loss_sup == 'pred' or args.loss_sup == 'predsim'):
            # Local classifier used by the prediction loss
            if args.bio:
                self.decoder_y = LinearFA(num_out, args.target_proj_size)
            else:
                self.decoder_y = nn.Linear(num_out, num_classes)
            # The local classifier starts from all-zero weights
            self.decoder_y.weight.data.zero_()
        if not args.backprop and args.bio:
            # Bias-free linear projection of the one-hot targets, used by the bio losses
            self.proj_y = nn.Linear(num_classes, args.target_proj_size, bias=False)
        if not args.backprop and not args.bio and (args.loss_unsup == 'sim' or args.loss_sup == 'sim' or args.loss_sup == 'predsim'):
            # Extra linear layer applied to h before the similarity matrix is computed
            self.linear_loss = nn.Linear(num_out, num_out, bias=False)

        if self.batchnorm:
            self.bn = torch.nn.BatchNorm1d(num_out)
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)

        # NOTE: self.nonlin is only defined for 'relu' and 'leakyrelu'
        if args.nonlin == 'relu':
            self.nonlin = nn.ReLU(inplace=True)
        elif args.nonlin == 'leakyrelu':
            self.nonlin = nn.LeakyReLU(negative_slope=0.01, inplace=True)

        if self.dropout_p > 0:
            self.dropout = torch.nn.Dropout(p=self.dropout_p, inplace=False)

        # Per-block optimizer over all parameters registered above; lr starts at 0.
        # NOTE: no optimizer is created for any other value of args.optim.
        if args.optim == 'sgd':
            self.optimizer = optim.SGD(self.parameters(), lr=0, weight_decay=args.weight_decay, momentum=args.momentum)
        elif args.optim == 'adam' or args.optim == 'amsgrad':
            self.optimizer = optim.Adam(self.parameters(), lr=0, weight_decay=args.weight_decay, amsgrad=args.optim == 'amsgrad')

        self.clear_stats()

    def clear_stats(self):
        '''Reset the accumulated loss/accuracy counters (no-op with --no-print-stats).'''
        if not args.no_print_stats:
            self.loss_sim = 0.0
            self.loss_pred = 0.0
            self.correct = 0
            self.examples = 0

    def print_stats(self):
        '''Return a one-line summary of the accumulated statistics.

        Returns an empty string with --backprop. The averages divide by
        self.examples, so at least one batch must have been accumulated.
        '''
        if not args.backprop:
            stats = '{}, loss_sim={:.4f}, loss_pred={:.4f}, error={:.3f}%, num_examples={}\n'.format(
                self.encoder,
                self.loss_sim / self.examples,
                self.loss_pred / self.examples,
                100.0 * float(self.examples - self.correct) / self.examples,
                self.examples)
            return stats
        else:
            return ''

    def set_learning_rate(self, lr):
        '''Store ``lr`` and apply it to every parameter group of this block's optimizer.'''
        self.lr = lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def optim_zero_grad(self):
        '''Clear the gradients held by this block's optimizer.'''
        self.optimizer.zero_grad()

    def optim_step(self):
        '''Perform one update step of this block's optimizer.'''
        self.optimizer.step()

    def forward(self, x, y, y_onehot):
        '''Run the block and, unless --backprop is set, compute its local loss.

        In training mode the local loss is back-propagated immediately; unless
        --no-detach is set the block's weights are also updated here and the
        returned activation is detached from the graph.

        Args:
            x: input to the linear layer.
            y: class targets, used by the cross-entropy loss and the accuracy count.
            y_onehot: one-hot version of the targets, used by the similarity
                and bio losses.
        Returns:
            Tuple (h_return, loss): the block output (after dropout, if any) and
            the local loss as a Python float (0.0 when no local loss was computed).
        '''
        # The linear transformation
        h = self.encoder(x)

        # Add batchnorm and nonlinearity
        if self.batchnorm:
            h = self.bn(h)
        h = self.nonlin(h)

        # Save return value and add dropout
        h_return = h
        if self.dropout_p > 0:
            h_return = self.dropout(h_return)

        # Calculate local loss and update weights
        # (also evaluated outside training when statistics are being collected)
        if (self.training or not args.no_print_stats) and not args.backprop:
            # Calculate hidden layer similarity matrix
            # (similarity_matrix is a helper defined elsewhere in this file)
            if args.loss_unsup == 'sim' or args.loss_sup == 'sim' or args.loss_sup == 'predsim':
                if args.bio:
                    h_loss = h
                else:
                    h_loss = self.linear_loss(h)
                Rh = similarity_matrix(h_loss)

            # Calculate unsupervised loss
            if args.loss_unsup == 'sim':
                Rx = similarity_matrix(x).detach()
                loss_unsup = F.mse_loss(Rh, Rx)
            elif args.loss_unsup == 'recon' and not self.first_layer:
                x_hat = self.nonlin(self.decoder_x(h))
                loss_unsup = F.mse_loss(x_hat, x.detach())
            else:
                # No unsupervised term: use a constant zero on the right device
                if args.cuda:
                    loss_unsup = torch.cuda.FloatTensor([0])
                else:
                    loss_unsup = torch.FloatTensor([0])

            # Calculate supervised loss
            # NOTE: loss_sup is only defined for 'sim', 'pred' and 'predsim'
            if args.loss_sup == 'sim':
                if args.bio:
                    Ry = similarity_matrix(self.proj_y(y_onehot)).detach()
                else:
                    Ry = similarity_matrix(y_onehot).detach()
                loss_sup = F.mse_loss(Rh, Ry)
                if not args.no_print_stats:
                    # Losses are accumulated weighted by batch size; print_stats() divides by examples
                    self.loss_sim += loss_sup.item() * h.size(0)
                    self.examples += h.size(0)
            elif args.loss_sup == 'pred':
                y_hat_local = self.decoder_y(h.view(h.size(0), -1))
                if args.bio:
                    # Binary targets: sign pattern of the projected one-hot labels
                    float_type = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
                    y_onehot_pred = self.proj_y(y_onehot).gt(0).type(float_type).detach()
                    loss_sup = F.binary_cross_entropy_with_logits(y_hat_local, y_onehot_pred)
                else:
                    loss_sup = F.cross_entropy(y_hat_local, y.detach())
                if not args.no_print_stats:
                    self.loss_pred += loss_sup.item() * h.size(0)
                    self.correct += y_hat_local.max(1)[1].eq(y).cpu().sum()
                    self.examples += h.size(0)
            elif args.loss_sup == 'predsim':
                y_hat_local = self.decoder_y(h.view(h.size(0), -1))
                if args.bio:
                    Ry = similarity_matrix(self.proj_y(y_onehot)).detach()
                    float_type = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
                    y_onehot_pred = self.proj_y(y_onehot).gt(0).type(float_type).detach()
                    loss_pred = (1-args.beta) * F.binary_cross_entropy_with_logits(y_hat_local, y_onehot_pred)
                else:
                    Ry = similarity_matrix(y_onehot).detach()
                    loss_pred = (1-args.beta) * F.cross_entropy(y_hat_local, y.detach())
                # predsim = (1 - beta) * prediction loss + beta * similarity matching loss
                loss_sim = args.beta * F.mse_loss(Rh, Ry)
                loss_sup = loss_pred + loss_sim
                if not args.no_print_stats:
                    self.loss_pred += loss_pred.item() * h.size(0)
                    self.loss_sim += loss_sim.item() * h.size(0)
                    self.correct += y_hat_local.max(1)[1].eq(y).cpu().sum()
                    self.examples += h.size(0)

            # Combine unsupervised and supervised loss
            loss = args.alpha * loss_unsup + (1 - args.alpha) * loss_sup

            # Single-step back-propagation
            # (the graph is retained only when it is not detached below)
            if self.training:
                loss.backward(retain_graph = args.no_detach)

            # Update weights in this layer and detach computational graph
            if self.training and not args.no_detach:
                self.optimizer.step()
                self.optimizer.zero_grad()
                h_return.detach_()

            loss = loss.item()
        else:
            loss = 0.0

        return h_return, loss
class LocalLossBlockConv(nn.Module):
'''
A block containing nn.Conv2d -> nn.BatchNorm2d -> nn.ReLU -> nn.Dropou2d
The block can be trained by backprop or by locally generated error signal based on cross-entropy and/or similarity matching loss.
Args:
ch_in (int): Number of input features maps.
ch_out (int): Number of output features maps.
kernel_size (int): Kernel size in Conv2d.
stride (int): Stride in Conv2d.
padding (int): Padding in Conv2d.
num_classes (int): Number of classes (used in local prediction loss).
dim_out (int): Feature map height/width for input (and output).
first_layer (bool): True if this is the first layer in the network (used in local reconstruction loss).
dropout (float): Dropout rate, if None, read from args.dropout.
bias (bool): True if to use trainable bias.
pre_act (bool): True if to apply layer order nn.BatchNorm2d -> nn.ReLU -> nn.Dropou2d -> nn.Conv2d (used for PreActResNet).
post_act (bool): True if to apply layer order nn.Conv2d -> nn.BatchNorm2d -> nn.ReLU -> nn.Dropou2d.
'''
def __init__(self, ch_in, ch_out, kernel_size, stride, padding, num_classes, dim_out, first_layer=False, dropout=None, bias=None, pre_act=False, post_act=True):
super(LocalLossBlockConv, self).__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.num_classes = num_classes
self.first_layer = first_layer
self.dropout_p = args.dropout if dropout is None else dropout
self.bias = True if bias is None else bias
self.pre_act = pre_act
self.post_act = post_act
self.encoder = nn.Conv2d(ch_in, ch_out, kernel_size, stride=stride, padding=padding, bias=self.bias)
if not args.backprop and args.loss_unsup == 'recon':
self.decoder_x = nn.ConvTranspose2d(ch_out, ch_in, kernel_size, stride=stride, padding=padding)
if args.bio or (not args.backprop and (args.loss_sup == 'pred' or args.loss_sup == 'predsim')):
# Resolve average-pooling kernel size in order for flattened dim to match args.dim_in_decoder
ks_h, ks_w = 1, 1
dim_out_h, dim_out_w = dim_out, dim_out
dim_in_decoder = ch_out*dim_out_h*dim_out_w
while dim_in_decoder > args.dim_in_decoder and ks_h < dim_out:
ks_h*=2
dim_out_h = math.ceil(dim_out / ks_h)
dim_in_decoder = ch_out*dim_out_h*dim_out_w
if dim_in_decoder > args.dim_in_decoder:
ks_w*=2
dim_out_w = math.ceil(dim_out / ks_w)
dim_in_decoder = ch_out*dim_out_h*dim_out_w
if ks_h > 1 or ks_w > 1:
pad_h = (ks_h * (dim_out_h - dim_out // ks_h)) // 2
pad_w = (ks_w * (dim_out_w - dim_out // ks_w)) // 2
self.avg_pool = nn.AvgPool2d((ks_h,ks_w), padding=(pad_h, pad_w))
else:
self.avg_pool = None
if not args.backprop and (args.loss_sup == 'pred' or args.loss_sup == 'predsim'):
if args.bio:
self.decoder_y = LinearFA(dim_in_decoder, args.target_proj_size)
else:
self.decoder_y = nn.Linear(dim_in_decoder, num_classes)
self.decoder_y.weight.data.zero_()
if not args.backprop and args.bio:
self.proj_y = nn.Linear(num_classes, args.target_proj_size, bias=False)
if not args.backprop and (args.loss_unsup == 'sim' or args.loss_sup == 'sim' or args.loss_sup == 'predsim'):
self.conv_loss = nn.Conv2d(ch_out, ch_out, 3, stride=1, padding=1, bias=False)
if not args.no_batch_norm:
if pre_act:
self.bn_pre = torch.nn.BatchNorm2d(ch_in)
if not (pre_act and args.backprop):
self.bn = torch.nn.BatchNorm2d(ch_out)
nn.init.constant_(self.bn.weight, 1)
nn.init.constant_(self.bn.bias, 0)
if args.nonlin == 'relu':
self.nonlin = nn.ReLU(inplace=True)
elif args.nonlin == 'leakyrelu':
self.nonlin = nn.LeakyReLU(negative_slope=0.01, inplace=True)
if self.dropout_p > 0:
self.dropout = torch.nn.Dropout2d(p=self.dropout_p, inplace=False)
if args.optim == 'sgd':
self.optimizer = optim.SGD(self.parameters(), lr=0, weight_decay=args.weight_decay, momentum=args.momentum)
elif args.optim == 'adam' or args.optim == 'amsgrad':
self.optimizer = optim.Adam(self.parameters(), lr=0, weight_decay=args.weight_decay, amsgrad=args.optim == 'amsgrad')
self.clear_stats()
def clear_stats(self):
    ''' Reset the running loss / accuracy accumulators of this layer.

    Does nothing when statistics are disabled via ``args.no_print_stats``.
    '''
    if args.no_print_stats:
        return
    # Summed (not averaged) losses; print_stats() divides by self.examples.
    self.loss_sim, self.loss_pred = 0.0, 0.0
    self.correct, self.examples = 0, 0
def print_stats(self):
    ''' Format the statistics accumulated since the last clear_stats().

    Returns:
        A one-line, newline-terminated summary (encoder repr, mean similarity
        loss, mean prediction loss, error percentage, example count), or an
        empty string when ``args.backprop`` is set.
    '''
    if args.backprop:
        return ''
    seen = self.examples
    # self.correct may be a 0-dim tensor (it is accumulated from .sum() in
    # forward), hence the explicit float() conversion.
    error_pct = 100.0 * float(seen - self.correct) / seen
    template = '{}, loss_sim={:.4f}, loss_pred={:.4f}, error={:.3f}%, num_examples={}\n'
    return template.format(
        self.encoder,
        self.loss_sim / seen,
        self.loss_pred / seen,
        error_pct,
        seen)
def set_learning_rate(self, lr):
    ''' Remember ``lr`` on the layer and apply it to every param group of the local optimizer. '''
    self.lr = lr
    for group in self.optimizer.param_groups:
        group['lr'] = lr
def optim_zero_grad(self):
    ''' Clear the gradients tracked by this layer's local optimizer. '''
    local_optimizer = self.optimizer
    local_optimizer.zero_grad()
def optim_step(self):
    ''' Apply one update step of this layer's local optimizer. '''
    local_optimizer = self.optimizer
    local_optimizer.step()
def forward(self, x, y, y_onehot, x_shortcut=None):
    ''' Apply the convolutional block and, unless plain back-prop is used,
    compute its local loss and update the layer's own weights.

    Args:
        x: Input activation of this block.
        y: Targets passed to ``F.cross_entropy`` and compared with the local
            prediction for the error statistics (presumably class indices --
            confirm against the caller).
        y_onehot: One-hot targets used to build the target similarity matrix
            and, with ``args.bio``, projected through ``self.proj_y``.
        x_shortcut: Optional tensor added to the encoder output (residual
            shortcut branch).

    Returns:
        Tuple ``(h_return, loss)``: the block output and the local loss as a
        Python float (``0.0`` when no local loss is computed).

    Raises:
        ValueError: If a local loss is requested but ``args.loss_sup`` is not
            one of 'sim', 'pred' or 'predsim'.

    Side effects:
        Accumulates statistics on ``self`` unless ``args.no_print_stats``.
        In training mode the loss is back-propagated; unless
        ``args.no_detach`` is set, the local optimizer is stepped, its
        gradients are cleared and ``h_return`` is detached in place.
    '''
    # If pre-activation, apply batchnorm->nonlin->dropout
    if self.pre_act:
        if not args.no_batch_norm:
            x = self.bn_pre(x)
        x = self.nonlin(x)
        if self.dropout_p > 0:
            x = self.dropout(x)

    # The convolutional transformation
    h = self.encoder(x)

    # If post-activation, apply batchnorm
    if self.post_act and not args.no_batch_norm:
        h = self.bn(h)

    # Add shortcut branch (used in residual networks)
    if x_shortcut is not None:
        h = h + x_shortcut

    # If post-activation, add nonlinearity
    if self.post_act:
        h = self.nonlin(h)

    # Save return value and add dropout
    h_return = h
    if self.post_act and self.dropout_p > 0:
        h_return = self.dropout(h_return)

    # No local loss with plain back-prop, or in eval mode with stats disabled
    if args.backprop or (args.no_print_stats and not self.training):
        return h_return, 0.0

    # Add batchnorm and nonlinearity if not done already
    if not self.post_act:
        if not args.no_batch_norm:
            h = self.bn(h)
        # NOTE(review): without batch norm, h still aliases h_return here, so
        # an in-place nonlinearity also rewrites the returned tensor -- confirm
        # this is intended for the pre-activation / no-batch-norm combination.
        h = self.nonlin(h)

    # Calculate hidden feature similarity matrix
    if args.loss_unsup == 'sim' or args.loss_sup == 'sim' or args.loss_sup == 'predsim':
        if args.bio:
            h_loss = h
            if self.avg_pool is not None:
                h_loss = self.avg_pool(h_loss)
        else:
            h_loss = self.conv_loss(h)
        Rh = similarity_matrix(h_loss)

    # Calculate unsupervised loss
    if args.loss_unsup == 'sim':
        Rx = similarity_matrix(x).detach()
        loss_unsup = F.mse_loss(Rh, Rx)
    elif args.loss_unsup == 'recon' and not self.first_layer:
        x_hat = self.nonlin(self.decoder_x(h))
        loss_unsup = F.mse_loss(x_hat, x.detach())
    else:
        # Zero placeholder created on the same device / dtype as the
        # activations instead of being selected through args.cuda.
        loss_unsup = h.new_zeros(1)

    # Calculate supervised loss
    if args.loss_sup == 'sim':
        if args.bio:
            Ry = similarity_matrix(self.proj_y(y_onehot)).detach()
        else:
            Ry = similarity_matrix(y_onehot).detach()
        loss_sup = F.mse_loss(Rh, Ry)
        if not args.no_print_stats:
            self.loss_sim += loss_sup.item() * h.size(0)
            self.examples += h.size(0)
    elif args.loss_sup == 'pred':
        if self.avg_pool is not None:
            h = self.avg_pool(h)
        y_hat_local = self.decoder_y(h.view(h.size(0), -1))
        if args.bio:
            # Binarised projected targets, matched to the logits' dtype/device
            y_onehot_pred = self.proj_y(y_onehot).gt(0).type_as(y_hat_local).detach()
            loss_sup = F.binary_cross_entropy_with_logits(y_hat_local, y_onehot_pred)
        else:
            loss_sup = F.cross_entropy(y_hat_local, y.detach())
        if not args.no_print_stats:
            self.loss_pred += loss_sup.item() * h.size(0)
            self.correct += y_hat_local.max(1)[1].eq(y).cpu().sum()
            self.examples += h.size(0)
    elif args.loss_sup == 'predsim':
        if self.avg_pool is not None:
            h = self.avg_pool(h)
        y_hat_local = self.decoder_y(h.view(h.size(0), -1))
        if args.bio:
            # Project the targets once and reuse the result for both terms
            y_proj = self.proj_y(y_onehot)
            Ry = similarity_matrix(y_proj).detach()
            y_onehot_pred = y_proj.gt(0).type_as(y_hat_local).detach()
            loss_pred = (1 - args.beta) * F.binary_cross_entropy_with_logits(y_hat_local, y_onehot_pred)
        else:
            Ry = similarity_matrix(y_onehot).detach()
            loss_pred = (1 - args.beta) * F.cross_entropy(y_hat_local, y.detach())
        loss_sim = args.beta * F.mse_loss(Rh, Ry)
        loss_sup = loss_pred + loss_sim
        if not args.no_print_stats:
            self.loss_pred += loss_pred.item() * h.size(0)
            self.loss_sim += loss_sim.item() * h.size(0)
            self.correct += y_hat_local.max(1)[1].eq(y).cpu().sum()
            self.examples += h.size(0)
    else:
        # Previously this fell through to an UnboundLocalError on loss_sup
        raise ValueError('unsupported args.loss_sup: {!r}'.format(args.loss_sup))

    # Combine unsupervised and supervised loss
    loss = args.alpha * loss_unsup + (1 - args.alpha) * loss_sup

    if self.training:
        # Single-step back-propagation; the graph is only kept when the
        # layers are not detached from each other
        loss.backward(retain_graph=args.no_detach)
        # Update weights in this layer and detach computational graph
        if not args.no_detach:
            self.optimizer.step()
            self.optimizer.zero_grad()
            h_return.detach_()

    return h_return, loss.item()
class BasicBlock(nn.Module):
    ''' Residual block of two 3x3 LocalLossBlockConv layers. Used in ResNet().

    When the stride is not 1 or the channel count changes, the shortcut is a
    1x1 convolution followed by batch norm, trained by its own optimizer
    (``self.optimizer``); otherwise the shortcut is an empty ``nn.Sequential``
    and no optimizer is created.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride, num_classes, input_dim):
        ''' Build the two local-loss conv layers and the optional projection shortcut.

        Args:
            in_planes: Number of input channels.
            planes: Number of output channels of both conv layers.
            stride: Stride of the first conv layer and of the shortcut conv.
            num_classes: Forwarded to the LocalLossBlockConv layers.
            input_dim: Spatial size forwarded to the LocalLossBlockConv layers.
        '''
        super(BasicBlock, self).__init__()
        self.input_dim = input_dim
        self.stride = stride
        self.conv1 = LocalLossBlockConv(in_planes, planes, 3, stride, 1, num_classes, input_dim, bias=False, pre_act=args.pre_act, post_act=not args.pre_act)
        self.conv2 = LocalLossBlockConv(planes, planes, 3, 1, 1, num_classes, input_dim, bias=False, pre_act=args.pre_act, post_act=not args.pre_act)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False, groups=1),
                nn.BatchNorm2d(self.expansion*planes)
            )
            # The shortcut parameters belong to no LocalLossBlockConv, so they
            # get a dedicated optimizer; lr is set via set_learning_rate().
            if args.optim == 'sgd':
                self.optimizer = optim.SGD(self.shortcut.parameters(), lr=0, weight_decay=args.weight_decay, momentum=args.momentum)
            elif args.optim == 'adam' or args.optim == 'amsgrad':
                self.optimizer = optim.Adam(self.shortcut.parameters(), lr=0, weight_decay=args.weight_decay, amsgrad=args.optim == 'amsgrad')

    def set_learning_rate(self, lr):
        ''' Set ``lr`` on both conv layers and, if present, on the shortcut optimizer. '''
        self.lr = lr
        self.conv1.set_learning_rate(lr)
        self.conv2.set_learning_rate(lr)
        if len(self.shortcut) > 0:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

    def optim_zero_grad(self):
        ''' Clear gradients of both conv layers and of the shortcut (if any). '''
        self.conv1.optim_zero_grad()
        self.conv2.optim_zero_grad()
        if len(self.shortcut) > 0:
            self.optimizer.zero_grad()

    def optim_step(self):
        ''' Step the optimizers of both conv layers and of the shortcut (if any). '''
        self.conv1.optim_step()
        self.conv2.optim_step()
        if len(self.shortcut) > 0:
            self.optimizer.step()

    def forward(self, input):
        ''' Run the block on an ``(x, y, y_onehot, loss_total)`` tuple.

        Returns a tuple with the same layout: the block output, the unchanged
        targets and ``loss_total`` increased by the local losses of both conv
        layers.
        '''
        x, y, y_onehot, loss_total = input
        out, loss = self.conv1(x, y, y_onehot)
        loss_total += loss
        out, loss = self.conv2(out, y, y_onehot, self.shortcut(x))
        loss_total += loss
        # Update the shortcut only while training, mirroring the guard in
        # LocalLossBlockConv.forward: stepping in eval mode can still move the
        # weights through weight decay / momentum / Adam state.
        if self.training and not args.no_detach:
            if len(self.shortcut) > 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

        return (out, y, y_onehot, loss_total)
class Bottleneck(nn.Module):
    ''' Residual bottleneck block (1x1 -> 3x3 -> 1x1 LocalLossBlockConv layers). Used in ResNet().

    When the stride is not 1 or the channel count changes, the shortcut is a
    1x1 convolution followed by batch norm, trained by its own optimizer
    (``self.optimizer``); otherwise the shortcut is an empty ``nn.Sequential``
    and no optimizer is created.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride, num_classes, input_dim):
        ''' Build the three local-loss conv layers and the optional projection shortcut.

        Args:
            in_planes: Number of input channels.
            planes: Channel count of the first two conv layers; the block
                outputs ``expansion * planes`` channels.
            stride: Stride of the 3x3 conv layer and of the shortcut conv.
            num_classes: Forwarded to the LocalLossBlockConv layers.
            input_dim: Spatial size forwarded to the first conv layer; the
                other two receive ``input_dim // stride``.
        '''
        super(Bottleneck, self).__init__()
        self.conv1 = LocalLossBlockConv(in_planes, planes, 1, 1, 0, num_classes, input_dim, bias=False)
        self.conv2 = LocalLossBlockConv(planes, planes, 3, stride, 1, num_classes, input_dim//stride, bias=False)
        self.conv3 = LocalLossBlockConv(planes, self.expansion*planes, 1, 1, 0, num_classes, input_dim//stride, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
            # The shortcut parameters belong to no LocalLossBlockConv, so they
            # get a dedicated optimizer; lr is set via set_learning_rate().
            if args.optim == 'sgd':
                self.optimizer = optim.SGD(self.shortcut.parameters(), lr=0, weight_decay=args.weight_decay, momentum=args.momentum)
            elif args.optim == 'adam' or args.optim == 'amsgrad':
                self.optimizer = optim.Adam(self.shortcut.parameters(), lr=0, weight_decay=args.weight_decay, amsgrad=args.optim == 'amsgrad')

    def set_learning_rate(self, lr):
        ''' Set ``lr`` on all three conv layers and, if present, on the shortcut optimizer. '''
        self.lr = lr
        self.conv1.set_learning_rate(lr)
        self.conv2.set_learning_rate(lr)
        self.conv3.set_learning_rate(lr)
        if len(self.shortcut) > 0:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

    def optim_zero_grad(self):
        ''' Clear gradients of all three conv layers and of the shortcut (if any). '''
        self.conv1.optim_zero_grad()
        self.conv2.optim_zero_grad()
        self.conv3.optim_zero_grad()
        if len(self.shortcut) > 0:
            self.optimizer.zero_grad()

    def optim_step(self):
        ''' Step the optimizers of all three conv layers and of the shortcut (if any). '''
        self.conv1.optim_step()
        self.conv2.optim_step()
        self.conv3.optim_step()
        if len(self.shortcut) > 0:
            self.optimizer.step()

    def forward(self, input):
        ''' Run the block on an ``(x, y, y_onehot, loss_total)`` tuple.

        Returns a tuple with the same layout: the block output, the unchanged
        targets and ``loss_total`` increased by the local losses of the three
        conv layers.
        '''
        x, y, y_onehot, loss_total = input
        out, loss = self.conv1(x, y, y_onehot)
        loss_total += loss
        out, loss = self.conv2(out, y, y_onehot)
        loss_total += loss
        out, loss = self.conv3(out, y, y_onehot, self.shortcut(x))
        loss_total += loss
        # Update the shortcut only while training, mirroring the guard in
        # LocalLossBlockConv.forward: stepping in eval mode can still move the
        # weights through weight decay / momentum / Adam state.
        if self.training and not args.no_detach:
            if len(self.shortcut) > 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

        return (out, y, y_onehot, loss_total)
class ResNet(nn.Module):
'''