
Commit

Merge branch 'master' into gb_cuda_memory_alloc_env
mfbalin authored Jul 31, 2024
2 parents c950e27 + eafb530 commit ecef989
Showing 8 changed files with 373 additions and 2 deletions.
86 changes: 86 additions & 0 deletions graphbolt/src/cache_policy.cc
@@ -73,6 +73,72 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
found_ptr_tensor.slice(0, 0, found_cnt)};
}

template <typename CachePolicy>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
BaseCachePolicy::QueryAndThenReplaceImpl(
CachePolicy& policy, torch::Tensor keys) {
auto positions = torch::empty_like(
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto indices = torch::empty_like(
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto pointers = torch::empty_like(keys, keys.options().dtype(torch::kInt64));
auto missing_keys = torch::empty_like(
keys, keys.options().pinned_memory(utils::is_pinned(keys)));
int64_t found_cnt = 0;
int64_t missing_cnt = keys.size(0);
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
auto positions_ptr = positions.data_ptr<int64_t>();
auto indices_ptr = indices.data_ptr<int64_t>();
static_assert(
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
auto pointers_ptr =
reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());
auto missing_keys_ptr = missing_keys.data_ptr<index_t>();
auto iterators = std::unique_ptr<typename CachePolicy::map_iterator[]>(
new typename CachePolicy::map_iterator[keys.size(0)]);
// Query phase, mirroring QueryImpl: split keys into found and missing.
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
const auto [it, can_read] = policy.Emplace(key);
if (can_read) {
auto& cache_key = *it->second;
positions_ptr[found_cnt] = cache_key.getPos();
pointers_ptr[found_cnt] = &cache_key;
indices_ptr[found_cnt++] = i;
} else {
indices_ptr[--missing_cnt] = i;
missing_keys_ptr[missing_cnt] = key;
iterators[missing_cnt] = it;
}
}
// Replace phase, mirroring ReplaceImpl: insert the missing keys.
set_t<int64_t> position_set;
position_set.reserve(keys.size(0));
for (int64_t i = missing_cnt; i < missing_keys.size(0); i++) {
auto it = iterators[i];
if (it->second == policy.getMapSentinelValue()) {
policy.Insert(it);
// After Insert, it->second is not nullptr anymore.
TORCH_CHECK(
// Duplicate keys whose entry was already inserted by an earlier
// iteration skip this branch, so the uniqueness check only ever sees
// the first occurrence of each position.
std::get<1>(position_set.insert(it->second->getPos())),
"Can't insert all, larger cache capacity is needed.");
}
auto& cache_key = *it->second;
positions_ptr[i] = cache_key.getPos();
pointers_ptr[i] = &cache_key;
}
}));
return {positions, indices, pointers, missing_keys.slice(0, found_cnt)};
}
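A worked illustration of the duplicate-key path above may help; the key values below are assumptions chosen for the example, not part of this commit.

// Hypothetical walkthrough: keys = [5, 5] with key 5 not yet cached.
// Query loop: the first occurrence emplaces the policy's sentinel value; both
// occurrences receive the same map iterator and land in the missing range,
// which is filled from the back, so indices = [1, 0] and missing_keys = [5, 5].
// Replace loop: only the first iteration still sees the sentinel, so Insert
// and the position_set uniqueness check run once; the second iteration reuses
// the now-populated it->second, leaving positions[0] == positions[1] and
// pointers[0] == pointers[1] for the single cached entry of key 5.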

template <typename CachePolicy>
std::tuple<torch::Tensor, torch::Tensor> BaseCachePolicy::ReplaceImpl(
CachePolicy& policy, torch::Tensor keys) {
@@ -140,6 +206,11 @@ S3FifoCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
S3FifoCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
return QueryAndThenReplaceImpl(*this, keys);
}

std::tuple<torch::Tensor, torch::Tensor> S3FifoCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
@@ -165,6 +236,11 @@ SieveCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
SieveCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
return QueryAndThenReplaceImpl(*this, keys);
}

std::tuple<torch::Tensor, torch::Tensor> SieveCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
@@ -189,6 +265,11 @@ LruCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
LruCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
return QueryAndThenReplaceImpl(*this, keys);
}

std::tuple<torch::Tensor, torch::Tensor> LruCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
@@ -213,6 +294,11 @@ ClockCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
ClockCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
return QueryAndThenReplaceImpl(*this, keys);
}

std::tuple<torch::Tensor, torch::Tensor> ClockCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
152 changes: 152 additions & 0 deletions graphbolt/src/cache_policy.h
@@ -131,6 +131,21 @@ class BaseCachePolicy {
virtual std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Query(torch::Tensor keys) = 0;

/**
* @brief The policy query-and-then-replace function.
* @param keys The keys to query the cache with.
*
* @return (positions, indices, pointers, missing_keys), where positions has
* the cache locations of the keys that were emplaced into the cache, pointers
* point to the emplaced CacheKey instances in the cache, missing_keys has the
* keys that were not found and were just inserted, and indices is defined
* such that keys[indices[:keys.size(0) - missing_keys.size(0)]] gives the
* found keys while keys[indices[keys.size(0) - missing_keys.size(0):]] is
* identical to missing_keys.
*/
virtual std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys) = 0;
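A hedged usage sketch of this contract; the policy instance, the key values, and the hit pattern are illustrative assumptions rather than part of this commit:

// Sketch only: assume a concrete policy instance `policy` whose cache
// already holds key 20.
torch::Tensor keys = torch::tensor({10, 20, 30}, torch::kInt64);
auto [positions, indices, pointers, missing_keys] =
    policy.QueryAndThenReplace(keys);
// found count == keys.size(0) - missing_keys.size(0) == 1
// indices      == [1, 2, 0]  -> keys[1] == 20 was found; keys[2] and keys[0]
//                               were missing and were just inserted.
// missing_keys == [30, 10]   -> identical to keys[indices[1:]].
// positions[0] is the slot already holding 20; positions[1:] are the slots
// just reserved for 30 and 10, with pointers giving the matching CacheKey*.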

/**
* @brief The policy replace function.
* @param keys The keys to query the cache.
@@ -165,6 +180,10 @@ class BaseCachePolicy {
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryImpl(CachePolicy& policy, torch::Tensor keys);

template <typename CachePolicy>
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplaceImpl(CachePolicy& policy, torch::Tensor keys);

template <typename CachePolicy>
static std::tuple<torch::Tensor, torch::Tensor> ReplaceImpl(
CachePolicy& policy, torch::Tensor keys);
@@ -180,6 +199,7 @@ class BaseCachePolicy {
**/
class S3FifoCachePolicy : public BaseCachePolicy {
public:
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
/**
* @brief Constructor for the S3FifoCachePolicy class.
*
@@ -199,6 +219,12 @@ class S3FifoCachePolicy : public BaseCachePolicy {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);

/**
* @brief See BaseCachePolicy::QueryAndThenReplace.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::Replace.
*/
@@ -234,6 +260,25 @@ class S3FifoCachePolicy : public BaseCachePolicy {
return std::nullopt;
}

auto getMapSentinelValue() const { return nullptr; }

std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.Increment().StartUse<false>();
return {it, true};
} else {
cache_key.Increment().StartUse<true>();
}
}
// First-time insertion or still being written; not ready to read yet.
return {it, false};
}
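Emplace relies on the fact that std::unordered_map-style emplace leaves an already-present entry untouched while still returning its iterator. Below is a minimal standalone sketch of that property, using plain std::unordered_map and throwaway names rather than the map_t alias used here:

#include <cstdint>
#include <iostream>
#include <unordered_map>

int main() {
  std::unordered_map<int64_t, int*> map;
  int value = 42;
  map.emplace(7, &value);                         // first emplace stores &value
  auto [it, inserted] = map.emplace(7, nullptr);  // second emplace is a no-op
  // `inserted` is false and it->second still points at `value`, so a sentinel
  // second argument cleanly distinguishes newly created entries.
  std::cout << std::boolalpha << inserted << " " << *it->second << "\n";
}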

std::pair<int64_t, CacheKey*> Insert(int64_t key) {
const auto pos = Evict();
const auto in_ghost_queue = ghost_set_.erase(key);
@@ -243,6 +288,14 @@ class S3FifoCachePolicy : public BaseCachePolicy {
return {pos, cache_key_ptr};
}

void Insert(map_iterator it) {
const auto key = it->first;
const auto pos = Evict();
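// Keys still tracked in the ghost set were evicted recently, so S3-FIFO
// readmits them directly into the main queue; new keys start in the small
// queue.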
const auto in_ghost_queue = ghost_set_.erase(key);
auto& queue = in_ghost_queue ? main_queue_ : small_queue_;
it->second = queue.Push(CacheKey(key, pos));
}

template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -306,6 +359,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
**/
class SieveCachePolicy : public BaseCachePolicy {
public:
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
/**
* @brief Constructor for the SieveCachePolicy class.
*
@@ -323,6 +377,12 @@ class SieveCachePolicy : public BaseCachePolicy {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);

/**
* @brief See BaseCachePolicy::QueryAndThenReplace.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::Replace.
*/
@@ -350,6 +410,25 @@ class SieveCachePolicy : public BaseCachePolicy {
return std::nullopt;
}

auto getMapSentinelValue() const { return nullptr; }

std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.SetFreq().StartUse<false>();
return {it, true};
} else {
cache_key.SetFreq().StartUse<true>();
}
}
// First-time insertion or still being written; not ready to read yet.
return {it, false};
}

std::pair<int64_t, CacheKey*> Insert(int64_t key) {
const auto pos = Evict();
queue_.push_front(CacheKey(key, pos));
@@ -358,6 +437,13 @@ class SieveCachePolicy : public BaseCachePolicy {
return {pos, cache_key_ptr};
}

void Insert(map_iterator it) {
const auto key = it->first;
const auto pos = Evict();
queue_.push_front(CacheKey(key, pos));
it->second = &queue_.front();
}

template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -398,6 +484,7 @@ class SieveCachePolicy : public BaseCachePolicy {
**/
class LruCachePolicy : public BaseCachePolicy {
public:
using map_iterator = map_t<int64_t, std::list<CacheKey>::iterator>::iterator;
/**
* @brief Constructor for the LruCachePolicy class.
*
@@ -415,6 +502,12 @@ class LruCachePolicy : public BaseCachePolicy {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);

/**
* @brief See BaseCachePolicy::QueryAndThenReplace.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::Replace.
*/
@@ -455,13 +548,40 @@ class LruCachePolicy : public BaseCachePolicy {
return std::nullopt;
}

auto getMapSentinelValue() { return queue_.end(); }

std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != queue_.end()) {
MoveToFront(it->second);
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.StartUse<false>();
return {it, true};
} else {
cache_key.StartUse<true>();
}
}
// First-time insertion or still being written; not ready to read yet.
return {it, false};
}
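Because LruCachePolicy stores CacheKey objects in a std::list, queue_.end() can act as the sentinel and MoveToFront can reorder an entry without invalidating the list iterators held in the map, assuming it is implemented with std::list::splice (that implementation is outside this diff). A minimal sketch of the relied-upon property:

#include <cassert>
#include <list>

int main() {
  std::list<int> q{1, 2, 3};
  auto it = std::next(q.begin());          // points at the element 2
  q.splice(q.begin(), q, it);              // move that element to the front
  assert(*it == 2 && &*it == &q.front());  // the iterator remains valid
}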

std::pair<int64_t, CacheKey*> Insert(int64_t key) {
const auto pos = Evict();
queue_.push_front(CacheKey(key, pos));
key_to_cache_key_[key] = queue_.begin();
return {pos, &queue_.front()};
}

void Insert(map_iterator it) {
const auto key = it->first;
const auto pos = Evict();
queue_.push_front(CacheKey(key, pos));
it->second = queue_.begin();
}

template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -501,6 +621,7 @@ class LruCachePolicy : public BaseCachePolicy {
**/
class ClockCachePolicy : public BaseCachePolicy {
public:
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
/**
* @brief Constructor for the ClockCachePolicy class.
*
@@ -520,6 +641,12 @@ class ClockCachePolicy : public BaseCachePolicy {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);

/**
* @brief See BaseCachePolicy::QueryAndThenReplace.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::Replace.
*/
@@ -547,13 +674,38 @@ class ClockCachePolicy : public BaseCachePolicy {
return std::nullopt;
}

auto getMapSentinelValue() const { return nullptr; }

std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.SetFreq().StartUse<false>();
return {it, true};
} else {
cache_key.SetFreq().StartUse<true>();
}
}
// First-time insertion or still being written; not ready to read yet.
return {it, false};
}

std::pair<int64_t, CacheKey*> Insert(int64_t key) {
const auto pos = Evict();
auto cache_key_ptr = queue_.Push(CacheKey(key, pos));
key_to_cache_key_[key] = cache_key_ptr;
return {pos, cache_key_ptr};
}

void Insert(map_iterator it) {
const auto key = it->first;
const auto pos = Evict();
it->second = queue_.Push(CacheKey(key, pos));
}

template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
(Diffs for the remaining 6 changed files are not shown here.)
