Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nvidia Apex for FP16 calculations #36

Merged
merged 2 commits into from
Jul 24, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
from model import Model
from test import validation

try:
from apex import amp
from apex import fp16_utils
APEX_AVAILABLE = True
amp_handle = amp.init(enabled=True)
except ModuleNotFoundError:
APEX_AVAILABLE = False

def train(opt):
""" dataset preparation """
Expand All @@ -42,7 +49,7 @@ def train(opt):

if opt.rgb:
opt.input_channel = 3
model = Model(opt)
model = Model(opt).cuda()
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
opt.SequenceModeling, opt.Prediction)
Expand All @@ -62,9 +69,7 @@ def train(opt):
param.data.fill_(1)
continue

# data parallel for multi-GPU
model = torch.nn.DataParallel(model).cuda()
model.train()

if opt.continue_model != '':
print(f'loading pretrained model from {opt.continue_model}')
model.load_state_dict(torch.load(opt.continue_model))
Expand Down Expand Up @@ -118,6 +123,13 @@ def train(opt):
best_norm_ED = 1e+6
i = start_iter

if APEX_AVAILABLE:
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

# data parallel for multi-GPU
model = torch.nn.DataParallel(model).cuda()
model.train()

while(True):
# train part
for p in model.parameters():
Expand All @@ -140,8 +152,13 @@ def train(opt):
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

model.zero_grad()
cost.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default)
if APEX_AVAILABLE:
with amp.scale_loss(cost, optimizer) as scaled_loss:
scaled_loss.backward()
fp16_utils.clip_grad_norm(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default)
else:
cost.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default)
optimizer.step()

loss_avg.add(cost)
Expand Down