Skip to content

Commit

Permalink
[GraphBolt] CachePolicy Writer lock for read_async correctness. (#7581)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Jul 24, 2024
1 parent b80c8f5 commit b8604a5
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 40 deletions.
34 changes: 25 additions & 9 deletions graphbolt/src/cache_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
auto filtered_keys_ptr = filtered_keys.data_ptr<index_t>();
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
auto pos = policy.Read(key);
auto pos = policy.template Read<false>(key);
if (pos.has_value()) {
positions_ptr[found_cnt] = *pos;
filtered_keys_ptr[found_cnt] = key;
Expand Down Expand Up @@ -78,7 +78,7 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
position_set.reserve(keys.size(0));
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
const auto pos_optional = policy.Read(key);
const auto pos_optional = policy.template Read<true>(key);
const auto pos = pos_optional ? *pos_optional : policy.Insert(key);
positions_ptr[i] = pos;
TORCH_CHECK(
Expand All @@ -91,14 +91,14 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
return positions;
}

template <typename CachePolicy>
void BaseCachePolicy::ReadingCompletedImpl(
template <bool write, typename CachePolicy>
void BaseCachePolicy::ReadingWritingCompletedImpl(
CachePolicy& policy, torch::Tensor keys) {
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
for (int64_t i = 0; i < keys.size(0); i++) {
policy.Unmark(keys_ptr[i]);
policy.template Unmark<write>(keys_ptr[i]);
}
}));
}
Expand All @@ -125,7 +125,11 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
}

void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingCompletedImpl(*this, keys);
ReadingWritingCompletedImpl<false>(*this, keys);
}

void S3FifoCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}

SieveCachePolicy::SieveCachePolicy(int64_t capacity)
Expand All @@ -145,7 +149,11 @@ torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) {
}

void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingCompletedImpl(*this, keys);
ReadingWritingCompletedImpl<false>(*this, keys);
}

void SieveCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}

LruCachePolicy::LruCachePolicy(int64_t capacity)
Expand All @@ -164,7 +172,11 @@ torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) {
}

void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingCompletedImpl(*this, keys);
ReadingWritingCompletedImpl<false>(*this, keys);
}

void LruCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}

ClockCachePolicy::ClockCachePolicy(int64_t capacity)
Expand All @@ -183,7 +195,11 @@ torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) {
}

void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingCompletedImpl(*this, keys);
ReadingWritingCompletedImpl<false>(*this, keys);
}

void ClockCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}

} // namespace storage
Expand Down
109 changes: 89 additions & 20 deletions graphbolt/src/cache_policy.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -24,14 +24,20 @@
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <limits>

#include "./circular_queue.h"

namespace graphbolt {
namespace storage {

struct CacheKey {
CacheKey(int64_t key, int64_t position)
: freq_(0), key_(key), position_in_cache_(position), reference_count_(1) {
: freq_(0),
key_(key),
position_in_cache_(position),
read_reference_count_(0),
write_reference_count_(1) {
static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));
}

Expand Down Expand Up @@ -63,17 +69,30 @@ struct CacheKey {
return *this;
}

template <bool write>
CacheKey& StartUse() {
++reference_count_;
if constexpr (write) {
TORCH_CHECK(
write_reference_count_++ < std::numeric_limits<int16_t>::max());
} else {
TORCH_CHECK(read_reference_count_++ < std::numeric_limits<int8_t>::max());
}
return *this;
}

template <bool write>
CacheKey& EndUse() {
--reference_count_;
if constexpr (write) {
--write_reference_count_;
} else {
--read_reference_count_;
}
return *this;
}

bool InUse() { return reference_count_ > 0; }
bool InUse() const { return read_reference_count_ || write_reference_count_; }

bool BeingWritten() const { return write_reference_count_; }

friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {
return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", "
Expand All @@ -83,8 +102,10 @@ struct CacheKey {
private:
int64_t freq_ : 3;
int64_t key_ : 61;
int64_t position_in_cache_ : 48;
int64_t reference_count_ : 16;
int64_t position_in_cache_ : 40;
int64_t read_reference_count_ : 8;
// There could be a chain of writes so it is better to have larger bit count.
int64_t write_reference_count_ : 16;
};

class BaseCachePolicy {
Expand Down Expand Up @@ -123,6 +144,12 @@ class BaseCachePolicy {
*/
virtual void ReadingCompleted(torch::Tensor keys) = 0;

/**
* @brief A writer has finished writing these keys, so they can be evicted.
* @param keys The keys to unmark.
*/
virtual void WritingCompleted(torch::Tensor keys) = 0;

protected:
template <typename CachePolicy>
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Expand All @@ -131,8 +158,9 @@ class BaseCachePolicy {
template <typename CachePolicy>
static torch::Tensor ReplaceImpl(CachePolicy& policy, torch::Tensor keys);

template <typename CachePolicy>
static void ReadingCompletedImpl(CachePolicy& policy, torch::Tensor keys);
template <bool write, typename CachePolicy>
static void ReadingWritingCompletedImpl(
CachePolicy& policy, torch::Tensor keys);
};

/**
Expand Down Expand Up @@ -170,6 +198,11 @@ class S3FifoCachePolicy : public BaseCachePolicy {
*/
void ReadingCompleted(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor keys);

friend std::ostream& operator<<(
std::ostream& os, const S3FifoCachePolicy& policy) {
return os << "small_queue_: " << policy.small_queue_ << "\n"
Expand All @@ -178,11 +211,13 @@ class S3FifoCachePolicy : public BaseCachePolicy {
<< "capacity_: " << policy.capacity_ << "\n";
}

template <bool write>
std::optional<int64_t> Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto& cache_key = *it->second;
return cache_key.Increment().StartUse().getPos();
if (write || !cache_key.BeingWritten())
return cache_key.Increment().StartUse<write>().getPos();
}
return std::nullopt;
}
Expand All @@ -195,7 +230,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {
return pos;
}

void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
template <bool write>
void Unmark(int64_t key) {
key_to_cache_key_[key]->EndUse<write>();
}

private:
int64_t EvictMainQueue() {
Expand Down Expand Up @@ -282,11 +320,18 @@ class SieveCachePolicy : public BaseCachePolicy {
*/
void ReadingCompleted(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor keys);

template <bool write>
std::optional<int64_t> Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto& cache_key = *it->second;
return cache_key.SetFreq().StartUse().getPos();
if (write || !cache_key.BeingWritten())
return cache_key.SetFreq().StartUse<write>().getPos();
}
return std::nullopt;
}
Expand All @@ -298,7 +343,10 @@ class SieveCachePolicy : public BaseCachePolicy {
return pos;
}

void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
template <bool write>
void Unmark(int64_t key) {
key_to_cache_key_[key]->EndUse<write>();
}

private:
int64_t Evict() {
Expand Down Expand Up @@ -362,14 +410,22 @@ class LruCachePolicy : public BaseCachePolicy {
*/
void ReadingCompleted(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor keys);

template <bool write>
std::optional<int64_t> Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto cache_key = *it->second;
queue_.erase(it->second);
queue_.push_front(cache_key.StartUse());
it->second = queue_.begin();
return cache_key.getPos();
if (write || !cache_key.BeingWritten()) {
queue_.erase(it->second);
queue_.push_front(cache_key.StartUse<write>());
it->second = queue_.begin();
return cache_key.getPos();
}
}
return std::nullopt;
}
Expand All @@ -381,7 +437,10 @@ class LruCachePolicy : public BaseCachePolicy {
return pos;
}

void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
template <bool write>
void Unmark(int64_t key) {
key_to_cache_key_[key]->EndUse<write>();
}

private:
int64_t Evict() {
Expand Down Expand Up @@ -443,11 +502,18 @@ class ClockCachePolicy : public BaseCachePolicy {
*/
void ReadingCompleted(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor keys);

template <bool write>
std::optional<int64_t> Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto& cache_key = *it->second;
return cache_key.SetFreq().StartUse().getPos();
if (write || !cache_key.BeingWritten())
return cache_key.SetFreq().StartUse<write>().getPos();
}
return std::nullopt;
}
Expand All @@ -458,7 +524,10 @@ class ClockCachePolicy : public BaseCachePolicy {
return pos;
}

void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
template <bool write>
void Unmark(int64_t key) {
key_to_cache_key_[key]->EndUse<write>();
}

private:
int64_t Evict() {
Expand Down
1 change: 1 addition & 0 deletions graphbolt/src/feature_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
const auto tensor_ptr = reinterpret_cast<std::byte*>(tensor_.data_ptr());
const auto positions_ptr = positions.data_ptr<int64_t>();
std::lock_guard lock(mtx_);
torch::parallel_for(
0, positions.size(0), kIntGrainSize, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) {
Expand Down
2 changes: 2 additions & 0 deletions graphbolt/src/feature_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ struct FeatureCache : public torch::CustomClassHolder {

private:
torch::Tensor tensor_;
// Protects writes only as reads are guaranteed to be safe.
std::mutex mtx_;
};

} // namespace storage
Expand Down
26 changes: 23 additions & 3 deletions graphbolt/src/partitioned_cache_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,14 @@ c10::intrusive_ptr<Future<torch::Tensor>> PartitionedCachePolicy::ReplaceAsync(
return async([=] { return Replace(keys); });
}

void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
template <bool write>
void PartitionedCachePolicy::ReadingWritingCompletedImpl(torch::Tensor keys) {
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
policies_[0]->ReadingCompleted(keys);
if constexpr (write)
policies_[0]->WritingCompleted(keys);
else
policies_[0]->ReadingCompleted(keys);
return;
}
torch::Tensor offsets, indices, permuted_keys;
Expand All @@ -257,15 +261,31 @@ void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
const auto tid = begin;
begin = offsets_ptr[tid];
end = offsets_ptr[tid + 1];
policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
if constexpr (write)
policies_.at(tid)->WritingCompleted(permuted_keys.slice(0, begin, end));
else
policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
});
}

void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<false>(keys);
}

void PartitionedCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(keys);
}

c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::ReadingCompletedAsync(
torch::Tensor keys) {
return async([=] { return ReadingCompleted(keys); });
}

c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::WritingCompletedAsync(
torch::Tensor keys) {
return async([=] { return WritingCompleted(keys); });
}

template <typename CachePolicy>
c10::intrusive_ptr<PartitionedCachePolicy> PartitionedCachePolicy::Create(
int64_t capacity, int64_t num_partitions) {
Expand Down
Loading

0 comments on commit b8604a5

Please sign in to comment.