[GraphBolt] CachePolicy Writer lock for read_async correctness. #7581

Merged · 1 commit · Jul 24, 2024
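
This PR splits each cache policy's single reference count into separate read and write counts so that read_async consumers are never handed a slot that a concurrent Replace is still filling. Below is a minimal sketch of the resulting pairing contract, assuming a single S3FifoCachePolicy and simplified signatures (the tuple order returned by Query is an assumption, and the feature copies are elided); it illustrates the diff's contract, not GraphBolt's public API:

#include <torch/torch.h>

// Sketch only: how the marking calls introduced in this PR are paired.
void ReadPath(S3FifoCachePolicy& policy, torch::Tensor keys) {
  // Query() calls Read<false>() per key: keys currently being written are
  // treated as cache misses, so a reader never observes a half-written slot.
  auto [positions, index, missing_keys, found_keys] = policy.Query(keys);
  // ... copy the cached feature rows at `positions` ...
  policy.ReadingCompleted(found_keys);  // drop the read pins
}

void WritePath(S3FifoCachePolicy& policy, torch::Tensor keys) {
  // Replace() calls Read<true>() or Insert() per key; either way the key has
  // BeingWritten() == true until the matching WritingCompleted().
  auto positions = policy.Replace(keys);
  // ... write the new feature rows at `positions` ...
  policy.WritingCompleted(keys);  // readers can hit these keys again
}
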
34 changes: 25 additions & 9 deletions graphbolt/src/cache_policy.cc
@@ -47,7 +47,7 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
auto filtered_keys_ptr = filtered_keys.data_ptr<index_t>();
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
- auto pos = policy.Read(key);
+ auto pos = policy.template Read<false>(key);
if (pos.has_value()) {
positions_ptr[found_cnt] = *pos;
filtered_keys_ptr[found_cnt] = key;
@@ -78,7 +78,7 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
position_set.reserve(keys.size(0));
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
- const auto pos_optional = policy.Read(key);
+ const auto pos_optional = policy.template Read<true>(key);
const auto pos = pos_optional ? *pos_optional : policy.Insert(key);
positions_ptr[i] = pos;
TORCH_CHECK(
@@ -91,14 +91,14 @@
return positions;
}

- template <typename CachePolicy>
- void BaseCachePolicy::ReadingCompletedImpl(
+ template <bool write, typename CachePolicy>
+ void BaseCachePolicy::ReadingWritingCompletedImpl(
CachePolicy& policy, torch::Tensor keys) {
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
for (int64_t i = 0; i < keys.size(0); i++) {
- policy.Unmark(keys_ptr[i]);
+ policy.template Unmark<write>(keys_ptr[i]);
}
}));
}
@@ -125,7 +125,11 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
}

void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
- ReadingCompletedImpl(*this, keys);
+ ReadingWritingCompletedImpl<false>(*this, keys);
}

void S3FifoCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}

SieveCachePolicy::SieveCachePolicy(int64_t capacity)
@@ -145,7 +149,11 @@ torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) {
}

void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
- ReadingCompletedImpl(*this, keys);
+ ReadingWritingCompletedImpl<false>(*this, keys);
}

void SieveCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}

LruCachePolicy::LruCachePolicy(int64_t capacity)
@@ -164,7 +172,11 @@ torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) {
}

void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
- ReadingCompletedImpl(*this, keys);
+ ReadingWritingCompletedImpl<false>(*this, keys);
}

void LruCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}

ClockCachePolicy::ClockCachePolicy(int64_t capacity)
@@ -183,7 +195,11 @@ torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) {
}

void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
- ReadingCompletedImpl(*this, keys);
+ ReadingWritingCompletedImpl<false>(*this, keys);
}

void ClockCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}

} // namespace storage
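
A side note on the policy.template Read<false>(key) spelling used in QueryImpl and ReplaceImpl above: inside these function templates, policy has the dependent type CachePolicy, so the template keyword is required to tell the parser that Read names a member template. A standalone illustration with a hypothetical ReadOne helper (not part of the diff):

#include <cstdint>
#include <optional>

template <typename Policy>
std::optional<int64_t> ReadOne(Policy& policy, int64_t key) {
  // Without `.template`, `policy.Read<false>(key)` would parse as the
  // comparisons (policy.Read < false) > key and fail to compile.
  return policy.template Read</*write=*/false>(key);
}
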
109 changes: 89 additions & 20 deletions graphbolt/src/cache_policy.h
@@ -1,5 +1,5 @@
/**
- * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
+ * Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,14 +24,20 @@
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <limits>

#include "./circular_queue.h"

namespace graphbolt {
namespace storage {

struct CacheKey {
CacheKey(int64_t key, int64_t position)
- : freq_(0), key_(key), position_in_cache_(position), reference_count_(1) {
+ : freq_(0),
+   key_(key),
+   position_in_cache_(position),
+   read_reference_count_(0),
+   write_reference_count_(1) {
static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));
}

@@ -63,17 +69,30 @@ struct CacheKey {
return *this;
}

template <bool write>
CacheKey& StartUse() {
- ++reference_count_;
+ if constexpr (write) {
+   TORCH_CHECK(
+       write_reference_count_++ < std::numeric_limits<int16_t>::max());
+ } else {
+   TORCH_CHECK(read_reference_count_++ < std::numeric_limits<int8_t>::max());
+ }
return *this;
}

template <bool write>
CacheKey& EndUse() {
- --reference_count_;
+ if constexpr (write) {
+   --write_reference_count_;
+ } else {
+   --read_reference_count_;
+ }
return *this;
}

- bool InUse() { return reference_count_ > 0; }
+ bool InUse() const { return read_reference_count_ || write_reference_count_; }

bool BeingWritten() const { return write_reference_count_; }

friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {
return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", "
@@ -83,8 +102,10 @@ struct CacheKey {
private:
int64_t freq_ : 3;
int64_t key_ : 61;
- int64_t position_in_cache_ : 48;
- int64_t reference_count_ : 16;
+ int64_t position_in_cache_ : 40;
+ int64_t read_reference_count_ : 8;
+ // There could be a chain of writes, so it is better to have a larger bit count.
+ int64_t write_reference_count_ : 16;
};
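
The new layout still packs into two 64-bit words, which is exactly what the static_assert in the constructor checks: 3 + 61 = 64 bits for the first word and 40 + 8 + 16 = 64 bits for the second, with position_in_cache_ shrunk from 48 to 40 bits to fund splitting the old 16-bit reference_count_ into read and write counts. Note also that the constructor now starts a fresh key with write_reference_count_ = 1: a newly inserted key is born "being written" until WritingCompleted unmarks it. A compile-time sketch of the arithmetic, assuming the usual bit-field packing:

#include <cstdint>

struct CacheKeyLayout {                 // mirrors the fields above
  int64_t freq_ : 3;                    //  3 bits \ word 1: 3 + 61 = 64
  int64_t key_ : 61;                    // 61 bits /
  int64_t position_in_cache_ : 40;      // 40 bits \
  int64_t read_reference_count_ : 8;    //  8 bits  > word 2: 40 + 8 + 16 = 64
  int64_t write_reference_count_ : 16;  // 16 bits /
};
static_assert(sizeof(CacheKeyLayout) == 2 * sizeof(int64_t));

The TORCH_CHECKs in StartUse cap the counters at std::numeric_limits<int8_t>::max() (127) concurrent readers and std::numeric_limits<int16_t>::max() (32767) chained writers per key.
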

class BaseCachePolicy {
@@ -123,6 +144,12 @@ class BaseCachePolicy {
*/
virtual void ReadingCompleted(torch::Tensor keys) = 0;

/**
* @brief A writer has finished writing these keys, so they can be evicted.
* @param keys The keys to unmark.
*/
virtual void WritingCompleted(torch::Tensor keys) = 0;

protected:
template <typename CachePolicy>
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -131,8 +158,9 @@ class BaseCachePolicy {
template <typename CachePolicy>
static torch::Tensor ReplaceImpl(CachePolicy& policy, torch::Tensor keys);

- template <typename CachePolicy>
- static void ReadingCompletedImpl(CachePolicy& policy, torch::Tensor keys);
+ template <bool write, typename CachePolicy>
+ static void ReadingWritingCompletedImpl(
+     CachePolicy& policy, torch::Tensor keys);
};

/**
@@ -170,6 +198,11 @@ class S3FifoCachePolicy : public BaseCachePolicy {
*/
void ReadingCompleted(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor keys);

friend std::ostream& operator<<(
std::ostream& os, const S3FifoCachePolicy& policy) {
return os << "small_queue_: " << policy.small_queue_ << "\n"
Expand All @@ -178,11 +211,13 @@ class S3FifoCachePolicy : public BaseCachePolicy {
<< "capacity_: " << policy.capacity_ << "\n";
}

template <bool write>
std::optional<int64_t> Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto& cache_key = *it->second;
- return cache_key.Increment().StartUse().getPos();
+ if (write || !cache_key.BeingWritten())
+   return cache_key.Increment().StartUse<write>().getPos();
}
return std::nullopt;
}
@@ -195,7 +230,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {
return pos;
}

- void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+ template <bool write>
+ void Unmark(int64_t key) {
+   key_to_cache_key_[key]->EndUse<write>();
+ }

private:
int64_t EvictMainQueue() {
@@ -282,11 +320,18 @@ class SieveCachePolicy : public BaseCachePolicy {
*/
void ReadingCompleted(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor keys);

template <bool write>
std::optional<int64_t> Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto& cache_key = *it->second;
- return cache_key.SetFreq().StartUse().getPos();
+ if (write || !cache_key.BeingWritten())
+   return cache_key.SetFreq().StartUse<write>().getPos();
}
return std::nullopt;
}
@@ -298,7 +343,10 @@
return pos;
}

- void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+ template <bool write>
+ void Unmark(int64_t key) {
+   key_to_cache_key_[key]->EndUse<write>();
+ }

private:
int64_t Evict() {
@@ -362,14 +410,22 @@ class LruCachePolicy : public BaseCachePolicy {
*/
void ReadingCompleted(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor keys);

template <bool write>
std::optional<int64_t> Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto cache_key = *it->second;
- queue_.erase(it->second);
- queue_.push_front(cache_key.StartUse());
- it->second = queue_.begin();
- return cache_key.getPos();
+ if (write || !cache_key.BeingWritten()) {
+   queue_.erase(it->second);
+   queue_.push_front(cache_key.StartUse<write>());
+   it->second = queue_.begin();
+   return cache_key.getPos();
+ }
}
return std::nullopt;
}
@@ -381,7 +437,10 @@
return pos;
}

- void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+ template <bool write>
+ void Unmark(int64_t key) {
+   key_to_cache_key_[key]->EndUse<write>();
+ }

private:
int64_t Evict() {
@@ -443,11 +502,18 @@ class ClockCachePolicy : public BaseCachePolicy {
*/
void ReadingCompleted(torch::Tensor keys);

/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor keys);

template <bool write>
std::optional<int64_t> Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
if (it != key_to_cache_key_.end()) {
auto& cache_key = *it->second;
- return cache_key.SetFreq().StartUse().getPos();
+ if (write || !cache_key.BeingWritten())
+   return cache_key.SetFreq().StartUse<write>().getPos();
}
return std::nullopt;
}
@@ -458,7 +524,10 @@
return pos;
}

- void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+ template <bool write>
+ void Unmark(int64_t key) {
+   key_to_cache_key_[key]->EndUse<write>();
+ }

private:
int64_t Evict() {
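
All four policies now share the same admission gate in Read: a plain reader must miss while the slot is being written, while a writer may join an ongoing write chain for the same key. Extracted for clarity as a sketch over this diff's CacheKey (not a new API):

#include <optional>

template <bool write>
std::optional<int64_t> GatedRead(CacheKey& cache_key) {
  // Readers (write == false) miss while the slot is mid-write; writers
  // (write == true) are allowed to pile onto the existing write chain.
  if (write || !cache_key.BeingWritten())
    return cache_key.StartUse<write>().getPos();
  return std::nullopt;
}
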
1 change: 1 addition & 0 deletions graphbolt/src/feature_cache.cc
@@ -72,6 +72,7 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
const auto tensor_ptr = reinterpret_cast<std::byte*>(tensor_.data_ptr());
const auto positions_ptr = positions.data_ptr<int64_t>();
std::lock_guard lock(mtx_);
torch::parallel_for(
0, positions.size(0), kIntGrainSize, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) {
2 changes: 2 additions & 0 deletions graphbolt/src/feature_cache.h
@@ -89,6 +89,8 @@ struct FeatureCache : public torch::CustomClassHolder {

private:
torch::Tensor tensor_;
// Protects writes only, as reads are guaranteed to be safe.
std::mutex mtx_;
};

} // namespace storage
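
The new mtx_ serializes writers with each other: two Replace calls may target overlapping positions (a write chain on the same key), and without the lock their parallel_for copies could interleave within a row. Readers take no lock; their safety comes from the policy layer, which never hands a reader a position that is still BeingWritten(). A sketch of this discipline, with a hypothetical ReadRows helper:

#include <mutex>
#include <torch/torch.h>

struct WriteLockedCache {
  void Replace(torch::Tensor positions, torch::Tensor values) {
    std::lock_guard lock(mtx_);                 // writers exclude each other
    tensor_.index_copy_(0, positions, values);  // copy whole rows in
  }
  torch::Tensor ReadRows(torch::Tensor positions) {
    return tensor_.index_select(0, positions);  // lock-free reads
  }
  torch::Tensor tensor_;
  std::mutex mtx_;  // protects writes only, mirroring feature_cache.h
};
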
26 changes: 23 additions & 3 deletions graphbolt/src/partitioned_cache_policy.cc
@@ -242,10 +242,14 @@ c10::intrusive_ptr<Future<torch::Tensor>> PartitionedCachePolicy::ReplaceAsync(
return async([=] { return Replace(keys); });
}

- void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
+ template <bool write>
+ void PartitionedCachePolicy::ReadingWritingCompletedImpl(torch::Tensor keys) {
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
- policies_[0]->ReadingCompleted(keys);
+ if constexpr (write)
+   policies_[0]->WritingCompleted(keys);
+ else
+   policies_[0]->ReadingCompleted(keys);
return;
}
torch::Tensor offsets, indices, permuted_keys;
@@ -257,15 +261,31 @@ void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
const auto tid = begin;
begin = offsets_ptr[tid];
end = offsets_ptr[tid + 1];
- policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
+ if constexpr (write)
+   policies_.at(tid)->WritingCompleted(permuted_keys.slice(0, begin, end));
+ else
+   policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
});
}

void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<false>(keys);
}

void PartitionedCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(keys);
}

c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::ReadingCompletedAsync(
torch::Tensor keys) {
return async([=] { return ReadingCompleted(keys); });
}

c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::WritingCompletedAsync(
torch::Tensor keys) {
return async([=] { return WritingCompleted(keys); });
}

template <typename CachePolicy>
c10::intrusive_ptr<PartitionedCachePolicy> PartitionedCachePolicy::Create(
int64_t capacity, int64_t num_partitions) {
Expand Down
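
Finally, WritingCompletedAsync mirrors the existing ReadingCompletedAsync so async read/write pipelines can unpin keys off the main thread. A hypothetical epilogue showing the intended pairing, assuming GraphBolt's Future exposes a blocking Wait():

void ReadEpilogue(
    c10::intrusive_ptr<PartitionedCachePolicy> policy, torch::Tensor keys) {
  policy->ReadingCompletedAsync(keys)->Wait();  // release the read pins
}

void WriteEpilogue(
    c10::intrusive_ptr<PartitionedCachePolicy> policy, torch::Tensor keys) {
  policy->WritingCompletedAsync(keys)->Wait();  // make keys visible/evictable
}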