refactor(dataset): return dict instead of tuple (#2106)
xingchensong authored Nov 2, 2023
1 parent 0d8344c commit b8a3340
Showing 4 changed files with 19 additions and 32 deletions.
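In short, the batch yielded by the dataset pipeline changes from a positional 5-tuple to a keyed dict. A minimal illustrative sketch of what that means for consumers, using the field names from this diff (the tensor values below are dummy placeholders, not real features):

import torch

# Dummy batch shaped like the dict that processor.padding() now yields.
batch_dict = {
    "keys": ["utt_001", "utt_002"],
    "feats": torch.zeros(2, 10, 80),                     # (batch, frames, feat_dim)
    "target": torch.full((2, 5), -1, dtype=torch.long),  # labels padded with -1
    "feats_lengths": torch.tensor([10, 7]),
    "target_lengths": torch.tensor([5, 3]),
}

# Before this commit every consumer unpacked all five fields positionally:
#   keys, feats, target, feats_lengths, target_lengths = batch
# After this commit consumers read only what they need, by name:
num_utts = batch_dict["target_lengths"].size(0)
feats = batch_dict["feats"]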
wenet/bin/train.py (2 additions, 5 deletions)
@@ -122,19 +122,16 @@ def main():
         lr = optimizer.param_groups[0]['lr']
         logging.info('Epoch {} TRAIN info lr {} rank {}'.format(epoch, lr, rank))
 
-        device = model.local_rank if args.deepspeed else device
-
         # NOTE(xcsong): Why we need a new group? see `train_utils.py::wenet_join`
         group_join = dist.new_group(backend="gloo",
                                     timeout=datetime.timedelta(seconds=30))
 
         dist.barrier()  # NOTE(xcsong): Ensure all ranks start Train at the same time.
-        executor.train(model, optimizer, scheduler, train_data_loader, device,
+        executor.train(model, optimizer, scheduler, train_data_loader,
                        writer, configs, scaler, group_join)
 
         dist.barrier()  # NOTE(xcsong): Ensure all ranks start CV at the same time.
-        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device,
-                                                configs)
+        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, configs)
         cv_loss = total_loss / num_seen_utts
 
         logging.info('Epoch {} CV info cv_loss {} rank {}'.format(epoch, cv_loss, rank))
wenet/dataset/processor.py (2 additions, 2 deletions)
@@ -641,5 +641,5 @@ def padding(data):
                                       batch_first=True,
                                       padding_value=-1)
 
-        yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
-               label_lengths)
+        yield {"keys": sorted_keys, "feats": padded_feats, "target": padding_labels,
+               "feats_lengths": feats_lengths, "target_lengths": label_lengths}
wenet/utils/executor.py (6 additions, 21 deletions)
@@ -29,7 +29,7 @@ class Executor:
     def __init__(self):
         self.step = 0
 
-    def train(self, model, optimizer, scheduler, data_loader, device, writer,
+    def train(self, model, optimizer, scheduler, data_loader, writer,
               configs, scaler, group_join):
         ''' Train one epoch
         '''
@@ -48,21 +48,13 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
             model_context = nullcontext
 
         with model_context():
-            for batch_idx, batch in enumerate(data_loader):
+            for batch_idx, batch_dict in enumerate(data_loader):
                 info_dict["step"] = self.step
                 info_dict["batch_idx"] = batch_idx
                 if wenet_join(group_join, info_dict):
                     break
 
-                key, feats, target, feats_lengths, target_lengths = batch
-
-                batch_dict = {}
-                batch_dict["feats"] = feats.to(device)
-                batch_dict["target"] = target.to(device)
-                batch_dict["feats_lengths"] = feats_lengths.to(device)
-                batch_dict["target_lengths"] = target_lengths.to(device)
-
-                if target_lengths.size(0) == 0:
+                if batch_dict["target_lengths"].size(0) == 0:
                     continue
 
                 context = None
@@ -88,26 +80,19 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
                 log_per_step(writer, info_dict)
                 self.step += 1
 
-    def cv(self, model, data_loader, device, configs):
+    def cv(self, model, data_loader, configs):
         ''' Cross validation on
         '''
         model.eval()
         info_dict = copy.deepcopy(configs)
         info_dict["tag"] = "CV"
         num_seen_utts, total_loss = 1, 0.0  # in order to avoid division by 0
         with torch.no_grad():
-            for batch_idx, batch in enumerate(data_loader):
+            for batch_idx, batch_dict in enumerate(data_loader):
                 info_dict["step"] = self.step
                 info_dict["batch_idx"] = batch_idx
-                key, feats, target, feats_lengths, target_lengths = batch
-
-                batch_dict = {}
-                batch_dict["feats"] = feats.to(device)
-                batch_dict["target"] = target.to(device)
-                batch_dict["feats_lengths"] = feats_lengths.to(device)
-                batch_dict["target_lengths"] = target_lengths.to(device)
-
-                num_utts = target_lengths.size(0)
+                num_utts = batch_dict["target_lengths"].size(0)
                 if num_utts == 0:
                     continue
wenet/utils/train_utils.py (9 additions, 4 deletions)
@@ -416,6 +416,7 @@ def wenet_join(group_join, info_dict):

 def batch_forward(model, batch, scaler, info_dict):
     train_engine = info_dict.get('train_engine', "torch_ddp")
+    device = int(os.environ.get('LOCAL_RANK', 0))
     accum_grad = info_dict.get('accum_grad', 1)
 
     dtype = info_dict.get("dtype", "fp32")
@@ -431,16 +432,20 @@ def batch_forward(model, batch, scaler, info_dict):
         with torch.cuda.amp.autocast(
                 enabled=dtype is not None, dtype=dtype, cache_enabled=False
         ):
-            loss_dict = model(batch["feats"], batch["feats_lengths"],
-                              batch["target"], batch["target_lengths"])
+            loss_dict = model(batch["feats"].to(device),
+                              batch["feats_lengths"].to(device),
+                              batch["target"].to(device),
+                              batch["target_lengths"].to(device))
     else:
         # torch_ddp
         # autocast context
         # The more details about amp can be found in
         # https://pytorch.org/docs/stable/notes/amp_examples.html
         with torch.cuda.amp.autocast(scaler is not None):
-            loss_dict = model(batch["feats"], batch["feats_lengths"],
-                              batch["target"], batch["target_lengths"])
+            loss_dict = model(batch["feats"].to(device),
+                              batch["feats_lengths"].to(device),
+                              batch["target"].to(device),
+                              batch["target_lengths"].to(device))
     info_dict['loss_dict'] = loss_dict
 
     return info_dict
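For reference, a minimal standalone sketch of the device-placement pattern this hunk introduces: the local device index comes from the LOCAL_RANK environment variable set by launchers such as torchrun or deepspeed, and batch tensors are moved to that device only inside batch_forward(). The commit passes the integer rank straight to .to(); the torch.device wrapper and CPU fallback below are added only so the snippet also runs on a machine without a GPU:

import os
import torch

local_rank = int(os.environ.get('LOCAL_RANK', 0))
device = torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')

# Stand-in batch; in WeNet this is the dict yielded by the dataset pipeline.
batch = {
    "feats": torch.zeros(2, 10, 80),
    "feats_lengths": torch.tensor([10, 7]),
    "target": torch.zeros(2, 5, dtype=torch.long),
    "target_lengths": torch.tensor([5, 3]),
}

# Same move-per-field pattern as the new model(...) calls above.
feats = batch["feats"].to(device)
feats_lengths = batch["feats_lengths"].to(device)
target = batch["target"].to(device)
target_lengths = batch["target_lengths"].to(device)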
