[GraphBolt][CUDA] Eliminate synchronization for overlap_graph_fetch.
mfbalin authored Aug 16, 2024
1 parent 396f5f1 commit 2521081
Showing 5 changed files with 91 additions and 79 deletions.
14 changes: 14 additions & 0 deletions graphbolt/src/index_select.cc
@@ -207,5 +207,19 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
return std::make_tuple(output_indptr, results);
}

c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
IndexSelectCSCBatchedAsync(
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size) {
return async(
[=] {
return IndexSelectCSCBatched(
indptr, indices_list, nodes, with_edge_ids, output_size);
},
utils::is_on_gpu(nodes));
}

} // namespace ops
} // namespace graphbolt
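The async variant simply wraps the synchronous IndexSelectCSCBatched in graphbolt's async helper and returns a Future, so the caller can launch the slice and keep the host busy until wait(). Per the comment added in _fetch_per_layer_async below, graphbolt::async also records a CUDA event that wait() makes the current stream wait on, avoiding a host-device synchronization. A minimal usage sketch of the new binding (toy tensors; assumes a CUDA build of DGL with the GraphBolt ops loaded):

import torch
import dgl.graphbolt  # noqa: F401 -- importing registers torch.ops.graphbolt.*

# A tiny CSC graph in pinned host memory so the GPU can read it via UVA.
indptr = torch.tensor([0, 2, 4, 5, 6]).pin_memory()
indices = torch.tensor([1, 2, 0, 3, 1, 2]).pin_memory()
seeds = torch.tensor([0, 2], device="cuda")  # rows to slice out

# Launches the batched slice without blocking the host thread.
future = torch.ops.graphbolt.index_select_csc_batched_async(
    indptr, [indices], seeds, True, None  # with_edge_ids=True, output_size=None
)
# ... overlap other host-side work here ...
out_indptr, (out_indices, edge_ids) = future.wait()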
7 changes: 7 additions & 0 deletions graphbolt/src/index_select.h
@@ -92,6 +92,13 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size);

c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
IndexSelectCSCBatchedAsync(
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size);

} // namespace ops
} // namespace graphbolt

1 change: 1 addition & 0 deletions graphbolt/src/python_binding.cc
@@ -184,6 +184,7 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("scatter_async", &ops::ScatterAsync);
m.def("index_select_csc", &ops::IndexSelectCSC);
m.def("index_select_csc_batched", &ops::IndexSelectCSCBatched);
m.def("index_select_csc_batched_async", &ops::IndexSelectCSCBatchedAsync);
m.def("ondisk_npy_array", &storage::OnDiskNpyArray::Create);
m.def("detect_io_uring", &io_uring::IsAvailable);
m.def("set_num_io_uring_threads", &io_uring::SetNumThreads);
135 changes: 67 additions & 68 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -25,7 +25,6 @@
"LayerNeighborSampler",
"SamplePerLayer",
"FetchInsubgraphData",
"ConcatHeteroSeeds",
"CombineCachedAndFetchedInSubgraph",
]

@@ -105,29 +104,6 @@ def _wait_replace_future(self, minibatch):
return minibatch


-@functional_datapipe("concat_hetero_seeds")
-class ConcatHeteroSeeds(Mapper):
-    """Concatenates the seeds into a single tensor in the hetero case."""
-
-    def __init__(self, datapipe, graph):
-        super().__init__(datapipe, self._concat)
-        self.graph = graph
-
-    def _concat(self, minibatch):
-        seeds = minibatch._seed_nodes
-        if isinstance(seeds, dict):
-            (
-                seeds,
-                seed_offsets,
-            ) = self.graph._convert_to_homogeneous_nodes(seeds)
-        else:
-            seed_offsets = None
-        minibatch._seeds = seeds
-        minibatch._seed_offsets = seed_offsets
-
-        return minibatch


@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(MiniBatchTransformer):
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
@@ -142,20 +118,46 @@ def __init__(
graph,
prob_name,
):
-        datapipe = datapipe.concat_hetero_seeds(graph)
datapipe = datapipe.transform(self._concat_hetero_seeds)
if graph._gpu_graph_cache is not None:
datapipe = datapipe.fetch_cached_insubgraph_data(
graph._gpu_graph_cache
)
-        datapipe = datapipe.transform(self._fetch_per_layer)
-        datapipe = datapipe.buffer().wait()
datapipe = datapipe.transform(self._fetch_per_layer_stage_1)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._fetch_per_layer_stage_2)
if graph._gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name)
super().__init__(datapipe)
self.graph = graph
self.prob_name = prob_name

-    def _fetch_per_layer(self, minibatch):
def _concat_hetero_seeds(self, minibatch):
"""Concatenates the seeds into a single tensor in the hetero case."""
seeds = minibatch._seed_nodes
if isinstance(seeds, dict):
(
seeds,
seed_offsets,
) = self.graph._convert_to_homogeneous_nodes(seeds)
else:
seed_offsets = None
minibatch._seeds = seeds
minibatch._seed_offsets = seed_offsets

return minibatch
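    # `_convert_to_homogeneous_nodes` is the graph's own helper: conceptually
    # it shifts each node type's local IDs by that type's base offset and
    # concatenates them, keeping boundaries so per-type results can be split
    # out again later. A simplified standalone sketch of that idea (names and
    # the offsets format are illustrative, not the actual
    # FusedCSCSamplingGraph method):

    # import torch
    #
    # def to_homogeneous_nodes(seeds, ntype_to_id, node_type_offset):
    #     # seeds: dict of node-type name -> tensor of type-local node IDs.
    #     # node_type_offset[t]: first homogeneous node ID of node type t.
    #     parts, offsets = [], [0]
    #     for ntype in sorted(ntype_to_id, key=ntype_to_id.get):
    #         ids = seeds.get(ntype, torch.empty(0, dtype=torch.int64))
    #         parts.append(ids + node_type_offset[ntype_to_id[ntype]])
    #         offsets.append(offsets[-1] + ids.numel())  # split boundary
    #     return torch.cat(parts), offsets
    #
    # to_homogeneous_nodes({"user": torch.tensor([0, 3]),
    #                       "item": torch.tensor([1])},
    #                      {"user": 0, "item": 1}, [0, 100])
    # # -> (tensor([0, 3, 101]), [0, 2, 3])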

def _fetch_per_layer_stage_1(self, minibatch):
minibatch._async_handle_fetch = self._fetch_per_layer_async(minibatch)
next(minibatch._async_handle_fetch)
return minibatch

def _fetch_per_layer_stage_2(self, minibatch):
minibatch = next(minibatch._async_handle_fetch)
delattr(minibatch, "_async_handle_fetch")
return minibatch
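    # Splitting the fetch into two `transform` stages around a `buffer()` is
    # what lets several minibatches be in flight at once: stage 1 advances a
    # generator to its first `yield`, which only launches the asynchronous
    # work, and stage 2 resumes it to collect the result. A self-contained
    # sketch of the pattern, with a thread-pool future standing in for the
    # GraphBolt async op (all names here are illustrative):

    # from concurrent.futures import ThreadPoolExecutor
    #
    # _pool = ThreadPoolExecutor(max_workers=1)
    #
    # def fetch_async(item):
    #     future = _pool.submit(lambda: item + 1)  # launch; don't block
    #     yield                                    # suspend; pipeline buffers
    #     yield future.result()                    # resume: wait, hand back
    #
    # def stage_1(item):
    #     handle = fetch_async(item)
    #     next(handle)  # run to the first yield: work is now in flight
    #     return handle
    #
    # def stage_2(handle):
    #     return next(handle)  # blocks only if the work is unfinished
    #
    # handles = [stage_1(i) for i in range(4)]  # launch back to back...
    # print([stage_2(h) for h in handles])      # ...then collect: [1, 2, 3, 4]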

def _fetch_per_layer_async(self, minibatch):
stream = torch.cuda.current_stream()
uva_stream = get_host_to_device_uva_stream()
uva_stream.wait_stream(stream)
@@ -167,11 +169,6 @@ def _fetch_per_layer(self, minibatch):

seeds.record_stream(torch.cuda.current_stream())

-        def record_stream(tensor):
-            if tensor.is_cuda:
-                tensor.record_stream(stream)
-            return tensor

# Packs tensors for batch slicing.
tensors_to_be_sliced = [self.graph.indices]

@@ -190,51 +187,53 @@ def record_stream(tensor):
has_probs_or_mask = True

# Slices the batched tensors.
-        (
-            indptr,
-            sliced_tensors,
-        ) = torch.ops.graphbolt.index_select_csc_batched(
future = torch.ops.graphbolt.index_select_csc_batched_async(
self.graph.csc_indptr, tensors_to_be_sliced, seeds, True, None
)
-        for tensor in [indptr] + sliced_tensors:
-            record_stream(tensor)

-        # Unpacks the sliced tensors.
-        indices = sliced_tensors[0]
-        sliced_tensors = sliced_tensors[1:]
yield

# graphbolt::async has already recorded a CUDAEvent for us and
# called CUDAStreamWaitEvent for us on the current stream.
indptr, sliced_tensors = future.wait()

-        type_per_edge = None
-        if has_type_per_edge:
-            type_per_edge = sliced_tensors[0]
-            sliced_tensors = sliced_tensors[1:]
for tensor in [indptr] + sliced_tensors:
tensor.record_stream(stream)

-        probs_or_mask = None
-        if has_probs_or_mask:
-            probs_or_mask = sliced_tensors[0]
-            sliced_tensors = sliced_tensors[1:]
# Unpacks the sliced tensors.
indices = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]

-        edge_ids = sliced_tensors[0]
type_per_edge = None
if has_type_per_edge:
type_per_edge = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
-        assert len(sliced_tensors) == 0

-        subgraph = fused_csc_sampling_graph(
-            indptr,
-            indices,
-            node_type_offset=self.graph.node_type_offset,
-            type_per_edge=type_per_edge,
-            node_type_to_id=self.graph.node_type_to_id,
-            edge_type_to_id=self.graph.edge_type_to_id,
-        )
-        if self.prob_name is not None and probs_or_mask is not None:
-            subgraph.add_edge_attribute(self.prob_name, probs_or_mask)
-        subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids)

-        subgraph._indptr_node_type_offset_list = seed_offsets
-        minibatch._sliced_sampling_graph = subgraph
probs_or_mask = None
if has_probs_or_mask:
probs_or_mask = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]

-        minibatch.wait = torch.cuda.current_stream().record_event().wait
edge_ids = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
assert len(sliced_tensors) == 0

subgraph = fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=self.graph.node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=self.graph.node_type_to_id,
edge_type_to_id=self.graph.edge_type_to_id,
)
if self.prob_name is not None and probs_or_mask is not None:
subgraph.add_edge_attribute(self.prob_name, probs_or_mask)
subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids)

subgraph._indptr_node_type_offset_list = seed_offsets
minibatch._sliced_sampling_graph = subgraph

-        return minibatch
yield minibatch
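    # Two pieces of CUDA bookkeeping make the hand-off above safe:
    # graphbolt::async records an event that `wait()` makes the current
    # stream wait on (the comment before `future.wait()` above), and
    # `record_stream` tells the caching allocator that a tensor allocated on
    # one stream is still in use on another, so its memory is not recycled
    # early. A minimal standalone sketch of the same pattern, with a
    # host-to-device copy on a side stream standing in for the UVA fetch
    # (illustrative, not GraphBolt code):

    # import torch
    #
    # main = torch.cuda.current_stream()
    # side = torch.cuda.Stream()  # plays the host-to-device UVA stream
    #
    # src = torch.randn(1 << 20, pin_memory=True)  # pinned: copy can be async
    #
    # side.wait_stream(main)  # side begins after work already queued on main
    # with torch.cuda.stream(side):
    #     dst = src.to("cuda", non_blocking=True)  # copy + allocation on side
    #
    # main.wait_stream(side)   # main may not read before the copy finishes
    # dst.record_stream(main)  # allocator: `dst` is also consumed on `main`
    # out = dst * 2            # safe to use on the main stream now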


@functional_datapipe("sample_per_layer")
13 changes: 2 additions & 11 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -132,23 +132,14 @@ def test_gpu_sampling_DataLoader(
dataloader, dataloader2 = dataloaders

bufferer_cnt = int(enable_feature_fetch and overlap_feature_fetch)
-    awaiter_cnt = 0
if overlap_graph_fetch:
bufferer_cnt += num_layers
-        awaiter_cnt += num_layers
if asynchronous:
bufferer_cnt += 2 * num_layers
if overlap_graph_fetch:
bufferer_cnt += 0 * num_layers
if num_gpu_cached_edges > 0:
bufferer_cnt += 2 * num_layers
if asynchronous:
bufferer_cnt += 2 * num_layers
datapipe = dataloader.dataset
datapipe_graph = traverse_dps(datapipe)
-    awaiters = find_dps(
-        datapipe_graph,
-        dgl.graphbolt.Waiter,
-    )
-    assert len(awaiters) == awaiter_cnt
bufferers = find_dps(
datapipe_graph,
dgl.graphbolt.Bufferer,
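With the Waiter datapipes gone, the test now only counts Bufferer instances. find_dps walks the graph returned by traverse_dps (torch.utils.data.graph); a sketch of such a counter, assuming the documented graph shape {id: (datapipe, subgraph)}:

from torch.utils.data.graph import traverse_dps

def count_dps(graph, target_type):
    # graph: {id: (datapipe, subgraph)}, as produced by traverse_dps(datapipe).
    total = 0
    for _dp_id, (dp, subgraph) in graph.items():
        total += isinstance(dp, target_type) + count_dps(subgraph, target_type)
    return total

# e.g. count_dps(traverse_dps(datapipe), dgl.graphbolt.Bufferer) == bufferer_cnt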
