
Commit

fix(train): Fix step & destroy group (#2168)
* fix(train): Fix log

* fix(train): Fix log

* fix(train): fix log

* fix(train): destroy group
xingchensong authored Nov 29, 2023
1 parent 3ab6718 commit 1aadcf7
Showing 3 changed files with 5 additions and 5 deletions.
wenet/bin/train.py: 6 changes (3 additions & 3 deletions)
@@ -112,7 +112,7 @@ def main():

# Get executor
executor = Executor()
-    executor.step = configs["init_infos"].get('step', 0)
+    executor.step = configs["init_infos"].get('step', -1)

# Init scaler, used for pytorch amp mixed precision training
scaler = None
@@ -129,13 +129,13 @@ def main():
lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(epoch, lr, rank))

-    dist.barrier()  # NOTE(xcsong): Ensure all ranks start Train at the same time.
# NOTE(xcsong): Why we need a new group? see `train_utils.py::wenet_join`
group_join = dist.new_group(backend="gloo",
timeout=datetime.timedelta(seconds=args.timeout))

+    dist.barrier()  # NOTE(xcsong): Ensure all ranks start Train at the same time.
executor.train(model, optimizer, scheduler, train_data_loader,
writer, configs, scaler, group_join)
+    dist.destroy_process_group(group_join)

dist.barrier() # NOTE(xcsong): Ensure all ranks start CV at the same time.
total_loss, num_seen_utts = executor.cv(model, cv_data_loader, configs)
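For readers skimming the diff: the net effect of the hunk above is to create a fresh gloo group for each epoch, synchronize all ranks right before training, and destroy that group once the epoch finishes. Below is a minimal sketch of that lifecycle, assuming a torchrun-style launch (RANK, WORLD_SIZE, MASTER_ADDR already set); `do_train_epoch` and `run_one_epoch` are hypothetical stand-ins, not WeNet code.

```python
# Sketch of the per-epoch helper-group lifecycle; hypothetical names, not WeNet code.
import datetime

import torch.distributed as dist


def do_train_epoch(group_join):
    # Placeholder for the real training loop; the extra gloo group is what a
    # wenet_join-style check would use to detect ranks that fall behind.
    pass


def run_one_epoch(timeout_seconds: int = 30):
    # A fresh helper group created before the epoch starts ...
    group_join = dist.new_group(
        backend="gloo",
        timeout=datetime.timedelta(seconds=timeout_seconds))

    dist.barrier()                             # all ranks enter training together
    do_train_epoch(group_join)
    dist.destroy_process_group(group_join)     # ... and torn down right after it


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")    # nccl in real GPU training
    run_one_epoch()
    dist.destroy_process_group()
```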
wenet/utils/executor.py: 2 changes (1 addition & 1 deletion)
@@ -78,7 +78,7 @@ def train(self, model, optimizer, scheduler, data_loader, writer,
scaler, info_dict
)
log_per_step(writer, info_dict)
-                self.step += 1
+                self.step += 1 if (batch_idx + 1) % info_dict["accum_grad"] == 0 else 0

def cv(self, model, data_loader, configs):
''' Cross validation on
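The changed line above advances `self.step` only on micro-batches where gradient accumulation completes, i.e. once every `accum_grad` batches. A small sketch of that counting rule, with illustrative names rather than WeNet's API:

```python
# Toy illustration of the new step-counting rule: the global step advances
# only when an optimizer update actually happens. Names are illustrative only.
def count_optimizer_steps(num_batches: int, accum_grad: int, start_step: int = -1) -> int:
    step = start_step
    for batch_idx in range(num_batches):
        # forward/backward on one micro-batch would happen here
        if (batch_idx + 1) % accum_grad == 0:
            # optimizer.step(); optimizer.zero_grad(); scheduler.step()
            step += 1
    return step


if __name__ == "__main__":
    # 8 micro-batches with accum_grad=4 -> two optimizer steps, so starting
    # from -1 the counter ends at 1 (steps 0 and 1 were taken).
    assert count_optimizer_steps(8, 4, start_step=-1) == 1
```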
wenet/utils/train_utils.py: 2 changes (1 addition & 1 deletion)
@@ -357,7 +357,7 @@ def scheduler(opt):
args=args, model=model, optimizer=optimizer,
lr_scheduler=scheduler, model_parameters=model.parameters())

-    step = configs["init_infos"].get("step", 0)
+    step = configs["init_infos"].get("step", -1)
scheduler.set_step(step)
return model, optimizer, scheduler

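The resume default changes from 0 to -1 here and in train.py, which is consistent with the common PyTorch scheduler convention that -1 means no step has been taken yet, so a fresh run's first update is counted as step 0 while a resumed run passes the checkpointed step. A toy sketch of that convention, assuming nothing about WeNet's actual WarmupLR:

```python
# Hypothetical warmup scheduler (not WeNet's WarmupLR) showing the convention:
# set_step(-1) means "no optimizer step taken yet", so the first step() after
# it is counted as step index 0; a resumed run passes the saved step instead.
class ToyWarmupScheduler:
    def __init__(self, base_lr: float, warmup_steps: int):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.last_step = -1

    def set_step(self, step: int):
        # Resume point: -1 for a fresh run, otherwise the checkpointed step.
        self.last_step = step

    def step(self) -> float:
        self.last_step += 1
        warmup = min(1.0, (self.last_step + 1) / self.warmup_steps)
        return self.base_lr * warmup


if __name__ == "__main__":
    sched = ToyWarmupScheduler(base_lr=1e-3, warmup_steps=1000)
    sched.set_step(-1)     # fresh run
    print(sched.step())    # first update uses step index 0 -> lr = 1e-3 * 1/1000
```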
