refactor(deepspeed): Refine training code #2055
Conversation
The Training Pipeline is split into multiple function calls:

# 1. Read config
with open(args.config, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)
if len(args.override_config) > 0:
    configs = override_config(configs, args.override_config)

# 2. Init env for ddp OR deepspeed
world_size, local_rank, rank = init_distributed(args)

# 3. Do some sanity checks and save config to args.model_dir
configs = check_modify_and_save_config(args, configs)

# 4. Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
    init_dataset_and_dataloader(args, configs)

# 5. Init asr model from configs
infos, model = init_model(args, configs)

# 6. Check model is jitable & print model architectures
trace_and_print_model(model, enable_trace=True)

# 7. Tensorboard summary
writer = init_summarywriter(args)

# 8. Dispatch model from cpu to gpu
model, device = wrap_cuda_model(args, model)

# 9. Get optimizer & scheduler
model, optimizer, scheduler = init_optimizer_and_scheduler(
    args, infos, configs, model)

# 10. Save checkpoints
save_model(args, model, tag="init", infos=None)

# 11. Get executor
executor = init_executor(infos)

# 12. Init scaler, used for pytorch amp mixed precision training
scaler = None
if args.use_amp:
    scaler = torch.cuda.amp.GradScaler()

# 13. Start training loop
for ...
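As a reference for step 2, here is a minimal sketch of what an `init_distributed` helper along these lines might look like when every worker is launched via torchrun. The body, and the `args.dist_backend` / `args.train_engine` attributes, are illustrative assumptions, not the exact code of this PR:

```python
import os

import deepspeed
import torch
import torch.distributed as dist


def init_distributed(args):
    # torchrun exports these env vars for every worker, for both engines.
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    torch.cuda.set_device(local_rank)
    if args.train_engine == 'deepspeed':
        # DeepSpeed wraps init_process_group and manages its own comm group.
        deepspeed.init_distributed(dist_backend=args.dist_backend)
    else:
        dist.init_process_group(backend=args.dist_backend)
    return world_size, local_rank, rank
```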
To better organize arguments, this PR also splits the args into different categories:

def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument...
    parser.add_argument...
    parser.add_argument...

===>

def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser = add_model_args(parser)
    parser = add_dataset_args(parser)
    parser = add_ddp_args(parser)
    parser = add_deepspeed_args(parser)
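A sketch of what one of these grouping helpers could look like; the flag names below are assumptions for illustration, not the exact argument set added by this PR:

```python
import argparse


def add_deepspeed_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical flags: the real PR may group different arguments here.
    parser.add_argument('--train_engine', default='torch_ddp',
                        choices=['torch_ddp', 'deepspeed'],
                        help='which engine drives the training loop')
    parser.add_argument('--deepspeed_config', default='',
                        help='path to the deepspeed json config')
    return parser
```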
Test Script (single-node multi-gpu):

# torch_ddp
bash run.sh \
  --stage 4 --stop_stage 4 \
  --data_type raw --train_set dev \
  --train_engine torch_ddp \
  --dir exp/conformer \
  --tensorboard_dir tensorboard/compare

# deepspeed
bash run.sh \
  --stage 4 --stop_stage 4 \
  --data_type raw --train_set dev \
  --train_engine deepspeed \
  --dir exp/conformer_deepspeed \
  --tensorboard_dir tensorboard/compare

Test Script (multi-node multi-gpu):

# torch_ddp
# without NCCL_IB_DISABLE=1:
#   RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1269, internal error, NCCL version 2.14.3
# without NCCL_SOCKET_IFNAME=ens10f0:
#   RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:29455
# ref: https://github.com/google/jax/issues/13559#issuecomment-1343573764
NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=enp NCCL_IB_DISABLE=1 bash run.sh \
  --stage 4 --stop_stage 4 \
  --data_type shard --train_set dev \
  --train_engine torch_ddp \
  --dir exp/conformer_ddp_2nodes \
  --tensorboard_dir tensorboard/compare \
  --HOST_NODE_ADDR gpu-dev052.hogpu.cc:29455 \
  --num_nodes 2

# deepspeed
# (same NCCL env vars required as in the torch_ddp run above)
NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=enp NCCL_IB_DISABLE=1 bash run.sh \
  --stage 4 --stop_stage 4 \
  --data_type shard --train_set dev \
  --train_engine deepspeed \
  --dir exp/conformer_ds_2nodes \
  --tensorboard_dir tensorboard/compare \
  --HOST_NODE_ADDR gpu-dev052.hogpu.cc:29455 \
  --num_nodes 2

multi-node training ref: jax-ml/jax#13559 (comment)

Test Result:
(result screenshots omitted)
Impl:

for batch_idx, batch in enumerate(data_loader):
    if wenet_join(configs, device, group_join):
        break
    ...  # training step

where

import os
import datetime
import logging

import torch.distributed as dist

def wenet_join(configs, device, group_join):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    train_engine = configs.get('train_engine', "torch_ddp")
    if train_engine != "deepspeed":
        return False
    try:
        # NOTE(xcsong): Why do we need a new group?
        # Because DeepSpeed has its own group in which all the relevant
        # communication operations are executed. If we add a communication
        # operation that is not managed by DeepSpeed to that group, it is
        # highly likely to cause communication chaos, resulting in
        # hard-to-troubleshoot hangs.
        dist.monitored_barrier(group=group_join,
                               timeout=datetime.timedelta(seconds=30))
    except RuntimeError as e:
        logging.info("Detected uneven workload distribution: {}\n".format(e) +
                     "Break current worker to manually join all workers, " +
                     "world_size {}, current rank {}, current local_rank {}".format(
                         world_size, rank, local_rank))
        return True
    return False

Now, without
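The `group_join` handle passed to `wenet_join` must be a separate process group, created outside DeepSpeed's own group; note that `dist.monitored_barrier` is only implemented for the gloo backend. A minimal sketch of how it might be created before the training loop (the 30 s timeout mirrors the barrier above; this is an assumption, not necessarily the PR's exact setup):

```python
import datetime

import torch.distributed as dist

# A dedicated gloo group for the join barrier, kept apart from the
# NCCL group that DeepSpeed uses for gradient communication.
group_join = dist.new_group(
    backend='gloo', timeout=datetime.timedelta(seconds=30))
```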
The line counts of the core files:
(table omitted)
Has the refactored version shown increased GPU memory usage or slower training (torch_ddp, torch==1.12.0)?

What I'm seeing on my side is this: I didn't load a model and also skipped wenet_join(). You haven't run into this problem, have you?

No. I've been running my experiments on 2080 Ti cards the whole time. With the original recipe the batch size goes up to 16 at most, and after the refactor the same configuration still trains. Would you mind adding me on WeChat so we can communicate faster? WeChat ID: currycode

OK, I've added you. Please accept the request.

After upgrading torch from 1.12.0 to 1.13.0, GPU memory usage and training speed are back to normal.
@whiteshirt0429 @robin1001 @Mddct, I think this PR is ready for a final review
Great job! As a follow-up we can trim the comments and move the multi-node multi-GPU instructions into the docs, referencing them from there.
Brief
- Split train.py & executor.py into multiple unified API calls and move those APIs into train_utils.py
- Unify torch_ddp & deepspeed via torchrun

TODO (in current PR)
- Support batch_type == dynamic (and avoid using filter_data before training)

TODO (in next PR)
- Refactor processor.py::padding to avoid making a new dict in executor.py: refactor(dataset): return dict instead of tuple #2106

Warning
- A known issue of deepspeed: microsoft/DeepSpeed#4298
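For context on the next-PR item above, a hedged sketch of what a dict-returning `processor.py::padding` could look like; the field names and sample layout are assumptions based on the referenced PR title, not its actual diff:

```python
import torch


def padding(data):
    # 'data' is assumed to be a list of samples, each a dict with
    # 'key' (str), 'feat' (FloatTensor [T, D]) and 'label' (list of int).
    keys = [sample['key'] for sample in data]
    feats = [sample['feat'] for sample in data]
    labels = [torch.tensor(sample['label']) for sample in data]
    feats_lengths = torch.tensor([f.size(0) for f in feats], dtype=torch.int32)
    labels_lengths = torch.tensor([l.size(0) for l in labels], dtype=torch.int32)
    # Returning a dict (instead of a positional tuple) lets executor.py
    # consume the batch by name, so it no longer has to build a new dict.
    return {
        'keys': keys,
        'feats': torch.nn.utils.rnn.pad_sequence(feats, batch_first=True),
        'feats_lengths': feats_lengths,
        'target': torch.nn.utils.rnn.pad_sequence(labels, batch_first=True,
                                                  padding_value=-1),
        'target_lengths': labels_lengths,
    }
```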