Fix the segmentation fault and add extra runtime checks.
mfbalin committed Jun 26, 2024
1 parent 7bdd685 commit 44c0aa9
Showing 4 changed files with 99 additions and 79 deletions.
87 changes: 49 additions & 38 deletions graphbolt/src/cuda/extension/gpu_graph_cache.cu
@@ -162,6 +162,11 @@ GpuGraphCache::~GpuGraphCache() {

std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
torch::Tensor seeds) {
TORCH_CHECK(seeds.device().is_cuda(), "Seeds should be on a CUDA device.");
TORCH_CHECK(
seeds.device().index() == device_id_,
"Seeds should be on the correct CUDA device.");
TORCH_CHECK(seeds.sizes().size() == 1, "Keys should be a 1D tensor.");
auto allocator = cuda::GetAllocator();
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
const dim3 block(kIntBlockSize);
@@ -175,8 +180,7 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
map->capacity() * kIntGrowthFactor,
cuco::cuda_stream_ref{cuda::GetCurrentStream()});
}
auto positions =
torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
auto positions = torch::empty_like(seeds);
CUDA_KERNEL_CALL(
_QueryAndIncrement, grid, block, 0,
static_cast<int64_t>(seeds.size(0)), seeds.data_ptr<index_t>(),
@@ -211,10 +215,8 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
thrust::counting_iterator<index_t> iota{0};
auto position_and_index =
thrust::make_zip_iterator(positions.data_ptr<index_t>(), iota);
auto output_positions =
torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
auto output_indices =
torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
auto output_positions = torch::empty_like(seeds);
auto output_indices = torch::empty_like(seeds);
auto output_position_and_index = thrust::make_zip_iterator(
output_positions.data_ptr<index_t>(),
output_indices.data_ptr<index_t>());
@@ -243,6 +245,10 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
num_tensors == cached_edge_tensors_.size(),
"Same number of tensors need to be passed!");
const auto num_nodes = seeds.size(0);
TORCH_CHECK(
indptr.size(0) == num_nodes - num_hit + 1,
"(indptr.size(0) == seeds.size(0) - num_hit + 1) failed.");
const int64_t num_buffers = num_nodes * num_tensors;
auto allocator = cuda::GetAllocator();
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
return AT_DISPATCH_INDEX_TYPES(
@@ -268,6 +274,11 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
cached_edge_tensors_[i].scalar_type() ==
edge_tensors[i].scalar_type(),
"The dtypes of edge tensors must match.");
if (i > 0) {
TORCH_CHECK(
edge_tensors[i - 1].size(0) == edge_tensors[i].size(0),
"The missing edge tensors should have identical size.");
}
cache_missing_dtype_ptr[i] = {
reinterpret_cast<std::byte*>(
cached_edge_tensors_[i].data_ptr()),
@@ -281,11 +292,9 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
copy_n, cache_missing_dtype_ptr, num_tensors,
cache_missing_dtype_dev.get());

auto input = allocator.AllocateStorage<std::byte*>(
num_tensors * num_nodes);
auto input = allocator.AllocateStorage<std::byte*>(num_buffers);
auto input_size =
allocator.AllocateStorage<size_t>(num_tensors * num_nodes);

allocator.AllocateStorage<size_t>(num_buffers + 1);
const auto cache_missing_dtype_dev_ptr =
cache_missing_dtype_dev.get();
const auto indices_ptr = indices.data_ptr<indices_t>();
@@ -294,28 +303,25 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
const auto input_size_ptr = input_size.get();
const auto cache_indptr = indptr_.data_ptr<indptr_t>();
const auto missing_indptr = indptr.data_ptr<indptr_t>();
CUB_CALL(
DeviceFor::Bulk, num_tensors * num_nodes,
[=] __device__(int64_t i) {
const auto tensor_idx = i / num_nodes;
const auto idx = i % num_nodes;
const auto pos = positions_ptr[idx];
const auto original_idx = indices_ptr[idx];
const auto [cache_ptr, missing_ptr, size] =
cache_missing_dtype_dev_ptr[tensor_idx];
const auto is_cached = pos >= 0;
const auto offset = is_cached
? cache_indptr[pos]
: missing_indptr[idx - num_hit];
const auto offset_end =
is_cached ? cache_indptr[pos + 1]
: missing_indptr[idx - num_hit + 1];
const auto out_idx = tensor_idx * num_nodes + original_idx;

input_ptr[out_idx] =
(is_cached ? cache_ptr : missing_ptr) + offset * size;
input_size_ptr[out_idx] = size * (offset_end - offset);
});
CUB_CALL(DeviceFor::Bulk, num_buffers, [=] __device__(int64_t i) {
const auto tensor_idx = i / num_nodes;
const auto idx = i % num_nodes;
const auto pos = positions_ptr[idx];
const auto original_idx = indices_ptr[idx];
const auto [cache_ptr, missing_ptr, size] =
cache_missing_dtype_dev_ptr[tensor_idx];
const auto is_cached = pos >= 0;
const auto offset = is_cached ? cache_indptr[pos]
: missing_indptr[idx - num_hit];
const auto offset_end = is_cached
? cache_indptr[pos + 1]
: missing_indptr[idx - num_hit + 1];
const auto out_idx = tensor_idx * num_nodes + original_idx;

input_ptr[out_idx] =
(is_cached ? cache_ptr : missing_ptr) + offset * size;
input_size_ptr[out_idx] = size * (offset_end - offset);
});
auto output_indptr = torch::empty(
num_nodes + 1, seeds.options().dtype(indptr_.scalar_type()));
auto output_indptr_ptr = output_indptr.data_ptr<indptr_t>();
@@ -340,8 +346,8 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
[=] __host__ __device__(indices_t x) {
return x == threshold;
});
auto output_indices = torch::empty(
num_threshold, seeds.options().dtype(index_dtype));
auto output_indices =
torch::empty(num_threshold, seeds.options());
CUB_CALL(
DeviceSelect::Flagged, iota, is_threshold,
output_indices.data_ptr<indices_t>(),
@@ -364,9 +370,9 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
indptr.size(0) - 2, cached_output_size);
cached_output_size = sindices.size(0);
enough_space = num_edges_ + *cached_output_size <=
cached_edge_tensors_.at(0).size(0);
cached_edge_tensors_[i].size(0);
if (enough_space) {
cached_edge_tensors_.at(i).slice(
cached_edge_tensors_[i].slice(
0, num_edges_, num_edges_ + *cached_output_size) =
sindices;
} else
@@ -406,7 +412,7 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
for (size_t i = 0; i < num_tensors; i++) {
output_edge_tensors.push_back(torch::empty(
static_cast<indptr_t>(output_size),
seeds.options().dtype(edge_tensors[i].scalar_type())));
cached_edge_tensors_[i].options()));
output_tensor_ptrs_ptr[i] = {
reinterpret_cast<std::byte*>(
output_edge_tensors.back().data_ptr()),
@@ -419,6 +425,11 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
THRUST_CALL(
copy_n, output_tensor_ptrs_ptr, num_tensors,
output_tensor_ptrs_dev.get());
// This event and the later synchronization is needed so that the
// allocated pinned tensor stays alive during the copy operation.
// @todo mfbalin: eliminate this synchronization.
at::cuda::CUDAEvent copy_event;
copy_event.record(cuda::GetCurrentStream());

{
thrust::counting_iterator<int64_t> iota{0};
@@ -434,7 +445,6 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
});
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();
const int64_t num_buffers = num_nodes * num_tensors;
for (int64_t i = 0; i < num_buffers; i += max_copy_at_once) {
CUB_CALL(
DeviceMemcpy::Batched, input.get() + i,
@@ -443,6 +453,7 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
}
}

copy_event.synchronize();
return std::make_tuple(output_indptr, output_edge_tensors);
}));
}));
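
Note on the synchronization added in this hunk: as the new comment in Replace() says, the at::cuda::CUDAEvent is recorded right after the pinned-memory copy of the output tensor pointers is enqueued and is synchronized before returning, so the pinned staging buffer stays alive until the asynchronous copy has read it. A rough Python analogue of that keep-alive pattern, for illustration only (the helper name and dtype below are made up and are not part of this commit):

import torch

def h2d_copy_with_keepalive(values, device="cuda"):
    # Pinned (page-locked) staging buffer; copies out of pinned memory with
    # non_blocking=True are asynchronous with respect to the host.
    staging = torch.tensor(values, dtype=torch.int64).pin_memory()
    out = staging.to(device, non_blocking=True)
    # Record an event on the current stream right after enqueueing the copy...
    copy_event = torch.cuda.Event()
    copy_event.record(torch.cuda.current_stream())
    # ...and wait on it before the staging buffer can go out of scope.
    # Returning without this wait could let the pinned memory be freed mid-copy.
    copy_event.synchronize()
    return out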
35 changes: 33 additions & 2 deletions python/dgl/graphbolt/dataloader.py
@@ -1,5 +1,6 @@
"""Graph Bolt DataLoaders"""

from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor

import torch
@@ -9,6 +10,7 @@

from .base import CopyTo
from .feature_fetcher import FeatureFetcher
from .impl.gpu_graph_cache import GPUGraphCache
from .impl.neighbor_sampler import SamplePerLayer

from .internal import datapipe_graph_to_adjlist
@@ -17,9 +19,34 @@

__all__ = [
"DataLoader",
"construct_gpu_graph_cache",
]


def construct_gpu_graph_cache(
sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold
):
"Construct a GPUGraphCache given a sample_per_layer_obj and cache parameters."
graph = sample_per_layer_obj.sampler.__self__
num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges)
dtypes = OrderedDict()
dtypes["indices"] = graph.indices.dtype
if graph.type_per_edge is not None:
dtypes["type_per_edge"] = graph.type_per_edge.dtype
if graph.edge_attributes is not None:
probs_or_mask = graph.edge_attributes.get(
sample_per_layer_obj.prob_name, None
)
if probs_or_mask is not None:
dtypes["probs_or_mask"] = probs_or_mask.dtype
return GPUGraphCache(
num_gpu_cached_edges,
gpu_cache_threshold,
graph.csc_indptr.dtype,
list(dtypes.values()),
)


def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = dp_utils.find_dps(
@@ -197,13 +224,17 @@ def __init__(
SamplePerLayer,
)
executor = ThreadPoolExecutor(max_workers=1)
gpu_graph_cache = None
for sampler in samplers:
if gpu_graph_cache is None:
gpu_graph_cache = construct_gpu_graph_cache(
sampler, num_gpu_cached_edges, gpu_cache_threshold
)
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(
num_gpu_cached_edges,
gpu_cache_threshold,
gpu_graph_cache,
_get_uva_stream(),
executor,
1,
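For context, the new construct_gpu_graph_cache helper is meant to be called once per DataLoader, and the loop in __init__ above hands the same cache instance to every SamplePerLayer stage. A minimal sketch of that wiring, assuming samplers, num_gpu_cached_edges, and gpu_cache_threshold are in scope as in the diff (the explicit num_gpu_cached_edges > 0 guard is an assumption here, since the enclosing condition is not shown in the hunk):

from dgl.graphbolt.dataloader import construct_gpu_graph_cache

gpu_graph_cache = None
for sampler in samplers:
    if gpu_graph_cache is None and num_gpu_cached_edges > 0:
        # Built from the first sampler's graph; the dtypes of indices,
        # type_per_edge, and probs/mask determine the cached tensor layout.
        gpu_graph_cache = construct_gpu_graph_cache(
            sampler, num_gpu_cached_edges, gpu_cache_threshold
        )
    # Each stage is then rebuilt with the shared cache (or None when the
    # GPU graph cache is disabled); see fetch_and_sample(...) in the diff
    # for the full argument list.
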
46 changes: 8 additions & 38 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -1,6 +1,5 @@
"""Neighbor subgraph samplers for GraphBolt."""

from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from functools import partial

@@ -13,7 +12,6 @@

from ..subgraph_sampler import SubgraphSampler
from .fused_csc_sampling_graph import fused_csc_sampling_graph
from .gpu_graph_cache import GPUGraphCache
from .sampled_subgraph_impl import SampledSubgraphImpl


@@ -34,33 +32,9 @@ class FetchCachedInsubgraphData(Mapper):
function that can be called with the fetched graph structure.
"""

def __init__(
self,
datapipe,
sample_per_layer_obj,
num_gpu_cached_edges,
gpu_cache_threshold,
):
def __init__(self, datapipe, gpu_graph_cache):
super().__init__(datapipe, self._fetch_per_layer)
graph = sample_per_layer_obj.sampler.__self__
num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges)
dtypes = OrderedDict()
dtypes["indices"] = graph.indices.dtype
if graph.type_per_edge is not None:
dtypes["type_per_edge"] = graph.type_per_edge.dtype
if graph.edge_attributes is not None:
probs_or_mask = graph.edge_attributes.get(
sample_per_layer_obj.prob_name, None
)
if probs_or_mask is not None:
dtypes["probs_or_mask"] = probs_or_mask.dtype
self.cache = GPUGraphCache(
num_gpu_cached_edges,
gpu_cache_threshold,
graph.csc_indptr.dtype,
list(dtypes.values()),
)
self.dtypes = dtypes
self.cache = gpu_graph_cache

def _fetch_per_layer(self, minibatch):
minibatch._seeds, minibatch._replace = self.cache.query(
@@ -143,17 +117,14 @@ def __init__(
self,
datapipe,
sample_per_layer_obj,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
gpu_graph_cache,
stream=None,
executor=None,
):
self.graph = sample_per_layer_obj.sampler.__self__
datapipe = datapipe.concat_hetero_seeds(sample_per_layer_obj)
if num_gpu_cached_edges > 0:
datapipe = datapipe.fetch_cached_insubgraph_data(
sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold
)
if gpu_graph_cache is not None:
datapipe = datapipe.fetch_cached_insubgraph_data(gpu_graph_cache)
super().__init__(datapipe, self._fetch_per_layer)
self.prob_name = sample_per_layer_obj.prob_name
self.stream = stream
@@ -335,17 +306,16 @@ class FetcherAndSampler(MiniBatchTransformer):
def __init__(
self,
sampler,
num_gpu_cached_edges,
gpu_cache_threshold,
gpu_graph_cache,
stream,
executor,
buffer_size,
):
datapipe = sampler.datapipe.fetch_insubgraph_data(
sampler, num_gpu_cached_edges, gpu_cache_threshold, stream, executor
sampler, gpu_graph_cache, stream, executor
)
datapipe = datapipe.buffer(buffer_size).wait_future().wait()
if num_gpu_cached_edges > 0:
if gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(sampler)
datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler)
super().__init__(datapipe)
10 changes: 9 additions & 1 deletion tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
@@ -7,6 +7,7 @@
import dgl.graphbolt as gb
import pytest
import torch
from dgl.graphbolt.dataloader import construct_gpu_graph_cache


def get_hetero_graph():
@@ -69,8 +70,15 @@ def test_NeighborSampler_GraphFetch(
compact_per_layer = sample_per_layer.compact_per_layer(True)
gb.seed(123)
expected_results = list(compact_per_layer)
gpu_graph_cache = None
if num_cached_edges > 0:
gpu_graph_cache = construct_gpu_graph_cache(
sample_per_layer, num_cached_edges, 1
)
datapipe = gb.FetchInsubgraphData(
datapipe, sample_per_layer, num_cached_edges, 1
datapipe,
sample_per_layer,
gpu_graph_cache,
)
datapipe = datapipe.wait_future()
if num_cached_edges > 0:
