Merge branch 'master' into gb_async_feature_fetch
mfbalin authored Jul 22, 2024
2 parents dd636ed + afcf65c commit d9c73e9
Showing 16 changed files with 228 additions and 139 deletions.
24 changes: 15 additions & 9 deletions graphbolt/src/cache_policy.cc
@@ -19,20 +19,24 @@
*/
#include "./cache_policy.h"

#include "./utils.h"

namespace graphbolt {
namespace storage {

template <typename CachePolicy>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
auto positions = torch::empty_like(
keys,
keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto indices = torch::empty_like(
keys,
keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
auto filtered_keys =
torch::empty_like(keys, keys.options().pinned_memory(keys.is_pinned()));
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto filtered_keys = torch::empty_like(
keys, keys.options().pinned_memory(utils::is_pinned(keys)));
int64_t found_cnt = 0;
int64_t missing_cnt = keys.size(0);
AT_DISPATCH_INDEX_TYPES(
@@ -63,8 +67,9 @@ template <typename CachePolicy>
torch::Tensor BaseCachePolicy::ReplaceImpl(
CachePolicy& policy, torch::Tensor keys) {
auto positions = torch::empty_like(
keys,
keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
@@ -124,7 +129,8 @@ void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
}

SieveCachePolicy::SieveCachePolicy(int64_t capacity)
: hand_(queue_.end()), capacity_(capacity), cache_usage_(0) {
// Ensure that queue_ is constructed before its `.end()` is accessed.
: queue_(), hand_(queue_.end()), capacity_(capacity), cache_usage_(0) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
key_to_cache_key_.reserve(capacity);
}
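Note on the SieveCachePolicy change above: in C++, non-static members are initialized in declaration order, not in the order they appear in the mem-initializer list, so listing `queue_()` explicitly documents that `hand_(queue_.end())` depends on a live `queue_` rather than changing the order itself. A minimal sketch of the pattern (the struct below is illustrative, not the real SieveCachePolicy):

```cpp
#include <list>

struct SieveLikePolicy {
  // Declared before hand_, so it is initialized first regardless of the
  // mem-initializer list order below.
  std::list<int> queue_;
  std::list<int>::iterator hand_;

  // Spelling out queue_() makes the dependency visible to readers; hand_ can
  // then safely capture queue_.end() because queue_ is already constructed.
  SieveLikePolicy() : queue_(), hand_(queue_.end()) {}
};

int main() { SieveLikePolicy p; return 0; }
```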
18 changes: 18 additions & 0 deletions graphbolt/src/cache_policy.h
@@ -89,6 +89,12 @@ struct CacheKey {

class BaseCachePolicy {
public:
/**
* @brief A virtual base class destructor ensures that the derived class
* destructor gets called when deleting through a base class pointer.
*/
virtual ~BaseCachePolicy() = default;

/**
* @brief The policy query function.
* @param keys The keys to query the cache.
@@ -144,6 +150,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {

S3FifoCachePolicy() = default;

S3FifoCachePolicy(S3FifoCachePolicy&&) = default;

virtual ~S3FifoCachePolicy() = default;

/**
* @brief See BaseCachePolicy::Query.
*/
@@ -254,6 +264,8 @@ class SieveCachePolicy : public BaseCachePolicy {

SieveCachePolicy() = default;

virtual ~SieveCachePolicy() = default;

/**
* @brief See BaseCachePolicy::Query.
*/
@@ -332,6 +344,8 @@ class LruCachePolicy : public BaseCachePolicy {

LruCachePolicy() = default;

virtual ~LruCachePolicy() = default;

/**
* @brief See BaseCachePolicy::Query.
*/
@@ -409,6 +423,10 @@ class ClockCachePolicy : public BaseCachePolicy {

ClockCachePolicy() = default;

ClockCachePolicy(ClockCachePolicy&&) = default;

virtual ~ClockCachePolicy() = default;

/**
* @brief See BaseCachePolicy::Query.
*/
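The virtual destructors added above matter because `PartitionedCachePolicy` owns its policies as `std::vector<std::unique_ptr<BaseCachePolicy>>`: destroying a derived policy through a base pointer is undefined behavior unless the base destructor is virtual. A hedged sketch of the failure mode being avoided (the type names are illustrative stand-ins, not graphbolt classes):

```cpp
#include <iostream>
#include <memory>
#include <vector>

struct PolicyBase {
  // Without this virtual destructor, destroying DerivedPolicy through a
  // PolicyBase* would be undefined behavior and ~DerivedPolicy() could be
  // skipped.
  virtual ~PolicyBase() = default;
};

struct DerivedPolicy : PolicyBase {
  ~DerivedPolicy() override { std::cout << "releasing policy state\n"; }
};

int main() {
  std::vector<std::unique_ptr<PolicyBase>> policies;
  policies.push_back(std::make_unique<DerivedPolicy>());
  // unique_ptr deletes through PolicyBase*, so the virtual destructor is what
  // guarantees ~DerivedPolicy() runs here.
  policies.clear();
  return 0;
}
```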
4 changes: 3 additions & 1 deletion graphbolt/src/cnumpy.cc
@@ -23,6 +23,7 @@
#include <stdexcept>

#include "./circular_queue.h"
#include "./utils.h"

namespace graphbolt {
namespace storage {
@@ -152,13 +153,14 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {
shape, index.options()
.dtype(dtype_)
.layout(torch::kStrided)
.pinned_memory(index.is_pinned())
.pinned_memory(utils::is_pinned(index))
.requires_grad(false));
auto result_buffer = reinterpret_cast<char *>(result.data_ptr());

// Indicator for index error.
std::atomic<int> error_flag{};
std::atomic<int64_t> work_queue{};
std::lock_guard lock(mtx_);
torch::parallel_for(0, num_thread_, 1, [&](int64_t begin, int64_t end) {
if (begin >= end) return;
const auto thread_id = begin;
2 changes: 2 additions & 0 deletions graphbolt/src/cnumpy.h
@@ -20,6 +20,7 @@
#include <fstream>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

@@ -113,6 +114,7 @@ class OnDiskNpyArray : public torch::CustomClassHolder {
int64_t aligned_length_; // Aligned feature_size.
int num_thread_; // Default thread number.
torch::Tensor read_tensor_; // Provides temporary read buffer.
std::mutex mtx_;

#ifdef HAVE_LIBRARY_LIBURING
std::unique_ptr<io_uring[]> io_uring_queue_; // io_uring queue.
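The new `std::mutex mtx_` member and the `std::lock_guard` at the top of `IndexSelectIOUringImpl` serialize callers so that only one thread at a time drives the shared io_uring queues and read buffer. A minimal sketch of the same guard pattern, using a made-up reader class rather than the real `OnDiskNpyArray`:

```cpp
#include <mutex>
#include <vector>

class DiskReaderSketch {
 public:
  // Concurrent callers are serialized: the lock_guard acquires mtx_ on entry
  // and releases it automatically when the function returns or throws.
  void Read(int request) {
    std::lock_guard<std::mutex> lock(mtx_);
    // Only one thread at a time mutates the shared queue below.
    pending_.push_back(request);
  }

 private:
  std::mutex mtx_;
  std::vector<int> pending_;
};
```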
2 changes: 2 additions & 0 deletions graphbolt/src/cuda/gpu_cache.cu
@@ -28,6 +28,8 @@ namespace cuda {
GpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {
TORCH_CHECK(shape.size() >= 2, "Shape must at least have 2 dimensions.");
const auto num_items = shape[0];
TORCH_CHECK(
num_items > 0, "The capacity of GpuCache needs to be positive.");
const int64_t num_feats =
std::accumulate(shape.begin() + 1, shape.end(), 1ll, std::multiplies<>());
const int element_size =
4 changes: 3 additions & 1 deletion graphbolt/src/feature_cache.cc
@@ -20,6 +20,7 @@
#include "./feature_cache.h"

#include "./index_select.h"
#include "./utils.h"

namespace graphbolt {
namespace storage {
@@ -34,7 +35,8 @@ FeatureCache::FeatureCache(

torch::Tensor FeatureCache::Query(
torch::Tensor positions, torch::Tensor indices, int64_t size) {
const bool pin_memory = positions.is_pinned() || indices.is_pinned();
const bool pin_memory =
utils::is_pinned(positions) || utils::is_pinned(indices);
std::vector<int64_t> output_shape{
tensor_.sizes().begin(), tensor_.sizes().end()};
output_shape[0] = size;
5 changes: 3 additions & 2 deletions graphbolt/src/index_select.cc
@@ -23,8 +23,9 @@ torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
auto output_shape = input.sizes().vec();
output_shape[0] = index.numel();
auto result = torch::empty(
output_shape,
index.options().dtype(input.dtype()).pinned_memory(index.is_pinned()));
output_shape, index.options()
.dtype(input.dtype())
.pinned_memory(utils::is_pinned(index)));
return torch::index_select_out(result, input, 0, index);
}

54 changes: 36 additions & 18 deletions graphbolt/src/partitioned_cache_policy.cc
@@ -21,6 +21,8 @@

#include <numeric>

#include "./utils.h"

namespace graphbolt {
namespace storage {

@@ -114,7 +116,10 @@ PartitionedCachePolicy::Partition(torch::Tensor keys) {

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
PartitionedCachePolicy::Query(torch::Tensor keys) {
if (policies_.size() == 1) return policies_[0]->Query(keys);
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
return policies_[0]->Query(keys);
}
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
auto offsets_ptr = offsets.data_ptr<int64_t>();
@@ -125,30 +130,36 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
torch::Tensor result_offsets_tensor =
torch::empty(policies_.size() * 2 + 1, offsets.options());
auto result_offsets = result_offsets_tensor.data_ptr<int64_t>();
torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
TORCH_CHECK(end - begin == 1);
const auto tid = begin;
begin = offsets_ptr[tid];
end = offsets_ptr[tid + 1];
results[tid] = policies_.at(tid)->Query(permuted_keys.slice(0, begin, end));
result_offsets[tid] = std::get<0>(results[tid]).size(0);
result_offsets[tid + policies_.size()] = std::get<2>(results[tid]).size(0);
});
{
std::lock_guard lock(mtx_);
torch::parallel_for(
0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
TORCH_CHECK(end - begin == 1);
const auto tid = begin;
begin = offsets_ptr[tid];
end = offsets_ptr[tid + 1];
results[tid] =
policies_.at(tid)->Query(permuted_keys.slice(0, begin, end));
result_offsets[tid] = std::get<0>(results[tid]).size(0);
result_offsets[tid + policies_.size()] =
std::get<2>(results[tid]).size(0);
});
}
std::exclusive_scan(
result_offsets, result_offsets + result_offsets_tensor.size(0),
result_offsets, 0);
torch::Tensor positions = torch::empty(
result_offsets[policies_.size()],
std::get<0>(results[0]).options().pinned_memory(keys.is_pinned()));
std::get<0>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
torch::Tensor output_indices = torch::empty_like(
indices, indices.options().pinned_memory(keys.is_pinned()));
indices, indices.options().pinned_memory(utils::is_pinned(keys)));
torch::Tensor missing_keys = torch::empty(
indices.size(0) - positions.size(0),
std::get<2>(results[0]).options().pinned_memory(keys.is_pinned()));
std::get<2>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
torch::Tensor found_keys = torch::empty(
positions.size(0),
std::get<3>(results[0]).options().pinned_memory(keys.is_pinned()));
std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
auto output_indices_ptr = output_indices.data_ptr<int64_t>();
torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
@@ -196,15 +207,20 @@ PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
}

torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) {
if (policies_.size() == 1) return policies_[0]->Replace(keys);
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
return policies_[0]->Replace(keys);
}
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
auto output_positions = torch::empty_like(
keys,
keys.options().dtype(torch::kInt64).pinned_memory(keys.is_pinned()));
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto offsets_ptr = offsets.data_ptr<int64_t>();
auto indices_ptr = indices.data_ptr<int64_t>();
auto output_positions_ptr = output_positions.data_ptr<int64_t>();
std::lock_guard lock(mtx_);
torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
const auto tid = begin;
@@ -228,12 +244,14 @@ c10::intrusive_ptr<Future<torch::Tensor>> PartitionedCachePolicy::ReplaceAsync(

void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
policies_[0]->ReadingCompleted(keys);
return;
}
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
auto offsets_ptr = offsets.data_ptr<int64_t>();
std::lock_guard lock(mtx_);
torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
const auto tid = begin;
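In `Query`, each partition writes its found-key count into `result_offsets[tid]` and its missing-key count into `result_offsets[tid + policies_.size()]`; the `std::exclusive_scan` then turns those counts into start offsets for copying every partition's results into the concatenated output tensors. A small standalone sketch of that step with two partitions and made-up counts:

```cpp
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // {found in partition 0, found in partition 1,
  //  missing in partition 0, missing in partition 1, trailing slot}
  std::vector<int64_t> offsets = {3, 5, 2, 4, 0};
  // In-place exclusive scan: offsets[i] becomes the sum of all earlier counts,
  // i.e. where partition i's slice starts in the concatenated output.
  std::exclusive_scan(
      offsets.begin(), offsets.end(), offsets.begin(), int64_t{0});
  // Prints: 0 3 8 10 14 -- offsets[2] == 8 is the total number of found keys.
  for (auto o : offsets) std::printf("%lld ", static_cast<long long>(o));
  std::printf("\n");
  return 0;
}
```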
2 changes: 2 additions & 0 deletions graphbolt/src/partitioned_cache_policy.h
@@ -24,6 +24,7 @@
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <mutex>
#include <pcg_random.hpp>
#include <random>
#include <type_traits>
@@ -118,6 +119,7 @@ class PartitionedCachePolicy : public BaseCachePolicy,

int64_t capacity_;
std::vector<std::unique_ptr<BaseCachePolicy>> policies_;
std::mutex mtx_;
};

} // namespace storage
2 changes: 2 additions & 0 deletions graphbolt/src/python_binding.cc
@@ -19,6 +19,7 @@
#include "./index_select.h"
#include "./partitioned_cache_policy.h"
#include "./random.h"
#include "./utils.h"

#ifdef GRAPHBOLT_USE_CUDA
#include "./cuda/extension/gpu_graph_cache.h"
@@ -145,6 +146,7 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("index_select_csc_batched", &ops::IndexSelectCSCBatched);
m.def("ondisk_npy_array", &storage::OnDiskNpyArray::Create);
m.def("detect_io_uring", &io_uring::IsAvailable);
m.def("set_worker_id", &utils::SetWorkerId);
m.def("set_seed", &RandomEngine::SetManualSeed);
#ifdef GRAPHBOLT_USE_CUDA
m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
36 changes: 36 additions & 0 deletions graphbolt/src/utils.cc
@@ -0,0 +1,36 @@
/**
* Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* @file utils.cc
* @brief Graphbolt utils implementations.
*/
#include "./utils.h"

#include <optional>

namespace graphbolt {
namespace utils {

namespace {
std::optional<int64_t> worker_id;
}

std::optional<int64_t> GetWorkerId() { return worker_id; }

void SetWorkerId(int64_t worker_id_value) { worker_id = worker_id_value; }

} // namespace utils
} // namespace graphbolt
20 changes: 20 additions & 0 deletions graphbolt/src/utils.h
@@ -12,6 +12,18 @@
namespace graphbolt {
namespace utils {

/**
* @brief If this process is a DataLoader worker, returns its assigned worker
* id, which is less than the number of workers.
*/
std::optional<int64_t> GetWorkerId();

/**
* @brief If this process is a DataLoader worker, this function is called to
* initialize its worker id, which is less than the number of workers.
*/
void SetWorkerId(int64_t worker_id_value);

/**
* @brief Checks whether the tensor is stored on the GPU.
*/
@@ -26,6 +38,14 @@ inline bool is_accessible_from_gpu(const torch::Tensor& tensor) {
return is_on_gpu(tensor) || tensor.is_pinned();
}

/**
* @brief Checks whether the tensor is stored in pinned memory.
*/
inline bool is_pinned(const torch::Tensor& tensor) {
// If this process is a worker, we should avoid initializing the CUDA context.
return !GetWorkerId() && tensor.is_pinned();
}

/**
* @brief Checks whether the tensors are all stored on the GPU or the pinned
* memory.
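`utils::is_pinned` short-circuits on `GetWorkerId()`: calling `torch::Tensor::is_pinned()` can lazily initialize a CUDA context, which DataLoader worker processes should avoid, so a worker simply reports false and falls back to unpinned allocations. A hedged sketch of the short-circuit with stand-in functions (torch is not involved here):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>

// Stand-in for utils::GetWorkerId(): set in workers, empty in the main process.
std::optional<int64_t> GetWorkerIdStub(bool in_worker) {
  if (in_worker) return 0;  // pretend this is the assigned worker id
  return std::nullopt;      // main process: no worker id was ever set
}

// Stand-in for Tensor::is_pinned(), which may lazily create a CUDA context.
bool TensorIsPinnedStub() {
  std::cout << "touching CUDA runtime\n";
  return true;
}

// Mirrors utils::is_pinned: when a worker id is set, the && short-circuits and
// the expensive is_pinned() call never happens in the worker process.
bool IsPinnedSketch(bool in_worker) {
  return !GetWorkerIdStub(in_worker) && TensorIsPinnedStub();
}

int main() {
  std::cout << IsPinnedSketch(/*in_worker=*/true) << "\n";   // 0, CUDA untouched
  std::cout << IsPinnedSketch(/*in_worker=*/false) << "\n";  // 1, CUDA touched
  return 0;
}
```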