This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Add distributed training examples of PyTorch #4821

Merged
merged 56 commits on Sep 10, 2020
56 commits
c612608
imagenet-nccl for test
vvfreesoul Aug 18, 2020
9b8ce66
imagenet-nccl for test
vvfreesoul Aug 21, 2020
e6772d3
imagenet-nccl for test
vvfreesoul Aug 21, 2020
b1f5b8c
imagenet-nccl for test
vvfreesoul Aug 21, 2020
31a46c8
imagenet-nccl for test
vvfreesoul Aug 23, 2020
9057564
imagenet-nccl for test
vvfreesoul Aug 23, 2020
610a420
imagenet-nccl for test
vvfreesoul Aug 23, 2020
da4b007
imagenet-nccl for test
vvfreesoul Aug 23, 2020
6a5fc8c
imagenet-nccl for test
vvfreesoul Aug 23, 2020
f51c5aa
imagenet-nccl for test
vvfreesoul Aug 23, 2020
b4f03fe
imagenet-nccl for test
vvfreesoul Aug 23, 2020
e18c9f8
imagenet-nccl for test
vvfreesoul Aug 23, 2020
cf7c284
imagenet-nccl for test
vvfreesoul Aug 23, 2020
3a84055
Add distributed training examples of PyTorch
vvfreesoul Aug 24, 2020
4ad2f85
Add distributed training examples of PyTorch
vvfreesoul Aug 24, 2020
43a11d2
Add distributed training examples of PyTorch
vvfreesoul Aug 24, 2020
2e59d33
Add distributed training examples of PyTorch
vvfreesoul Aug 24, 2020
ed0a7c6
Add distributed training examples of PyTorch
vvfreesoul Aug 24, 2020
6ac0633
Add distributed training examples of PyTorch
vvfreesoul Aug 25, 2020
562c448
Add distributed training examples of PyTorch
vvfreesoul Aug 25, 2020
e4b5dd1
Add distributed training examples of PyTorch
vvfreesoul Aug 26, 2020
ce8b3ce
Add distributed training examples of PyTorch
vvfreesoul Aug 26, 2020
0fd1f19
Add distributed training examples of PyTorch
vvfreesoul Aug 26, 2020
7db6cbd
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
f46a663
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
4519685
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
326b051
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
d9f2d8d
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
4bdb7c5
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
4cbb352
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
2c488f5
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
9a93e9f
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
1bb98ac
Add distributed training examples of PyTorch
vvfreesoul Aug 31, 2020
353bfdf
Add distributed training examples of PyTorch
vvfreesoul Sep 2, 2020
078d645
Add distributed training examples of PyTorch
vvfreesoul Sep 2, 2020
4efc9ac
Add distributed training examples of PyTorch
vvfreesoul Sep 2, 2020
6373f3a
Add distributed training examples of PyTorch
vvfreesoul Sep 4, 2020
429a6e9
Add distributed training examples of PyTorch
vvfreesoul Sep 4, 2020
f8fa108
Add distributed training examples of PyTorch
vvfreesoul Sep 4, 2020
659c48b
Add distributed training examples of PyTorch
vvfreesoul Sep 4, 2020
0037ab4
Merge remote-tracking branch 'origin/master'
vvfreesoul Sep 4, 2020
f957c60
Add distributed training examples of PyTorch
vvfreesoul Sep 4, 2020
863eda6
Add distributed training examples of PyTorch
vvfreesoul Sep 4, 2020
640c193
Add distributed training examples of PyTorch
vvfreesoul Sep 4, 2020
42cda8e
Add distributed training examples of PyTorch
vvfreesoul Sep 7, 2020
8c2c599
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
eed7c7f
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
adeb4c6
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
1f675a1
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
f0242c7
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
a54c606
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
f494dcf
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
c46c462
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
b18d0df
Add distributed training examples of PyTorch
vvfreesoul Sep 9, 2020
f585648
Add distributed training examples of PyTorch
vvfreesoul Sep 10, 2020
853d112
Add distributed training examples of PyTorch
vvfreesoul Sep 10, 2020
7 changes: 4 additions & 3 deletions docs/manual/cluster-user/README.md
@@ -11,6 +11,7 @@ This manual is for cluster users to learn how to submit job, debug job, manage d
3. [How to Manage Data](./how-to-manage-data.md)
4. [How to Debug Jobs](./how-to-debug-jobs.md)
5. [How to Use Advanced Job Settings](./how-to-use-advanced-job-settings.md)
-6. [Use Marketplace](./use-marketplace.md)
-7. [Use VSCode Extension](./use-vscode-extension.md)
-8. [Use Jupyter Notebook Extension](./use-jupyter-notebook-extension.md)
+6. [How to Run Distributed Job](./how-to-run-distributed-job.md)
+7. [Use Marketplace](./use-marketplace.md)
+8. [Use VSCode Extension](./use-vscode-extension.md)
+9. [Use Jupyter Notebook Extension](./use-jupyter-notebook-extension.md)
33 changes: 33 additions & 0 deletions docs/manual/cluster-user/how-to-run-distributed-job.md
@@ -0,0 +1,33 @@
## How OpenPAI Deploys Distributed Jobs
### Taskrole and Instance
When we run distributed programs on PAI, we can add different task roles to our job. A single-node job has only one task role, while a distributed job may have several. For example, when TensorFlow is used to run a distributed job, it has two roles: the parameter server and the worker. In distributed jobs, each role may have one or more instances; if a worker role has 8 instances, PAI launches 8 Docker containers for that role. Please visit [here](./how-to-use-advanced-job-settings.md#multiple-task-roles) for specific operations.

### Environmental variables
In a distributed job, one task might communicate with others (when we say task, we mean a single instance of a task role), so a task needs to be aware of other tasks' runtime information such as IP, port, etc. The system exposes such runtime information as environment variables in each task's Docker container. For mutual communication, users can write code in the container that reads those runtime environment variables. Please visit [here](./how-to-use-advanced-job-settings.md#environmental-variables-and-port-reservation) for specific operations.
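
For illustration, here is a minimal sketch of reading these variables inside a container, assuming a task role named `worker` with a reserved port labeled `SynPort` (the same names used in the examples below):

```
import os

# Runtime information injected by OpenPAI into each task's container.
# The exact variable names depend on your task role name and port label.
master_ip = os.environ['PAI_HOST_IP_worker_0']         # IP of instance 0 of the "worker" task role
master_port = os.environ['PAI_worker_0_SynPort_PORT']  # port reserved under the label "SynPort"
task_index = int(os.environ['PAI_TASK_INDEX'])         # index of the current instance

print('instance {} will rendezvous at {}:{}'.format(task_index, master_ip, master_port))
```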

### Retry policy and Completion policy
If an unknown error occurs, PAI retries the job according to user settings. To set a retry policy and a completion policy for a job, PAI asks the user to switch to Advanced mode. Please visit [here](./how-to-use-advanced-job-settings.md#job-exit-spec-retry-policy-and-completion-policy) for specific operations.

### Run PyTorch Distributed Jobs in OpenPAI
Example Name | Multi-GPU | Multi-Node | Backend | Apex | Job protocol |
---|---|---|---|---|---|
Single-Node DataParallel CIFAR-10 | ✓ | x | - | - | [cifar10-single-node-gpus-cpu-DP.yaml](https://github.com/microsoft/pai/tree/master/examples/Distributed-example/cifar10-single-node-gpus-cpu-DP.yaml)|
cifar10-single-mul-DDP-gloo | ✓ | ✓ | gloo | - | [cifar10-single-mul-DDP-gloo.yaml](https://github.com/microsoft/pai/tree/master/examples/Distributed-example/cifar10-single-mul-DDP-gloo.yaml)|
cifar10-single-mul-DDP-nccl | ✓ | ✓ | nccl | - | [cifar10-single-mul-DDP-nccl.yaml](https://github.com/microsoft/pai/tree/master/examples/Distributed-example/cifar10-single-mul-DDP-nccl.yaml)|
cifar10-single-mul-DDP-gloo-Apex-mixed | ✓ | ✓ | gloo | ✓ | [cifar10-single-mul-DDP-gloo-Apex-mixed.yaml](https://github.com/microsoft/pai/tree/master/examples/Distributed-example/cifar10-single-mul-DDP-gloo-Apex-mixed.yaml)|
cifar10-single-mul-DDP-nccl-Apex-mixed | ✓ | ✓ | nccl | ✓ | [cifar10-single-mul-DDP-gloo-Apex-mixed.yaml](https://github.com/microsoft/pai/tree/master/examples/Distributed-example/cifar10-single-mul-DDP-gloo-Apex-mixed.yaml)|
imagenet-single-mul-DDP-gloo | ✓ | ✓ | gloo | - | [imagenet-single-mul-DDP-gloo.yaml](https://github.com/microsoft/pai/tree/master/examples/Distributed-example/Lite-imagenet-single-mul-DDP-gloo.yaml)|
## DataParallel
The single-node program is simple: the program executed on PAI is exactly the same as the one you run on your own machine. Note that you only need one worker task role with a single instance on PAI, and within that worker you can request as many GPUs as you need. We provide an [example](../../../examples/Distributed-example/cifar10-single-node-gpus-cpu-DP.py) of DP.
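
For reference, a minimal DP sketch (using `resnet18` only as a stand-in for your own model; this is not the full CIFAR-10 example) looks like this:

```
import torch
import torch.nn as nn
import torchvision.models as models

# Single-node DataParallel: one process drives all visible GPUs.
model = models.resnet18()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # each batch is split across the GPUs
model = model.cuda()

inputs = torch.randn(8, 3, 224, 224).cuda()
outputs = model(inputs)  # the forward pass is transparently parallelized
```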

## DistributedDataParallel
DDP requires users to set a master node IP and port for synchronization in PyTorch. For the port, you can simply pick a fixed one, such as `5000`, as your master port. However, this port may conflict with others. To prevent port conflicts, you can reserve a port in OpenPAI, as we mentioned [here](./how-to-use-advanced-job-settings.md#environmental-variables-and-port-reservation). The port you reserved is available in environment variables like `PAI_PORT_LIST_$taskRole_$taskIndex_$portLabel`, where `$taskIndex` means the instance index of that task role. For example, if your task role name is `worker` and your port label is `SynPort`, you can add the following code in your PyTorch DDP program:

```
os.environ['MASTER_ADDR'] = os.environ['PAI_HOST_IP_worker_0']
os.environ['MASTER_PORT'] = os.environ['PAI_worker_0_SynPort_PORT']
```
If you are using `gloo` as your DDP communication backend, please also set the correct network interface, e.g. `export GLOO_SOCKET_IFNAME=eth0`.


We provide examples with [gloo](https://github.com/microsoft/pai/tree/master/examples/Distributed-example/cifar10-single-mul-DDP-gloo.yaml) and [nccl](https://github.com/microsoft/pai/tree/master/examples/Distributed-example/cifar10-single-mul-DDP-nccl.yaml) as backends.
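
Putting the pieces together, the sketch below shows the core DDP setup pattern used by these examples. It assumes the same `worker` task role and `SynPort` port label as above, and `resnet18` is only a placeholder model; each GPU process (e.g. spawned with `torch.multiprocessing.spawn`) would call this once before training:

```
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision.models as models

def setup_ddp(local_gpu, gpus_per_node, world_size, backend='nccl'):
    # Rendezvous address and port come from OpenPAI's runtime environment variables.
    os.environ['MASTER_ADDR'] = os.environ['PAI_HOST_IP_worker_0']
    os.environ['MASTER_PORT'] = os.environ['PAI_worker_0_SynPort_PORT']
    # Global rank = node index (PAI instance index) * GPUs per node + local GPU index.
    rank = int(os.environ['PAI_TASK_INDEX']) * gpus_per_node + local_gpu
    dist.init_process_group(backend=backend, init_method='env://',
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(local_gpu)
    model = models.resnet18().cuda(local_gpu)
    return nn.parallel.DistributedDataParallel(model, device_ids=[local_gpu])
```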

@@ -0,0 +1,50 @@
protocolVersion: 2
name: imagenet-gloo_8ba8ed42_7606233c
type: job
jobRetryCount: 0
prerequisites:
- type: dockerimage
uri: 'openpai/standard:python_3.6-pytorch_1.2.0-gpu'
name: docker_image_0
taskRoles:
worker:
instances: 2
completion:
minFailedInstances: 1
taskRetryCount: 0
dockerImage: docker_image_0
resourcePerInstance:
gpu: 4
cpu: 16
memoryMB: 32768
ports:
SynPort: 1
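        # "SynPort" is the reserved port label; the training script reads it as PAI_worker_0_SynPort_PORT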
commands:
- export GLOO_SOCKET_IFNAME=eth0
- 'git clone https://github.com/NVIDIA/apex'
- cd apex
- python setup.py install
- cd ..
- apt update
- apt install -y nfs-common
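      # NOTE: the NFS server address and ImageNet path below are specific to this example cluster; replace them with your own data storage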
- mkdir -p /mnt/data
- 'mount 10.151.40.32:/mnt/ImagenetData /mnt/data'
- >-
wget
https://raw.githubusercontent.com/microsoft/pai/master/examples/Distributed-example/Lite-imagenet-single-mul-DDP-nccl-gloo.py
- >-
python Lite-imagenet-single-mul-DDP-nccl-gloo.py -n 2 -g 4
--dist-backend gloo --epochs 2 /mnt/data/imagenet/unzipped
defaults:
virtualCluster: default
extras:
com.microsoft.pai.runtimeplugin:
- plugin: ssh
parameters:
jobssh: true
userssh: {}
hivedScheduler:
taskRoles:
worker:
skuNum: 1
skuType: null
117 changes: 117 additions & 0 deletions examples/Distributed-example/Lite-imagenet-single-mul-DDP-nccl-gloo.py
@@ -0,0 +1,117 @@
import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP
from apex import amp

import torchvision.datasets as datasets
import torchvision.models as models
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
def main():
print('run main')
parser = argparse.ArgumentParser()
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of nodes (default: 1)')
parser.add_argument('-g', '--gpus', default=1, type=int,
help='number of gpus per node')
parser.add_argument('-nr', '--nr', default=0, type=int,
help='ranking within the nodes')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N',
help='mini-batch size (default: 256), this is the total '
'batch size of all GPUs on the current node when '
'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--epochs', default=2, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--dist-backend', default='nccl', type=str,
help='distributed backend')
args = parser.parse_args()
args.world_size = args.gpus * args.nodes
print('world_size:',args.world_size)
os.environ['MASTER_ADDR'] = os.environ['PAI_HOST_IP_worker_0']
os.environ['MASTER_PORT'] = os.environ['PAI_worker_0_SynPort_PORT']
print('master:', os.environ['MASTER_ADDR'], 'port:', os.environ['MASTER_PORT'])
mp.spawn(train, nprocs=args.gpus, args=(args,))

def train(gpu, args):
print("start train")
rank = int(os.environ['PAI_TASK_INDEX']) * args.gpus + gpu
dist.init_process_group(backend=args.dist_backend, init_method='env://', world_size=args.world_size, rank=rank)
torch.manual_seed(0)
    model = models.__dict__[args.arch]()
torch.cuda.set_device(gpu)
model.cuda(gpu)
batch_size = args.batch_size
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda(gpu)
optimizer = torch.optim.SGD(model.parameters(), 1e-4)
# Wrap the model
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
# Data loading code
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.nodes, pin_memory=True, sampler=train_sampler)

val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=False,
num_workers=args.nodes, pin_memory=True)
start = datetime.now()
total_step = len(train_loader)
for epoch in range(args.epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)

# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
#if (i + 1) % 100 == 0 and gpu == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step,
loss.item()))
if gpu == 0:
print("Training complete in: " + str(datetime.now() - start))


if __name__ == '__main__':
main()
@@ -0,0 +1,122 @@
import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP
from apex import amp

import torchvision.datasets as datasets
import torchvision.models as models
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
def main():
print('run main')
parser = argparse.ArgumentParser()
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of nodes (default: 1)')
parser.add_argument('-g', '--gpus', default=1, type=int,
help='number of gpus per node')
parser.add_argument('-nr', '--nr', default=0, type=int,
help='ranking within the nodes')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N',
help='mini-batch size (default: 256), this is the total '
'batch size of all GPUs on the current node when '
'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--epochs', default=2, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--dist-backend', default='nccl', type=str,
help='distributed backend')
args = parser.parse_args()
args.world_size = args.gpus * args.nodes
print('world_size:',args.world_size)
os.environ['MASTER_ADDR'] = os.environ['PAI_HOST_IP_worker_0']
os.environ['MASTER_PORT'] = os.environ['PAI_worker_0_SynPort_PORT']
print('master:', os.environ['MASTER_ADDR'], 'port:', os.environ['MASTER_PORT'])
mp.spawn(train, nprocs=args.gpus, args=(args,))

def train(gpu, args):
print("start train")
rank = int(os.environ['PAI_TASK_INDEX']) * args.gpus + gpu
dist.init_process_group(backend=args.dist_backend, init_method='env://', world_size=args.world_size, rank=rank)
torch.manual_seed(0)
    model = models.__dict__[args.arch]()
torch.cuda.set_device(gpu)
model.cuda(gpu)
batch_size = args.batch_size
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda(gpu)
optimizer = torch.optim.SGD(model.parameters(), 1e-4)
    # Initialize Apex AMP for mixed precision, then wrap the model with Apex DDP
    model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
    model = DDP(model)
# Data loading code
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.nodes, pin_memory=True, sampler=train_sampler)

val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=False,
num_workers=args.nodes, pin_memory=True)
start = datetime.now()
total_step = len(train_loader)
for epoch in range(args.epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)

# Backward and optimize
optimizer.zero_grad()
            # Scale the loss for mixed precision and backpropagate once through the scaled loss
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
optimizer.step()
#if (i + 1) % 100 == 0 and gpu == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step,
loss.item()))
if gpu == 0:
print("Training complete in: " + str(datetime.now() - start))

if __name__ == '__main__':
main()