From fc29d0eb02e91cf825e7a9ec9fd5d608d866fda9 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Sun, 18 Aug 2024 10:24:55 -0400 Subject: [PATCH] [GraphBolt][CUDA] Overlap original edge ids fetch. (#7714) --- .../src/cuda/extension/gpu_graph_cache.cu | 101 +++++++++++------- .../src/cuda/extension/gpu_graph_cache.h | 11 +- .../impl/fused_csc_sampling_graph.py | 90 +++++++++++----- python/dgl/graphbolt/impl/gpu_graph_cache.py | 11 +- python/dgl/graphbolt/impl/neighbor_sampler.py | 89 ++++++++++++--- .../graphbolt/impl/test_gpu_graph_cache.py | 8 +- .../graphbolt/impl/test_neighbor_sampler.py | 12 ++- 7 files changed, 231 insertions(+), 91 deletions(-) diff --git a/graphbolt/src/cuda/extension/gpu_graph_cache.cu b/graphbolt/src/cuda/extension/gpu_graph_cache.cu index 80a4bcfd7171..c0e70421bf44 100644 --- a/graphbolt/src/cuda/extension/gpu_graph_cache.cu +++ b/graphbolt/src/cuda/extension/gpu_graph_cache.cu @@ -115,14 +115,16 @@ constexpr int kIntBlockSize = 512; c10::intrusive_ptr GpuGraphCache::Create( const int64_t num_edges, const int64_t threshold, - torch::ScalarType indptr_dtype, std::vector dtypes) { + torch::ScalarType indptr_dtype, std::vector dtypes, + bool has_original_edge_ids) { return c10::make_intrusive( - num_edges, threshold, indptr_dtype, dtypes); + num_edges, threshold, indptr_dtype, dtypes, has_original_edge_ids); } GpuGraphCache::GpuGraphCache( const int64_t num_edges, const int64_t threshold, - torch::ScalarType indptr_dtype, std::vector dtypes) { + torch::ScalarType indptr_dtype, std::vector dtypes, + bool has_original_edge_ids) { const int64_t initial_node_capacity = 1024; AT_DISPATCH_INDEX_TYPES( dtypes.at(0), "GpuGraphCache::GpuGraphCache", ([&] { @@ -149,7 +151,9 @@ GpuGraphCache::GpuGraphCache( num_edges_ = 0; indptr_ = torch::zeros(initial_node_capacity + 1, options.dtype(indptr_dtype)); - offset_ = torch::empty(indptr_.size(0) - 1, indptr_.options()); + if (!has_original_edge_ids) { + offset_ = torch::empty(indptr_.size(0) - 1, indptr_.options()); + } for (auto dtype : dtypes) { cached_edge_tensors_.push_back( torch::empty(num_edges, options.dtype(dtype))); @@ -249,8 +253,9 @@ std::tuple> GpuGraphCache::Replace( torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions, int64_t num_hit, int64_t num_threshold, torch::Tensor indptr, std::vector edge_tensors) { + const auto with_edge_ids = offset_.has_value(); // The last element of edge_tensors has the edge ids. - const auto num_tensors = edge_tensors.size() - 1; + const auto num_tensors = edge_tensors.size() - with_edge_ids; TORCH_CHECK( num_tensors == cached_edge_tensors_.size(), "Same number of tensors need to be passed!"); @@ -312,8 +317,12 @@ std::tuple> GpuGraphCache::Replace( auto input = allocator.AllocateStorage(num_buffers); auto input_size = allocator.AllocateStorage(num_buffers + 1); - auto edge_id_offsets = torch::empty( - num_nodes, seeds.options().dtype(offset_.scalar_type())); + torch::optional edge_id_offsets; + if (with_edge_ids) { + edge_id_offsets = torch::empty( + num_nodes, + seeds.options().dtype(offset_.value().scalar_type())); + } const auto cache_missing_dtype_dev_ptr = cache_missing_dtype_dev.get(); const auto indices_ptr = indices.data_ptr(); @@ -321,12 +330,15 @@ std::tuple> GpuGraphCache::Replace( const auto input_ptr = input.get(); const auto input_size_ptr = input_size.get(); const auto edge_id_offsets_ptr = - edge_id_offsets.data_ptr(); + edge_id_offsets ? 
edge_id_offsets->data_ptr() + : nullptr; const auto cache_indptr = indptr_.data_ptr(); const auto missing_indptr = indptr.data_ptr(); - const auto cache_offset = offset_.data_ptr(); + const auto cache_offset = + offset_ ? offset_->data_ptr() : nullptr; const auto missing_edge_ids = - edge_tensors.back().data_ptr(); + edge_id_offsets ? edge_tensors.back().data_ptr() + : nullptr; CUB_CALL(DeviceFor::Bulk, num_buffers, [=] __device__(int64_t i) { const auto tensor_idx = i / num_nodes; const auto idx = i % num_nodes; @@ -340,14 +352,14 @@ std::tuple> GpuGraphCache::Replace( const auto offset_end = is_cached ? cache_indptr[pos + 1] : missing_indptr[idx - num_hit + 1]; - const auto edge_id = - is_cached ? cache_offset[pos] : missing_edge_ids[offset]; const auto out_idx = tensor_idx * num_nodes + original_idx; input_ptr[out_idx] = (is_cached ? cache_ptr : missing_ptr) + offset * size; input_size_ptr[out_idx] = size * (offset_end - offset); - if (i < num_nodes) { + if (edge_id_offsets_ptr && i < num_nodes) { + const auto edge_id = + is_cached ? cache_offset[pos] : missing_edge_ids[offset]; edge_id_offsets_ptr[out_idx] = edge_id; } }); @@ -390,10 +402,12 @@ std::tuple> GpuGraphCache::Replace( indptr_.size(0) * kIntGrowthFactor, indptr_.options()); new_indptr.slice(0, 0, indptr_.size(0)) = indptr_; indptr_ = new_indptr; - auto new_offset = - torch::empty(indptr_.size(0) - 1, offset_.options()); - new_offset.slice(0, 0, offset_.size(0)) = offset_; - offset_ = new_offset; + if (offset_) { + auto new_offset = + torch::empty(indptr_.size(0) - 1, offset_->options()); + new_offset.slice(0, 0, offset_->size(0)) = *offset_; + offset_ = new_offset; + } } torch::Tensor sindptr; bool enough_space; @@ -415,22 +429,32 @@ std::tuple> GpuGraphCache::Replace( } if (enough_space) { auto num_edges = num_edges_; - auto transform_input_it = thrust::make_zip_iterator( - sindptr.data_ptr() + 1, - sliced_indptr.data_ptr()); - auto transform_output_it = thrust::make_zip_iterator( - indptr_.data_ptr() + num_nodes_ + 1, - offset_.data_ptr() + num_nodes_); - THRUST_CALL( - transform, transform_input_it, - transform_input_it + sindptr.size(0) - 1, - transform_output_it, - [=] __host__ __device__( - const thrust::tuple& x) { - return thrust::make_tuple( - thrust::get<0>(x) + num_edges, - missing_edge_ids[thrust::get<1>(x)]); - }); + if (offset_) { + auto transform_input_it = thrust::make_zip_iterator( + sindptr.data_ptr() + 1, + sliced_indptr.data_ptr()); + auto transform_output_it = thrust::make_zip_iterator( + indptr_.data_ptr() + num_nodes_ + 1, + offset_->data_ptr() + num_nodes_); + THRUST_CALL( + transform, transform_input_it, + transform_input_it + sindptr.size(0) - 1, + transform_output_it, + [=] __host__ __device__( + const thrust::tuple& x) { + return thrust::make_tuple( + thrust::get<0>(x) + num_edges, + missing_edge_ids[thrust::get<1>(x)]); + }); + } else { + THRUST_CALL( + transform, sindptr.data_ptr() + 1, + sindptr.data_ptr() + sindptr.size(0), + indptr_.data_ptr() + num_nodes_ + 1, + [=] __host__ __device__(const indptr_t& x) { + return x + num_edges; + }); + } auto map = reinterpret_cast*>(map_); const dim3 block(kIntBlockSize); const dim3 grid( @@ -467,10 +491,13 @@ std::tuple> GpuGraphCache::Replace( .view(edge_tensors[i].scalar_type()) .slice(0, 0, static_cast(output_size))); } - // Append the edge ids as the last element of the output. 
- output_edge_tensors.push_back(ops::IndptrEdgeIdsImpl( - output_indptr, output_indptr.scalar_type(), edge_id_offsets, - static_cast(static_cast(output_size)))); + if (edge_id_offsets) { + // Append the edge ids as the last element of the output. + output_edge_tensors.push_back(ops::IndptrEdgeIdsImpl( + output_indptr, output_indptr.scalar_type(), + *edge_id_offsets, + static_cast(static_cast(output_size)))); + } { thrust::counting_iterator iota{0}; diff --git a/graphbolt/src/cuda/extension/gpu_graph_cache.h b/graphbolt/src/cuda/extension/gpu_graph_cache.h index 0708f5d00917..42324ef33140 100644 --- a/graphbolt/src/cuda/extension/gpu_graph_cache.h +++ b/graphbolt/src/cuda/extension/gpu_graph_cache.h @@ -47,10 +47,13 @@ class GpuGraphCache : public torch::CustomClassHolder { * @param indptr_dtype The node id datatype. * @param dtypes The dtypes of the edge tensors to be cached. dtypes[0] is * reserved for the indices edge tensor holding node ids. + * @param has_original_edge_ids Whether the graph to be cached has original + * edge ids. */ GpuGraphCache( const int64_t num_edges, const int64_t threshold, - torch::ScalarType indptr_dtype, std::vector dtypes); + torch::ScalarType indptr_dtype, std::vector dtypes, + bool has_original_edge_ids); GpuGraphCache() = default; @@ -109,7 +112,8 @@ class GpuGraphCache : public torch::CustomClassHolder { static c10::intrusive_ptr Create( const int64_t num_edges, const int64_t threshold, - torch::ScalarType indptr_dtype, std::vector dtypes); + torch::ScalarType indptr_dtype, std::vector dtypes, + bool has_original_edge_ids); private: void* map_; // pointer to the hash table. @@ -119,7 +123,8 @@ class GpuGraphCache : public torch::CustomClassHolder { int64_t num_nodes_; // The number of cached nodes in the cache. int64_t num_edges_; // The number of cached edges in the cache. torch::Tensor indptr_; // The cached graph structure indptr tensor. - torch::Tensor offset_; // The original graph's sliced_indptr tensor. + torch::optional + offset_; // The original graph's sliced_indptr tensor. std::vector cached_edge_tensors_; // The cached graph // structure edge tensors. std::mutex mtx_; // Protects the data structure and makes it threadsafe. diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 75f87d19aa2b..cc9137092cf6 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -23,19 +23,32 @@ class _SampleNeighborsWaiter: - def __init__(self, fn, future, seed_offsets): + def __init__( + self, fn, future, seed_offsets, fetching_original_edge_ids_is_optional + ): self.fn = fn self.future = future self.seed_offsets = seed_offsets + self.fetching_original_edge_ids_is_optional = ( + fetching_original_edge_ids_is_optional + ) def wait(self): """Returns the stored value when invoked.""" fn = self.fn C_sampled_subgraph = self.future.wait() seed_offsets = self.seed_offsets + fetching_original_edge_ids_is_optional = ( + self.fetching_original_edge_ids_is_optional + ) # Ensure there is no memory leak. 
self.fn = self.future = self.seed_offsets = None - return fn(C_sampled_subgraph, seed_offsets) + self.fetching_original_edge_ids_is_optional = None + return fn( + C_sampled_subgraph, + seed_offsets, + fetching_original_edge_ids_is_optional, + ) class FusedCSCSamplingGraph(SamplingGraph): @@ -592,6 +605,7 @@ def _convert_to_sampled_subgraph( self, C_sampled_subgraph: torch.ScriptObject, seed_offsets: Optional[list] = None, + fetching_original_edge_ids_is_optional: bool = False, ) -> SampledSubgraphImpl: """An internal function used to convert a fused homogeneous sampled subgraph to general struct 'SampledSubgraphImpl'.""" @@ -611,9 +625,15 @@ def _convert_to_sampled_subgraph( and ORIGINAL_EDGE_ID in self.edge_attributes ) original_edge_ids = ( - torch.ops.graphbolt.index_select( - self.edge_attributes[ORIGINAL_EDGE_ID], - edge_ids_in_fused_csc_sampling_graph, + ( + torch.ops.graphbolt.index_select( + self.edge_attributes[ORIGINAL_EDGE_ID], + edge_ids_in_fused_csc_sampling_graph, + ) + if not fetching_original_edge_ids_is_optional + or not edge_ids_in_fused_csc_sampling_graph.is_cuda + or not self.edge_attributes[ORIGINAL_EDGE_ID].is_pinned() + else None ) if has_original_eids else edge_ids_in_fused_csc_sampling_graph @@ -621,8 +641,8 @@ def _convert_to_sampled_subgraph( if type_per_edge is None and etype_offsets is None: # The sampled graph is already a homogeneous graph. sampled_csc = CSCFormatBase(indptr=indptr, indices=indices) - if indices is not None: - # Only needed to fetch indices. + if indices is not None and original_edge_ids is not None: + # Only needed to fetch indices or original_edge_ids. edge_ids_in_fused_csc_sampling_graph = None else: offset = self._node_type_offset_list @@ -691,11 +711,17 @@ def _convert_to_sampled_subgraph( ] ] ) - original_hetero_edge_ids[etype] = original_edge_ids[ - etype_offsets[etype_id] : etype_offsets[etype_id + 1] - ] - if indices is None: - # Only needed to fetch indices. + original_hetero_edge_ids[etype] = ( + None + if original_edge_ids is None + else original_edge_ids[ + etype_offsets[etype_id] : etype_offsets[ + etype_id + 1 + ] + ] + ) + if indices is None or original_edge_ids is None: + # Only needed to fetch indices or original edge ids. sampled_hetero_edge_ids_in_fused_csc_sampling_graph[ etype ] = edge_ids_in_fused_csc_sampling_graph[ @@ -727,7 +753,7 @@ def sample_neighbors( fanouts: torch.Tensor, replace: bool = False, probs_name: Optional[str] = None, - returning_indices_is_optional: bool = False, + returning_indices_and_original_edge_ids_are_optional: bool = False, async_op: bool = False, ) -> SampledSubgraphImpl: """Sample neighboring edges of the given nodes and return the induced @@ -768,10 +794,12 @@ def sample_neighbors( corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor, with the number of elements equalling the total number of edges. - returning_indices_is_optional: bool + returning_indices_and_original_edge_ids_are_optional: bool Boolean indicating whether it is okay for the call to this function - to leave the indices tensor uninitialized. In this case, it is the - user's responsibility to gather it using the edge ids. + to leave the indices and the original edge ids tensors + uninitialized. In this case, it is the user's responsibility to + gather them using _edge_ids_in_fused_csc_sampling_graph if either is + missing. async_op: bool Boolean indicating whether the call is asynchronous. If so, the result can be obtained by calling wait on the returned future. 
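Note on the new skip condition in _convert_to_sampled_subgraph: the gather of
ORIGINAL_EDGE_ID is deferred only when the caller opts in via
returning_indices_and_original_edge_ids_are_optional AND the sampled edge ids
already live on the GPU AND the ORIGINAL_EDGE_ID attribute is pinned, so the
fetch can later be overlapped on a separate UVA copy stream. A condensed,
illustrative restatement of that predicate follows (the helper itself is not
part of the patch):

import torch

def _skip_original_edge_id_gather(
    fetching_original_edge_ids_is_optional: bool,
    edge_ids_in_fused_csc_sampling_graph: torch.Tensor,
    original_edge_ids_attr: torch.Tensor,
) -> bool:
    # Mirrors the inverted condition in _convert_to_sampled_subgraph:
    # gather eagerly unless (optional AND edge ids on CUDA AND attribute pinned).
    return (
        fetching_original_edge_ids_is_optional
        and edge_ids_in_fused_csc_sampling_graph.is_cuda
        and original_edge_ids_attr.is_pinned()
    )

When this returns True, original_edge_ids is left as None and the caller (see
_fetch_indices_and_original_edge_ids in neighbor_sampler.py) gathers it later
using the stashed _edge_ids_in_fused_csc_sampling_graph.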
@@ -818,7 +846,7 @@ def sample_neighbors( fanouts, replace=replace, probs_or_mask=probs_or_mask, - returning_indices_is_optional=returning_indices_is_optional, + returning_indices_is_optional=returning_indices_and_original_edge_ids_are_optional, async_op=async_op, ) if async_op: @@ -826,10 +854,13 @@ def sample_neighbors( self._convert_to_sampled_subgraph, C_sampled_subgraph, seed_offsets, + returning_indices_and_original_edge_ids_are_optional, ) else: return self._convert_to_sampled_subgraph( - C_sampled_subgraph, seed_offsets + C_sampled_subgraph, + seed_offsets, + returning_indices_and_original_edge_ids_are_optional, ) def _check_sampler_arguments(self, nodes, fanouts, probs_or_mask): @@ -956,7 +987,7 @@ def sample_layer_neighbors( fanouts: torch.Tensor, replace: bool = False, probs_name: Optional[str] = None, - returning_indices_is_optional: bool = False, + returning_indices_and_original_edge_ids_are_optional: bool = False, random_seed: torch.Tensor = None, seed2_contribution: float = 0.0, async_op: bool = False, @@ -1001,10 +1032,12 @@ def sample_layer_neighbors( corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor, with the number of elements equalling the total number of edges. - returning_indices_is_optional: bool + returning_indices_and_original_edge_ids_are_optional: bool Boolean indicating whether it is okay for the call to this function - to leave the indices tensor uninitialized. In this case, it is the - user's responsibility to gather it using the edge ids. + to leave the indices and the original edge ids tensors + uninitialized. In this case, it is the user's responsibility to + gather them using _edge_ids_in_fused_csc_sampling_graph if either is + missing. random_seed: torch.Tensor, optional An int64 tensor with one or two elements. @@ -1092,7 +1125,7 @@ def sample_layer_neighbors( fanouts.tolist(), replace, True, # is_labor - returning_indices_is_optional, + returning_indices_and_original_edge_ids_are_optional, probs_or_mask, random_seed, seed2_contribution, @@ -1102,10 +1135,13 @@ def sample_layer_neighbors( self._convert_to_sampled_subgraph, C_sampled_subgraph, seed_offsets, + returning_indices_and_original_edge_ids_are_optional, ) else: return self._convert_to_sampled_subgraph( - C_sampled_subgraph, seed_offsets + C_sampled_subgraph, + seed_offsets, + returning_indices_and_original_edge_ids_are_optional, ) def temporal_sample_neighbors( @@ -1512,15 +1548,21 @@ def _initialize_gpu_graph_cache( dtypes = [self.indices.dtype] if self.type_per_edge is not None: dtypes.append(self.type_per_edge.dtype) + has_original_edge_ids = False if self.edge_attributes is not None: probs_or_mask = self.edge_attributes.get(prob_name, None) if probs_or_mask is not None: dtypes.append(probs_or_mask.dtype) + original_edge_ids = self.edge_attributes.get(ORIGINAL_EDGE_ID, None) + if original_edge_ids is not None: + dtypes.append(original_edge_ids.dtype) + has_original_edge_ids = True self._gpu_graph_cache_ = GPUGraphCache( num_gpu_cached_edges, gpu_cache_threshold, self.csc_indptr.dtype, dtypes, + has_original_edge_ids, ) diff --git a/python/dgl/graphbolt/impl/gpu_graph_cache.py b/python/dgl/graphbolt/impl/gpu_graph_cache.py index e4cf78b589af..a6a640dfa6c3 100644 --- a/python/dgl/graphbolt/impl/gpu_graph_cache.py +++ b/python/dgl/graphbolt/impl/gpu_graph_cache.py @@ -17,15 +17,19 @@ class GPUGraphCache(object): The dtype of the indptr tensor of the graph. dtypes : list[torch.dtype] The dtypes of the edge tensors that are going to be cached. 
+ has_original_edge_ids : bool + Whether the graph to be cached has original edge ids. """ - def __init__(self, num_edges, threshold, indptr_dtype, dtypes): + def __init__( + self, num_edges, threshold, indptr_dtype, dtypes, has_original_edge_ids + ): major, _ = torch.cuda.get_device_capability() assert ( major >= 7 ), "GPUGraphCache is supported only on CUDA compute capability >= 70 (Volta)." self._cache = torch.ops.graphbolt.gpu_graph_cache( - num_edges, threshold, indptr_dtype, dtypes + num_edges, threshold, indptr_dtype, dtypes, has_original_edge_ids ) self.total_miss = 0 self.total_queries = 0 @@ -44,7 +48,8 @@ def query(self, keys): A tuple containing (missing_keys, replace_fn) where replace_fn is a function that should be called with the graph structure corresponding to the missing keys. Its arguments are - (Tensor, list(Tensor)). + (Tensor, list(Tensor)), where the first tensor is the missing indptr + and the second list is the missing edge tensors. """ self.total_queries += keys.shape[0] ( diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index fc834718ef4d..4229edb6be00 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -177,7 +177,8 @@ def _fetch_per_layer_async(self, minibatch): tensors_to_be_sliced.append(self.graph.type_per_edge) has_type_per_edge = True - has_probs_or_mask = None + has_probs_or_mask = False + has_original_edge_ids = False if self.graph.edge_attributes is not None: probs_or_mask = self.graph.edge_attributes.get( self.prob_name, None @@ -185,10 +186,21 @@ def _fetch_per_layer_async(self, minibatch): if probs_or_mask is not None: tensors_to_be_sliced.append(probs_or_mask) has_probs_or_mask = True + original_edge_ids = self.graph.edge_attributes.get( + ORIGINAL_EDGE_ID, None + ) + if original_edge_ids is not None: + tensors_to_be_sliced.append(original_edge_ids) + has_original_edge_ids = True # Slices the batched tensors. future = torch.ops.graphbolt.index_select_csc_batched_async( - self.graph.csc_indptr, tensors_to_be_sliced, seeds, True, None + self.graph.csc_indptr, + tensors_to_be_sliced, + seeds, + # When there are no edge ids, we assume it is arange(num_edges). 
+ not has_original_edge_ids, + None, ) yield @@ -251,19 +263,35 @@ def __init__( asynchronous=False, ): graph = sampler.__self__ - self.returning_indices_is_optional = False + self.returning_indices_and_original_edge_ids_are_optional = False + original_edge_ids = ( + None + if graph.edge_attributes is None + else graph.edge_attributes.get(ORIGINAL_EDGE_ID, None) + ) if ( overlap_fetch and sampler.__name__ == "sample_neighbors" - and graph.indices.is_pinned() + and ( + graph.indices.is_pinned() + or ( + original_edge_ids is not None + and original_edge_ids.is_pinned() + ) + ) and graph._gpu_graph_cache is None ): datapipe = datapipe.transform(self._sample_per_layer) if asynchronous: datapipe = datapipe.buffer() datapipe = datapipe.transform(self._wait_subgraph_future) + fetch_indices_and_original_edge_ids_fn = partial( + self._fetch_indices_and_original_edge_ids, + graph.indices, + original_edge_ids, + ) datapipe = ( - datapipe.transform(partial(self._fetch_indices, graph.indices)) + datapipe.transform(fetch_indices_and_original_edge_ids_fn) .buffer() .wait() ) @@ -276,7 +304,7 @@ def __init__( graph.node_type_to_id, ) ) - self.returning_indices_is_optional = True + self.returning_indices_and_original_edge_ids_are_optional = True elif overlap_fetch: datapipe = datapipe.fetch_insubgraph_data(graph, prob_name) datapipe = datapipe.transform( @@ -309,7 +337,7 @@ def _sample_per_layer(self, minibatch): self.fanout, self.replace, self.prob_name, - self.returning_indices_is_optional, + self.returning_indices_and_original_edge_ids_are_optional, async_op=self.asynchronous, **kwargs, ) @@ -341,7 +369,7 @@ def _wait_subgraph_future(minibatch): return minibatch @staticmethod - def _fetch_indices(indices, minibatch): + def _fetch_indices_and_original_edge_ids(indices, orig_edge_ids, minibatch): stream = torch.cuda.current_stream() host_to_device_stream = get_host_to_device_uva_stream() host_to_device_stream.wait_stream(stream) @@ -366,16 +394,43 @@ def record_stream(tensor): index_select(indices, edge_ids) ) minibatch._indices_needs_offset_subtraction = True - elif subgraph.sampled_csc.indices is None: - subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream( - torch.cuda.current_stream() - ) - subgraph.sampled_csc.indices = record_stream( - index_select( - indices, subgraph._edge_ids_in_fused_csc_sampling_graph + if ( + orig_edge_ids is not None + and subgraph.original_edge_ids[etype] is None + ): + edge_ids = ( + subgraph._edge_ids_in_fused_csc_sampling_graph[ + etype + ] + ) + edge_ids.record_stream(torch.cuda.current_stream()) + subgraph.original_edge_ids[etype] = record_stream( + index_select(orig_edge_ids, edge_ids) + ) + else: + if subgraph.sampled_csc.indices is None: + subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream( + torch.cuda.current_stream() + ) + subgraph.sampled_csc.indices = record_stream( + index_select( + indices, + subgraph._edge_ids_in_fused_csc_sampling_graph, + ) + ) + if ( + orig_edge_ids is not None + and subgraph.original_edge_ids is None + ): + subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream( + torch.cuda.current_stream() + ) + subgraph.original_edge_ids = record_stream( + index_select( + orig_edge_ids, + subgraph._edge_ids_in_fused_csc_sampling_graph, + ) ) - ) - minibatch._indices_needs_offset_subtraction = True subgraph._edge_ids_in_fused_csc_sampling_graph = None minibatch.wait = torch.cuda.current_stream().record_event().wait diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py 
b/tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py index fdd7329cccbf..e6034cf77019 100644 --- a/tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py @@ -36,7 +36,8 @@ ], ) @pytest.mark.parametrize("cache_size", [4, 9, 11]) -def test_gpu_graph_cache(indptr_dtype, dtype, cache_size): +@pytest.mark.parametrize("with_edge_ids", [True, False]) +def test_gpu_graph_cache(indptr_dtype, dtype, cache_size, with_edge_ids): indices_dtype = torch.int32 indptr = torch.tensor([0, 3, 6, 10], dtype=indptr_dtype, pin_memory=True) indices = torch.arange(0, indptr[-1], dtype=indices_dtype, pin_memory=True) @@ -48,6 +49,7 @@ def test_gpu_graph_cache(indptr_dtype, dtype, cache_size): 2, indptr.dtype, [e.dtype for e in edge_tensors], + not with_edge_ids, ) for i in range(10): @@ -59,7 +61,7 @@ def test_gpu_graph_cache(indptr_dtype, dtype, cache_size): missing_indptr, missing_edge_tensors, ) = torch.ops.graphbolt.index_select_csc_batched( - indptr, edge_tensors, missing_keys, True, None + indptr, edge_tensors, missing_keys, with_edge_ids, None ) output_indptr, output_edge_tensors = replace( missing_indptr, missing_edge_tensors @@ -69,7 +71,7 @@ def test_gpu_graph_cache(indptr_dtype, dtype, cache_size): reference_indptr, reference_edge_tensors, ) = torch.ops.graphbolt.index_select_csc_batched( - indptr, edge_tensors, keys, True, None + indptr, edge_tensors, keys, with_edge_ids, None ) assert torch.equal(output_indptr, reference_indptr) diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index 3f827923e2f5..547b8867fd41 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -3,13 +3,12 @@ import backend as F -import dgl import dgl.graphbolt as gb import pytest import torch -def get_hetero_graph(): +def get_hetero_graph(include_original_edge_ids): # COO graph: # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1] @@ -26,6 +25,10 @@ def get_hetero_graph(): ), "mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]), } + if include_original_edge_ids: + edge_attributes[gb.ORIGINAL_EDGE_ID] = ( + torch.arange(indices.size(0), 0, -1) - 1 + ) node_type_offset = torch.LongTensor([0, 1, 3, 6]) return gb.fused_csc_sampling_graph( indptr, @@ -44,8 +47,9 @@ def get_hetero_graph(): @pytest.mark.parametrize("sorted", [False, True]) @pytest.mark.parametrize("num_cached_edges", [0, 10]) @pytest.mark.parametrize("is_pinned", [False, True]) +@pytest.mark.parametrize("has_orig_edge_ids", [False, True]) def test_NeighborSampler_GraphFetch( - hetero, prob_name, sorted, num_cached_edges, is_pinned + hetero, prob_name, sorted, num_cached_edges, is_pinned, has_orig_edge_ids ): if sorted: items = torch.arange(3) @@ -53,7 +57,7 @@ def test_NeighborSampler_GraphFetch( items = torch.tensor([2, 0, 1]) names = "seeds" itemset = gb.ItemSet(items, names=names) - graph = get_hetero_graph() + graph = get_hetero_graph(has_orig_edge_ids) graph = graph.pin_memory_() if is_pinned else graph.to(F.ctx()) if hetero: itemset = gb.HeteroItemSet({"n3": itemset})
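For reference, a minimal sketch of the GPUGraphCache query/replace protocol
under the new has_original_edge_ids flag, modeled on test_gpu_graph_cache.py
above. Tensor contents are illustrative and a Volta-or-newer GPU is assumed.
With has_original_edge_ids=False the cache keeps the per-node offset tensor
and appends reconstructed original edge ids as the last output tensor, so the
slicing op must be asked for edge ids (with_edge_ids=True); with
has_original_edge_ids=True the ORIGINAL_EDGE_ID attribute is cached like any
other edge tensor and no extra tensor is produced.

import torch
from dgl.graphbolt.impl.gpu_graph_cache import GPUGraphCache

indptr = torch.tensor([0, 3, 6, 10], dtype=torch.int64, pin_memory=True)
indices = torch.arange(0, 10, dtype=torch.int32, pin_memory=True)
edge_tensors = [indices]  # dtypes[0] is reserved for the indices tensor.

cache = GPUGraphCache(
    num_edges=8,
    threshold=2,
    indptr_dtype=indptr.dtype,
    dtypes=[e.dtype for e in edge_tensors],
    has_original_edge_ids=False,  # cache synthesizes original edge ids
)

keys = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
missing_keys, replace_fn = cache.query(keys)
# Slice the missing rows from the pinned graph; with_edge_ids=True appends the
# sliced edge ids as the last tensor, which Replace() expects in this mode.
missing_indptr, missing_edge_tensors = (
    torch.ops.graphbolt.index_select_csc_batched(
        indptr, edge_tensors, missing_keys, True, None
    )
)
output_indptr, output_edge_tensors = replace_fn(
    missing_indptr, missing_edge_tensors
)
# output_edge_tensors is [indices, original_edge_ids] here; with
# has_original_edge_ids=True the caller would instead include the pinned
# ORIGINAL_EDGE_ID attribute in edge_tensors and pass with_edge_ids=False.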