diff --git a/graphbolt/src/cuda/extension/gpu_graph_cache.cu b/graphbolt/src/cuda/extension/gpu_graph_cache.cu
index 35f0a91f52ec..1edab7ea20cb 100644
--- a/graphbolt/src/cuda/extension/gpu_graph_cache.cu
+++ b/graphbolt/src/cuda/extension/gpu_graph_cache.cu
@@ -162,6 +162,11 @@ GpuGraphCache::~GpuGraphCache() {
 
 std::tuple GpuGraphCache::Query(
     torch::Tensor seeds) {
+  TORCH_CHECK(seeds.device().is_cuda(), "Seeds should be on a CUDA device.");
+  TORCH_CHECK(
+      seeds.device().index() == device_id_,
+      "Seeds should be on the correct CUDA device.");
+  TORCH_CHECK(seeds.sizes().size() == 1, "Keys should be a 1D tensor.");
   auto allocator = cuda::GetAllocator();
   auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
   const dim3 block(kIntBlockSize);
@@ -175,8 +180,7 @@ std::tuple GpuGraphCache::Query(
             map->capacity() * kIntGrowthFactor,
             cuco::cuda_stream_ref{cuda::GetCurrentStream()});
         }
-        auto positions =
-            torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
+        auto positions = torch::empty_like(seeds);
         CUDA_KERNEL_CALL(
             _QueryAndIncrement, grid, block, 0,
             static_cast(seeds.size(0)), seeds.data_ptr(),
@@ -211,10 +215,8 @@ std::tuple GpuGraphCache::Query(
         thrust::counting_iterator iota{0};
         auto position_and_index =
             thrust::make_zip_iterator(positions.data_ptr(), iota);
-        auto output_positions =
-            torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
-        auto output_indices =
-            torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
+        auto output_positions = torch::empty_like(seeds);
+        auto output_indices = torch::empty_like(seeds);
         auto output_position_and_index = thrust::make_zip_iterator(
             output_positions.data_ptr(),
             output_indices.data_ptr());
@@ -243,6 +245,10 @@ std::tuple> GpuGraphCache::Replace(
       num_tensors == cached_edge_tensors_.size(),
      "Same number of tensors need to be passed!");
   const auto num_nodes = seeds.size(0);
+  TORCH_CHECK(
+      indptr.size(0) == num_nodes - num_hit + 1,
+      "(indptr.size(0) == seeds.size(0) - num_hit + 1) failed.");
+  const int64_t num_buffers = num_nodes * num_tensors;
   auto allocator = cuda::GetAllocator();
   auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
   return AT_DISPATCH_INDEX_TYPES(
@@ -268,6 +274,11 @@ std::tuple> GpuGraphCache::Replace(
                 cached_edge_tensors_[i].scalar_type() ==
                     edge_tensors[i].scalar_type(),
                 "The dtypes of edge tensors must match.");
+            if (i > 0) {
+              TORCH_CHECK(
+                  edge_tensors[i - 1].size(0) == edge_tensors[i].size(0),
+                  "The missing edge tensors should have identical size.");
+            }
             cache_missing_dtype_ptr[i] = {
                 reinterpret_cast(
                     cached_edge_tensors_[i].data_ptr()),
@@ -281,11 +292,9 @@ std::tuple> GpuGraphCache::Replace(
              copy_n, cache_missing_dtype_ptr, num_tensors,
              cache_missing_dtype_dev.get());
 
-          auto input = allocator.AllocateStorage(
-              num_tensors * num_nodes);
+          auto input = allocator.AllocateStorage(num_buffers);
           auto input_size =
-              allocator.AllocateStorage(num_tensors * num_nodes);
-
+              allocator.AllocateStorage(num_buffers + 1);
           const auto cache_missing_dtype_dev_ptr =
               cache_missing_dtype_dev.get();
           const auto indices_ptr = indices.data_ptr();
@@ -294,28 +303,25 @@ std::tuple> GpuGraphCache::Replace(
           const auto input_size_ptr = input_size.get();
           const auto cache_indptr = indptr_.data_ptr();
           const auto missing_indptr = indptr.data_ptr();
-          CUB_CALL(
-              DeviceFor::Bulk, num_tensors * num_nodes,
-              [=] __device__(int64_t i) {
-                const auto tensor_idx = i / num_nodes;
-                const auto idx = i % num_nodes;
-                const auto pos = positions_ptr[idx];
-                const auto original_idx = indices_ptr[idx];
-                const auto [cache_ptr, missing_ptr, size] =
-                    cache_missing_dtype_dev_ptr[tensor_idx];
-                const auto is_cached = pos >= 0;
-                const auto offset = is_cached
-                                        ? cache_indptr[pos]
-                                        : missing_indptr[idx - num_hit];
-                const auto offset_end =
-                    is_cached ? cache_indptr[pos + 1]
-                              : missing_indptr[idx - num_hit + 1];
-                const auto out_idx = tensor_idx * num_nodes + original_idx;
-
-                input_ptr[out_idx] =
-                    (is_cached ? cache_ptr : missing_ptr) + offset * size;
-                input_size_ptr[out_idx] = size * (offset_end - offset);
-              });
+          CUB_CALL(DeviceFor::Bulk, num_buffers, [=] __device__(int64_t i) {
+            const auto tensor_idx = i / num_nodes;
+            const auto idx = i % num_nodes;
+            const auto pos = positions_ptr[idx];
+            const auto original_idx = indices_ptr[idx];
+            const auto [cache_ptr, missing_ptr, size] =
+                cache_missing_dtype_dev_ptr[tensor_idx];
+            const auto is_cached = pos >= 0;
+            const auto offset = is_cached ? cache_indptr[pos]
+                                          : missing_indptr[idx - num_hit];
+            const auto offset_end = is_cached
+                                        ? cache_indptr[pos + 1]
+                                        : missing_indptr[idx - num_hit + 1];
+            const auto out_idx = tensor_idx * num_nodes + original_idx;
+
+            input_ptr[out_idx] =
+                (is_cached ? cache_ptr : missing_ptr) + offset * size;
+            input_size_ptr[out_idx] = size * (offset_end - offset);
+          });
           auto output_indptr = torch::empty(
               num_nodes + 1, seeds.options().dtype(indptr_.scalar_type()));
           auto output_indptr_ptr = output_indptr.data_ptr();
@@ -340,8 +346,8 @@ std::tuple> GpuGraphCache::Replace(
               [=] __host__ __device__(indices_t x) {
                 return x == threshold;
               });
-          auto output_indices = torch::empty(
-              num_threshold, seeds.options().dtype(index_dtype));
+          auto output_indices =
+              torch::empty(num_threshold, seeds.options());
           CUB_CALL(
               DeviceSelect::Flagged, iota, is_threshold,
               output_indices.data_ptr(),
@@ -364,9 +370,9 @@ std::tuple> GpuGraphCache::Replace(
                   indptr.size(0) - 2, cached_output_size);
               cached_output_size = sindices.size(0);
               enough_space = num_edges_ + *cached_output_size <=
-                             cached_edge_tensors_.at(0).size(0);
+                             cached_edge_tensors_[i].size(0);
               if (enough_space) {
-                cached_edge_tensors_.at(i).slice(
+                cached_edge_tensors_[i].slice(
                     0, num_edges_, num_edges_ + *cached_output_size) =
                     sindices;
               } else
@@ -406,7 +412,7 @@ std::tuple> GpuGraphCache::Replace(
           for (size_t i = 0; i < num_tensors; i++) {
             output_edge_tensors.push_back(torch::empty(
                 static_cast(output_size),
-                seeds.options().dtype(edge_tensors[i].scalar_type())));
+                cached_edge_tensors_[i].options()));
             output_tensor_ptrs_ptr[i] = {
                 reinterpret_cast(
                     output_edge_tensors.back().data_ptr()),
@@ -419,6 +425,11 @@ std::tuple> GpuGraphCache::Replace(
           THRUST_CALL(
              copy_n, output_tensor_ptrs_ptr, num_tensors,
              output_tensor_ptrs_dev.get());
+          // This event and the later synchronization is needed so that the
+          // allocated pinned tensor stays alive during the copy operation.
+          // @todo mfbalin: eliminate this synchronization.
+          at::cuda::CUDAEvent copy_event;
+          copy_event.record(cuda::GetCurrentStream());
 
           {
             thrust::counting_iterator iota{0};
@@ -434,7 +445,6 @@ std::tuple> GpuGraphCache::Replace(
                 });
             constexpr int64_t max_copy_at_once =
                 std::numeric_limits::max();
-            const int64_t num_buffers = num_nodes * num_tensors;
             for (int64_t i = 0; i < num_buffers; i += max_copy_at_once) {
               CUB_CALL(
                   DeviceMemcpy::Batched, input.get() + i,
@@ -443,6 +453,7 @@ std::tuple> GpuGraphCache::Replace(
             }
           }
 
+          copy_event.synchronize();
           return std::make_tuple(output_indptr, output_edge_tensors);
         }));
   }));
diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py
index c3c79f2914e7..08c89ba95bd2 100644
--- a/python/dgl/graphbolt/dataloader.py
+++ b/python/dgl/graphbolt/dataloader.py
@@ -1,5 +1,6 @@
 """Graph Bolt DataLoaders"""
 
+from collections import OrderedDict
 from concurrent.futures import ThreadPoolExecutor
 
 import torch
@@ -9,6 +10,7 @@
 
 from .base import CopyTo
 from .feature_fetcher import FeatureFetcher
+from .impl.gpu_graph_cache import GPUGraphCache
 from .impl.neighbor_sampler import SamplePerLayer
 from .internal import datapipe_graph_to_adjlist
 
@@ -17,9 +19,34 @@
 
 __all__ = [
     "DataLoader",
+    "construct_gpu_graph_cache",
 ]
 
 
+def construct_gpu_graph_cache(
+    sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold
+):
+    "Construct a GPUGraphCache given a sample_per_layer_obj and cache parameters."
+    graph = sample_per_layer_obj.sampler.__self__
+    num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges)
+    dtypes = OrderedDict()
+    dtypes["indices"] = graph.indices.dtype
+    if graph.type_per_edge is not None:
+        dtypes["type_per_edge"] = graph.type_per_edge.dtype
+    if graph.edge_attributes is not None:
+        probs_or_mask = graph.edge_attributes.get(
+            sample_per_layer_obj.prob_name, None
+        )
+        if probs_or_mask is not None:
+            dtypes["probs_or_mask"] = probs_or_mask.dtype
+    return GPUGraphCache(
+        num_gpu_cached_edges,
+        gpu_cache_threshold,
+        graph.csc_indptr.dtype,
+        list(dtypes.values()),
+    )
+
+
 def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
     """Find parent of target_datapipe and wrap it with ."""
     datapipes = dp_utils.find_dps(
@@ -197,13 +224,17 @@ def __init__(
                 SamplePerLayer,
             )
             executor = ThreadPoolExecutor(max_workers=1)
+            gpu_graph_cache = None
             for sampler in samplers:
+                if gpu_graph_cache is None:
+                    gpu_graph_cache = construct_gpu_graph_cache(
+                        sampler, num_gpu_cached_edges, gpu_cache_threshold
+                    )
                 datapipe_graph = dp_utils.replace_dp(
                     datapipe_graph,
                     sampler,
                     sampler.fetch_and_sample(
-                        num_gpu_cached_edges,
-                        gpu_cache_threshold,
+                        gpu_graph_cache,
                         _get_uva_stream(),
                         executor,
                         1,
diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py
index 009e5408905a..5a8edbab31b3 100644
--- a/python/dgl/graphbolt/impl/neighbor_sampler.py
+++ b/python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -1,6 +1,5 @@
 """Neighbor subgraph samplers for GraphBolt."""
 
-from collections import OrderedDict
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 
@@ -13,7 +12,6 @@
 
 from ..subgraph_sampler import SubgraphSampler
 from .fused_csc_sampling_graph import fused_csc_sampling_graph
-from .gpu_graph_cache import GPUGraphCache
 from .sampled_subgraph_impl import SampledSubgraphImpl
 
 
@@ -34,33 +32,9 @@ class FetchCachedInsubgraphData(Mapper):
     function that can be called with the fetched graph structure.
""" - def __init__( - self, - datapipe, - sample_per_layer_obj, - num_gpu_cached_edges, - gpu_cache_threshold, - ): + def __init__(self, datapipe, gpu_graph_cache): super().__init__(datapipe, self._fetch_per_layer) - graph = sample_per_layer_obj.sampler.__self__ - num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges) - dtypes = OrderedDict() - dtypes["indices"] = graph.indices.dtype - if graph.type_per_edge is not None: - dtypes["type_per_edge"] = graph.type_per_edge.dtype - if graph.edge_attributes is not None: - probs_or_mask = graph.edge_attributes.get( - sample_per_layer_obj.prob_name, None - ) - if probs_or_mask is not None: - dtypes["probs_or_mask"] = probs_or_mask.dtype - self.cache = GPUGraphCache( - num_gpu_cached_edges, - gpu_cache_threshold, - graph.csc_indptr.dtype, - list(dtypes.values()), - ) - self.dtypes = dtypes + self.cache = gpu_graph_cache def _fetch_per_layer(self, minibatch): minibatch._seeds, minibatch._replace = self.cache.query( @@ -143,17 +117,14 @@ def __init__( self, datapipe, sample_per_layer_obj, - num_gpu_cached_edges=0, - gpu_cache_threshold=1, + gpu_graph_cache, stream=None, executor=None, ): self.graph = sample_per_layer_obj.sampler.__self__ datapipe = datapipe.concat_hetero_seeds(sample_per_layer_obj) - if num_gpu_cached_edges > 0: - datapipe = datapipe.fetch_cached_insubgraph_data( - sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold - ) + if gpu_graph_cache is not None: + datapipe = datapipe.fetch_cached_insubgraph_data(gpu_graph_cache) super().__init__(datapipe, self._fetch_per_layer) self.prob_name = sample_per_layer_obj.prob_name self.stream = stream @@ -335,17 +306,16 @@ class FetcherAndSampler(MiniBatchTransformer): def __init__( self, sampler, - num_gpu_cached_edges, - gpu_cache_threshold, + gpu_graph_cache, stream, executor, buffer_size, ): datapipe = sampler.datapipe.fetch_insubgraph_data( - sampler, num_gpu_cached_edges, gpu_cache_threshold, stream, executor + sampler, gpu_graph_cache, stream, executor ) datapipe = datapipe.buffer(buffer_size).wait_future().wait() - if num_gpu_cached_edges > 0: + if gpu_graph_cache is not None: datapipe = datapipe.combine_cached_and_fetched_insubgraph(sampler) datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler) super().__init__(datapipe) diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index a85b6ed06748..d476fdb92da1 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -7,6 +7,7 @@ import dgl.graphbolt as gb import pytest import torch +from dgl.graphbolt.dataloader import construct_gpu_graph_cache def get_hetero_graph(): @@ -69,8 +70,15 @@ def test_NeighborSampler_GraphFetch( compact_per_layer = sample_per_layer.compact_per_layer(True) gb.seed(123) expected_results = list(compact_per_layer) + gpu_graph_cache = None + if num_cached_edges > 0: + gpu_graph_cache = construct_gpu_graph_cache( + sample_per_layer, num_cached_edges, 1 + ) datapipe = gb.FetchInsubgraphData( - datapipe, sample_per_layer, num_cached_edges, 1 + datapipe, + sample_per_layer, + gpu_graph_cache, ) datapipe = datapipe.wait_future() if num_cached_edges > 0: