From df2fe6d8ed67207c6ef03e6bc2e37e5c0a225e32 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Wed, 24 Jul 2024 18:06:01 -0400
Subject: [PATCH] [GraphBolt] CachePolicy Writer lock for read_async correctness.

---
 graphbolt/src/cache_policy.cc                |  34 ++++--
 graphbolt/src/cache_policy.h                 | 109 ++++++++++++++----
 graphbolt/src/feature_cache.cc               |   1 +
 graphbolt/src/feature_cache.h                |   2 +
 graphbolt/src/partitioned_cache_policy.cc    |  26 ++++-
 graphbolt/src/partitioned_cache_policy.h     |  11 ++
 graphbolt/src/python_binding.cc              |   8 +-
 .../dgl/graphbolt/impl/cpu_cached_feature.py |  12 +-
 python/dgl/graphbolt/impl/feature_cache.py   |   2 +-
 9 files changed, 165 insertions(+), 40 deletions(-)

diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc
index aac56d130f57..f48607d21752 100644
--- a/graphbolt/src/cache_policy.cc
+++ b/graphbolt/src/cache_policy.cc
@@ -47,7 +47,7 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
         auto filtered_keys_ptr = filtered_keys.data_ptr<index_t>();
         for (int64_t i = 0; i < keys.size(0); i++) {
           const auto key = keys_ptr[i];
-          auto pos = policy.Read(key);
+          auto pos = policy.template Read<false>(key);
           if (pos.has_value()) {
             positions_ptr[found_cnt] = *pos;
             filtered_keys_ptr[found_cnt] = key;
@@ -78,7 +78,7 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
         position_set.reserve(keys.size(0));
         for (int64_t i = 0; i < keys.size(0); i++) {
           const auto key = keys_ptr[i];
-          const auto pos_optional = policy.Read(key);
+          const auto pos_optional = policy.template Read<true>(key);
           const auto pos = pos_optional ? *pos_optional : policy.Insert(key);
           positions_ptr[i] = pos;
           TORCH_CHECK(
@@ -91,14 +91,14 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
   return positions;
 }
 
-template <typename CachePolicy>
-void BaseCachePolicy::ReadingCompletedImpl(
+template <bool write, typename CachePolicy>
+void BaseCachePolicy::ReadingWritingCompletedImpl(
     CachePolicy& policy, torch::Tensor keys) {
   AT_DISPATCH_INDEX_TYPES(
       keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] {
         auto keys_ptr = keys.data_ptr<index_t>();
         for (int64_t i = 0; i < keys.size(0); i++) {
-          policy.Unmark(keys_ptr[i]);
+          policy.template Unmark<write>(keys_ptr[i]);
         }
       }));
 }
@@ -125,7 +125,11 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
 }
 
 void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
-  ReadingCompletedImpl(*this, keys);
+  ReadingWritingCompletedImpl<false>(*this, keys);
+}
+
+void S3FifoCachePolicy::WritingCompleted(torch::Tensor keys) {
+  ReadingWritingCompletedImpl<true>(*this, keys);
 }
 
 SieveCachePolicy::SieveCachePolicy(int64_t capacity)
@@ -145,7 +149,11 @@ torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) {
 }
 
 void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
-  ReadingCompletedImpl(*this, keys);
+  ReadingWritingCompletedImpl<false>(*this, keys);
+}
+
+void SieveCachePolicy::WritingCompleted(torch::Tensor keys) {
+  ReadingWritingCompletedImpl<true>(*this, keys);
 }
 
 LruCachePolicy::LruCachePolicy(int64_t capacity)
@@ -164,7 +172,11 @@ torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) {
 }
 
 void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
-  ReadingCompletedImpl(*this, keys);
+  ReadingWritingCompletedImpl<false>(*this, keys);
+}
+
+void LruCachePolicy::WritingCompleted(torch::Tensor keys) {
+  ReadingWritingCompletedImpl<true>(*this, keys);
 }
 
 ClockCachePolicy::ClockCachePolicy(int64_t capacity)
@@ -183,7 +195,11 @@ torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) {
 }
 
 void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
-  ReadingCompletedImpl(*this, keys);
+  ReadingWritingCompletedImpl<false>(*this, keys);
+}
+
+void ClockCachePolicy::WritingCompleted(torch::Tensor keys) {
+  ReadingWritingCompletedImpl<true>(*this, keys);
 }
 
 }  // namespace storage
diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h
index e45ff16a9d3a..b7fc0e61116d 100644
--- a/graphbolt/src/cache_policy.h
+++ b/graphbolt/src/cache_policy.h
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
+ * Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
  * All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,6 +24,8 @@
 #include <torch/custom_class.h>
 #include <torch/torch.h>
 
+#include <limits>
+
 #include "./circular_queue.h"
 
 namespace graphbolt {
@@ -31,7 +33,11 @@ namespace storage {
 
 struct CacheKey {
   CacheKey(int64_t key, int64_t position)
-      : freq_(0), key_(key), position_in_cache_(position), reference_count_(1) {
+      : freq_(0),
+        key_(key),
+        position_in_cache_(position),
+        read_reference_count_(0),
+        write_reference_count_(1) {
     static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));
   }
 
@@ -63,17 +69,30 @@ struct CacheKey {
     return *this;
  }
 
+  template <bool write>
   CacheKey& StartUse() {
-    ++reference_count_;
+    if constexpr (write) {
+      TORCH_CHECK(
+          write_reference_count_++ < std::numeric_limits<int16_t>::max());
+    } else {
+      TORCH_CHECK(read_reference_count_++ < std::numeric_limits<int8_t>::max());
+    }
     return *this;
   }
 
+  template <bool write>
   CacheKey& EndUse() {
-    --reference_count_;
+    if constexpr (write) {
+      --write_reference_count_;
+    } else {
+      --read_reference_count_;
+    }
     return *this;
   }
 
-  bool InUse() { return reference_count_ > 0; }
+  bool InUse() const { return read_reference_count_ || write_reference_count_; }
+
+  bool BeingWritten() const { return write_reference_count_; }
 
   friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {
     return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", "
@@ -83,8 +102,10 @@ struct CacheKey {
  private:
   int64_t freq_ : 3;
   int64_t key_ : 61;
-  int64_t position_in_cache_ : 48;
-  int64_t reference_count_ : 16;
+  int64_t position_in_cache_ : 40;
+  int64_t read_reference_count_ : 8;
+  // There could be a chain of writes so it is better to have a larger bit count.
+  int64_t write_reference_count_ : 16;
 };
 
 class BaseCachePolicy {
@@ -123,6 +144,12 @@ class BaseCachePolicy {
    */
  virtual void ReadingCompleted(torch::Tensor keys) = 0;
 
+  /**
+   * @brief A writer has finished writing these keys, so they can be evicted.
+   * @param keys The keys to unmark.
+   */
+  virtual void WritingCompleted(torch::Tensor keys) = 0;
+
  protected:
   template <typename CachePolicy>
   static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
   QueryImpl(CachePolicy& policy, torch::Tensor keys);
 
   template <typename CachePolicy>
   static torch::Tensor ReplaceImpl(CachePolicy& policy, torch::Tensor keys);
 
-  template <typename CachePolicy>
-  static void ReadingCompletedImpl(CachePolicy& policy, torch::Tensor keys);
+  template <bool write, typename CachePolicy>
+  static void ReadingWritingCompletedImpl(
+      CachePolicy& policy, torch::Tensor keys);
 };
@@ -170,6 +198,11 @@ class S3FifoCachePolicy : public BaseCachePolicy {
    */
   void ReadingCompleted(torch::Tensor keys);
 
+  /**
+   * @brief See BaseCachePolicy::WritingCompleted.
+   */
+  void WritingCompleted(torch::Tensor keys);
+
   friend std::ostream& operator<<(
       std::ostream& os, const S3FifoCachePolicy& policy) {
     return os << "small_queue_: " << policy.small_queue_ << "\n"
@@ -178,11 +211,13 @@ class S3FifoCachePolicy : public BaseCachePolicy {
               << "capacity_: " << policy.capacity_ << "\n";
   }
 
+  template <bool write>
   std::optional<int64_t> Read(int64_t key) {
     auto it = key_to_cache_key_.find(key);
     if (it != key_to_cache_key_.end()) {
       auto& cache_key = *it->second;
-      return cache_key.Increment().StartUse().getPos();
+      if (write || !cache_key.BeingWritten())
+        return cache_key.Increment().StartUse<write>().getPos();
     }
     return std::nullopt;
   }
@@ -195,7 +230,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {
     return pos;
   }
 
-  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+  template <bool write>
+  void Unmark(int64_t key) {
+    key_to_cache_key_[key]->EndUse<write>();
+  }
 
  private:
   int64_t EvictMainQueue() {
@@ -282,11 +320,18 @@ class SieveCachePolicy : public BaseCachePolicy {
    */
   void ReadingCompleted(torch::Tensor keys);
 
+  /**
+   * @brief See BaseCachePolicy::WritingCompleted.
+   */
+  void WritingCompleted(torch::Tensor keys);
+
+  template <bool write>
   std::optional<int64_t> Read(int64_t key) {
     auto it = key_to_cache_key_.find(key);
     if (it != key_to_cache_key_.end()) {
       auto& cache_key = *it->second;
-      return cache_key.SetFreq().StartUse().getPos();
+      if (write || !cache_key.BeingWritten())
+        return cache_key.SetFreq().StartUse<write>().getPos();
     }
     return std::nullopt;
   }
@@ -298,7 +343,10 @@ class SieveCachePolicy : public BaseCachePolicy {
     return pos;
   }
 
-  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+  template <bool write>
+  void Unmark(int64_t key) {
+    key_to_cache_key_[key]->EndUse<write>();
+  }
 
  private:
   int64_t Evict() {
@@ -362,14 +410,22 @@ class LruCachePolicy : public BaseCachePolicy {
    */
   void ReadingCompleted(torch::Tensor keys);
 
+  /**
+   * @brief See BaseCachePolicy::WritingCompleted.
+   */
+  void WritingCompleted(torch::Tensor keys);
+
+  template <bool write>
   std::optional<int64_t> Read(int64_t key) {
     auto it = key_to_cache_key_.find(key);
     if (it != key_to_cache_key_.end()) {
       auto cache_key = *it->second;
-      queue_.erase(it->second);
-      queue_.push_front(cache_key.StartUse());
-      it->second = queue_.begin();
-      return cache_key.getPos();
+      if (write || !cache_key.BeingWritten()) {
+        queue_.erase(it->second);
+        queue_.push_front(cache_key.StartUse<write>());
+        it->second = queue_.begin();
+        return cache_key.getPos();
+      }
     }
     return std::nullopt;
   }
@@ -381,7 +437,10 @@ class LruCachePolicy : public BaseCachePolicy {
     return pos;
   }
 
-  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+  template <bool write>
+  void Unmark(int64_t key) {
+    key_to_cache_key_[key]->EndUse<write>();
+  }
 
  private:
   int64_t Evict() {
@@ -443,11 +502,18 @@ class ClockCachePolicy : public BaseCachePolicy {
    */
   void ReadingCompleted(torch::Tensor keys);
 
+  /**
+   * @brief See BaseCachePolicy::WritingCompleted.
+   */
+  void WritingCompleted(torch::Tensor keys);
+
+  template <bool write>
   std::optional<int64_t> Read(int64_t key) {
     auto it = key_to_cache_key_.find(key);
     if (it != key_to_cache_key_.end()) {
       auto& cache_key = *it->second;
-      return cache_key.SetFreq().StartUse().getPos();
+      if (write || !cache_key.BeingWritten())
+        return cache_key.SetFreq().StartUse<write>().getPos();
     }
     return std::nullopt;
   }
@@ -458,7 +524,10 @@ class ClockCachePolicy : public BaseCachePolicy {
     return pos;
   }
 
-  void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
+  template <bool write>
+  void Unmark(int64_t key) {
+    key_to_cache_key_[key]->EndUse<write>();
+  }
 
  private:
   int64_t Evict() {
diff --git a/graphbolt/src/feature_cache.cc b/graphbolt/src/feature_cache.cc
index 291ff4ed2f22..eeab728ad194 100644
--- a/graphbolt/src/feature_cache.cc
+++ b/graphbolt/src/feature_cache.cc
@@ -72,6 +72,7 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
   auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
   const auto tensor_ptr = reinterpret_cast<std::byte*>(tensor_.data_ptr());
   const auto positions_ptr = positions.data_ptr<int64_t>();
+  std::lock_guard lock(mtx_);
   torch::parallel_for(
       0, positions.size(0), kIntGrainSize, [&](int64_t begin, int64_t end) {
         for (int64_t i = begin; i < end; i++) {
diff --git a/graphbolt/src/feature_cache.h b/graphbolt/src/feature_cache.h
index 3a5a05a86843..6a8b78a078e8 100644
--- a/graphbolt/src/feature_cache.h
+++ b/graphbolt/src/feature_cache.h
@@ -89,6 +89,8 @@ struct FeatureCache : public torch::CustomClassHolder {
 
  private:
   torch::Tensor tensor_;
+  // Protects writes only, as reads are guaranteed to be safe.
+  std::mutex mtx_;
 };
 
 }  // namespace storage
diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc
index df62e2b8b0a5..77958a19126c 100644
--- a/graphbolt/src/partitioned_cache_policy.cc
+++ b/graphbolt/src/partitioned_cache_policy.cc
@@ -242,10 +242,14 @@ c10::intrusive_ptr<Future<torch::Tensor>> PartitionedCachePolicy::ReplaceAsync(
   return async([=] { return Replace(keys); });
 }
 
-void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
+template <bool write>
+void PartitionedCachePolicy::ReadingWritingCompletedImpl(torch::Tensor keys) {
   if (policies_.size() == 1) {
     std::lock_guard lock(mtx_);
-    policies_[0]->ReadingCompleted(keys);
+    if constexpr (write)
+      policies_[0]->WritingCompleted(keys);
+    else
+      policies_[0]->ReadingCompleted(keys);
     return;
   }
   torch::Tensor offsets, indices, permuted_keys;
@@ -257,15 +261,31 @@ void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
     const auto tid = begin;
     begin = offsets_ptr[tid];
     end = offsets_ptr[tid + 1];
-    policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
+    if constexpr (write)
+      policies_.at(tid)->WritingCompleted(permuted_keys.slice(0, begin, end));
+    else
+      policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
   });
 }
 
+void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
+  ReadingWritingCompletedImpl<false>(keys);
+}
+
+void PartitionedCachePolicy::WritingCompleted(torch::Tensor keys) {
+  ReadingWritingCompletedImpl<true>(keys);
+}
+
 c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::ReadingCompletedAsync(
     torch::Tensor keys) {
   return async([=] { return ReadingCompleted(keys); });
 }
 
+c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::WritingCompletedAsync(
+    torch::Tensor keys) {
+  return async([=] { return WritingCompleted(keys); });
+}
+
 template <typename CachePolicyType>
 c10::intrusive_ptr<PartitionedCachePolicy> PartitionedCachePolicy::Create(
     int64_t capacity, int64_t num_partitions) {
diff --git a/graphbolt/src/partitioned_cache_policy.h b/graphbolt/src/partitioned_cache_policy.h
index 3a9ba7e26c81..5d7495be9ad0 100644
--- a/graphbolt/src/partitioned_cache_policy.h
+++ b/graphbolt/src/partitioned_cache_policy.h
@@ -81,14 +81,25 @@ class PartitionedCachePolicy : public BaseCachePolicy,
   c10::intrusive_ptr<Future<torch::Tensor>> ReplaceAsync(torch::Tensor keys);
 
+  template <bool write>
+  void ReadingWritingCompletedImpl(torch::Tensor keys);
+
   /**
    * @brief A reader has finished reading these keys, so they can be evicted.
    * @param keys The keys to unmark.
    */
   void ReadingCompleted(torch::Tensor keys);
 
+  /**
+   * @brief A writer has finished writing these keys, so they can be evicted.
+   * @param keys The keys to unmark.
+   */
+  void WritingCompleted(torch::Tensor keys);
+
   c10::intrusive_ptr<Future<void>> ReadingCompletedAsync(torch::Tensor keys);
 
+  c10::intrusive_ptr<Future<void>> WritingCompletedAsync(torch::Tensor keys);
+
   template <typename CachePolicyType>
   static c10::intrusive_ptr<PartitionedCachePolicy> Create(
       int64_t capacity, int64_t num_partitions);
diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc
index 5b5d3f863120..ca62fab419ba 100644
--- a/graphbolt/src/python_binding.cc
+++ b/graphbolt/src/python_binding.cc
@@ -114,7 +114,13 @@ TORCH_LIBRARY(graphbolt, m) {
           &storage::PartitionedCachePolicy::ReadingCompleted)
       .def(
           "reading_completed_async",
-          &storage::PartitionedCachePolicy::ReadingCompletedAsync);
+          &storage::PartitionedCachePolicy::ReadingCompletedAsync)
+      .def(
+          "writing_completed",
+          &storage::PartitionedCachePolicy::WritingCompleted)
+      .def(
+          "writing_completed_async",
+          &storage::PartitionedCachePolicy::WritingCompletedAsync);
   m.def(
       "s3_fifo_cache_policy",
      &storage::PartitionedCachePolicy::Create<storage::S3FifoCachePolicy>);
diff --git a/python/dgl/graphbolt/impl/cpu_cached_feature.py b/python/dgl/graphbolt/impl/cpu_cached_feature.py
index 2bdc6153f6a8..58f6aa7f0f00 100644
--- a/python/dgl/graphbolt/impl/cpu_cached_feature.py
+++ b/python/dgl/graphbolt/impl/cpu_cached_feature.py
@@ -172,7 +172,7 @@ def read_async(self, ids: torch.Tensor):
             reading_completed.wait()
             replace_future.wait()
-            reading_completed = policy.reading_completed_async(missing_keys)
+            writing_completed = policy.writing_completed_async(missing_keys)
 
             num_found = positions.size(0)
 
             class _Waiter:
@@ -201,7 +201,7 @@ def wait(self):
                     return values
 
             yield _Waiter(
-                [missing_values_copy_event, reading_completed],
+                [missing_values_copy_event, writing_completed],
                 values_from_cpu,
                 missing_values_cuda,
                 index,
@@ -264,7 +264,7 @@ def wait(self):
             reading_completed.wait()
             replace_future.wait()
-            reading_completed = policy.reading_completed_async(missing_keys)
+            writing_completed = policy.writing_completed_async(missing_keys)
 
             class _Waiter:
                 def __init__(self, events, values):
@@ -280,7 +280,7 @@ def wait(self):
                     self.events = self.values = None
                     return values
 
-            yield _Waiter([values_copy_event, reading_completed], values_cuda)
+            yield _Waiter([values_copy_event, writing_completed], values_cuda)
         else:
             policy_future = policy.query_async(ids)
@@ -319,7 +319,7 @@ def wait(self):
             reading_completed.wait()
             replace_future.wait()
-            reading_completed = policy.reading_completed_async(missing_keys)
+            writing_completed = policy.writing_completed_async(missing_keys)
 
             class _Waiter:
                 def __init__(self, event, values):
@@ -334,7 +334,7 @@ def wait(self):
                     self.event = self.values = None
                     return values
 
-            yield _Waiter(reading_completed, values)
+            yield _Waiter(writing_completed, values)
 
     def read_async_num_stages(self, ids_device: torch.device):
         """The number of stages of the read_async operation. See read_async
diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py
index 1cdeaaa3867a..c61924a7093b 100644
--- a/python/dgl/graphbolt/impl/feature_cache.py
+++ b/python/dgl/graphbolt/impl/feature_cache.py
@@ -96,7 +96,7 @@ def replace(self, keys, values):
         """
         positions = self._policy.replace(keys)
         self._cache.replace(positions, values)
-        self._policy.reading_completed(keys)
+        self._policy.writing_completed(keys)
 
     @property
     def miss_rate(self):
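
Usage note (editorial, not part of the patch): the calling convention this diff establishes is that query()/reading_completed() bracket reads while replace()/writing_completed() bracket writes, so a cache slot handed out by replace() can be neither read nor evicted until its writer calls writing_completed(). Below is a minimal Python sketch of that ordering; the function name, the policy/cache/backing_feature objects, and the simplified query() return signature are illustrative assumptions and do not reproduce the exact GraphBolt signatures.

def read_through_cache(policy, cache, backing_feature, ids):
    # Illustrative only: shows the read/write marking order, not real signatures.
    # Reads: query() takes a read mark on every hit (Read<false>), so keys that
    # are currently being written are reported as misses instead of torn reads.
    positions, index, missing_keys = policy.query(ids)  # simplified signature
    found_values = cache.query(positions)
    policy.reading_completed(ids[index])  # drop the read marks

    # Writes: replace() takes a write mark (Read<true> or Insert), the copy into
    # the cache buffer is serialized by FeatureCache::mtx_, and only
    # writing_completed() makes the keys readable and evictable again.
    missing_values = backing_feature[missing_keys.long()]
    positions = policy.replace(missing_keys)
    cache.replace(positions, missing_values)
    policy.writing_completed(missing_keys)

    return found_values, missing_values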