[GraphBolt][CUDA] Gpu graph cache cpp (#7483)
mfbalin authored Jun 27, 2024
1 parent 95dc96a · commit 08bdd9f
Showing 1 changed file with 137 additions and 126 deletions.
graphbolt/src/cuda/extension/gpu_graph_cache.cu: 137 additions & 126 deletions (263 changes)
@@ -162,6 +162,11 @@ GpuGraphCache::~GpuGraphCache() {

std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
torch::Tensor seeds) {
+ TORCH_CHECK(seeds.device().is_cuda(), "Seeds should be on a CUDA device.");
+ TORCH_CHECK(
+ seeds.device().index() == device_id_,
+ "Seeds should be on the correct CUDA device.");
+ TORCH_CHECK(seeds.sizes().size() == 1, "Keys should be a 1D tensor.");
auto allocator = cuda::GetAllocator();
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
const dim3 block(kIntBlockSize);
@@ -175,8 +180,7 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
map->capacity() * kIntGrowthFactor,
cuco::cuda_stream_ref{cuda::GetCurrentStream()});
}
- auto positions =
- torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
+ auto positions = torch::empty_like(seeds);
CUDA_KERNEL_CALL(
_QueryAndIncrement, grid, block, 0,
static_cast<int64_t>(seeds.size(0)), seeds.data_ptr<index_t>(),
@@ -211,10 +215,8 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
thrust::counting_iterator<index_t> iota{0};
auto position_and_index =
thrust::make_zip_iterator(positions.data_ptr<index_t>(), iota);
- auto output_positions =
- torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
- auto output_indices =
- torch::empty(seeds.size(0), seeds.options().dtype(index_dtype));
+ auto output_positions = torch::empty_like(seeds);
+ auto output_indices = torch::empty_like(seeds);
auto output_position_and_index = thrust::make_zip_iterator(
output_positions.data_ptr<index_t>(),
output_indices.data_ptr<index_t>());
@@ -243,6 +245,10 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
num_tensors == cached_edge_tensors_.size(),
"Same number of tensors need to be passed!");
const auto num_nodes = seeds.size(0);
+ TORCH_CHECK(
+ indptr.size(0) == num_nodes - num_hit + 1,
+ "(indptr.size(0) == seeds.size(0) - num_hit + 1) failed.");
+ const int64_t num_buffers = num_nodes * num_tensors;
auto allocator = cuda::GetAllocator();
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
return AT_DISPATCH_INDEX_TYPES(
@@ -257,35 +263,44 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
static_assert(
sizeof(std::byte) == 1, "Byte needs to have a size of 1.");
auto cache_missing_dtype = torch::empty(
- 3 * num_tensors, c10::TensorOptions()
+ // Below, we use this storage to store a tuple of 4 elements,
+ // since each element is 64-bit, we need 4x int64 storage.
+ 4 * num_tensors, c10::TensorOptions()
.dtype(torch::kInt64)
.pinned_memory(true));
- auto cache_missing_dtype_ptr = reinterpret_cast<
- ::cuda::std::tuple<std::byte*, std::byte*, int64_t>*>(
- cache_missing_dtype.data_ptr());
+ auto cache_missing_dtype_ptr =
+ reinterpret_cast<::cuda::std::tuple<
+ std::byte*, std::byte*, int64_t, int64_t>*>(
+ cache_missing_dtype.data_ptr());
+ int64_t total_size = 0;
for (size_t i = 0; i < num_tensors; i++) {
TORCH_CHECK(
cached_edge_tensors_[i].scalar_type() ==
edge_tensors[i].scalar_type(),
"The dtypes of edge tensors must match.");
+ if (i > 0) {
+ TORCH_CHECK(
+ edge_tensors[i - 1].size(0) == edge_tensors[i].size(0),
+ "The missing edge tensors should have identical size.");
+ }
+ const int64_t element_size = edge_tensors[i].element_size();
cache_missing_dtype_ptr[i] = {
reinterpret_cast<std::byte*>(
cached_edge_tensors_[i].data_ptr()),
reinterpret_cast<std::byte*>(edge_tensors[i].data_ptr()),
- edge_tensors[i].element_size()};
+ element_size, total_size};
+ total_size += element_size;
}
auto cache_missing_dtype_dev = allocator.AllocateStorage<
- ::cuda::std::tuple<std::byte*, std::byte*, int64_t>>(
+ ::cuda::std::tuple<std::byte*, std::byte*, int64_t, int64_t>>(
num_tensors);
THRUST_CALL(
copy_n, cache_missing_dtype_ptr, num_tensors,
cache_missing_dtype_dev.get());

- auto input = allocator.AllocateStorage<std::byte*>(
- num_tensors * num_nodes);
+ auto input = allocator.AllocateStorage<std::byte*>(num_buffers);
auto input_size =
- allocator.AllocateStorage<size_t>(num_tensors * num_nodes);

+ allocator.AllocateStorage<size_t>(num_buffers + 1);
const auto cache_missing_dtype_dev_ptr =
cache_missing_dtype_dev.get();
const auto indices_ptr = indices.data_ptr<indices_t>();
@@ -294,28 +309,25 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
const auto input_size_ptr = input_size.get();
const auto cache_indptr = indptr_.data_ptr<indptr_t>();
const auto missing_indptr = indptr.data_ptr<indptr_t>();
- CUB_CALL(
- DeviceFor::Bulk, num_tensors * num_nodes,
- [=] __device__(int64_t i) {
- const auto tensor_idx = i / num_nodes;
- const auto idx = i % num_nodes;
- const auto pos = positions_ptr[idx];
- const auto original_idx = indices_ptr[idx];
- const auto [cache_ptr, missing_ptr, size] =
- cache_missing_dtype_dev_ptr[tensor_idx];
- const auto is_cached = pos >= 0;
- const auto offset = is_cached
- ? cache_indptr[pos]
- : missing_indptr[idx - num_hit];
- const auto offset_end =
- is_cached ? cache_indptr[pos + 1]
- : missing_indptr[idx - num_hit + 1];
- const auto out_idx = tensor_idx * num_nodes + original_idx;

- input_ptr[out_idx] =
- (is_cached ? cache_ptr : missing_ptr) + offset * size;
- input_size_ptr[out_idx] = size * (offset_end - offset);
- });
+ CUB_CALL(DeviceFor::Bulk, num_buffers, [=] __device__(int64_t i) {
+ const auto tensor_idx = i / num_nodes;
+ const auto idx = i % num_nodes;
+ const auto pos = positions_ptr[idx];
+ const auto original_idx = indices_ptr[idx];
+ const auto [cache_ptr, missing_ptr, size, cum_size] =
+ cache_missing_dtype_dev_ptr[tensor_idx];
+ const auto is_cached = pos >= 0;
+ const auto offset = is_cached ? cache_indptr[pos]
+ : missing_indptr[idx - num_hit];
+ const auto offset_end = is_cached
+ ? cache_indptr[pos + 1]
+ : missing_indptr[idx - num_hit + 1];
+ const auto out_idx = tensor_idx * num_nodes + original_idx;

+ input_ptr[out_idx] =
+ (is_cached ? cache_ptr : missing_ptr) + offset * size;
+ input_size_ptr[out_idx] = size * (offset_end - offset);
+ });
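For readers following the rewritten `DeviceFor::Bulk` lambda above: every flattened index `i` in `[0, num_buffers)` names one (tensor, seed) buffer via `tensor_idx = i / num_nodes` and `idx = i % num_nodes`, and the source offset and byte count come either from the cache (`pos >= 0`) or from the freshly fetched missing batch. The host-side sketch below replays only that index and offset arithmetic on made-up data; the sizes, indptr values, and hit/miss pattern are hypothetical and not taken from the commit.

// Illustrative host-side replay of the per-buffer gather arithmetic used by
// the DeviceFor::Bulk lambda above. Every value below is hypothetical.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t num_tensors = 2, num_nodes = 3, num_hit = 1;
  // positions[idx] >= 0: seed idx is cached at that cache row;
  // negative: it is a miss served by the freshly fetched batch.
  const std::vector<int64_t> positions = {4, -1, -1};
  const std::vector<int64_t> indices = {0, 1, 2};  // original seed order
  const std::vector<int64_t> cache_indptr = {0, 2, 5, 7, 9, 12, 15};
  const std::vector<int64_t> missing_indptr = {0, 3, 4};  // 2 missing seeds
  const std::vector<int64_t> element_size = {8, 4};  // e.g. int64 and int32

  std::vector<int64_t> buffer_bytes(num_tensors * num_nodes);
  for (int64_t i = 0; i < num_tensors * num_nodes; i++) {
    const auto tensor_idx = i / num_nodes;
    const auto idx = i % num_nodes;
    const auto pos = positions[idx];
    const bool is_cached = pos >= 0;
    const auto offset =
        is_cached ? cache_indptr[pos] : missing_indptr[idx - num_hit];
    const auto offset_end =
        is_cached ? cache_indptr[pos + 1] : missing_indptr[idx - num_hit + 1];
    // Scatter by the original seed index, exactly like out_idx in the kernel.
    const auto out_idx = tensor_idx * num_nodes + indices[idx];
    buffer_bytes[out_idx] = element_size[tensor_idx] * (offset_end - offset);
  }
  // Seed 0 is cached at row 4 and owns edges [12, 15): 3 edges per tensor.
  assert(buffer_bytes[0] == 3 * 8 && buffer_bytes[3] == 3 * 4);
  return 0;
}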
auto output_indptr = torch::empty(
num_nodes + 1, seeds.options().dtype(indptr_.scalar_type()));
auto output_indptr_ptr = output_indptr.data_ptr<indptr_t>();
@@ -330,111 +342,110 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
num_nodes + 1);
CopyScalar output_size{output_indptr_ptr + num_nodes};

- auto missing_positions = positions.slice(0, num_hit);
- auto missing_indices = indices.slice(0, num_hit);

- thrust::counting_iterator<indices_t> iota{0};
- auto threshold = -threshold_;
- auto is_threshold = thrust::make_transform_iterator(
- missing_positions.data_ptr<indices_t>(),
- [=] __host__ __device__(indices_t x) {
- return x == threshold;
- });
- auto output_indices = torch::empty(
- num_threshold, seeds.options().dtype(index_dtype));
- CUB_CALL(
- DeviceSelect::Flagged, iota, is_threshold,
- output_indices.data_ptr<indices_t>(),
- cub::DiscardOutputIterator{}, missing_positions.size(0));
- auto [in_degree, sliced_indptr] =
- ops::SliceCSCIndptr(indptr, output_indices);
- while (num_nodes_ + num_threshold >= indptr_.size(0)) {
- auto new_indptr = torch::empty(
- indptr_.size(0) * kIntGrowthFactor, indptr_.options());
- new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;
- indptr_ = new_indptr;
- }
- torch::Tensor sindptr;
- bool enough_space;
- torch::optional<int64_t> cached_output_size;
- for (size_t i = 0; i < edge_tensors.size(); i++) {
- torch::Tensor sindices;
- std::tie(sindptr, sindices) = ops::IndexSelectCSCImpl(
- in_degree, sliced_indptr, edge_tensors[i], output_indices,
- indptr.size(0) - 2, cached_output_size);
- cached_output_size = sindices.size(0);
- enough_space = num_edges_ + *cached_output_size <=
- cached_edge_tensors_.at(0).size(0);
- if (enough_space) {
- cached_edge_tensors_.at(i).slice(
- 0, num_edges_, num_edges_ + *cached_output_size) =
- sindices;
- } else
- break;
- }
- if (enough_space) {
- auto num_edges = num_edges_;
- THRUST_CALL(
- transform, sindptr.data_ptr<indptr_t>() + 1,
- sindptr.data_ptr<indptr_t>() + sindptr.size(0),
- indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
- [=] __host__ __device__(indptr_t x) {
- return x + num_edges;
+ if (num_threshold > 0) {
+ // Insert the vertices whose access count equal threshold.
+ auto missing_positions = positions.slice(0, num_hit);
+ auto missing_indices = indices.slice(0, num_hit);

+ thrust::counting_iterator<indices_t> iota{0};
+ auto threshold = -threshold_;
+ auto is_threshold = thrust::make_transform_iterator(
+ missing_positions.data_ptr<indices_t>(),
+ [=] __host__ __device__(indices_t x) {
+ return x == threshold;
+ });
- auto map = reinterpret_cast<map_t<indices_t>*>(map_);
- const dim3 block(kIntBlockSize);
- const dim3 grid(
- (num_threshold + kIntBlockSize - 1) / kIntBlockSize);
- CUDA_KERNEL_CALL(
- _Insert, grid, block, 0, output_indices.size(0),
- static_cast<indices_t>(num_nodes_),
- seeds.data_ptr<indices_t>(),
- missing_indices.data_ptr<indices_t>(),
- output_indices.data_ptr<indices_t>(), map->ref(cuco::find));
- num_edges_ += *cached_output_size;
- num_nodes_ += num_threshold;
+ auto output_indices =
+ torch::empty(num_threshold, seeds.options());
+ CUB_CALL(
+ DeviceSelect::Flagged, iota, is_threshold,
+ output_indices.data_ptr<indices_t>(),
+ cub::DiscardOutputIterator{}, missing_positions.size(0));
+ auto [in_degree, sliced_indptr] =
+ ops::SliceCSCIndptr(indptr, output_indices);
+ while (num_nodes_ + num_threshold >= indptr_.size(0)) {
+ auto new_indptr = torch::empty(
+ indptr_.size(0) * kIntGrowthFactor, indptr_.options());
+ new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;
+ indptr_ = new_indptr;
+ }
+ torch::Tensor sindptr;
+ bool enough_space;
+ torch::optional<int64_t> cached_output_size;
+ for (size_t i = 0; i < edge_tensors.size(); i++) {
+ torch::Tensor sindices;
+ std::tie(sindptr, sindices) = ops::IndexSelectCSCImpl(
+ in_degree, sliced_indptr, edge_tensors[i], output_indices,
+ indptr.size(0) - 2, cached_output_size);
+ cached_output_size = sindices.size(0);
+ enough_space = num_edges_ + *cached_output_size <=
+ cached_edge_tensors_[i].size(0);
+ if (enough_space) {
+ cached_edge_tensors_[i].slice(
+ 0, num_edges_, num_edges_ + *cached_output_size) =
+ sindices;
+ } else
+ break;
+ }
+ if (enough_space) {
+ auto num_edges = num_edges_;
+ THRUST_CALL(
+ transform, sindptr.data_ptr<indptr_t>() + 1,
+ sindptr.data_ptr<indptr_t>() + sindptr.size(0),
+ indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
+ [=] __host__ __device__(indptr_t x) {
+ return x + num_edges;
+ });
+ auto map = reinterpret_cast<map_t<indices_t>*>(map_);
+ const dim3 block(kIntBlockSize);
+ const dim3 grid(
+ (num_threshold + kIntBlockSize - 1) / kIntBlockSize);
+ CUDA_KERNEL_CALL(
+ _Insert, grid, block, 0, output_indices.size(0),
+ static_cast<indices_t>(num_nodes_),
+ seeds.data_ptr<indices_t>(),
+ missing_indices.data_ptr<indices_t>(),
+ output_indices.data_ptr<indices_t>(),
+ map->ref(cuco::find));
+ num_edges_ += *cached_output_size;
+ num_nodes_ += num_threshold;
+ }
+ }
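Taken together, the `x == threshold` selection over `missing_positions` and the new `if (num_threshold > 0)` guard suggest the admission policy: a missing key's position acts as a negative access counter, and the key is inserted into the cache on the query in which that counter reaches `-threshold_`; when no key crosses the threshold, the whole insertion path is now skipped. The toy model below mimics that policy on the host. It is an illustrative simplification inferred from the diff, not the cuco-map logic of `_QueryAndIncrement`.

// Toy host-side model of the admission policy suggested above: each miss
// decrements a per-key counter and the key is admitted exactly when the
// counter reaches -threshold. Illustrative only; not the GraphBolt code.
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

struct ToyAdmissionPolicy {
  int64_t threshold;
  std::unordered_map<int64_t, int64_t> counter;  // key -> negative miss count

  // Returns true when the key should be inserted into the cache now.
  bool QueryAndMaybeAdmit(int64_t key) {
    const int64_t position = --counter[key];  // one more miss for this key
    return position == -threshold;
  }
};

int main() {
  ToyAdmissionPolicy policy{/*threshold=*/3, {}};
  std::vector<bool> admitted;
  for (int step = 0; step < 4; step++) {
    admitted.push_back(policy.QueryAndMaybeAdmit(42));
  }
  // Admitted exactly on the third miss, and only once.
  assert(!admitted[0] && !admitted[1] && admitted[2] && !admitted[3]);
  return 0;
}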

+ constexpr int alignment = 128;
+ const auto output_allocation_count =
+ (static_cast<indptr_t>(output_size) + alignment - 1) /
+ alignment * alignment;
+ auto output_allocation = torch::empty(
+ output_allocation_count * total_size,
+ seeds.options().dtype(torch::kInt8));
+ const auto output_allocation_ptr =
+ output_allocation.data_ptr<int8_t>();

std::vector<torch::Tensor> output_edge_tensors;
- auto output_tensor_ptrs = torch::empty(
- 2 * num_tensors, c10::TensorOptions()
- .dtype(torch::kInt64)
- .pinned_memory(true));
- const auto output_tensor_ptrs_ptr =
- reinterpret_cast<::cuda::std::tuple<std::byte*, int64_t>*>(
- output_tensor_ptrs.data_ptr());
for (size_t i = 0; i < num_tensors; i++) {
- output_edge_tensors.push_back(torch::empty(
- static_cast<indptr_t>(output_size),
- seeds.options().dtype(edge_tensors[i].scalar_type())));
- output_tensor_ptrs_ptr[i] = {
- reinterpret_cast<std::byte*>(
- output_edge_tensors.back().data_ptr()),
- ::cuda::std::get<2>(cache_missing_dtype_ptr[i])};
+ const auto cum_size =
+ ::cuda::std::get<3>(cache_missing_dtype_ptr[i]);
+ output_edge_tensors.push_back(
+ output_allocation
+ .slice(0, cum_size * output_allocation_count)
+ .view(edge_tensors[i].scalar_type())
+ .slice(0, 0, static_cast<indptr_t>(output_size)));
}
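The loop above replaces the old per-tensor `torch::empty` outputs and the `output_tensor_ptrs` pointer table of the removed lines with typed views into one shared `kInt8` allocation: `output_allocation_count` is the output edge count rounded up to a multiple of 128, and tensor `i`'s region starts `cum_size * output_allocation_count` bytes into the buffer, where `cum_size` is the sum of the element sizes of the preceding tensors. Because every region offset is then a multiple of 128 bytes, each slice stays suitably aligned for `.view()` with any of the packed dtypes. The standalone sketch below works through that offset arithmetic with hypothetical element sizes.

// Standalone illustration of the packed output layout built above: one byte
// buffer holds every output edge tensor, and tensor i starts at
// cum_size_i * output_allocation_count bytes. Element sizes are hypothetical.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  constexpr int64_t alignment = 128;
  const int64_t output_size = 1000;  // total output edges (hypothetical)
  const std::vector<int64_t> element_size = {8, 8, 4};  // e.g. 2x int64, int32

  // Round the edge count up so every region offset is a multiple of 128.
  const int64_t output_allocation_count =
      (output_size + alignment - 1) / alignment * alignment;
  const int64_t total_size =
      std::accumulate(element_size.begin(), element_size.end(), int64_t{0});

  // Byte offset of each tensor's region: running sum of element sizes times
  // the aligned edge count (this is the cum_size stored in the 4-tuple).
  std::vector<int64_t> byte_offset(element_size.size());
  int64_t cum_size = 0;
  for (size_t i = 0; i < element_size.size(); i++) {
    byte_offset[i] = cum_size * output_allocation_count;
    cum_size += element_size[i];
  }

  assert(output_allocation_count == 1024);
  assert(byte_offset[0] == 0 && byte_offset[1] == 8 * 1024);
  assert(byte_offset[2] == 16 * 1024);
  // The regions exactly tile the output_allocation_count * total_size bytes
  // that the single output_allocation holds.
  assert(byte_offset[2] + element_size[2] * output_allocation_count ==
         output_allocation_count * total_size);
  return 0;
}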
- auto output_tensor_ptrs_dev =
- allocator
- .AllocateStorage<::cuda::std::tuple<std::byte*, int64_t>>(
- num_tensors);
- THRUST_CALL(
- copy_n, output_tensor_ptrs_ptr, num_tensors,
- output_tensor_ptrs_dev.get());

{
thrust::counting_iterator<int64_t> iota{0};
- auto output_tensor_ptrs_dev_ptr = output_tensor_ptrs_dev.get();
auto output_buffer_it = thrust::make_transform_iterator(
iota, [=] __host__ __device__(int64_t i) {
const auto tensor_idx = i / num_nodes;
const auto idx = i % num_nodes;
const auto offset = output_indptr_ptr[idx];
- const auto [output_ptr, size] =
- output_tensor_ptrs_dev_ptr[tensor_idx];
- return output_ptr + offset * size;
+ const auto [_0, _1, size, cum_size] =
+ cache_missing_dtype_dev_ptr[tensor_idx];
+ return output_allocation_ptr +
+ cum_size * output_allocation_count + offset * size;
});
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();
- const int64_t num_buffers = num_nodes * num_tensors;
for (int64_t i = 0; i < num_buffers; i += max_copy_at_once) {
CUB_CALL(
DeviceMemcpy::Batched, input.get() + i,
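The tail of the hunk (cut off at the context boundary above) feeds the per-buffer source pointers, destination pointers, and byte sizes into `cub::DeviceMemcpy::Batched`, looping in chunks of at most `std::numeric_limits<int32_t>::max()` buffers per call, presumably because the primitive's buffer-count parameter is 32-bit. Below is a minimal, self-contained usage sketch of that CUB primitive on three tiny made-up buffers; it is not GraphBolt code, and it assumes a CUB version recent enough to ship `DeviceMemcpy` (bundled with newer CUDA Toolkits).

// Minimal usage sketch of cub::DeviceMemcpy::Batched. Buffer count, sizes,
// and contents are made up; error handling is kept deliberately small.
#include <cub/device/device_memcpy.cuh>
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

#define CUDA_OK(call)                                          \
  do {                                                         \
    cudaError_t e = (call);                                    \
    if (e != cudaSuccess) {                                    \
      std::printf("CUDA error: %s\n", cudaGetErrorString(e));  \
      return 1;                                                \
    }                                                          \
  } while (0)

int main() {
  constexpr int kNumBuffers = 3;
  const std::vector<size_t> sizes = {4, 16, 8};  // bytes per buffer

  // One source and one destination region per buffer.
  std::vector<void*> srcs(kNumBuffers), dsts(kNumBuffers);
  for (int i = 0; i < kNumBuffers; i++) {
    CUDA_OK(cudaMalloc(&srcs[i], sizes[i]));
    CUDA_OK(cudaMalloc(&dsts[i], sizes[i]));
    CUDA_OK(cudaMemset(srcs[i], i + 1, sizes[i]));  // recognizable fill value
  }

  // The pointer and size arrays themselves must be device accessible.
  void **d_srcs, **d_dsts;
  size_t* d_sizes;
  CUDA_OK(cudaMalloc(&d_srcs, kNumBuffers * sizeof(void*)));
  CUDA_OK(cudaMalloc(&d_dsts, kNumBuffers * sizeof(void*)));
  CUDA_OK(cudaMalloc(&d_sizes, kNumBuffers * sizeof(size_t)));
  CUDA_OK(cudaMemcpy(d_srcs, srcs.data(), kNumBuffers * sizeof(void*),
                     cudaMemcpyHostToDevice));
  CUDA_OK(cudaMemcpy(d_dsts, dsts.data(), kNumBuffers * sizeof(void*),
                     cudaMemcpyHostToDevice));
  CUDA_OK(cudaMemcpy(d_sizes, sizes.data(), kNumBuffers * sizeof(size_t),
                     cudaMemcpyHostToDevice));

  // Standard CUB two-phase pattern: query temp storage size, then run.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  CUDA_OK(cub::DeviceMemcpy::Batched(
      d_temp, temp_bytes, d_srcs, d_dsts, d_sizes, kNumBuffers));
  CUDA_OK(cudaMalloc(&d_temp, temp_bytes));
  CUDA_OK(cub::DeviceMemcpy::Batched(
      d_temp, temp_bytes, d_srcs, d_dsts, d_sizes, kNumBuffers));
  CUDA_OK(cudaDeviceSynchronize());

  // Spot-check the second destination buffer (filled with the value 2).
  std::vector<unsigned char> host(sizes[1]);
  CUDA_OK(cudaMemcpy(host.data(), dsts[1], sizes[1], cudaMemcpyDeviceToHost));
  std::printf("dsts[1][0] = %d (expected 2)\n", host[0]);
  return 0;
}

The size-query call with a null temp-storage pointer followed by the real call is the usual way CUB device algorithms are driven, which is presumably what the CUB_CALL wrapper used throughout this file abstracts away.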
