Adds torch.cuda.set_device calls to DDP examples (pytorch#1142)
Add set_device calls to DDP examples
subramen authored May 15, 2023
1 parent 6a64939 commit 79ef786
Showing 4 changed files with 4 additions and 0 deletions.
distributed/ddp-tutorial-series/multigpu.py (1 addition, 0 deletions)
@@ -19,6 +19,7 @@ def ddp_setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "12355"
     init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)

 class Trainer:
     def __init__(
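For reference, here is a minimal sketch of what the patched ddp_setup in multigpu.py looks like after this commit, with the imports the hunk above does not show; only the final set_device line is new, everything else is the existing tutorial code:

import os
import torch
from torch.distributed import init_process_group

def ddp_setup(rank, world_size):
    # Spawn-based setup: the caller (e.g. torch.multiprocessing.spawn)
    # passes this process's rank and the total world size explicitly.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # Pin this process to its own GPU so that NCCL collectives and plain
    # "cuda" device strings do not all resolve to device 0.
    torch.cuda.set_device(rank)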
distributed/ddp-tutorial-series/multigpu_torchrun.py (1 addition, 0 deletions)
@@ -12,6 +12,7 @@

 def ddp_setup():
     init_process_group(backend="nccl")
+    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

 class Trainer:
     def __init__(
distributed/ddp-tutorial-series/multinode.py (1 addition, 0 deletions)
@@ -12,6 +12,7 @@

 def ddp_setup():
     init_process_group(backend="nccl")
+    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

 class Trainer:
     def __init__(
distributed/minGPT-ddp/mingpt/main.py (1 addition, 0 deletions)
@@ -8,6 +8,7 @@

 def ddp_setup():
     init_process_group(backend="nccl")
+    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

 def get_train_objs(gpt_cfg: GPTConfig, opt_cfg: OptimizerConfig, data_cfg: DataConfig):
     dataset = CharDataset(data_cfg)
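The other three files share the same torchrun-flavored ddp_setup, so the equivalent sketch (again with the imports the hunks omit) is, roughly:

import os
import torch
from torch.distributed import init_process_group

def ddp_setup():
    # torchrun sets LOCAL_RANK, RANK and WORLD_SIZE in the environment,
    # so init_process_group needs no explicit rank/world_size arguments.
    init_process_group(backend="nccl")
    # Bind this worker to the GPU matching its local rank on the node.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

These scripts are meant to be launched with torchrun (for example torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py ..., script arguments omitted), which is what populates LOCAL_RANK for each worker process.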
