Skip to content

Commit

Permalink
[GraphBolt] Make CachePolicy hetero-capable.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 20, 2024
1 parent 2ce0ea0 commit f4f5a0a
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 30 deletions.
54 changes: 44 additions & 10 deletions graphbolt/src/partitioned_cache_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,37 @@
#include "./partitioned_cache_policy.h"

#include <algorithm>
#include <limits>
#include <numeric>

#include "./utils.h"

namespace graphbolt {
namespace storage {

constexpr int kIntGrainSize = 64;
constexpr int kIntGrainSize = 256;

// Returns a copy of the 1-D `keys` tensor with `offset` added to every
// element, preserving dtype and pinned-memory placement. When `offset` is 0
// the input tensor is returned as-is (no copy). For index dtypes narrower
// than int64_t, each shifted value is checked against the dtype's
// representable range before the narrowing cast.
torch::Tensor AddOffset(torch::Tensor keys, int64_t offset) {
  if (offset == 0) return keys;
  const auto num_keys = keys.size(0);
  // Allocate the result with the same options; keep it pinned iff the
  // input is pinned so async CUDA copies remain possible.
  auto shifted = torch::empty(
      num_keys, keys.options().pinned_memory(utils::is_pinned(keys)));
  AT_DISPATCH_INDEX_TYPES(
      keys.scalar_type(), "AddOffset", ([&] {
        auto src = keys.data_ptr<index_t>();
        auto dst = shifted.data_ptr<index_t>();
        graphbolt::parallel_for_each(
            0, num_keys, kIntGrainSize, [&](int64_t idx) {
              // The sum is computed in int64_t (index_t promotes to the
              // int64_t offset), so it cannot overflow for narrow dtypes.
              const auto shifted_key = src[idx] + offset;
              if constexpr (!std::is_same_v<index_t, int64_t>) {
                TORCH_CHECK(
                    std::numeric_limits<index_t>::min() <= shifted_key &&
                    shifted_key <= std::numeric_limits<index_t>::max());
              }
              dst[idx] = static_cast<index_t>(shifted_key);
            });
      }));
  return shifted;
}

template <typename CachePolicy>
PartitionedCachePolicy::PartitionedCachePolicy(
Expand Down Expand Up @@ -117,7 +140,8 @@ PartitionedCachePolicy::Partition(torch::Tensor keys) {
std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
PartitionedCachePolicy::Query(torch::Tensor keys) {
PartitionedCachePolicy::Query(torch::Tensor keys, const int64_t offset) {
keys = AddOffset(keys, offset);
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
auto [positions, output_indices, missing_keys, found_pointers] =
Expand All @@ -133,6 +157,7 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
found_and_missing_offsets_ptr[3] = missing_keys.size(0);
auto found_offsets = found_and_missing_offsets.slice(0, 0, 2);
auto missing_offsets = found_and_missing_offsets.slice(0, 2);
missing_keys = AddOffset(missing_keys, -offset);
return {positions, output_indices, missing_keys,
found_pointers, found_offsets, missing_offsets};
};
Expand Down Expand Up @@ -211,17 +236,18 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
num_missing * missing_keys.element_size());
});
auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1);
missing_keys = AddOffset(missing_keys, -offset);
return std::make_tuple(
positions, output_indices, missing_keys, found_pointers, found_offsets,
missing_offsets);
}

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
PartitionedCachePolicy::QueryAsync(torch::Tensor keys, const int64_t offset) {
return async([=] {
auto
[positions, output_indices, missing_keys, found_pointers, found_offsets,
missing_offsets] = Query(keys);
missing_offsets] = Query(keys, offset);
return std::vector{positions, output_indices, missing_keys,
found_pointers, found_offsets, missing_offsets};
});
Expand All @@ -230,7 +256,9 @@ PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
PartitionedCachePolicy::QueryAndReplace(torch::Tensor keys) {
PartitionedCachePolicy::QueryAndReplace(
torch::Tensor keys, const int64_t offset) {
keys = AddOffset(keys, offset);
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
auto [positions, output_indices, pointers, missing_keys] =
Expand All @@ -246,6 +274,7 @@ PartitionedCachePolicy::QueryAndReplace(torch::Tensor keys) {
found_and_missing_offsets_ptr[3] = missing_keys.size(0);
auto found_offsets = found_and_missing_offsets.slice(0, 0, 2);
auto missing_offsets = found_and_missing_offsets.slice(0, 2);
missing_keys = AddOffset(missing_keys, -offset);
return {positions, output_indices, pointers,
missing_keys, found_offsets, missing_offsets};
}
Expand Down Expand Up @@ -336,25 +365,29 @@ PartitionedCachePolicy::QueryAndReplace(torch::Tensor keys) {
num_missing * missing_keys.element_size());
});
auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1);
missing_keys = AddOffset(missing_keys, -offset);
return std::make_tuple(
positions, output_indices, pointers, missing_keys, found_offsets,
missing_offsets);
}

// Asynchronous wrapper around QueryAndReplace: runs the combined
// query-and-replace on a background task and exposes the six result
// tensors (positions, output_indices, pointers, missing_keys,
// found_offsets, missing_offsets) as a vector inside a Future.
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
PartitionedCachePolicy::QueryAndReplaceAsync(
    torch::Tensor keys, const int64_t offset) {
  // Capture by value so the tensors stay alive for the async task.
  return async([this, keys, offset] {
    auto result = QueryAndReplace(keys, offset);
    return std::vector<torch::Tensor>{
        std::get<0>(result), std::get<1>(result), std::get<2>(result),
        std::get<3>(result), std::get<4>(result), std::get<5>(result)};
  });
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
PartitionedCachePolicy::Replace(
torch::Tensor keys, torch::optional<torch::Tensor> offsets) {
torch::Tensor keys, torch::optional<torch::Tensor> offsets,
const int64_t offset) {
keys = AddOffset(keys, offset);
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
auto [positions, pointers] = policies_[0]->Replace(keys);
Expand Down Expand Up @@ -419,9 +452,10 @@ PartitionedCachePolicy::Replace(

// Asynchronous wrapper around Replace: performs the insertion on a
// background task and exposes the three result tensors (positions,
// pointers, partition offsets) as a vector inside a Future.
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
PartitionedCachePolicy::ReplaceAsync(
    torch::Tensor keys, torch::optional<torch::Tensor> offsets,
    const int64_t offset) {
  // Capture by value so the tensors stay alive for the async task.
  return async([this, keys, offsets, offset] {
    torch::Tensor positions, pointers, partition_offsets;
    std::tie(positions, pointers, partition_offsets) =
        Replace(keys, offsets, offset);
    return std::vector<torch::Tensor>{positions, pointers, partition_offsets};
  });
}
Expand Down
17 changes: 11 additions & 6 deletions graphbolt/src/partitioned_cache_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class PartitionedCachePolicy : public torch::CustomClassHolder {
/**
* @brief The policy query function.
* @param keys The keys to query the cache.
* @param offset The offset to be added to the keys.
*
* @return (positions, indices, missing_keys, found_ptrs, found_offsets,
* missing_offsets), where positions has the locations of the keys which were
Expand All @@ -69,14 +70,15 @@ class PartitionedCachePolicy : public torch::CustomClassHolder {
std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
Query(torch::Tensor keys);
Query(torch::Tensor keys, int64_t offset);

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(
torch::Tensor keys);
torch::Tensor keys, int64_t offset);

/**
* @brief The policy query and then replace function.
* @param keys The keys to query the cache.
* @param offset The offset to be added to the keys.
*
* @return (positions, indices, pointers, missing_keys, found_offsets,
* missing_offsets), where positions has the locations of the keys which were
Expand All @@ -92,25 +94,28 @@ class PartitionedCachePolicy : public torch::CustomClassHolder {
std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
QueryAndReplace(torch::Tensor keys);
QueryAndReplace(torch::Tensor keys, int64_t offset);

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAndReplaceAsync(
torch::Tensor keys);
torch::Tensor keys, int64_t offset);

/**
* @brief The policy replace function.
* @param keys The keys to query the cache.
* @param offsets The partition offsets for the keys.
* @param offset The offset to be added to the keys.
*
* @return (positions, pointers, offsets), where positions holds the locations
* of the replaced entries in the cache, pointers holds the CacheKey pointers
* for the inserted keys and offsets holds the partition offsets for pointers.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Replace(
torch::Tensor keys, torch::optional<torch::Tensor> offsets);
torch::Tensor keys, torch::optional<torch::Tensor> offsets,
int64_t offset);

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> ReplaceAsync(
torch::Tensor keys, torch::optional<torch::Tensor> offsets);
torch::Tensor keys, torch::optional<torch::Tensor> offsets,
int64_t offset);

template <bool write>
void ReadingWritingCompletedImpl(
Expand Down
11 changes: 6 additions & 5 deletions python/dgl/graphbolt/impl/cpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
pin_memory=pin_memory,
)
self._is_pinned = pin_memory
self._offset = 0

def read(self, ids: torch.Tensor = None):
"""Read the feature by index.
Expand All @@ -79,7 +80,7 @@ def read(self, ids: torch.Tensor = None):
if ids is None:
return self._fallback_feature.read()
return self._feature.query_and_replace(
ids.cpu(), self._fallback_feature.read
ids.cpu(), self._fallback_feature.read, self._offset
).to(ids.device)

def read_async(self, ids: torch.Tensor):
Expand Down Expand Up @@ -124,7 +125,7 @@ def read_async(self, ids: torch.Tensor):
yield # first stage is done.

ids_copy_event.synchronize()
policy_future = policy.query_and_replace_async(ids)
policy_future = policy.query_and_replace_async(ids, self._offset)

yield

Expand Down Expand Up @@ -241,7 +242,7 @@ def wait(self):
yield # first stage is done.

ids_copy_event.synchronize()
policy_future = policy.query_and_replace_async(ids)
policy_future = policy.query_and_replace_async(ids, self._offset)

yield

Expand Down Expand Up @@ -319,7 +320,7 @@ def wait(self):

yield _Waiter([values_copy_event, writing_completed], values)
else:
policy_future = policy.query_and_replace_async(ids)
policy_future = policy.query_and_replace_async(ids, self._offset)

yield

Expand Down Expand Up @@ -448,4 +449,4 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
)
else:
self._fallback_feature.update(value, ids)
self._feature.replace(ids, value)
self._feature.replace(ids, value, None, self._offset)
18 changes: 12 additions & 6 deletions python/dgl/graphbolt/impl/feature_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ def __init__(
self.total_miss = 0
self.total_queries = 0

def query(self, keys):
def query(self, keys, offset=0):
"""Queries the cache.
Parameters
----------
keys : Tensor
The keys to query the cache with.
offset : int
The offset to be added to the keys. Default is 0.
Returns
-------
Expand All @@ -85,14 +87,14 @@ def query(self, keys):
found_pointers,
found_offsets,
missing_offsets,
) = self._policy.query(keys)
) = self._policy.query(keys, offset)
values = self._cache.query(positions, index, keys.shape[0])
self._policy.reading_completed(found_pointers, found_offsets)
self.total_miss += missing_keys.shape[0]
missing_index = index[positions.size(0) :]
return values, missing_index, missing_keys, missing_offsets

def query_and_replace(self, keys, reader_fn):
def query_and_replace(self, keys, reader_fn, offset=0):
"""Queries the cache. Then inserts the keys that are not found by
reading them by calling `reader_fn(missing_keys)`, which are then
inserted into the cache using the selected caching policy algorithm
Expand All @@ -105,6 +107,8 @@ def query_and_replace(self, keys, reader_fn):
reader_fn : reader_fn(keys: torch.Tensor) -> torch.Tensor
A function that will take a missing keys tensor and will return
their values.
offset : int
The offset to be added to the keys. Default is 0.
Returns
-------
Expand All @@ -120,7 +124,7 @@ def query_and_replace(self, keys, reader_fn):
missing_keys,
found_offsets,
missing_offsets,
) = self._policy.query_and_replace(keys)
) = self._policy.query_and_replace(keys, offset)
found_cnt = keys.size(0) - missing_keys.size(0)
found_positions = positions[:found_cnt]
values = self._cache.query(found_positions, index, keys.shape[0])
Expand All @@ -136,7 +140,7 @@ def query_and_replace(self, keys, reader_fn):
self._policy.writing_completed(missing_pointers, missing_offsets)
return values

def replace(self, keys, values, offsets=None):
def replace(self, keys, values, offsets=None, offset=0):
"""Inserts key-value pairs into the cache using the selected caching
policy algorithm to remove old key-value pairs if it is full.
Expand All @@ -148,8 +152,10 @@ def replace(self, keys, values, offsets=None):
The values to insert to the cache.
offsets : Tensor, optional
The partition offsets of the keys.
offset : int
The offset to be added to the keys. Default is 0.
"""

Check warning on line 157 in python/dgl/graphbolt/impl/feature_cache.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
positions, pointers, offsets = self._policy.replace(keys, offsets)
positions, pointers, offsets = self._policy.replace(keys, offsets, offset)
self._cache.replace(positions, values)
self._policy.writing_completed(pointers, offsets)

Expand Down
7 changes: 4 additions & 3 deletions tests/python/pytorch/graphbolt/impl/test_feature_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


def _test_query_and_replace(policy1, policy2, keys):
offset = 111111
# Testing query_and_replace equivalence to query and then replace.
(
_,
Expand All @@ -15,7 +16,7 @@ def _test_query_and_replace(policy1, policy2, keys):
missing_keys,
found_offsets,
missing_offsets,
) = policy1.query_and_replace(keys)
) = policy1.query_and_replace(keys, offset)
found_cnt = keys.size(0) - missing_keys.size(0)
found_pointers = pointers[:found_cnt]
policy1.reading_completed(found_pointers, found_offsets)
Expand All @@ -29,10 +30,10 @@ def _test_query_and_replace(policy1, policy2, keys):
found_pointers2,
found_offsets2,
missing_offsets2,
) = policy2.query(keys)
) = policy2.query(keys, offset)
policy2.reading_completed(found_pointers2, found_offsets2)
(_, missing_pointers2, missing_offsets2) = policy2.replace(
missing_keys2, missing_offsets2
missing_keys2, missing_offsets2, offset
)
policy2.writing_completed(missing_pointers2, missing_offsets2)

Expand Down

0 comments on commit f4f5a0a

Please sign in to comment.