Skip to content

Commit

Permalink
Simplify the GPU memory logger. (dmlc#10927)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 24, 2024
1 parent e8a3ead commit 18edf86
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 76 deletions.
121 changes: 45 additions & 76 deletions src/common/device_vector.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@

#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1

#include <cuda.h>
#include <cuda.h> // for CUmemGenericAllocationHandle

#include <atomic> // for atomic, memory_order
#include <cstddef> // for size_t
#include <cstdint> // for int64_t
#include <cub/util_allocator.cuh> // for CachingDeviceAllocator
#include <cub/util_device.cuh> // for CurrentDevice
#include <map> // for map
#include <memory> // for unique_ptr
#include <mutex> // for defer_lock

#include "common.h" // for safe_cuda, HumanMemUnit
#include "cuda_dr_utils.h" // for CuDriverApi
Expand All @@ -41,92 +41,70 @@

namespace dh {
namespace detail {
// std::atomic::fetch_max in c++26
template <typename T>
T AtomicFetchMax(std::atomic<T> &atom, T val,  // NOLINT
                 std::memory_order order = std::memory_order_seq_cst) {
  // Atomically raise `atom` to max(atom, val) and return the value observed
  // immediately before the update, mirroring the fetch_* family of atomics.
  T prev = atom.load();
  // Retry exactly while max(prev, val) == val, i.e. while the stored value is
  // not already strictly greater than `val`.  A failed CAS refreshes `prev`
  // with the currently stored value before the next check.
  while (prev <= val) {
    if (atom.compare_exchange_strong(prev, val, order, order)) {
      break;  // Installed `val`; `prev` still holds the displaced value.
    }
  }
  return prev;
}

/** \brief Keeps track of global device memory allocations. Thread safe.*/
class MemoryLogger {
// Information for a single device
struct DeviceStats {
std::size_t currently_allocated_bytes{0};
size_t peak_allocated_bytes{0};
size_t num_allocations{0};
size_t num_deallocations{0};
std::map<void *, size_t> device_allocations;
void RegisterAllocation(void *ptr, size_t n) {
auto itr = device_allocations.find(ptr);
if (itr != device_allocations.cend()) {
LOG(WARNING) << "Attempting to allocate " << n << " bytes."
<< " that was already allocated\nptr:" << ptr << "\n"
<< dmlc::StackTrace();
}
device_allocations[ptr] = n;
// Use signed int to allow temporary under-flow.
std::atomic<std::int64_t> currently_allocated_bytes{0};
std::atomic<std::int64_t> peak_allocated_bytes{0};
void RegisterAllocation(std::int64_t n) {
currently_allocated_bytes += n;
peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
num_allocations++;
CHECK_GT(num_allocations, num_deallocations);
}
void RegisterDeallocation(void *ptr, size_t n, int current_device) {
auto itr = device_allocations.find(ptr);
if (itr == device_allocations.end()) {
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device
<< " that was never allocated\nptr:" << ptr << "\n"
<< dmlc::StackTrace();
} else {
num_deallocations++;
CHECK_LE(num_deallocations, num_allocations);
currently_allocated_bytes -= itr->second;
device_allocations.erase(itr);
}
AtomicFetchMax(peak_allocated_bytes, currently_allocated_bytes.load());
}
void RegisterDeallocation(std::int64_t n) { currently_allocated_bytes -= n; }
};
DeviceStats stats_;
std::mutex mutex_;

public:
/**
* @brief Register the allocation for logging.
*
* @param lock Set to false if the allocator has a locking mechanism.
*/
void RegisterAllocation(void *ptr, size_t n, bool lock) {
void RegisterAllocation(std::size_t n) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
return;
}
std::unique_lock guard{mutex_, std::defer_lock};
if (lock) {
guard.lock();
}
stats_.RegisterAllocation(ptr, n);
stats_.RegisterAllocation(static_cast<std::int64_t>(n));
}
/**
* @brief Register the deallocation for logging.
*
* @param lock Set to false if the allocator has a locking mechanism.
*/
void RegisterDeallocation(void *ptr, size_t n, bool lock) {
void RegisterDeallocation(std::size_t n) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
return;
}
std::unique_lock guard{mutex_, std::defer_lock};
if (lock) {
guard.lock();
}
stats_.RegisterDeallocation(ptr, n, cub::CurrentDevice());
stats_.RegisterDeallocation(static_cast<std::int64_t>(n));
}
std::int64_t PeakMemory() const { return stats_.peak_allocated_bytes; }
std::int64_t CurrentlyAllocatedBytes() const { return stats_.currently_allocated_bytes; }
void Clear() {
stats_.currently_allocated_bytes = 0;
stats_.peak_allocated_bytes = 0;
}
size_t PeakMemory() const { return stats_.peak_allocated_bytes; }
size_t CurrentlyAllocatedBytes() const { return stats_.currently_allocated_bytes; }
void Clear() { stats_ = DeviceStats(); }

void Log() {
void Log() const {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
return;
}
std::lock_guard<std::mutex> guard(mutex_);
int current_device;
dh::safe_cuda(cudaGetDevice(&current_device));
auto current_device = cub::CurrentDevice();
LOG(CONSOLE) << "======== Device " << current_device << " Memory Allocations: "
<< " ========";
LOG(CONSOLE) << "Peak memory usage: "
<< xgboost::common::HumanMemUnit(stats_.peak_allocated_bytes);
LOG(CONSOLE) << "Number of allocations: " << stats_.num_allocations;
}
};

Expand Down Expand Up @@ -313,12 +291,11 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
} catch (const std::exception &e) {
detail::ThrowOOMError(e.what(), n * sizeof(T));
}
// We can't place a lock here as template allocator is transient.
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T), true);
GlobalMemoryLogger().RegisterAllocation(n * sizeof(T));
return ptr;
}
void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true);
GlobalMemoryLogger().RegisterDeallocation(n * sizeof(T));
SuperT::deallocate(ptr, n);
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
Expand Down Expand Up @@ -367,14 +344,13 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
detail::ThrowOOMError(e.what(), n * sizeof(T));
}
}
// We can't place a lock here as template allocator is transient.
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T), true);
GlobalMemoryLogger().RegisterAllocation(n * sizeof(T));
return thrust_ptr;
}
void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true);
GlobalMemoryLogger().RegisterDeallocation(n * sizeof(T));
if (use_cub_allocator_) {
GetGlobalCachingAllocator().DeviceFree(ptr.get());
GetGlobalCachingAllocator().DeviceFree(thrust::raw_pointer_cast(ptr));
} else {
SuperT::deallocate(ptr, n);
}
Expand Down Expand Up @@ -402,7 +378,9 @@ using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl<T>;
template <typename T>
using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;

/** @brief Specialisation of thrust device vector using custom allocator. */
/** @brief Specialisation of thrust device vector using custom allocator. In addition, it catches
* OOM errors.
*/
template <typename T>
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>; // NOLINT
template <typename T>
Expand All @@ -414,7 +392,6 @@ using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocato
*/
class LoggingResource : public rmm::mr::device_memory_resource {
rmm::mr::device_memory_resource *mr_{rmm::mr::get_current_device_resource()};
std::mutex lock_;

public:
LoggingResource() = default;
Expand All @@ -432,13 +409,9 @@ class LoggingResource : public rmm::mr::device_memory_resource {
}

void *do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override { // NOLINT
std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
guard.lock();
}
try {
auto const ptr = mr_->allocate(bytes, stream);
GlobalMemoryLogger().RegisterAllocation(ptr, bytes, false);
GlobalMemoryLogger().RegisterAllocation(bytes);
return ptr;
} catch (rmm::bad_alloc const &e) {
detail::ThrowOOMError(e.what(), bytes);
Expand All @@ -448,12 +421,8 @@ class LoggingResource : public rmm::mr::device_memory_resource {

void do_deallocate(void *ptr, std::size_t bytes, // NOLINT
rmm::cuda_stream_view stream) override {
std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
guard.lock();
}
mr_->deallocate(ptr, bytes, stream);
GlobalMemoryLogger().RegisterDeallocation(ptr, bytes, false);
GlobalMemoryLogger().RegisterDeallocation(bytes);
}

[[nodiscard]] bool do_is_equal( // NOLINT
Expand Down
19 changes: 19 additions & 0 deletions tests/cpp/common/test_device_vector.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <thread> // for thread

#include <numeric> // for iota
#include <thrust/detail/sequence.inl> // for sequence
Expand Down Expand Up @@ -120,4 +121,22 @@ TEST(TestVirtualMem, Version) {
}
#endif // defined(xgboost_IS_WIN)
}

TEST(AtomicFetch, Max) {  // fixed typo: was "AtomitFetch"
  // Hammer detail::AtomicFetchMax from many threads: thread t pushes the
  // values [t, t + add), so the final stored value must be the global max.
  auto n_threads = std::thread::hardware_concurrency();
  if (n_threads == 0) {
    // hardware_concurrency() may legitimately return 0 when the value is not
    // computable; fall back to one thread so the expectation below holds.
    n_threads = 1;
  }
  std::vector<std::thread> threads;
  threads.reserve(n_threads);
  std::atomic<std::int64_t> n{0};
  decltype(n)::value_type add = 64;
  for (decltype(n_threads) t = 0; t < n_threads; ++t) {
    threads.emplace_back([=, &n] {
      for (decltype(add) i = 0; i < add; ++i) {
        detail::AtomicFetchMax(n, static_cast<decltype(add)>(t + i));
      }
    });
  }
  for (auto& t : threads) {
    t.join();
  }
  // Largest pushed value: (n_threads - 1) + (add - 1), 0-based indexing.
  ASSERT_EQ(n, n_threads - 1 + add - 1);
}
} // namespace dh

0 comments on commit 18edf86

Please sign in to comment.