-
Notifications
You must be signed in to change notification settings - Fork 2
/
main_linprobe.py
294 lines (240 loc) · 12.3 KB
/
main_linprobe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import random
import time
from pathlib import Path
import paddle
import paddle.nn as nn
import paddle.vision.transforms as transforms
import paddle.vision.datasets as datasets
from paddle.fluid.optimizer import LarsMomentumOptimizer
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
from paddle.nn.initializer import TruncatedNormal
import util.misc as misc
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.misc import WandbLogger
from util.crop import RandomResizedCrop
import models_vit
from engine_finetune import train_one_epoch, evaluate
def get_args_parser():
    """Build the command-line parser for MAE linear probing.

    The parser is created with ``add_help=False``, so no ``-h/--help``
    option is registered on it.

    Returns:
        argparse.ArgumentParser: parser exposing the training, optimizer,
        checkpoint, dataset and logging options used by ``main``.
    """
    parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False)
    parser.add_argument('--batch_size', default=512, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=90, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay (default: 0 for linear probe following MoCo v1)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')
    # --global_pool and --cls_token write to the same destination:
    # the former stores True, the latter stores False (the default).
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=False)
    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='Use class token instead of global pool for classification')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--nb_classes', default=1000, type=int,
                        help='number of the classification types')
    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation (recommended during training for faster monitor)')
    parser.add_argument('--num_workers', default=10, type=int)

    # logger training parameters
    parser.add_argument('--log_wandb', action='store_true',
                        help='log training and validation metrics to wandb')
    parser.add_argument('--wandb_entity', default=None, type=str,
                        help='user or team name of wandb')
    # Previously this help text was a copy of the --log_wandb description.
    parser.add_argument('--wandb_project', default=None, type=str,
                        help='project name of wandb')
    parser.add_argument('--debug', action='store_true')
    return parser
def main(args):
    """Run linear probing (or evaluation only) of a ViT classifier.

    The function initialises distributed mode, builds the train/val
    datasets and loaders, creates the model named by ``args.model``,
    optionally loads a pre-trained checkpoint, wraps the classification
    head with a BatchNorm layer, freezes everything except the original
    head layer, and then either evaluates once (``args.eval``) or trains
    for ``args.epochs`` epochs while saving checkpoints and logs.

    Args:
        args: parsed namespace produced by ``get_args_parser``. ``args.lr``
            is filled in from ``args.blr`` when it is ``None``.
    """
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    # NOTE(review): paddle is seeded with args.seed (identical on every rank)
    # while numpy/random use the rank-offset seed -- confirm this is intended.
    paddle.seed(args.seed)
    np.random.seed(seed)
    random.seed(seed)

    if args.debug and paddle.is_compiled_with_cuda():
        # Assigning an attribute on paddle.version.cudnn does not change any
        # runtime flag; the deterministic-cuDNN switch is set via set_flags.
        paddle.set_flags({'FLAGS_cudnn_deterministic': True})

    # linear probe: weak augmentation
    transform_train = transforms.Compose([
        RandomResizedCrop(224, interpolation='bicubic'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    transform_val = transforms.Compose([
        transforms.Resize(256, interpolation='bicubic'),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # In debug mode the (smaller) 'val' split is also used for training.
    dataset_train = datasets.DatasetFolder(os.path.join(args.data_path, 'train' if not args.debug else 'val'), transform=transform_train)
    dataset_val = datasets.DatasetFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
    print(dataset_train)
    print(dataset_val)

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = DistributedBatchSampler(
        dataset_train, args.batch_size, shuffle=True, drop_last=True)
    print("Sampler_train = %s" % str(sampler_train))
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = DistributedBatchSampler(
            dataset_val, args.batch_size, shuffle=True, drop_last=False)  # shuffle=True to reduce monitor bias
    else:
        sampler_val = BatchSampler(dataset=dataset_val, batch_size=args.batch_size)

    # Only rank 0 logs to wandb, and only when training.
    if global_rank == 0 and args.log_wandb and not args.eval:
        log_writer = WandbLogger(args, entity=args.wandb_entity, project=args.wandb_project)
    else:
        log_writer = None

    data_loader_train = DataLoader(dataset_train, batch_sampler=sampler_train, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, batch_sampler=sampler_val, num_workers=args.num_workers)

    model = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        global_pool=args.global_pool,
    )

    if args.finetune and not args.eval:
        checkpoint = paddle.load(args.finetune)
        print("Load pre-trained checkpoint from: %s" % args.finetune)
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop head weights whose shape does not match the new classifier.
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        interpolate_pos_embed(model, checkpoint_model)

        # load pre-trained model
        model.set_state_dict(checkpoint_model)

        # manually initialize fc layer: following MoCo v3
        TruncatedNormal(std=0.01)(model.head.weight)

    # for linear prob only
    # hack: revise model's head with BN
    model.head = nn.Sequential(
        nn.BatchNorm1D(model.head.weight.shape[0], epsilon=1e-6, weight_attr=False, bias_attr=False),
        model.head)
    # freeze all but the head
    for _, p in model.named_parameters():
        p.stop_gradient = True
    # model.head[1] is the original head layer placed after the BatchNorm.
    for _, p in model.head[1].named_parameters():
        p.stop_gradient = False

    model_without_ddp = model
    n_parameters = sum(p.numel().item() for p in model.parameters() if not p.stop_gradient)

    print("Model = %s" % str(model_without_ddp))
    print('number of params (M): %.2f' % (n_parameters / 1.e6))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    model = paddle.DataParallel(model)
    model_without_ddp = model._layers

    # Only the trainable head layer is handed to the optimizer.
    optimizer = LarsMomentumOptimizer(
        learning_rate=args.lr,
        momentum=0.9,
        lars_weight_decay=args.weight_decay,
        parameter_list=model_without_ddp.head[1].parameters()
    )
    print(optimizer)
    loss_scaler = NativeScaler()

    criterion = nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    if args.eval:
        test_stats = evaluate(data_loader_val, model)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        # Plain return instead of the site-provided exit(), which is meant
        # for interactive use and is not guaranteed to exist.
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        data_loader_train.batch_sampler.set_epoch(epoch)
        if log_writer is not None:
            num_training_steps_per_epoch = len(dataset_train) // eff_batch_size
            log_writer.set_step(epoch * num_training_steps_per_epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, epoch, loss_scaler,
            log_writer=log_writer,
            args=args
        )

        test_stats = evaluate(data_loader_val, model)
        if args.output_dir:
            misc.save_model(
                args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch, tag='latest')
            # max_accuracy still holds the best value of previous epochs here.
            if test_stats["acc1"] > max_accuracy:
                misc.save_model(
                    args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, tag='best')
            # Periodic snapshot every 20 epochs and at the final epoch.
            if (epoch + 1) % 20 == 0 or epoch + 1 == args.epochs:
                misc.save_model(
                    args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch)

        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch, 'n_parameters': n_parameters}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.update(log_stats)
                log_writer.flush()
            # One JSON object per epoch is appended to log.txt.
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    # Script entry point: parse the CLI, make sure the output directory
    # exists (when one is requested), then hand over to main().
    args = get_args_parser().parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)