forked from sovit-123/fasterrcnn-pytorch-training-pipeline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
214 lines (193 loc) · 7.57 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""
USAGE
Training on custom ResNet:
python train.py --model fasterrcnn_custom_resnet --epochs 2 --config data_configs/voc.yaml --no-mosaic --batch-size 16
Training on ResNet50 FPN with custom project folder name and visualizing transformed images before training begins:
python train.py --model fasterrcnn_resnet50_fpn --epochs 2 --config data_configs/voc.yaml -vt --project-name resnet50fpn_voc --no-mosaic --batch-size 16
"""
from tabnanny import check
from torch_utils.engine import (
train_one_epoch, evaluate
)
from datasets import (
create_train_dataset, create_valid_dataset,
create_train_loader, create_valid_loader
)
from models.create_fasterrcnn_model import create_model
from utils.general import (
set_training_dir, Averager,
save_model_state, save_train_loss_plot,
show_tranformed_image
)
from utils.logging import log, set_log, coco_log
import torch
import argparse
import yaml
import numpy as np
# For same annotation colors each time.
np.random.seed(42)


def parse_opt(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests or programmatic use).

    Returns:
        dict: The parsed options, keyed by each argument's ``dest`` name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-m', '--model', default='fasterrcnn_resnet50_fpn',
        help='name of the model'
    )
    # Required: training cannot proceed without the dataset description,
    # and a missing value previously crashed later in `open(None)`.
    parser.add_argument(
        '-c', '--config', required=True,
        help='path to the data config file'
    )
    parser.add_argument(
        '-d', '--device',
        default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
        help='computation/training device, default is GPU if GPU present'
    )
    parser.add_argument(
        '-e', '--epochs', default=5, type=int,
        help='number of epochs to train for'
    )
    parser.add_argument(
        '-w', '--workers', default=4, type=int,
        help='number of workers for data processing/transforms/augmentations'
    )
    parser.add_argument(
        '-b', '--batch-size', dest='batch_size', default=8, type=int,
        help='batch size to load the data'
    )
    parser.add_argument(
        '-ims', '--img-size', dest='img_size', default=512, type=int,
        help='image size to feed to the network'
    )
    parser.add_argument(
        '-pn', '--project-name', default=None, type=str, dest='project_name',
        help='training result dir name in outputs/training/, (default res_#)'
    )
    parser.add_argument(
        '-vt', '--viz-transformed', dest='vis_transformed', action='store_true',
        help='visualize transformed images fed to the network'
    )
    # NOTE: `store_false` means args['no_mosaic'] is True (mosaic ON) by
    # default and becomes False only when --no-mosaic is passed.
    parser.add_argument(
        '-nm', '--no-mosaic', dest='no_mosaic', action='store_false',
        help='pass this to not to use mosaic augmentation'
    )
    parser.add_argument(
        '-uta', '--use-train-aug', dest='use_train_aug', action='store_true',
        help='whether to use train augmentation, uses some advanced '
             'augmentation that may make training difficult when used '
             'with mosaic'
    )
    parser.add_argument(
        '-ca', '--cosine-annealing', dest='cosine_annealing', action='store_true',
        help='use cosine annealing warm restarts'
    )
    parser.add_argument(
        '-mw', '--weights', default=None, type=str,
        help='path to model weights if resuming training'
    )
    return vars(parser.parse_args(argv))


def main(args):
    """Train and evaluate a Faster R-CNN model from parsed options.

    Loads the YAML data config, builds the train/validation datasets and
    loaders, creates the model named by ``args['model']`` (optionally
    loading weights from ``args['weights']``), then for every epoch trains,
    evaluates, and saves the model state, the loss plot and the COCO stats
    into the training output directory.

    Args:
        args: Options dict as returned by :func:`parse_opt`.

    Raises:
        SystemExit: If ``args['model']`` is not a known model name.
    """
    # Load the data configurations.
    with open(args['config'], encoding='utf-8') as file:
        data_configs = yaml.safe_load(file)

    # Settings/parameters/constants.
    train_dir_images = data_configs['TRAIN_DIR_IMAGES']
    train_dir_labels = data_configs['TRAIN_DIR_LABELS']
    valid_dir_images = data_configs['VALID_DIR_IMAGES']
    valid_dir_labels = data_configs['VALID_DIR_LABELS']
    classes = data_configs['CLASSES']
    num_classes = data_configs['NC']
    # Optional key: saving prediction images is off unless the config asks.
    save_valid_predictions = data_configs.get(
        'SAVE_VALID_PREDICTION_IMAGES', False
    )
    num_workers = args['workers']
    device = args['device']
    num_epochs = args['epochs']
    batch_size = args['batch_size']
    visualize_transformed_images = args['vis_transformed']
    out_dir = set_training_dir(args['project_name'])
    # One RGB color (floats in [0, 1)) per class; deterministic thanks to
    # the module-level seed.
    colors = np.random.uniform(0, 1, size=(len(classes), 3))
    # Set logging file.
    set_log(out_dir)

    # Model configurations: square input images.
    image_width = args['img_size']
    image_height = args['img_size']

    train_dataset = create_train_dataset(
        train_dir_images, train_dir_labels,
        image_width, image_height, classes,
        use_train_aug=args['use_train_aug'],
        # True unless --no-mosaic was passed (see parse_opt).
        mosaic=args['no_mosaic']
    )
    valid_dataset = create_valid_dataset(
        valid_dir_images, valid_dir_labels,
        image_width, image_height, classes
    )
    train_loader = create_train_loader(train_dataset, batch_size, num_workers)
    valid_loader = create_valid_loader(valid_dataset, batch_size, num_workers)
    print(f"Number of training samples: {len(train_dataset)}")
    print(f"Number of validation samples: {len(valid_dataset)}\n")

    if visualize_transformed_images:
        show_tranformed_image(train_loader, device, classes, colors)

    # Running average of the training loss, reset every epoch.
    train_loss_hist = Averager()
    # Batch-wise training losses of all iterations till the end, used to
    # plot graphs over all iterations.
    train_loss_list = []

    # `create_model` maps model names to builder functions; a separate name
    # is used so the imported registry is not shadowed.
    try:
        build_model = create_model[args['model']]
    except KeyError:
        raise SystemExit(
            f"Unknown model '{args['model']}'. "
            f"Available models: {', '.join(sorted(create_model))}"
        ) from None
    model = build_model(num_classes=num_classes)

    # Load pretrained weights if path is provided.
    if args['weights'] is not None:
        print('Loading trained weights...')
        # Map to CPU: the model is not on `device` yet, and this lets a
        # GPU-saved checkpoint load on a CPU-only machine.
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints.
        checkpoint = torch.load(args['weights'], map_location='cpu')
        # Only the model weights are restored; the optimizer, scheduler and
        # epoch counter below start fresh.
        model.load_state_dict(checkpoint['model_state_dict'])
    print(model)
    model = model.to(device)

    # Total parameters and trainable parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    params = [p for p in model.parameters() if p.requires_grad]
    total_trainable_params = sum(p.numel() for p in params)
    print(f"{total_trainable_params:,} training parameters.")

    # Define the optimizer over the trainable parameters only.
    optimizer = torch.optim.AdamW(params, lr=0.0001, weight_decay=0.0005)

    scheduler = None
    if args['cosine_annealing']:
        # CosineAnnealingWarmRestarts decays the LR towards zero over T_0
        # scheduler "epochs", then restarts (e.g. T_0=5 -> every 5 epochs).
        # T_0 is set beyond num_epochs, so presumably the LR is annealed
        # without restarting or reaching zero; this assumes train_one_epoch
        # steps the scheduler in epoch units -- TODO confirm.
        steps = num_epochs + 10
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=steps,
            T_mult=1
        )

    for epoch in range(num_epochs):
        train_loss_hist.reset()

        _, batch_loss_list = train_one_epoch(
            model,
            optimizer,
            train_loader,
            device,
            epoch,
            train_loss_hist,
            print_freq=100,
            scheduler=scheduler
        )

        _, stats = evaluate(
            model,
            valid_loader,
            device=device,
            save_valid_preds=save_valid_predictions,
            out_dir=out_dir,
            classes=classes,
            colors=colors
        )

        # Add the current epoch's batch-wise losses to `train_loss_list`.
        train_loss_list.extend(batch_loss_list)

        # Save the current epoch model state. This can be used to resume
        # training.
        save_model_state(epoch, model, optimizer, out_dir)
        # Save loss plot.
        save_train_loss_plot(out_dir, train_loss_list)
        coco_log(out_dir, stats)


if __name__ == '__main__':
    main(parse_opt())