
Commit 9cb17bd
don't take insertion code path when there is nothing to insert.
mfbalin committed Jun 26, 2024
1 parent 0fcf5fe commit 9cb17bd
Showing 1 changed file with 66 additions and 62 deletions.
128 changes: 66 additions & 62 deletions graphbolt/src/cuda/extension/gpu_graph_cache.cu
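
The change wraps the insertion path of GpuGraphCache::Replace in an if (num_threshold > 0) guard. Missing vertices appear to store negated access counts in positions (hence threshold = -threshold_), so the comparison x == -threshold_ selects exactly the vertices whose access count has just reached the insertion threshold, and num_threshold is the number of such vertices. When it is zero, the guarded block (the CUB DeviceSelect pass, the indptr_ growth loop, the edge-tensor copies, and the _Insert kernel launch) is now skipped entirely; a grid of (num_threshold + kIntBlockSize - 1) / kIntBlockSize blocks would evaluate to zero, and CUDA rejects kernel launches with a zero-sized grid as an invalid configuration.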
@@ -340,69 +340,73 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
             num_nodes + 1);
         CopyScalar output_size{output_indptr_ptr + num_nodes};

-        auto missing_positions = positions.slice(0, num_hit);
-        auto missing_indices = indices.slice(0, num_hit);
-
-        thrust::counting_iterator<indices_t> iota{0};
-        auto threshold = -threshold_;
-        auto is_threshold = thrust::make_transform_iterator(
-            missing_positions.data_ptr<indices_t>(),
-            [=] __host__ __device__(indices_t x) {
-              return x == threshold;
-            });
-        auto output_indices =
-            torch::empty(num_threshold, seeds.options());
-        CUB_CALL(
-            DeviceSelect::Flagged, iota, is_threshold,
-            output_indices.data_ptr<indices_t>(),
-            cub::DiscardOutputIterator{}, missing_positions.size(0));
-        auto [in_degree, sliced_indptr] =
-            ops::SliceCSCIndptr(indptr, output_indices);
-        while (num_nodes_ + num_threshold >= indptr_.size(0)) {
-          auto new_indptr = torch::empty(
-              indptr_.size(0) * kIntGrowthFactor, indptr_.options());
-          new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;
-          indptr_ = new_indptr;
-        }
-        torch::Tensor sindptr;
-        bool enough_space;
-        torch::optional<int64_t> cached_output_size;
-        for (size_t i = 0; i < edge_tensors.size(); i++) {
-          torch::Tensor sindices;
-          std::tie(sindptr, sindices) = ops::IndexSelectCSCImpl(
-              in_degree, sliced_indptr, edge_tensors[i], output_indices,
-              indptr.size(0) - 2, cached_output_size);
-          cached_output_size = sindices.size(0);
-          enough_space = num_edges_ + *cached_output_size <=
-                         cached_edge_tensors_[i].size(0);
-          if (enough_space) {
-            cached_edge_tensors_[i].slice(
-                0, num_edges_, num_edges_ + *cached_output_size) =
-                sindices;
-          } else
-            break;
-        }
-        if (enough_space) {
-          auto num_edges = num_edges_;
-          THRUST_CALL(
-              transform, sindptr.data_ptr<indptr_t>() + 1,
-              sindptr.data_ptr<indptr_t>() + sindptr.size(0),
-              indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
-              [=] __host__ __device__(indptr_t x) {
-                return x + num_edges;
+        if (num_threshold > 0) {
+          // Insert the vertices whose access count equal threshold.
+          auto missing_positions = positions.slice(0, num_hit);
+          auto missing_indices = indices.slice(0, num_hit);
+
+          thrust::counting_iterator<indices_t> iota{0};
+          auto threshold = -threshold_;
+          auto is_threshold = thrust::make_transform_iterator(
+              missing_positions.data_ptr<indices_t>(),
+              [=] __host__ __device__(indices_t x) {
+                return x == threshold;
               });
-          auto map = reinterpret_cast<map_t<indices_t>*>(map_);
-          const dim3 block(kIntBlockSize);
-          const dim3 grid(
-              (num_threshold + kIntBlockSize - 1) / kIntBlockSize);
-          CUDA_KERNEL_CALL(
-              _Insert, grid, block, 0, output_indices.size(0),
-              static_cast<indices_t>(num_nodes_),
-              seeds.data_ptr<indices_t>(),
-              missing_indices.data_ptr<indices_t>(),
-              output_indices.data_ptr<indices_t>(), map->ref(cuco::find));
-          num_edges_ += *cached_output_size;
-          num_nodes_ += num_threshold;
+          auto output_indices =
+              torch::empty(num_threshold, seeds.options());
+          CUB_CALL(
+              DeviceSelect::Flagged, iota, is_threshold,
+              output_indices.data_ptr<indices_t>(),
+              cub::DiscardOutputIterator{}, missing_positions.size(0));
+          auto [in_degree, sliced_indptr] =
+              ops::SliceCSCIndptr(indptr, output_indices);
+          while (num_nodes_ + num_threshold >= indptr_.size(0)) {
+            auto new_indptr = torch::empty(
+                indptr_.size(0) * kIntGrowthFactor, indptr_.options());
+            new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;
+            indptr_ = new_indptr;
+          }
+          torch::Tensor sindptr;
+          bool enough_space;
+          torch::optional<int64_t> cached_output_size;
+          for (size_t i = 0; i < edge_tensors.size(); i++) {
+            torch::Tensor sindices;
+            std::tie(sindptr, sindices) = ops::IndexSelectCSCImpl(
+                in_degree, sliced_indptr, edge_tensors[i], output_indices,
+                indptr.size(0) - 2, cached_output_size);
+            cached_output_size = sindices.size(0);
+            enough_space = num_edges_ + *cached_output_size <=
+                           cached_edge_tensors_[i].size(0);
+            if (enough_space) {
+              cached_edge_tensors_[i].slice(
+                  0, num_edges_, num_edges_ + *cached_output_size) =
+                  sindices;
+            } else
+              break;
+          }
+          if (enough_space) {
+            auto num_edges = num_edges_;
+            THRUST_CALL(
+                transform, sindptr.data_ptr<indptr_t>() + 1,
+                sindptr.data_ptr<indptr_t>() + sindptr.size(0),
+                indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
+                [=] __host__ __device__(indptr_t x) {
+                  return x + num_edges;
+                });
+            auto map = reinterpret_cast<map_t<indices_t>*>(map_);
+            const dim3 block(kIntBlockSize);
+            const dim3 grid(
+                (num_threshold + kIntBlockSize - 1) / kIntBlockSize);
+            CUDA_KERNEL_CALL(
+                _Insert, grid, block, 0, output_indices.size(0),
+                static_cast<indices_t>(num_nodes_),
+                seeds.data_ptr<indices_t>(),
+                missing_indices.data_ptr<indices_t>(),
+                output_indices.data_ptr<indices_t>(),
+                map->ref(cuco::find));
+            num_edges_ += *cached_output_size;
+            num_nodes_ += num_threshold;
+          }
         }

         constexpr int alignment = 128;
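
To make the skipped failure mode concrete, here is a minimal standalone CUDA sketch of the same guard pattern. It is illustrative only: Touch, LaunchGuarded, and kBlockSize are hypothetical stand-ins rather than GraphBolt code, and it assumes the zero-work case would otherwise compute a zero-block grid, as the grid arithmetic above does when num_threshold == 0.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int kBlockSize = 512;

// Stand-in kernel: increments the first num_items entries of data.
__global__ void Touch(int* data, int num_items) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_items) data[i] += 1;
}

// The guard pattern from this commit: skip everything when there is no
// work, because a <<<0, kBlockSize>>> launch is an invalid configuration.
void LaunchGuarded(int* data, int num_items) {
  if (num_items > 0) {
    const dim3 block(kBlockSize);
    const dim3 grid((num_items + kBlockSize - 1) / kBlockSize);
    Touch<<<grid, block>>>(data, num_items);
  }
}

int main() {
  int* data = nullptr;
  cudaMalloc(&data, sizeof(int));
  LaunchGuarded(data, 0);  // no-op instead of a zero-block launch
  LaunchGuarded(data, 1);  // normal one-block launch
  cudaDeviceSynchronize();
  printf("last error: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(data);
  return 0;
}

Removing the num_items > 0 check and calling LaunchGuarded(data, 0) would surface cudaErrorInvalidConfiguration on the next error check, which is the kind of no-work hazard the guard eliminates.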