From c96a8b3e91d0a6cdbb8b103fe84b1374e94053f9 Mon Sep 17 00:00:00 2001 From: shhssdm Date: Sun, 11 Apr 2021 18:22:36 -0500 Subject: [PATCH 1/6] Add pytorch-direct version --- .../train_sampling_pytorch_direct.py | 434 ++++++++++++++++++ examples/pytorch/graphsage/utils.py | 39 +- 2 files changed, 472 insertions(+), 1 deletion(-) create mode 100644 examples/pytorch/graphsage/train_sampling_pytorch_direct.py diff --git a/examples/pytorch/graphsage/train_sampling_pytorch_direct.py b/examples/pytorch/graphsage/train_sampling_pytorch_direct.py new file mode 100644 index 000000000000..7d6e12bf28b8 --- /dev/null +++ b/examples/pytorch/graphsage/train_sampling_pytorch_direct.py @@ -0,0 +1,434 @@ +import dgl +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.multiprocessing as mp +import dgl.nn.pytorch as dglnn +import time +import math +import argparse +from torch.nn.parallel import DistributedDataParallel +import tqdm +import utils + + +from utils import thread_wrapped_func +from load_graph import load_reddit, inductive_split + +class SAGE(nn.Module): + def __init__(self, + in_feats, + n_hidden, + n_classes, + n_layers, + activation, + dropout): + super().__init__() + self.n_layers = n_layers + self.n_hidden = n_hidden + self.n_classes = n_classes + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) + for i in range(1, n_layers - 1): + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) + self.dropout = nn.Dropout(dropout) + self.activation = activation + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if l != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + return h + + def inference(self, g, x, device): + """ + Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling). + g : the entire graph. + x : the input of entire node set. + + The inference code is written in a fashion that it could handle any number of nodes and + layers. + """ + # During inference with sampling, multi-layer blocks are very inefficient because + # lots of computations in the first few layers are repeated. + # Therefore, we compute the representation of all nodes layer by layer. The nodes + # on each layer are of course splitted in batches. + # TODO: can we standardize this? + for l, layer in enumerate(self.layers): + y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) + + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) + dataloader = dgl.dataloading.NodeDataLoader( + g, + th.arange(g.num_nodes()), + sampler, + batch_size=args.batch_size, + shuffle=True, + drop_last=False, + num_workers=args.num_workers) + + for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): + block = blocks[0] + + block = block.int().to(device) + h = x[input_nodes].to(device) + h = layer(block, h) + if l != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + + y[output_nodes] = h.cpu() + + x = y + return y + +def compute_acc(pred, labels): + """ + Compute the accuracy of prediction given the labels. + """ + return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) + +def evaluate(model, g, nfeat, labels, val_nid, device): + """ + Evaluate the model on the validation set specified by ``val_nid``. + g : The entire graph. 
+ inputs : The features of all the nodes. + labels : The labels of all the nodes. + val_nid : A node ID tensor indicating which nodes do we actually compute the accuracy for. + device : The GPU device to evaluate on. + """ + model.eval() + with th.no_grad(): + pred = model.inference(g, nfeat, device) + model.train() + return compute_acc(pred[val_nid], labels[val_nid]) + +def load_subtensor(nfeat, labels, seeds, input_nodes, device): + """ + Extracts features and labels for a subset of nodes. + """ + batch_inputs = nfeat[input_nodes].to(device) + batch_labels = labels[seeds].to(device) + return batch_inputs, batch_labels + +def producer(q, idxf1, idxf2, idxl1, idxl2, idxf1_len, idxf2_len, idxl1_len, idxl2_len, event1, event2, train_nfeat, train_labels, feat_dimension, label_dimension, device): + th.cuda.set_device(device) + + # Map input tensors into GPU address + train_nfeat = train_nfeat.to(device="unified") + train_labels = train_labels.to(device="unified") + + # Create GPU-side ping pong buffers + in_feat1 = th.zeros(feat_dimension, device=device) + in_feat2 = th.zeros(feat_dimension, device=device) + in_label1 = th.zeros(label_dimension, dtype=th.long, device=device) + in_label2 = th.zeros(label_dimension, dtype=th.long, device=device) + + # Termination signal + finish = th.ones(1, dtype=th.bool) + + # Share with the training process + q.put((in_feat1, in_feat2, in_label1, in_label2, finish)) + print("Allocation done") + + flag = 1 + + with th.no_grad(): + while(1): + event1.wait() + event1.clear() + if not finish: + break + if flag: + th.index_select(train_nfeat, 0, idxf1[0:idxf1_len].to(device=device), out=in_feat1[0:idxf1_len]) + th.index_select(train_labels, 0, idxl1[0:idxl1_len].to(device=device), out=in_label1[0:idxl1_len]) + else: + th.index_select(train_nfeat, 0, idxf2[0:idxf2_len].to(device=device), out=in_feat2[0:idxf2_len]) + th.index_select(train_labels, 0, idxl2[0:idxl2_len].to(device=device), out=in_label2[0:idxl2_len]) + flag = (flag == False) + th.cuda.synchronize() + event2.set() + +#### Entry point + +def run(q, args, device, data, in_feats, idxf1, idxf2, idxl1, idxl2, idxf1_len, idxf2_len, idxl1_len, idxl2_len, event1, event2): + th.cuda.set_device(device) + + # Unpack data + n_classes, train_g, val_g, test_g = data + + train_mask = train_g.ndata['train_mask'] + val_mask = val_g.ndata['val_mask'] + test_mask = ~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']) + train_nid = train_mask.nonzero().squeeze() + val_nid = val_mask.nonzero().squeeze() + test_nid = test_mask.nonzero().squeeze() + + # Create PyTorch DataLoader for constructing blocks + sampler = dgl.dataloading.MultiLayerNeighborSampler( + [int(fanout) for fanout in args.fan_out.split(',')]) + dataloader = dgl.dataloading.NodeDataLoader( + train_g, + train_nid, + sampler, + batch_size=args.batch_size, + shuffle=True, + drop_last=False, + num_workers=args.num_workers) + + # Define model and optimizer + model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout) + model = model.to(device) + loss_fcn = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + in_feat1, in_feat2, in_label1, in_label2, finish = q.get() + + # A prologue for the pipelining purpose, just for the first minibatch of the first epoch + # ------------------------------------------------------ + flag = True + input_nodes, seeds, blocks_next = next(iter(dataloader)) + + # Send node indices for the next minibatch to the producer + if flag: + 
idxf1[0:len(input_nodes)].copy_(input_nodes) + idxl1[0:len(seeds)].copy_(seeds) + idxf1_len.fill_(len(input_nodes)) + idxl1_len.fill_(len(seeds)) + else: + idxf2[0:len(input_nodes)].copy_(input_nodes) + idxl2[0:len(seeds)].copy_(seeds) + idxf2_len.fill_(len(input_nodes)) + idxl2_len.fill_(len(seeds)) + event1.set() + time.sleep(1) + + input_nodes_n = len(input_nodes) + seeds_n = len(seeds) + flag = (flag == False) + blocks_temp = blocks_next + # ------------------------------------------------------ + # Prologue done + + # Training loop + avg = 0 + iter_tput = [] + for epoch in range(args.num_epochs): + tic = time.time() + + # Loop over the dataloader to sample the computation dependency graph as a list of + # blocks. + for step, (input_nodes, seeds, blocks_next) in enumerate(dataloader): + tic_step = time.time() + + # Send node indices for the next minibatch to the producer + if flag: + idxf1[0:len(input_nodes)].copy_(input_nodes) + idxl1[0:len(seeds)].copy_(seeds) + idxf1_len.fill_(len(input_nodes)) + idxl1_len.fill_(len(seeds)) + else: + idxf2[0:len(input_nodes)].copy_(input_nodes) + idxl2[0:len(seeds)].copy_(seeds) + idxf2_len.fill_(len(input_nodes)) + idxl2_len.fill_(len(seeds)) + + event1.set() + + event2.wait() + event2.clear() + + # Load the input features as well as output labels + if not flag: + batch_inputs = in_feat1[0:input_nodes_n] + batch_labels = in_label1[0:seeds_n] + else: + batch_inputs = in_feat2[0:input_nodes_n] + batch_labels = in_label2[0:seeds_n] + + blocks = [block.int().to(device) for block in blocks_temp] + # Compute loss and prediction + batch_pred = model(blocks, batch_inputs) + loss = loss_fcn(batch_pred, batch_labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + flag = (flag == False) + input_nodes_n = len(input_nodes) + seeds_n = len(seeds) + blocks_temp = blocks_next + + iter_tput.append(len(seeds) / (time.time() - tic_step)) + if step % args.log_every == 0: + acc = compute_acc(batch_pred, batch_labels) + print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format( + epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), th.cuda.max_memory_allocated() / 1000000)) + + # A prologue for the next epoch + # ------------------------------------------------------ + input_nodes, seeds, blocks_next = next(iter(dataloader)) + + if flag: + idxf1[0:len(input_nodes)].copy_(input_nodes) + idxl1[0:len(seeds)].copy_(seeds) + idxf1_len.fill_(len(input_nodes)) + idxl1_len.fill_(len(seeds)) + else: + idxf2[0:len(input_nodes)].copy_(input_nodes) + idxl2[0:len(seeds)].copy_(seeds) + idxf2_len.fill_(len(input_nodes)) + idxl2_len.fill_(len(seeds)) + event1.set() + + event2.wait() + event2.clear() + + # Load the input features as well as output labels + if not flag: + batch_inputs = in_feat1[0:input_nodes_n] + batch_labels = in_label1[0:seeds_n] + else: + batch_inputs = in_feat2[0:input_nodes_n] + batch_labels = in_label2[0:seeds_n] + + # Compute loss and prediction + blocks = [block.int().to(device) for block in blocks_temp] + batch_pred = model(blocks, batch_inputs) + loss = loss_fcn(batch_pred, batch_labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + flag = (flag == False) + input_nodes_n = len(input_nodes) + seeds_n = len(seeds) + blocks_temp = blocks_next + # ------------------------------------------------------ + # Prologue done + + toc = time.time() + print('Epoch Time(s): {:.4f}'.format(toc - tic)) + if epoch >= 5: + avg += toc - tic + if epoch % args.eval_every == 
0 and epoch != 0: + eval_acc = evaluate( + model, val_g, val_nfeat, val_labels, val_nid, device) + test_acc = evaluate( + model, test_g, test_nfeat, test_labels, test_nid, device) + print('Eval Acc {:.4f}'.format(eval_acc)) + print('Test Acc: {:.4f}'.format(test_acc)) + + print('Avg epoch time: {}'.format(avg / (epoch - 4))) + + # Send a termination signal to the producer + finish.copy_(th.zeros(1, dtype=th.bool)) + event1.set() + +if __name__ == '__main__': + argparser = argparse.ArgumentParser("multi-gpu training") + argparser.add_argument('--gpu', type=int, default=0, + help="GPU device ID. Use -1 for CPU training") + argparser.add_argument('--num-epochs', type=int, default=20) + argparser.add_argument('--num-hidden', type=int, default=16) + argparser.add_argument('--num-layers', type=int, default=2) + argparser.add_argument('--fan-out', type=str, default='10,25') + argparser.add_argument('--batch-size', type=int, default=1000) + argparser.add_argument('--log-every', type=int, default=20) + argparser.add_argument('--eval-every', type=int, default=5) + argparser.add_argument('--lr', type=float, default=0.003) + argparser.add_argument('--dropout', type=float, default=0.5) + argparser.add_argument('--num-workers', type=int, default=4, + help="Number of sampling processes. Use 0 for no extra process.") + argparser.add_argument('--inductive', action='store_true', + help="Inductive learning setting") + argparser.add_argument('--mps', type=str, default='0') + args = argparser.parse_args() + + device = th.device('cuda:%d' % args.gpu) + mps = list(map(str, args.mps.split(','))) + + # If MPS values are given, then setup MPS + if float(mps[0]) != 0: + user_id = utils.mps_get_user_id() + utils.mps_daemon_start() + utils.mps_server_start(user_id) + server_pid = utils.mps_get_server_pid() + time.sleep(4) + + g, n_classes = load_reddit() + # Construct graph + g = dgl.as_heterograph(g) + + if args.inductive: + train_g, val_g, test_g = inductive_split(g) + else: + train_g = val_g = test_g = g + + # Create csr/coo/csc formats before launching training processes with multi-gpu. + # This avoids creating certain formats in each sub-process, which saves momory and CPU. 
+ train_g.create_formats_() + val_g.create_formats_() + test_g.create_formats_() + # Pack data + data = n_classes, train_g, val_g, test_g + + train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features').share_memory_() + train_labels = val_labels = test_labels = g.ndata.pop('labels').share_memory_() + in_feats = train_nfeat.shape[1] + + fanout_max = 1 + for fanout in args.fan_out.split(','): + fanout_max = fanout_max * int(fanout) + + feat_dimension = [args.batch_size * fanout_max, train_nfeat.shape[1]] + label_dimension = [args.batch_size] + + ctx = mp.get_context('spawn') + + if float(mps[0]) != 0: + utils.mps_set_active_thread_percentage(server_pid, mps[0]) + # Just in case we add a timer to make sure MPS setup is done before we launch producer + time.sleep(4) + + # TODO: shared structure declarations can be futher simplified + q = ctx.SimpleQueue() + + # Synchornization signals + event1 = ctx.Event() + event2 = ctx.Event() + + # Indices and the their lengths shared between the producer and the training processes + idxf1 = th.zeros([args.batch_size * fanout_max], dtype=th.long).share_memory_() + idxf2 = th.zeros([args.batch_size * fanout_max], dtype=th.long).share_memory_() + idxl1 = th.zeros([args.batch_size * fanout_max], dtype=th.long).share_memory_() + idxl2 = th.zeros([args.batch_size * fanout_max], dtype=th.long).share_memory_() + idxf1_len = th.zeros([1], dtype=th.long).share_memory_() + idxf2_len = th.zeros([1], dtype=th.long).share_memory_() + idxl1_len = th.zeros([1], dtype=th.long).share_memory_() + idxl2_len = th.zeros([1], dtype=th.long).share_memory_() + + print("Producer Start") + producer_inst = ctx.Process(target=producer, + args=(q, idxf1, idxf2, idxl1, idxl2, idxf1_len, idxf2_len, idxl1_len, idxl2_len, event1, event2, train_nfeat, train_labels, feat_dimension, label_dimension, device)) + producer_inst.start() + + if float(mps[0]) != 0: + # Just in case we add timers to make sure MPS setup is done before we launch training + time.sleep(8) + utils.mps_set_active_thread_percentage(server_pid, mps[1]) + time.sleep(4) + + print("Run Start") + p = mp.Process(target=thread_wrapped_func(run), + args=(q, args, device, data, in_feats, idxf1, idxf2, idxl1, idxl2, idxf1_len, idxf2_len, idxl1_len, idxl2_len, event1, event2)) + p.start() + + p.join() + producer_inst.join() diff --git a/examples/pytorch/graphsage/utils.py b/examples/pytorch/graphsage/utils.py index f5f8a27c378b..8c7a4f481624 100644 --- a/examples/pytorch/graphsage/utils.py +++ b/examples/pytorch/graphsage/utils.py @@ -10,6 +10,7 @@ import torch.multiprocessing as mp from _thread import start_new_thread from functools import wraps +import subprocess import traceback def thread_wrapped_func(func): @@ -35,4 +36,40 @@ def _queue_result(): else: assert isinstance(exception, Exception) raise exception.__class__(trace) - return decorated_function \ No newline at end of file + return decorated_function + +# Get user id +def mps_get_user_id(): + result = subprocess.run(['id', '-u'], stdout=subprocess.PIPE) + return result.stdout.decode('utf-8').rstrip() + +# Start MPS daemon +def mps_daemon_start(): + result = subprocess.run(['nvidia-cuda-mps-control', '-d'], stdout=subprocess.PIPE) + print(result.stdout.decode('utf-8').rstrip()) + +# Start MPS server with user id +def mps_server_start(user_id): + ps = subprocess.Popen(('echo', 'start_server -uid ' + user_id), stdout=subprocess.PIPE) + output = subprocess.check_output(('nvidia-cuda-mps-control'), stdin=ps.stdout) + ps.wait() + +# Get created server pid +def 
mps_get_server_pid(): + ps = subprocess.Popen(('echo', 'get_server_list'), stdout=subprocess.PIPE) + output = subprocess.check_output(('nvidia-cuda-mps-control'), stdin=ps.stdout) + ps.wait() + return output.decode('utf-8').rstrip() + +# Set active thread percentage with the pid for producer +def mps_set_active_thread_percentage(server_pid, percentage): + ps = subprocess.Popen(('echo', 'set_active_thread_percentage ' + server_pid + ' ' + str(percentage)), stdout=subprocess.PIPE) + output = subprocess.check_output(('nvidia-cuda-mps-control'), stdin=ps.stdout) + ps.wait() + print('Setting set_active_thread_percentage to', output.decode('utf-8').rstrip()) + +# Quit MPS +def mps_quit(): + ps = subprocess.Popen(('echo', 'quit'), stdout=subprocess.PIPE) + output = subprocess.check_output(('nvidia-cuda-mps-control'), stdin=ps.stdout) + ps.wait() \ No newline at end of file From beca41cd9be451195c494ea3faee1acea6d08d48 Mon Sep 17 00:00:00 2001 From: davidmin7 Date: Tue, 6 Jul 2021 21:32:10 -0500 Subject: [PATCH 2/6] remove --- .../train_sampling_pytorch_direct.py | 434 ------------------ python/dgl/multiprocessing/pytorch.py | 39 -- 2 files changed, 473 deletions(-) delete mode 100644 examples/pytorch/graphsage/train_sampling_pytorch_direct.py diff --git a/examples/pytorch/graphsage/train_sampling_pytorch_direct.py b/examples/pytorch/graphsage/train_sampling_pytorch_direct.py deleted file mode 100644 index 7d6e12bf28b8..000000000000 --- a/examples/pytorch/graphsage/train_sampling_pytorch_direct.py +++ /dev/null @@ -1,434 +0,0 @@ -import dgl -import numpy as np -import torch as th -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import torch.multiprocessing as mp -import dgl.nn.pytorch as dglnn -import time -import math -import argparse -from torch.nn.parallel import DistributedDataParallel -import tqdm -import utils - - -from utils import thread_wrapped_func -from load_graph import load_reddit, inductive_split - -class SAGE(nn.Module): - def __init__(self, - in_feats, - n_hidden, - n_classes, - n_layers, - activation, - dropout): - super().__init__() - self.n_layers = n_layers - self.n_hidden = n_hidden - self.n_classes = n_classes - self.layers = nn.ModuleList() - self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) - for i in range(1, n_layers - 1): - self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) - self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) - self.dropout = nn.Dropout(dropout) - self.activation = activation - - def forward(self, blocks, x): - h = x - for l, (layer, block) in enumerate(zip(self.layers, blocks)): - h = layer(block, h) - if l != len(self.layers) - 1: - h = self.activation(h) - h = self.dropout(h) - return h - - def inference(self, g, x, device): - """ - Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling). - g : the entire graph. - x : the input of entire node set. - - The inference code is written in a fashion that it could handle any number of nodes and - layers. - """ - # During inference with sampling, multi-layer blocks are very inefficient because - # lots of computations in the first few layers are repeated. - # Therefore, we compute the representation of all nodes layer by layer. The nodes - # on each layer are of course splitted in batches. - # TODO: can we standardize this? 
- for l, layer in enumerate(self.layers): - y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) - - sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) - dataloader = dgl.dataloading.NodeDataLoader( - g, - th.arange(g.num_nodes()), - sampler, - batch_size=args.batch_size, - shuffle=True, - drop_last=False, - num_workers=args.num_workers) - - for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): - block = blocks[0] - - block = block.int().to(device) - h = x[input_nodes].to(device) - h = layer(block, h) - if l != len(self.layers) - 1: - h = self.activation(h) - h = self.dropout(h) - - y[output_nodes] = h.cpu() - - x = y - return y - -def compute_acc(pred, labels): - """ - Compute the accuracy of prediction given the labels. - """ - return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) - -def evaluate(model, g, nfeat, labels, val_nid, device): - """ - Evaluate the model on the validation set specified by ``val_nid``. - g : The entire graph. - inputs : The features of all the nodes. - labels : The labels of all the nodes. - val_nid : A node ID tensor indicating which nodes do we actually compute the accuracy for. - device : The GPU device to evaluate on. - """ - model.eval() - with th.no_grad(): - pred = model.inference(g, nfeat, device) - model.train() - return compute_acc(pred[val_nid], labels[val_nid]) - -def load_subtensor(nfeat, labels, seeds, input_nodes, device): - """ - Extracts features and labels for a subset of nodes. - """ - batch_inputs = nfeat[input_nodes].to(device) - batch_labels = labels[seeds].to(device) - return batch_inputs, batch_labels - -def producer(q, idxf1, idxf2, idxl1, idxl2, idxf1_len, idxf2_len, idxl1_len, idxl2_len, event1, event2, train_nfeat, train_labels, feat_dimension, label_dimension, device): - th.cuda.set_device(device) - - # Map input tensors into GPU address - train_nfeat = train_nfeat.to(device="unified") - train_labels = train_labels.to(device="unified") - - # Create GPU-side ping pong buffers - in_feat1 = th.zeros(feat_dimension, device=device) - in_feat2 = th.zeros(feat_dimension, device=device) - in_label1 = th.zeros(label_dimension, dtype=th.long, device=device) - in_label2 = th.zeros(label_dimension, dtype=th.long, device=device) - - # Termination signal - finish = th.ones(1, dtype=th.bool) - - # Share with the training process - q.put((in_feat1, in_feat2, in_label1, in_label2, finish)) - print("Allocation done") - - flag = 1 - - with th.no_grad(): - while(1): - event1.wait() - event1.clear() - if not finish: - break - if flag: - th.index_select(train_nfeat, 0, idxf1[0:idxf1_len].to(device=device), out=in_feat1[0:idxf1_len]) - th.index_select(train_labels, 0, idxl1[0:idxl1_len].to(device=device), out=in_label1[0:idxl1_len]) - else: - th.index_select(train_nfeat, 0, idxf2[0:idxf2_len].to(device=device), out=in_feat2[0:idxf2_len]) - th.index_select(train_labels, 0, idxl2[0:idxl2_len].to(device=device), out=in_label2[0:idxl2_len]) - flag = (flag == False) - th.cuda.synchronize() - event2.set() - -#### Entry point - -def run(q, args, device, data, in_feats, idxf1, idxf2, idxl1, idxl2, idxf1_len, idxf2_len, idxl1_len, idxl2_len, event1, event2): - th.cuda.set_device(device) - - # Unpack data - n_classes, train_g, val_g, test_g = data - - train_mask = train_g.ndata['train_mask'] - val_mask = val_g.ndata['val_mask'] - test_mask = ~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']) - train_nid = train_mask.nonzero().squeeze() - val_nid = val_mask.nonzero().squeeze() - test_nid = 
test_mask.nonzero().squeeze() - - # Create PyTorch DataLoader for constructing blocks - sampler = dgl.dataloading.MultiLayerNeighborSampler( - [int(fanout) for fanout in args.fan_out.split(',')]) - dataloader = dgl.dataloading.NodeDataLoader( - train_g, - train_nid, - sampler, - batch_size=args.batch_size, - shuffle=True, - drop_last=False, - num_workers=args.num_workers) - - # Define model and optimizer - model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout) - model = model.to(device) - loss_fcn = nn.CrossEntropyLoss() - optimizer = optim.Adam(model.parameters(), lr=args.lr) - - in_feat1, in_feat2, in_label1, in_label2, finish = q.get() - - # A prologue for the pipelining purpose, just for the first minibatch of the first epoch - # ------------------------------------------------------ - flag = True - input_nodes, seeds, blocks_next = next(iter(dataloader)) - - # Send node indices for the next minibatch to the producer - if flag: - idxf1[0:len(input_nodes)].copy_(input_nodes) - idxl1[0:len(seeds)].copy_(seeds) - idxf1_len.fill_(len(input_nodes)) - idxl1_len.fill_(len(seeds)) - else: - idxf2[0:len(input_nodes)].copy_(input_nodes) - idxl2[0:len(seeds)].copy_(seeds) - idxf2_len.fill_(len(input_nodes)) - idxl2_len.fill_(len(seeds)) - event1.set() - time.sleep(1) - - input_nodes_n = len(input_nodes) - seeds_n = len(seeds) - flag = (flag == False) - blocks_temp = blocks_next - # ------------------------------------------------------ - # Prologue done - - # Training loop - avg = 0 - iter_tput = [] - for epoch in range(args.num_epochs): - tic = time.time() - - # Loop over the dataloader to sample the computation dependency graph as a list of - # blocks. - for step, (input_nodes, seeds, blocks_next) in enumerate(dataloader): - tic_step = time.time() - - # Send node indices for the next minibatch to the producer - if flag: - idxf1[0:len(input_nodes)].copy_(input_nodes) - idxl1[0:len(seeds)].copy_(seeds) - idxf1_len.fill_(len(input_nodes)) - idxl1_len.fill_(len(seeds)) - else: - idxf2[0:len(input_nodes)].copy_(input_nodes) - idxl2[0:len(seeds)].copy_(seeds) - idxf2_len.fill_(len(input_nodes)) - idxl2_len.fill_(len(seeds)) - - event1.set() - - event2.wait() - event2.clear() - - # Load the input features as well as output labels - if not flag: - batch_inputs = in_feat1[0:input_nodes_n] - batch_labels = in_label1[0:seeds_n] - else: - batch_inputs = in_feat2[0:input_nodes_n] - batch_labels = in_label2[0:seeds_n] - - blocks = [block.int().to(device) for block in blocks_temp] - # Compute loss and prediction - batch_pred = model(blocks, batch_inputs) - loss = loss_fcn(batch_pred, batch_labels) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - flag = (flag == False) - input_nodes_n = len(input_nodes) - seeds_n = len(seeds) - blocks_temp = blocks_next - - iter_tput.append(len(seeds) / (time.time() - tic_step)) - if step % args.log_every == 0: - acc = compute_acc(batch_pred, batch_labels) - print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format( - epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), th.cuda.max_memory_allocated() / 1000000)) - - # A prologue for the next epoch - # ------------------------------------------------------ - input_nodes, seeds, blocks_next = next(iter(dataloader)) - - if flag: - idxf1[0:len(input_nodes)].copy_(input_nodes) - idxl1[0:len(seeds)].copy_(seeds) - idxf1_len.fill_(len(input_nodes)) - idxl1_len.fill_(len(seeds)) - else: - 
idxf2[0:len(input_nodes)].copy_(input_nodes) - idxl2[0:len(seeds)].copy_(seeds) - idxf2_len.fill_(len(input_nodes)) - idxl2_len.fill_(len(seeds)) - event1.set() - - event2.wait() - event2.clear() - - # Load the input features as well as output labels - if not flag: - batch_inputs = in_feat1[0:input_nodes_n] - batch_labels = in_label1[0:seeds_n] - else: - batch_inputs = in_feat2[0:input_nodes_n] - batch_labels = in_label2[0:seeds_n] - - # Compute loss and prediction - blocks = [block.int().to(device) for block in blocks_temp] - batch_pred = model(blocks, batch_inputs) - loss = loss_fcn(batch_pred, batch_labels) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - flag = (flag == False) - input_nodes_n = len(input_nodes) - seeds_n = len(seeds) - blocks_temp = blocks_next - # ------------------------------------------------------ - # Prologue done - - toc = time.time() - print('Epoch Time(s): {:.4f}'.format(toc - tic)) - if epoch >= 5: - avg += toc - tic - if epoch % args.eval_every == 0 and epoch != 0: - eval_acc = evaluate( - model, val_g, val_nfeat, val_labels, val_nid, device) - test_acc = evaluate( - model, test_g, test_nfeat, test_labels, test_nid, device) - print('Eval Acc {:.4f}'.format(eval_acc)) - print('Test Acc: {:.4f}'.format(test_acc)) - - print('Avg epoch time: {}'.format(avg / (epoch - 4))) - - # Send a termination signal to the producer - finish.copy_(th.zeros(1, dtype=th.bool)) - event1.set() - -if __name__ == '__main__': - argparser = argparse.ArgumentParser("multi-gpu training") - argparser.add_argument('--gpu', type=int, default=0, - help="GPU device ID. Use -1 for CPU training") - argparser.add_argument('--num-epochs', type=int, default=20) - argparser.add_argument('--num-hidden', type=int, default=16) - argparser.add_argument('--num-layers', type=int, default=2) - argparser.add_argument('--fan-out', type=str, default='10,25') - argparser.add_argument('--batch-size', type=int, default=1000) - argparser.add_argument('--log-every', type=int, default=20) - argparser.add_argument('--eval-every', type=int, default=5) - argparser.add_argument('--lr', type=float, default=0.003) - argparser.add_argument('--dropout', type=float, default=0.5) - argparser.add_argument('--num-workers', type=int, default=4, - help="Number of sampling processes. Use 0 for no extra process.") - argparser.add_argument('--inductive', action='store_true', - help="Inductive learning setting") - argparser.add_argument('--mps', type=str, default='0') - args = argparser.parse_args() - - device = th.device('cuda:%d' % args.gpu) - mps = list(map(str, args.mps.split(','))) - - # If MPS values are given, then setup MPS - if float(mps[0]) != 0: - user_id = utils.mps_get_user_id() - utils.mps_daemon_start() - utils.mps_server_start(user_id) - server_pid = utils.mps_get_server_pid() - time.sleep(4) - - g, n_classes = load_reddit() - # Construct graph - g = dgl.as_heterograph(g) - - if args.inductive: - train_g, val_g, test_g = inductive_split(g) - else: - train_g = val_g = test_g = g - - # Create csr/coo/csc formats before launching training processes with multi-gpu. - # This avoids creating certain formats in each sub-process, which saves momory and CPU. 
- train_g.create_formats_() - val_g.create_formats_() - test_g.create_formats_() - # Pack data - data = n_classes, train_g, val_g, test_g - - train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features').share_memory_() - train_labels = val_labels = test_labels = g.ndata.pop('labels').share_memory_() - in_feats = train_nfeat.shape[1] - - fanout_max = 1 - for fanout in args.fan_out.split(','): - fanout_max = fanout_max * int(fanout) - - feat_dimension = [args.batch_size * fanout_max, train_nfeat.shape[1]] - label_dimension = [args.batch_size] - - ctx = mp.get_context('spawn') - - if float(mps[0]) != 0: - utils.mps_set_active_thread_percentage(server_pid, mps[0]) - # Just in case we add a timer to make sure MPS setup is done before we launch producer - time.sleep(4) - - # TODO: shared structure declarations can be futher simplified - q = ctx.SimpleQueue() - - # Synchornization signals - event1 = ctx.Event() - event2 = ctx.Event() - - # Indices and the their lengths shared between the producer and the training processes - idxf1 = th.zeros([args.batch_size * fanout_max], dtype=th.long).share_memory_() - idxf2 = th.zeros([args.batch_size * fanout_max], dtype=th.long).share_memory_() - idxl1 = th.zeros([args.batch_size * fanout_max], dtype=th.long).share_memory_() - idxl2 = th.zeros([args.batch_size * fanout_max], dtype=th.long).share_memory_() - idxf1_len = th.zeros([1], dtype=th.long).share_memory_() - idxf2_len = th.zeros([1], dtype=th.long).share_memory_() - idxl1_len = th.zeros([1], dtype=th.long).share_memory_() - idxl2_len = th.zeros([1], dtype=th.long).share_memory_() - - print("Producer Start") - producer_inst = ctx.Process(target=producer, - args=(q, idxf1, idxf2, idxl1, idxl2, idxf1_len, idxf2_len, idxl1_len, idxl2_len, event1, event2, train_nfeat, train_labels, feat_dimension, label_dimension, device)) - producer_inst.start() - - if float(mps[0]) != 0: - # Just in case we add timers to make sure MPS setup is done before we launch training - time.sleep(8) - utils.mps_set_active_thread_percentage(server_pid, mps[1]) - time.sleep(4) - - print("Run Start") - p = mp.Process(target=thread_wrapped_func(run), - args=(q, args, device, data, in_feats, idxf1, idxf2, idxl1, idxl2, idxf1_len, idxf2_len, idxl1_len, idxl2_len, event1, event2)) - p.start() - - p.join() - producer_inst.join() diff --git a/python/dgl/multiprocessing/pytorch.py b/python/dgl/multiprocessing/pytorch.py index 94e6311d8969..f4d4602c8d8e 100644 --- a/python/dgl/multiprocessing/pytorch.py +++ b/python/dgl/multiprocessing/pytorch.py @@ -1,6 +1,5 @@ """PyTorch multiprocessing wrapper.""" from functools import wraps -import subprocess import traceback from _thread import start_new_thread import torch.multiprocessing as mp @@ -30,47 +29,9 @@ def _queue_result(): raise exception.__class__(trace) return decorated_function -<<<<<<< HEAD:examples/pytorch/graphsage/utils.py -# Get user id -def mps_get_user_id(): - result = subprocess.run(['id', '-u'], stdout=subprocess.PIPE) - return result.stdout.decode('utf-8').rstrip() - -# Start MPS daemon -def mps_daemon_start(): - result = subprocess.run(['nvidia-cuda-mps-control', '-d'], stdout=subprocess.PIPE) - print(result.stdout.decode('utf-8').rstrip()) - -# Start MPS server with user id -def mps_server_start(user_id): - ps = subprocess.Popen(('echo', 'start_server -uid ' + user_id), stdout=subprocess.PIPE) - output = subprocess.check_output(('nvidia-cuda-mps-control'), stdin=ps.stdout) - ps.wait() - -# Get created server pid -def mps_get_server_pid(): - ps = subprocess.Popen(('echo', 
'get_server_list'), stdout=subprocess.PIPE) - output = subprocess.check_output(('nvidia-cuda-mps-control'), stdin=ps.stdout) - ps.wait() - return output.decode('utf-8').rstrip() - -# Set active thread percentage with the pid for producer -def mps_set_active_thread_percentage(server_pid, percentage): - ps = subprocess.Popen(('echo', 'set_active_thread_percentage ' + server_pid + ' ' + str(percentage)), stdout=subprocess.PIPE) - output = subprocess.check_output(('nvidia-cuda-mps-control'), stdin=ps.stdout) - ps.wait() - print('Setting set_active_thread_percentage to', output.decode('utf-8').rstrip()) - -# Quit MPS -def mps_quit(): - ps = subprocess.Popen(('echo', 'quit'), stdout=subprocess.PIPE) - output = subprocess.check_output(('nvidia-cuda-mps-control'), stdin=ps.stdout) - ps.wait() -======= # pylint: disable=missing-docstring class Process(mp.Process): # pylint: disable=dangerous-default-value def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None): target = thread_wrapped_func(target) super().__init__(group, target, name, args, kwargs, daemon=daemon) ->>>>>>> upstream/master:python/dgl/multiprocessing/pytorch.py From 9b21f998b75c9b4fc40bfbc378ef02fee3abea0b Mon Sep 17 00:00:00 2001 From: davidmin7 Date: Sun, 25 Jul 2021 04:30:17 -0500 Subject: [PATCH 3/6] Add multi-gpu unified tensor test for pytorch --- tests/pytorch/test_unified_tensor.py | 53 ++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/pytorch/test_unified_tensor.py b/tests/pytorch/test_unified_tensor.py index 4770361fdfcd..44ccf8566c45 100644 --- a/tests/pytorch/test_unified_tensor.py +++ b/tests/pytorch/test_unified_tensor.py @@ -1,9 +1,18 @@ +import dgl.multiprocessing as mp import unittest, os +import pytest import torch as th import dgl import backend as F +def start_unified_tensor_worker(dev_id, input, seq_idx, rand_idx, output_seq, output_rand): + device = th.device('cuda:'+str(dev_id)) + th.cuda.set_device(device) + input_unified = dgl.contrib.UnifiedTensor(input, device=device) + output_seq.copy_(input_unified[seq_idx.to(device)].cpu()) + output_rand.copy_(input_unified[rand_idx.to(device)].cpu()) + @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(F.ctx().type == 'cpu', reason='gpu only test') def test_unified_tensor(): @@ -27,5 +36,49 @@ def test_unified_tensor(): rand_idx = rand_idx.to(th.device('cuda')) assert th.all(th.eq(input[rand_idx].to(th.device('cuda')), input_unified[rand_idx])) +@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') +@unittest.skipIf(F.ctx().type == 'cpu', reason='gpu only test') +def test_multi_gpu_unified_tensor(): + if F.ctx().type == 'cuda' and th.cuda.device_count() > 1: + pytest.skip("Only one GPU detected, skip multi-gpu test.") + + num_workers = th.cuda.device_count() + + test_row_size = 65536 + test_col_size = 128 + + rand_test_size = 8192 + + input = th.rand((test_row_size, test_col_size)).share_memory_() + seq_idx = th.arange(0, test_row_size).share_memory_() + rand_idx = th.randint(0, test_row_size, (rand_test_size,)).share_memory_() + + output_seq = [] + output_rand = [] + + output_seq_cpu = input[seq_idx] + output_rand_cpu = input[rand_idx] + + worker_list = [] + + ctx = mp.get_context('spawn') + for i in range(num_workers): + output_seq.append(th.zeros((test_row_size, test_col_size)).share_memory_()) + output_rand.append(th.zeros((rand_test_size, test_col_size)).share_memory_()) + p = ctx.Process(target=start_unified_tensor_worker, + args=(i, input, 
seq_idx, rand_idx, output_seq[i], output_rand[i],)) + p.start() + worker_list.append(p) + + for p in worker_list: + p.join() + for p in worker_list: + assert p.exitcode == 0 + for i in range(num_workers): + assert th.all(th.eq(output_seq_cpu, output_seq[i])) + assert th.all(th.eq(output_rand_cpu, output_rand[i])) + + if __name__ == '__main__': test_unified_tensor() + test_multi_gpu_unified_tensor() From dae7f91e5559ad8899301bb594650ad286429096 Mon Sep 17 00:00:00 2001 From: davidmin7 Date: Sun, 25 Jul 2021 08:04:10 -0500 Subject: [PATCH 4/6] relocate verification step to each process --- tests/pytorch/test_unified_tensor.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/test_unified_tensor.py b/tests/pytorch/test_unified_tensor.py index 44ccf8566c45..4a7f9ffaa831 100644 --- a/tests/pytorch/test_unified_tensor.py +++ b/tests/pytorch/test_unified_tensor.py @@ -6,12 +6,17 @@ import dgl import backend as F -def start_unified_tensor_worker(dev_id, input, seq_idx, rand_idx, output_seq, output_rand): +def start_unified_tensor_worker(dev_id, input, seq_idx, rand_idx, output_seq_cpu, output_rand_cpu): device = th.device('cuda:'+str(dev_id)) th.cuda.set_device(device) input_unified = dgl.contrib.UnifiedTensor(input, device=device) - output_seq.copy_(input_unified[seq_idx.to(device)].cpu()) - output_rand.copy_(input_unified[rand_idx.to(device)].cpu()) + + seq_idx = seq_idx.to(device) + assert th.all(th.eq(output_seq_cpu, input_unified[seq_idx])) + + rand_idx = rand_idx.to(device) + assert th.all(th.eq(output_rand_cpu, input_unified[rand_idx])) + @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(F.ctx().type == 'cpu', reason='gpu only test') @@ -50,8 +55,8 @@ def test_multi_gpu_unified_tensor(): rand_test_size = 8192 input = th.rand((test_row_size, test_col_size)).share_memory_() - seq_idx = th.arange(0, test_row_size).share_memory_() - rand_idx = th.randint(0, test_row_size, (rand_test_size,)).share_memory_() + seq_idx = th.arange(0, test_row_size) + rand_idx = th.randint(0, test_row_size, (rand_test_size,)) output_seq = [] output_rand = [] @@ -63,10 +68,8 @@ def test_multi_gpu_unified_tensor(): ctx = mp.get_context('spawn') for i in range(num_workers): - output_seq.append(th.zeros((test_row_size, test_col_size)).share_memory_()) - output_rand.append(th.zeros((rand_test_size, test_col_size)).share_memory_()) p = ctx.Process(target=start_unified_tensor_worker, - args=(i, input, seq_idx, rand_idx, output_seq[i], output_rand[i],)) + args=(i, input, seq_idx, rand_idx, output_seq_cpu, output_rand_cpu,)) p.start() worker_list.append(p) @@ -74,9 +77,6 @@ def test_multi_gpu_unified_tensor(): p.join() for p in worker_list: assert p.exitcode == 0 - for i in range(num_workers): - assert th.all(th.eq(output_seq_cpu, output_seq[i])) - assert th.all(th.eq(output_rand_cpu, output_rand[i])) if __name__ == '__main__': From 1080da5c52bc005fd70cf61c858798290a15e8a4 Mon Sep 17 00:00:00 2001 From: davidmin7 Date: Sun, 25 Jul 2021 09:51:02 -0500 Subject: [PATCH 5/6] reduce number of workers --- tests/pytorch/test_unified_tensor.py | 32 +++++++++++++--------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_unified_tensor.py b/tests/pytorch/test_unified_tensor.py index 4a7f9ffaa831..f7658756a468 100644 --- a/tests/pytorch/test_unified_tensor.py +++ b/tests/pytorch/test_unified_tensor.py @@ -6,17 +6,12 @@ import dgl import backend as F -def start_unified_tensor_worker(dev_id, input, seq_idx, 
rand_idx, output_seq_cpu, output_rand_cpu): +def start_unified_tensor_worker(dev_id, input, seq_idx, rand_idx, output_seq, output_rand): device = th.device('cuda:'+str(dev_id)) th.cuda.set_device(device) input_unified = dgl.contrib.UnifiedTensor(input, device=device) - - seq_idx = seq_idx.to(device) - assert th.all(th.eq(output_seq_cpu, input_unified[seq_idx])) - - rand_idx = rand_idx.to(device) - assert th.all(th.eq(output_rand_cpu, input_unified[rand_idx])) - + output_seq.copy_(input_unified[seq_idx.to(device)].cpu()) + output_rand.copy_(input_unified[rand_idx.to(device)].cpu()) @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(F.ctx().type == 'cpu', reason='gpu only test') @@ -43,11 +38,9 @@ def test_unified_tensor(): @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(F.ctx().type == 'cpu', reason='gpu only test') -def test_multi_gpu_unified_tensor(): - if F.ctx().type == 'cuda' and th.cuda.device_count() > 1: - pytest.skip("Only one GPU detected, skip multi-gpu test.") - - num_workers = th.cuda.device_count() +def test_multi_gpu_unified_tensor(num_workers): + if F.ctx().type == 'cuda' and th.cuda.device_count() < num_workers: + pytest.skip("Not enough number of GPUs to do this test, skip multi-gpu test.") test_row_size = 65536 test_col_size = 128 @@ -55,8 +48,8 @@ def test_multi_gpu_unified_tensor(): rand_test_size = 8192 input = th.rand((test_row_size, test_col_size)).share_memory_() - seq_idx = th.arange(0, test_row_size) - rand_idx = th.randint(0, test_row_size, (rand_test_size,)) + seq_idx = th.arange(0, test_row_size).share_memory_() + rand_idx = th.randint(0, test_row_size, (rand_test_size,)).share_memory_() output_seq = [] output_rand = [] @@ -68,8 +61,10 @@ def test_multi_gpu_unified_tensor(): ctx = mp.get_context('spawn') for i in range(num_workers): + output_seq.append(th.zeros((test_row_size, test_col_size)).share_memory_()) + output_rand.append(th.zeros((rand_test_size, test_col_size)).share_memory_()) p = ctx.Process(target=start_unified_tensor_worker, - args=(i, input, seq_idx, rand_idx, output_seq_cpu, output_rand_cpu,)) + args=(i, input, seq_idx, rand_idx, output_seq[i], output_rand[i],)) p.start() worker_list.append(p) @@ -77,8 +72,11 @@ def test_multi_gpu_unified_tensor(): p.join() for p in worker_list: assert p.exitcode == 0 + for i in range(num_workers): + assert th.all(th.eq(output_seq_cpu, output_seq[i])) + assert th.all(th.eq(output_rand_cpu, output_rand[i])) if __name__ == '__main__': test_unified_tensor() - test_multi_gpu_unified_tensor() + test_multi_gpu_unified_tensor(2) From 0f3fe70b100c972c83e5e055910024c7678e448d Mon Sep 17 00:00:00 2001 From: davidmin7 Date: Sun, 25 Jul 2021 10:08:59 -0500 Subject: [PATCH 6/6] add parameter --- tests/pytorch/test_unified_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pytorch/test_unified_tensor.py b/tests/pytorch/test_unified_tensor.py index f7658756a468..2990715a87cb 100644 --- a/tests/pytorch/test_unified_tensor.py +++ b/tests/pytorch/test_unified_tensor.py @@ -38,6 +38,7 @@ def test_unified_tensor(): @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(F.ctx().type == 'cpu', reason='gpu only test') +@pytest.mark.parametrize("num_workers", [1, 2]) def test_multi_gpu_unified_tensor(num_workers): if F.ctx().type == 'cuda' and th.cuda.device_count() < num_workers: pytest.skip("Not enough number of GPUs to do this test, skip multi-gpu test.") @@ -79,4 +80,5 @@ def 
test_multi_gpu_unified_tensor(num_workers): if __name__ == '__main__': test_unified_tensor() + test_multi_gpu_unified_tensor(1) test_multi_gpu_unified_tensor(2)
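
Note on the access pattern exercised by these tests: the sketch below is illustrative only (the tensor sizes and the cuda:0 device are placeholders, not values taken from the patches). It wraps a host tensor in dgl.contrib.UnifiedTensor and gathers rows with a GPU-resident index tensor, which is what the tests verify and is a similar zero-copy pattern to what the producer process in the removed train_sampling_pytorch_direct.py example used via .to(device="unified").

    import torch as th
    import dgl

    # Assumes a machine with at least one CUDA device.
    device = th.device('cuda:0')
    th.cuda.set_device(device)

    # Host-resident feature matrix; UnifiedTensor exposes it to the GPU
    # without copying the whole tensor into device memory.
    features = th.rand(65536, 128)
    unified = dgl.contrib.UnifiedTensor(features, device=device)

    # The index tensor lives on the GPU; the gather pulls only the
    # selected rows across the bus.
    idx = th.randint(0, features.shape[0], (8192,), device=device)
    gathered = unified[idx]

    # The result matches an ordinary CPU gather moved to the GPU.
    assert th.all(th.eq(features[idx.cpu()].to(device), gathered))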