From 6070b55443d14ae480a0f359f3aff45308e7341d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 23 Jul 2019 17:46:01 +0200 Subject: [PATCH] fix #868: wrap model in DistributedDataParallel after apex fp16 initialization --- examples/run_glue.py | 13 +++++++------ examples/run_squad.py | 13 +++++++------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index b383bbcb808856..5d9abd06fc6e56 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -92,6 +92,12 @@ def train(args, train_dataset, model, tokenizer): raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) @@ -411,13 +417,8 @@ def main(): if args.local_rank == 0: torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab - # Distributed and parallel training model.to(args.device) - if args.local_rank != -1: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], - output_device=args.local_rank, - find_unused_parameters=True) - elif args.n_gpu > 1: + if args.n_gpu > 1: model = torch.nn.DataParallel(model) logger.info("Training/evaluation parameters %s", args) diff --git a/examples/run_squad.py b/examples/run_squad.py index 53ea0bfd648c3a..36e03fb012bf7c 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -101,6 +101,12 @@ def train(args, train_dataset, model, tokenizer): raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") model, optimizer = amp.initialize(model, optimizer, 
opt_level=args.fp16_opt_level) + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) @@ -450,13 +456,8 @@ def main(): if args.local_rank == 0: torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab - # Distributed and parrallel training model.to(args.device) - if args.local_rank != -1: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], - output_device=args.local_rank, - find_unused_parameters=True) - elif args.n_gpu > 1: + if args.n_gpu > 1: model = torch.nn.DataParallel(model) logger.info("Training/evaluation parameters %s", args)