Skip to content

Commit

Permalink
simplify code and fix indentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 20, 2024
1 parent f4f5a0a commit df89b5d
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions graphbolt/src/partitioned_cache_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,22 @@ constexpr int kIntGrainSize = 256;

torch::Tensor AddOffset(torch::Tensor keys, int64_t offset) {
if (offset == 0) return keys;
const auto numel = keys.size(0);
auto output =
torch::empty(numel, keys.options().pinned_memory(utils::is_pinned(keys)));
auto output = torch::empty_like(
keys, keys.options().pinned_memory(utils::is_pinned(keys)));
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "AddOffset", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
auto output_ptr = output.data_ptr<index_t>();
graphbolt::parallel_for_each(0, numel, kIntGrainSize, [&](int64_t i) {
const auto result = keys_ptr[i] + offset;
if constexpr (!std::is_same_v<index_t, int64_t>) {
TORCH_CHECK(
std::numeric_limits<index_t>::min() <= result &&
result <= std::numeric_limits<index_t>::max());
}
output_ptr[i] = static_cast<index_t>(result);
});
graphbolt::parallel_for_each(
0, keys.numel(), kIntGrainSize, [&](int64_t i) {
const auto result = keys_ptr[i] + offset;
if constexpr (!std::is_same_v<index_t, int64_t>) {
TORCH_CHECK(
std::numeric_limits<index_t>::min() <= result &&
result <= std::numeric_limits<index_t>::max());
}
output_ptr[i] = static_cast<index_t>(result);
});
}));
return output;
}
Expand Down

0 comments on commit df89b5d

Please sign in to comment.