[GraphBolt][CUDA] gb.indptr_edge_ids. #7592

Merged: 4 commits, Jul 26, 2024
19 changes: 19 additions & 0 deletions graphbolt/include/graphbolt/cuda_ops.h
@@ -228,6 +228,25 @@ torch::Tensor ExpandIndptrImpl(
torch::optional<torch::Tensor> node_ids = torch::nullopt,
torch::optional<int64_t> output_size = torch::nullopt);

/**
* @brief IndptrEdgeIdsImpl implements conversion from a given indptr offset
* tensor to a COO edge ids tensor. For a given indptr [0, 2, 5, 7] and offset
* tensor [0, 100, 200], the output will be [0, 1, 100, 101, 102, 200, 201]. If
* offset is not provided, the output will be [0, 1, 0, 1, 2, 0, 1].
*
* @param indptr The indptr offset tensor.
* @param dtype The dtype of the returned output tensor.
* @param offset The offset tensor.
* @param output_size Optional value of indptr[-1]. Passing it eliminates a
* CPU-GPU synchronization.
*
* @return The resulting tensor.
*/
torch::Tensor IndptrEdgeIdsImpl(
torch::Tensor indptr, torch::ScalarType dtype,
torch::optional<torch::Tensor> offset,
torch::optional<int64_t> output_size);
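
[Editorial aside] The contract documented above can be exercised directly. A minimal usage sketch, assuming a CUDA build of GraphBolt with this header on the include path — illustration only, not code from this PR:

#include <graphbolt/cuda_ops.h>
#include <torch/torch.h>

void IndptrEdgeIdsDemo() {
  auto opts = torch::dtype(torch::kInt64).device(torch::kCUDA);
  auto indptr = torch::tensor({0, 2, 5, 7}, opts);   // rows of degree 2, 3, 2
  auto offset = torch::tensor({0, 100, 200}, opts);  // per-row starting edge id
  // Each row counts up from its offset for `degree` steps.
  auto ids = graphbolt::ops::IndptrEdgeIdsImpl(
      indptr, torch::kInt64, offset, /*output_size=*/7);
  // ids: [0, 1, 100, 101, 102, 200, 201]
  // Without an offset tensor, the ids restart from 0 in every row.
  auto local = graphbolt::ops::IndptrEdgeIdsImpl(
      indptr, torch::kInt64, torch::nullopt, /*output_size=*/7);
  // local: [0, 1, 0, 1, 2, 0, 1]
}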

/**
* @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and
* 'src_ids' tensor and applies the uniqueness information to compact both
52 changes: 43 additions & 9 deletions graphbolt/src/cuda/expand_indptr.cu
@@ -37,6 +37,14 @@ struct RepeatIndex {
}
};

template <typename indices_t, typename nodes_t>
struct IotaIndex {
const nodes_t* nodes;
__host__ __device__ auto operator()(indices_t i) {
return thrust::make_counting_iterator(nodes ? nodes[i] : 0);
}
};

template <typename indptr_t, typename indices_t>
struct OutputBufferIndexer {
const indptr_t* indptr;
@@ -54,8 +62,8 @@ struct AdjacentDifference {

torch::Tensor ExpandIndptrImpl(
torch::Tensor indptr, torch::ScalarType dtype,
torch::optional<torch::Tensor> nodes,
torch::optional<int64_t> output_size) {
torch::optional<torch::Tensor> nodes, torch::optional<int64_t> output_size,
const bool edge_ids) {
[Inline review on the `edge_ids` parameter]
Collaborator: include edge ids? has edge ids?
Collaborator (Author): Alright, let me change to a better name.

if (!output_size.has_value()) {
output_size = AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "ExpandIndptrIndptr[-1]", ([&]() -> int64_t {
@@ -84,8 +92,6 @@ torch::Tensor ExpandIndptrImpl(
nodes ? nodes.value().data_ptr<nodes_t>() : nullptr;

thrust::counting_iterator<int64_t> iota(0);
auto input_buffer = thrust::make_transform_iterator(
iota, RepeatIndex<indices_t, nodes_t>{nodes_ptr});
auto output_buffer = thrust::make_transform_iterator(
iota, OutputBufferIndexer<indptr_t, indices_t>{
indptr_ptr, csc_rows_ptr});
@@ -95,17 +101,45 @@
const auto num_rows = indptr.size(0) - 1;
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();
for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
CUB_CALL(
DeviceCopy::Batched, input_buffer + i,
output_buffer + i, buffer_sizes + i,
std::min(num_rows - i, max_copy_at_once));

if (edge_ids) {
auto input_buffer = thrust::make_transform_iterator(
iota, IotaIndex<indices_t, nodes_t>{nodes_ptr});
for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
CUB_CALL(
DeviceCopy::Batched, input_buffer + i,
output_buffer + i, buffer_sizes + i,
std::min(num_rows - i, max_copy_at_once));
}
} else {
auto input_buffer = thrust::make_transform_iterator(
iota, RepeatIndex<indices_t, nodes_t>{nodes_ptr});
for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
CUB_CALL(
DeviceCopy::Batched, input_buffer + i,
output_buffer + i, buffer_sizes + i,
std::min(num_rows - i, max_copy_at_once));
}
}
}));
}));
}));
return csc_rows;
}

torch::Tensor ExpandIndptrImpl(
torch::Tensor indptr, torch::ScalarType dtype,
torch::optional<torch::Tensor> nodes,
torch::optional<int64_t> output_size) {
return ExpandIndptrImpl(indptr, dtype, nodes, output_size, false);
}

torch::Tensor IndptrEdgeIdsImpl(
torch::Tensor indptr, torch::ScalarType dtype,
torch::optional<torch::Tensor> offset,
torch::optional<int64_t> output_size) {
return ExpandIndptrImpl(indptr, dtype, offset, output_size, true);
}

} // namespace ops
} // namespace graphbolt
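
[Editorial note] Both branches above run the same CUB DeviceCopy::Batched loop, chunked at INT32_MAX batches per call; the only difference is the per-row input iterator. RepeatIndex yields a constant iterator that repeats the row's node id (matching ExpandIndptr's repeat_interleave semantics), while IotaIndex yields a counting iterator starting at the row's offset. A hedged host-side analogue of the two paths — illustration only, not PR code:

#include <cstdint>
#include <vector>

// Host-side reference for the unified kernel: per row i, write `degree`
// elements into out at position indptr[i], drawn from one of two sequences.
std::vector<int64_t> ExpandIndptrReference(
    const std::vector<int64_t>& indptr, const int64_t* nodes, bool edge_ids) {
  std::vector<int64_t> out(indptr.back());
  for (size_t i = 0; i + 1 < indptr.size(); ++i) {
    const int64_t degree = indptr[i + 1] - indptr[i];
    for (int64_t j = 0; j < degree; ++j) {
      out[indptr[i] + j] =
          edge_ids ? (nodes ? nodes[i] : 0) + j                     // IotaIndex path
                   : (nodes ? nodes[i] : static_cast<int64_t>(i));  // RepeatIndex path
    }
  }
  return out;
}

With edge_ids = true and nodes = {0, 100, 200}, this reproduces the example in the header comment; with edge_ids = false it reduces to a repeat_interleave over the node ids.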
7 changes: 2 additions & 5 deletions graphbolt/src/cuda/insubgraph.cu
@@ -39,11 +39,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
in_degree, sliced_indptr, type_per_edge.value(), nodes,
indptr.size(0) - 2, num_edges));
}
auto rows = ExpandIndptrImpl(
output_indptr, indices.scalar_type(), torch::nullopt, num_edges);
auto i = torch::arange(output_indices.size(0), output_indptr.options());
auto edge_ids =
i - output_indptr.gather(0, rows) + sliced_indptr.gather(0, rows);
auto edge_ids = IndptrEdgeIdsImpl(
output_indptr, sliced_indptr.scalar_type(), sliced_indptr, num_edges);

return c10::make_intrusive<sampling::FusedSampledSubgraph>(
output_indptr, output_indices, nodes, torch::nullopt, edge_ids,
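
[Editorial note] The removed three-tensor computation and the new fused call are equivalent: for the edge at global position i in output row r, i - output_indptr[r] is its local index j within the row, and adding sliced_indptr[r] (where row r's edges begin in the original graph) yields sliced_indptr[r] + j — exactly the per-row run IndptrEdgeIdsImpl emits when sliced_indptr is passed as the offset tensor. A small self-check under hypothetical toy values, not PR code:

#include <cassert>
#include <cstdint>
#include <vector>

void CheckInSubgraphEdgeIds() {
  const std::vector<int64_t> output_indptr{0, 2, 5, 7};  // toy subgraph indptr
  const std::vector<int64_t> sliced_indptr{10, 40, 90};  // toy row starts in the full graph
  for (size_t r = 0; r + 1 < output_indptr.size(); ++r) {
    for (int64_t j = 0; j < output_indptr[r + 1] - output_indptr[r]; ++j) {
      const int64_t i = output_indptr[r] + j;                           // torch::arange position
      const int64_t removed = i - output_indptr[r] + sliced_indptr[r];  // old gather-based formula
      const int64_t fused = sliced_indptr[r] + j;                       // IndptrEdgeIdsImpl output
      assert(removed == fused);
    }
  }
}

Note also that the new call passes sliced_indptr.scalar_type(), so edge ids now take the indptr dtype rather than the indices dtype.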
27 changes: 27 additions & 0 deletions graphbolt/src/expand_indptr.cc
@@ -30,6 +30,19 @@ torch::Tensor ExpandIndptr(
indptr.diff(), 0, output_size);
}

torch::Tensor IndptrEdgeIds(
torch::Tensor indptr, torch::ScalarType dtype,
torch::optional<torch::Tensor> offset,
torch::optional<int64_t> output_size) {
if (utils::is_on_gpu(indptr) &&
(!offset.has_value() || utils::is_on_gpu(offset.value()))) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IndptrEdgeIds",
{ return IndptrEdgeIdsImpl(indptr, dtype, offset, output_size); });
}
TORCH_CHECK(false, "CPU implementation of IndptrEdgeIds is not available.");
}

TORCH_LIBRARY_IMPL(graphbolt, CPU, m) {
m.impl("expand_indptr", &ExpandIndptr);
}
@@ -44,5 +57,19 @@ TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {
m.impl("expand_indptr", torch::autograd::autogradNotImplementedFallback());
}

TORCH_LIBRARY_IMPL(graphbolt, CPU, m) {
m.impl("indptr_edge_ids", &IndptrEdgeIds);
}

#ifdef GRAPHBOLT_USE_CUDA
TORCH_LIBRARY_IMPL(graphbolt, CUDA, m) {
m.impl("indptr_edge_ids", &IndptrEdgeIdsImpl);
}
#endif

TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {
m.impl("indptr_edge_ids", torch::autograd::autogradNotImplementedFallback());
}

} // namespace ops
} // namespace graphbolt
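
[Editorial note] The registrations above make indptr_edge_ids a GPU-only op: CUDA tensors route to IndptrEdgeIdsImpl, while a CPU call deliberately fails the TORCH_CHECK. A hedged sketch of that contract — the include path for the internal header is an assumption, and this is illustration, not PR code:

#include <torch/torch.h>
#include "expand_indptr.h"  // internal GraphBolt header (path assumed)

void DispatchContractDemo() {
  auto gpu_indptr = torch::tensor(
      {0, 2, 5, 7}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  // GPU input: routed to the CUDA kernel via GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE.
  auto ids = graphbolt::ops::IndptrEdgeIds(
      gpu_indptr, torch::kInt64, torch::nullopt, /*output_size=*/7);
  // CPU input: would raise
  // "CPU implementation of IndptrEdgeIds is not available."
  // graphbolt::ops::IndptrEdgeIds(gpu_indptr.cpu(), torch::kInt64, torch::nullopt, 7);
}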
19 changes: 19 additions & 0 deletions graphbolt/src/expand_indptr.h
@@ -30,6 +30,19 @@ torch::Tensor ExpandIndptr(
torch::optional<torch::Tensor> node_ids = torch::nullopt,
torch::optional<int64_t> output_size = torch::nullopt);

/**
* @brief IndptrEdgeIds implements conversion from a given indptr offset
* tensor to a COO edge ids tensor. For a given indptr [0, 2, 5, 7] and offset
* tensor [0, 100, 200], the output will be [0, 1, 100, 101, 102, 200, 201]. If
* offset is not provided, the output will be [0, 1, 0, 1, 2, 0, 1].
*
* @param indptr The indptr offset tensor.
* @param dtype The dtype of the returned output tensor.
* @param offset The offset tensor.
* @param output_size Optional value of indptr[-1]. Passing it eliminates a
* CPU-GPU synchronization.
*
* @return The resulting tensor.
*/
torch::Tensor IndptrEdgeIds(
torch::Tensor indptr, torch::ScalarType dtype,
torch::optional<torch::Tensor> offset,
torch::optional<int64_t> output_size);

} // namespace ops
} // namespace graphbolt

9 changes: 9 additions & 0 deletions graphbolt/src/python_binding.cc
@@ -167,6 +167,15 @@ TORCH_LIBRARY(graphbolt, m) {
#ifdef HAS_PT2_COMPLIANT_TAG
,
{at::Tag::pt2_compliant_tag}
#endif
);
m.def(
"indptr_edge_ids(Tensor indptr, ScalarType dtype, Tensor? offset, "
"SymInt? output_size) -> "
"Tensor"
#ifdef HAS_PT2_COMPLIANT_TAG
,
{at::Tag::pt2_compliant_tag}
#endif
);
}