
refactor(deepspeed): Refine training code #2055

Merged: 19 commits, Nov 2, 2023

Conversation

xingchensong (Member) commented Oct 16, 2023

Brief

  • split train.py & executor.py into multiple unified API calls and move those APIs into train_utils.py
  • launch torch_ddp & deepspeed via torchrun

TODO (in current PR)

  • check training is correct (both ddp & ds)
  • impl model.join for deepspeed (to seamlessly support batch_type == dynamic and avoid using filter_data before training)
  • log grad_norm in tensorboard for debug purpose
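
For the grad_norm item, a minimal sketch of what the tensorboard logging could look like is shown below; `model`, `writer`, and `global_step` are assumed to exist in the training loop, and the variable names are illustrative rather than the exact code in this PR.

import torch

# Hypothetical sketch: clip_grad_norm_ returns the total gradient norm,
# which can be written to tensorboard next to the loss.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
if torch.isfinite(grad_norm):
    writer.add_scalar('train/grad_norm', grad_norm, global_step)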

TODO (in next PR)

Warning

A known issue of deepspeed: microsoft/DeepSpeed#4298

xingchensong (Member, Author) commented Oct 16, 2023

The Training Pipeline is split into multiple function calls:

    # 1. Read config
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    # 2. Init env for ddp OR deepspeed
    world_size, local_rank, rank = init_distributed(args)

    # 3. Do some sanity checks and save config to args.model_dir
    configs = check_modify_and_save_config(args, configs)

    # 4. Get dataset & dataloader
    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs)

    # 5. Init asr model from configs
    infos, model = init_model(args, configs)

    # 6. Check model is jitable & print model architectures
    trace_and_print_model(model, enable_trace=True)

    # 7. Tensorboard summary
    writer = init_summarywriter(args)

    # 8. Dispatch model from cpu to gpu
    model, device = wrap_cuda_model(args, model)

    # 9. Get optimizer & scheduler
    model, optimizer, scheduler = init_optimizer_and_scheduler(
        args, infos, configs, model)

    # 10. Save checkpoints
    save_model(args, model, tag="init", infos=None)

    # 11. Get executor
    executor = init_executor(infos)

    # 12. Init scaler, used for pytorch amp mixed precision training
    scaler = None
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    # 13. Start training loop
    for ...
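
For orientation, a minimal sketch of what step 13 could look like when wired to the pieces above; the executor method names, their arguments, and the epoch bookkeeping are assumptions for illustration, not the exact code merged in this PR:

    # 13. (hypothetical sketch) Start training loop
    max_epoch = configs.get('max_epoch', 100)
    start_epoch = infos.get('epoch', -1) + 1
    for epoch in range(start_epoch, max_epoch):
        configs['epoch'] = epoch
        executor.train(model, optimizer, scheduler, train_data_loader,
                       writer, configs, scaler)
        loss_dict = executor.cv(model, cv_data_loader, configs)
        save_model(args, model, tag=str(epoch),
                   infos={'epoch': epoch, **loss_dict})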

xingchensong (Member, Author) commented:

To better organize arguments, this PR also splits and classifies the args into different categories.

def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument...
    parser.add_argument...
    parser.add_argument...

===>

def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser = add_model_args(parser)
    parser = add_dataset_args(parser)
    parser = add_ddp_args(parser)
    parser = add_deepspeed_args(parser)
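
As an illustration, one of these helpers could be shaped roughly as below; the concrete flags and defaults are examples, not the exact argument set added in this PR:

def add_ddp_args(parser):
    # Hypothetical sketch: group the launcher/engine-related flags in one place.
    parser.add_argument('--train_engine',
                        default='torch_ddp',
                        choices=['torch_ddp', 'deepspeed'],
                        help='training engine')
    parser.add_argument('--dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    return parser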

xingchensong (Member, Author) commented Oct 17, 2023

Test Script (single-node multi-gpu):

# torchddp
bash run.sh \
  --stage 4 --stop_stage 4 \
  --data_type raw --train_set dev \
  --train_engine torch_ddp \
  --dir exp/conformer \
  --tensorboard_dir tensorboard/compare
# deepspeed
bash run.sh \
  --stage 4 --stop_stage 4 \
  --data_type raw --train_set dev \
  --train_engine deepspeed \
  --dir exp/conformer_deepspeed \
  --tensorboard_dir tensorboard/compare

Test Script (multi-node multi-gpu):

# torch_ddp
# without NCCL_IB_DISABLE=1
#   RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1269, internal error, NCCL version 2.14.3
# without NCCL_SOCKET_IFNAME=ens10f0
#   RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:29455
# ref: https://github.com/google/jax/issues/13559#issuecomment-1343573764
NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=enp NCCL_IB_DISABLE=1 bash run.sh \
  --stage 4 --stop_stage 4 \
  --data_type shard --train_set dev \
  --train_engine torch_ddp \
  --dir exp/conformer_ddp_2nodes \
  --tensorboard_dir tensorboard/compare \
  --HOST_NODE_ADDR gpu-dev052.hogpu.cc:29455 \
  --num_nodes 2
# deepspeed
# without NCCL_IB_DISABLE=1
#   RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1269, internal error, NCCL version 2.14.3
# without NCCL_SOCKET_IFNAME=ens10f0
#   RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:29455
# ref: https://github.com/google/jax/issues/13559#issuecomment-1343573764
NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=enp NCCL_IB_DISABLE=1 bash run.sh \
  --stage 4 --stop_stage 4 \
  --data_type shard --train_set dev \
  --train_engine deepspeed \
  --dir exp/conformer_ds_2nodes \
  --tensorboard_dir tensorboard/compare \
  --HOST_NODE_ADDR gpu-dev052.hogpu.cc:29455 \
  --num_nodes 2

multi-node training ref: jax-ml/jax#13559 (comment)

Test Result:

torch_ddp (in blue & red) is almost identical to deepspeed (in orange and sky-blue)

[screenshot: tensorboard loss-curve comparison, torch_ddp vs deepspeed]

compare_tensorboard.zip

whiteshirt0429 previously approved these changes Oct 19, 2023
xingchensong (Member, Author) commented Oct 20, 2023

Impl join for deepspeed:

for batch_idx, batch in enumerate(data_loader):
    if wenet_join(configs, device, group_join):
        break
    ...  # training step

where wenet_join is defined as:

import datetime
import logging
import os

import torch.distributed as dist


def wenet_join(configs, device, group_join):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    train_engine = configs.get('train_engine', "torch_ddp")

    if train_engine != "deepspeed":
        return False

    try:
        # NOTE(xcsong): Why do we need a new group?
        #   Because Deepspeed has its own group where all the relevant communication
        #   operations are executed. If we add a communication operation that is not
        #   managed by Deepspeed in this group, it's highly likely to cause
        #   communication chaos, resulting in hard-to-troubleshoot hangs.
        dist.monitored_barrier(group=group_join,
                               timeout=datetime.timedelta(seconds=30))
    except RuntimeError as e:
        logging.info("Detected uneven workload distribution: {}\n".format(e) +
                     "Break current worker to manually join all workers, " +
                     "world_size {}, current rank {}, current local_rank {}".format(
                         world_size, rank, local_rank))
        return True

    return False
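
For context, the group_join passed in above is a process group separate from the one Deepspeed uses internally. A minimal sketch of how it could be created, assuming a gloo-backed group (dist.monitored_barrier is only supported on the gloo backend); the timeout value and placement are illustrative, not necessarily the exact PR code:

import datetime

import torch.distributed as dist

# Hypothetical sketch: after the default process group has been initialized,
# create a dedicated gloo group that is only used for the join barrier.
group_join = dist.new_group(backend="gloo",
                            timeout=datetime.timedelta(seconds=30))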

Now, without filtering data, we can continue deepspeed training even if the data is unevenly distributed across workers:

[screenshot: training log showing deepspeed continuing despite uneven data]

wenet/bin/train.py: review comments (outdated, resolved)
xingchensong (Member, Author) commented:

Line counts of the core files:

  • wenet/bin/train.py: 425 -> 157
  • wenet/utils/executor.py: 204 -> 122

Mddct previously approved these changes Oct 24, 2023
wenet-e2e deleted a comment from Mddct Oct 27, 2023
kobenaxie (Contributor) commented Oct 30, 2023

Has the refactored version shown increased GPU memory usage and slower training (torch_ddp, torch==1.12.0)?

xingchensong (Member, Author) commented Oct 31, 2023

> Has the refactored version shown increased GPU memory usage and slower training (torch_ddp)?

Do you have comparison data? Are you resuming training from a checkpoint? If so, the symptom is probably that GPU 0 uses noticeably more memory than the other cards; the load-checkpoint function in checkpoint.py needs to use CPU as the map location for torch.load (#2091). The slower training is because torch_ddp also went through wenet_join, which adds an extra 30 s timeout; on small datasets this can noticeably increase end-to-end time (the latest commit already skips wenet_join for torch_ddp, ce8850f).
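
For reference, the CPU-side checkpoint loading mentioned above would look roughly like the sketch below; ckpt_path and model are placeholders, and this is not necessarily the exact change made in #2091:

import torch

# Hypothetical sketch: map checkpoint tensors to CPU first so that every rank
# does not materialize them on GPU 0; each rank moves the model to its own
# device afterwards.
checkpoint = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint, strict=False)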

kobenaxie (Contributor) commented:

> Has the refactored version shown increased GPU memory usage and slower training (torch_ddp)?

version     GPU memory    training time (100 batches)
Base        15~20 GB      35 s
Refactor    35~40 GB      70 s

This is what I observe on my side: no checkpoint is loaded and wenet_join() is skipped as well. You haven't run into this problem?

Mddct self-requested a review Oct 31, 2023 03:20
xingchensong (Member, Author) commented Oct 31, 2023

> version     GPU memory    training time (100 batches)
> Base        15~20 GB      35 s
> Refactor    35~40 GB      70 s
> This is what I observe on my side: no checkpoint is loaded and wenet_join() is skipped as well. You haven't run into this problem?

No, I haven't. I've been running all my experiments on 2080 Ti cards; with the original recipe the batch size goes up to at most 16, and the same config still trains after the refactor. Could we add each other on WeChat so we can communicate faster? WeChat ID: currycode

kobenaxie (Contributor) commented:

> No, I haven't. I've been running all my experiments on 2080 Ti cards; with the original recipe the batch size goes up to at most 16, and the same config still trains after the refactor. Could we add each other on WeChat so we can communicate faster? WeChat ID: currycode

OK, I've added you; please accept the request.

xingchensong (Member, Author) commented:

Update: with 8 x 2080 Ti GPUs and torch 1.13.0, the test results show that the loss curves and training time of the old and new code are identical.

[screenshot: tensorboard loss curves, old vs refactored code]

kobenaxie (Contributor) commented:

After upgrading torch from 1.12.0 to 1.13.0, GPU memory usage and training speed are back to normal.

xingchensong (Member, Author) commented:

@whiteshirt0429 @robin1001 @Mddct, I think this PR is ready for a final review

xingchensong added the enhancement (New feature or request) label Nov 2, 2023
robin1001 (Collaborator) commented:

Great job! As a follow-up we can trim the comments and move the multi-node multi-GPU instructions into the docs, then just reference them here.
