[utils] log everything to tensorboard #2307

Merged
merged 1 commit on Jan 19, 2024
9 changes: 4 additions & 5 deletions wenet/bin/train.py
@@ -145,19 +145,18 @@ def main():

         dist.barrier(
         )  # NOTE(xcsong): Ensure all ranks start CV at the same time.
-        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, configs)
-        cv_loss = total_loss / num_seen_utts
+        loss_dict = executor.cv(model, cv_data_loader, configs)

         lr = optimizer.param_groups[0]['lr']
-        logging.info('Epoch {} CV info lr {} cv_loss {} rank {}'.format(
-            epoch, lr, cv_loss, rank))
+        logging.info('Epoch {} CV info lr {} cv_loss {} rank {} acc {}'.format(
+            epoch, lr, loss_dict["loss"], rank, loss_dict["acc"]))
         info_dict = {
             'epoch': epoch,
             'lr': lr,
-            'cv_loss': cv_loss,
             'step': executor.step,
             'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
             'tag': "epoch_{}".format(epoch),
+            'loss_dict': loss_dict,
             **configs
         }
         log_per_epoch(writer, info_dict=info_dict)
31 changes: 19 additions & 12 deletions wenet/utils/executor.py
@@ -81,14 +81,13 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                 save_interval = info_dict.get('save_interval', 100000000000000)
                 if self.step % save_interval == 0 and self.step != 0 \
                         and (batch_idx + 1) % info_dict["accum_grad"] == 0:
-                    total_loss, num_seen_utts = self.cv(
-                        model, cv_data_loader, configs)
+                    loss_dict = self.cv(model, cv_data_loader, configs)
                     model.train()
                     info_dict.update({
                         "tag":
                         "step_{}".format(self.step),
-                        "cv_loss":
-                        total_loss / num_seen_utts,
+                        "loss_dict":
+                        loss_dict,
                         "save_time":
                         datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
                         "lr":
@@ -104,7 +103,7 @@ def cv(self, model, cv_data_loader, configs):
         '''
         model.eval()
         info_dict = copy.deepcopy(configs)
-        num_seen_utts, total_loss = 1, 0.0  # in order to avoid division by 0
+        num_seen_utts, loss_dict, total_acc = 1, {}, []  # avoid division by 0
         with torch.no_grad():
             for batch_idx, batch_dict in enumerate(cv_data_loader):
                 info_dict["tag"] = "CV"
@@ -116,12 +115,20 @@
                     continue

                 info_dict = batch_forward(model, batch_dict, None, info_dict)
-                loss = info_dict['loss_dict']['loss']
+                _dict = info_dict["loss_dict"]

-                if torch.isfinite(loss):
-                    num_seen_utts += num_utts
-                    total_loss += loss.item() * num_utts
-
-                info_dict["history_loss"] = total_loss / num_seen_utts
+                num_seen_utts += num_utts
+                total_acc.append(_dict['th_accuracy']
+                                 if _dict['th_accuracy'] is not None else 0.0)
+                for loss_name, loss_value in _dict.items():
+                    if loss_value is not None and "loss" in loss_name \
+                            and torch.isfinite(loss_value):
+                        loss_value = loss_value.item()
+                        loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \
+                            loss_value * num_utts
+
                 log_per_step(writer=None, info_dict=info_dict)
-        return total_loss, num_seen_utts
+        for loss_name, loss_value in loss_dict.items():
+            loss_dict[loss_name] = loss_dict[loss_name] / num_seen_utts
+        loss_dict["acc"] = sum(total_acc) / len(total_acc)
+        return loss_dict
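For reference, the new cv() now returns a dict of utterance-weighted average losses plus a mean accuracy instead of a (total_loss, num_seen_utts) pair. A minimal standalone sketch of that reduction, using made-up batch values rather than WeNet's actual data pipeline:

```python
# Sketch of the utterance-weighted averaging done by the new cv()
# (hypothetical batches; not WeNet's actual data loader output).
import torch

batches = [
    {"num_utts": 4, "loss_dict": {"loss": torch.tensor(2.0),
                                  "loss_att": torch.tensor(1.5),
                                  "th_accuracy": 0.8}},
    {"num_utts": 6, "loss_dict": {"loss": torch.tensor(1.0),
                                  "loss_att": torch.tensor(0.5),
                                  "th_accuracy": 0.9}},
]

num_seen_utts, loss_dict, total_acc = 1, {}, []  # start at 1 to avoid division by 0
for batch in batches:
    num_utts, _dict = batch["num_utts"], batch["loss_dict"]
    num_seen_utts += num_utts
    total_acc.append(_dict["th_accuracy"] if _dict["th_accuracy"] is not None else 0.0)
    for name, value in _dict.items():
        # only finite *loss* entries are accumulated, weighted by batch size
        if value is not None and "loss" in name and torch.isfinite(value):
            loss_dict[name] = loss_dict.get(name, 0) + value.item() * num_utts

for name in loss_dict:
    loss_dict[name] /= num_seen_utts
loss_dict["acc"] = sum(total_acc) / len(total_acc)
print(loss_dict)  # {'loss': 1.27..., 'loss_att': 0.81..., 'acc': 0.85}
```

Because num_seen_utts starts at 1, the averages are biased slightly low; that keeps the original division-by-zero guard.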
14 changes: 8 additions & 6 deletions wenet/utils/train_utils.py
@@ -564,7 +564,6 @@ def log_per_step(writer, info_dict):
     accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1
     log_interval = info_dict.get('log_interval', 10)
     lr = info_dict.get("lr", 0.0)
-    history_loss = info_dict.get("history_loss", 0.0)
     is_gradient_accumulation_boundary = info_dict.get(
         "is_gradient_accumulation_boundary", False)

@@ -577,10 +576,13 @@
                               loss_dict['loss'] * accum_grad, step + 1)
             writer.add_scalar('train/grad_norm', info_dict['grad_norm'],
                               step + 1)
+            for name, value in loss_dict.items():
+                if name != 'loss' and value is not None:
+                    writer.add_scalar('train/{}'.format(name), value, step + 1)
     elif "step_" in tag and rank == 0 and writer is not None:
-        writer.add_scalar('global_step/cv_loss', info_dict["cv_loss"],
-                          step + 1)
         writer.add_scalar('global_step/lr', lr, step + 1)
+        for name, value in loss_dict.items():
+            writer.add_scalar('global_step/{}'.format(name), value, step + 1)

     if (batch_idx + 1) % log_interval == 0:
         log_str = '{} Batch {}/{} loss {:.6f} '.format(
@@ -591,16 +593,16 @@
         if tag == "TRAIN":
             log_str += 'lr {:.8f} grad_norm {:.6f} rank {}'.format(
                 lr, info_dict['grad_norm'], rank)
-        elif tag == "CV":
-            log_str += 'history loss {:.6f} rank {}'.format(history_loss, rank)
         logging.debug(log_str)


 def log_per_epoch(writer, info_dict):
     epoch = info_dict["epoch"]
+    loss_dict = info_dict["loss_dict"]
     if int(os.environ.get('RANK', 0)) == 0:
-        writer.add_scalar('epoch/cv_loss', info_dict["cv_loss"], epoch)
         writer.add_scalar('epoch/lr', info_dict["lr"], epoch)
+        for name, value in loss_dict.items():
+            writer.add_scalar('epoch/{}'.format(name), value, epoch)


 def freeze_modules(model, args):
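With this change every entry of loss_dict gets its own TensorBoard tag instead of a single cv_loss scalar. A small sketch of the resulting epoch-level calls, assuming torch.utils.tensorboard and example values (the log directory and numbers are illustrative, not from the PR):

```python
# Sketch of the per-name epoch scalars written after this change
# (example values; 'tensorboard/demo' is an arbitrary log directory).
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("tensorboard/demo")
info_dict = {
    "epoch": 3,
    "lr": 4e-4,
    "loss_dict": {"loss": 1.27, "loss_att": 0.82, "loss_ctc": 2.31, "acc": 0.85},
}

epoch = info_dict["epoch"]
writer.add_scalar('epoch/lr', info_dict["lr"], epoch)
for name, value in info_dict["loss_dict"].items():
    # produces tags: epoch/loss, epoch/loss_att, epoch/loss_ctc, epoch/acc
    writer.add_scalar('epoch/{}'.format(name), value, epoch)
writer.close()
```

The same pattern applies at step level ('train/<name>' and 'global_step/<name>'), so each CV loss component and the accuracy appear as separate curves.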