From 12bba85ebac70eb3eef02dffc75ba1ae0beef1a4 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 28 Sep 2021 11:08:49 +0000 Subject: [PATCH 01/40] Use cuda virtual memory management and merge blocks, test=develop --- paddle/fluid/memory/allocation/CMakeLists.txt | 10 +- .../memory/allocation/allocator_facade.cc | 39 +++ .../auto_growth_best_fit_allocator.cc | 1 + .../auto_growth_best_fit_allocator_v2.cc | 300 ++++++++++++++++++ .../auto_growth_best_fit_allocator_v2.h | 90 ++++++ .../allocation/cuda_virtual_mem_allocator.cc | 193 +++++++++++ .../allocation/cuda_virtual_mem_allocator.h | 60 ++++ paddle/fluid/platform/dynload/cuda_driver.h | 12 +- 8 files changed, 703 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc create mode 100644 paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h create mode 100644 paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc create mode 100644 paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 6b4afae9f8c75..0f4de48079957 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -18,6 +18,9 @@ if (WITH_GPU) nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator) cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator) + if(CUDA_VERSION GREATER_EQUAL 10.2) + nv_library(cuda_virtual_mem_allocator SRCS cuda_virtual_mem_allocator.cc DEPS dynload_cuda) + endif() endif() if (WITH_ROCM) @@ -36,6 +39,9 @@ cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) if (WITH_GPU OR WITH_ROCM) set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator) + if(CUDA_VERSION GREATER_EQUAL 10.2) + 
list(APPEND AllocatorFacadeDeps cuda_virtual_mem_allocator) + endif() elseif(WITH_XPU) set(AllocatorFacadeDeps xpu_info) elseif(WITH_ASCEND) @@ -72,7 +78,7 @@ else() cpu_allocator) endif() -list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator aligned_allocator retry_allocator buffered_allocator naive_best_fit_allocator auto_growth_best_fit_allocator best_fit_allocator) +list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator aligned_allocator retry_allocator buffered_allocator naive_best_fit_allocator auto_growth_best_fit_allocator auto_growth_best_fit_allocator_v2 best_fit_allocator) if (WITH_ASCEND_CL) list(APPEND AllocatorFacadeDeps npu_pinned_allocator) @@ -103,6 +109,8 @@ cc_library(auto_growth_best_fit_allocator SRCS auto_growth_best_fit_allocator.cc cc_test(auto_growth_best_fit_allocator_facade_test SRCS auto_growth_best_fit_allocator_facade_test.cc DEPS cpu_allocator auto_growth_best_fit_allocator) cc_test(auto_growth_best_fit_allocator_test SRCS auto_growth_best_fit_allocator_test.cc DEPS auto_growth_best_fit_allocator) +cc_library(auto_growth_best_fit_allocator_v2 SRCS auto_growth_best_fit_allocator_v2.cc DEPS allocator aligned_allocator) + if(NOT WIN32) cc_library(mmap_allocator SRCS mmap_allocator.cc DEPS allocator) cc_test(mmap_allocator_test SRCS mmap_allocator_test.cc DEPS mmap_allocator allocator) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 78bce53b6f4ff..c50659ea49fc9 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h" +#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" #include "paddle/fluid/memory/allocation/cpu_allocator.h" #include 
"paddle/fluid/memory/allocation/naive_best_fit_allocator.h" #ifdef PADDLE_WITH_ASCEND_CL @@ -32,9 +33,13 @@ #include "paddle/fluid/memory/allocation/thread_local_allocator.h" #include "paddle/fluid/platform/gpu_info.h" #endif +#if CUDA_VERSION >= 10020 +#include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" +#endif #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/xpu/xpu_info.h" #endif +#include "paddle/fluid/platform/dynload/cuda_driver.h" #include "paddle/fluid/platform/npu_info.h" PADDLE_DEFINE_EXPORTED_int64( @@ -184,9 +189,43 @@ class AllocatorFacadePrivate { } void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p) { +#if defined(PADDLE_WITH_HIP) + auto cuda_allocator = std::make_shared(p); + allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize()); +#endif + +#if defined(PADDLE_WITH_CUDA) +#if CUDA_VERSION >= 10020 + CUdevice device; + PADDLE_ENFORCE_EQ( + paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()), + CUDA_SUCCESS, platform::errors::Fatal("Call cuDeviceGet faild.")); + + int val; + PADDLE_ENFORCE_EQ( + paddle::platform::dynload::cuDeviceGetAttribute( + &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, + device), + CUDA_SUCCESS, + platform::errors::Fatal("Call cuDeviceGetAttribute faild.")); + + if (val > 0) { + auto cuda_allocator = std::make_shared(p); + allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize()); + } else { + auto cuda_allocator = std::make_shared(p); + allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize()); + } + +#else auto cuda_allocator = std::make_shared(p); allocators_[p] = std::make_shared( cuda_allocator, platform::GpuMinChunkSize()); +#endif +#endif } #endif diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc index a35d8a73f7eda..25029444b2b85 100644 --- 
a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc @@ -97,6 +97,7 @@ Allocation *AutoGrowthBestFitAllocator::AllocateImpl(size_t size) { VLOG(2) << "Not found and reallocate " << realloc_size << ", and remaining " << remaining_size; } + return new BlockAllocation(block_it); } diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc new file mode 100644 index 0000000000000..5903d5b33f59f --- /dev/null +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -0,0 +1,300 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "paddle/fluid/memory/allocation/aligned_allocator.h" +#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" + +#pragma GCC diagnostic ignored "-Wpointer-arith" + +namespace paddle { +namespace memory { +namespace allocation { + +bool NeedSplit(size_t block_size, size_t alignment, size_t allock_size) { + return block_size > (allock_size * 2) || + (block_size - allock_size) > alignment; +} + +AutoGrowthBestFitAllocatorV2::AutoGrowthBestFitAllocatorV2( + const std::shared_ptr &underlying_allocator, size_t alignment) + : underlying_allocator_( + std::make_shared(underlying_allocator, alignment)), + alignment_(alignment) {} + +Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl(size_t size) { + size = AlignedSize(size, alignment_); + auto result = AllocFromFreeBlocks(size); + + if (!result) { + auto allocateptr = underlying_allocator_->Allocate(size); + TryMergeAlloctation2Blocks(allocateptr->ptr(), allocateptr->size()); + regions_.emplace(std::move(allocateptr)); + result = AllocFromFreeBlocks(size); + } + + // std::cout << "alloc " << result->ptr() << " " << result->size() << + // std::endl; + + return result; +} + +void AutoGrowthBestFitAllocatorV2::FreeImpl(Allocation *allocation) { + auto block_it = static_cast(allocation)->block_it_; + TryMergeBlock2Blocks(block_it); + delete allocation; +} + +uint64_t AutoGrowthBestFitAllocatorV2::ReleaseImpl( + const platform::Place &place) { + return 0; +} + +void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( + std::list::iterator block) { + std::lock_guard guard(spinlock_); + if (block->ptr_ == all_blocks_.front().ptr_ && + block->ptr_ == all_blocks_.back().ptr_) { + block->is_free_ = true; + // std::cout << "back1 " << block->ptr_ << " " << block->size_ << std::endl; + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } else if (block->ptr_ == all_blocks_.front().ptr_) { + block++; + auto next = block; + block--; + if (next->is_free_ && + 
reinterpret_cast(block->ptr_) + block->size_ == next->ptr_) { + // merge with next + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + block->size_ += next->size_; + block->is_free_ = true; + // std::cout << "merge1 " << block->ptr_ << " " << next->ptr_ << " " << + // block->size_ << std::endl; + all_blocks_.erase(next); + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } else { + block->is_free_ = true; + // std::cout << "back2 " << block->ptr_ << " " << block->size_ << + // std::endl; + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } + } else if (block->ptr_ == all_blocks_.back().ptr_) { + block--; + auto pre = block; + block++; + if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_) { + // merge with pre + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + pre->size_ += block->size_; + // std::cout << "merge2 " << pre->ptr_ << " " << block->ptr_ << " " << + // pre->size_ << std::endl; + all_blocks_.erase(block); + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + } else { + block->is_free_ = true; + // std::cout << "back3 " << block->ptr_ << " " << block->size_ << + // std::endl; + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } + } else { + block--; + auto pre = block; + block++; + block++; + auto next = block; + block--; + if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_ && + !(next->is_free_ && + reinterpret_cast(block->ptr_) + block->size_ == + next->ptr_)) { + // merge with pre + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + pre->size_ += block->size_; + // std::cout << "merge3 " << pre->ptr_ << " " << block->ptr_ << " " << + // pre->size_ << std::endl; + all_blocks_.erase(block); + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + } else if (next->is_free_ && + reinterpret_cast(block->ptr_) + block->size_ == + next->ptr_ && + !(pre->is_free_ && + 
reinterpret_cast(pre->ptr_) + pre->size_ == + block->ptr_)) { + // merge with next + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + block->size_ += next->size_; + block->is_free_ = true; + // std::cout << "merge4 " << block->ptr_ << " " << next->ptr_ << " " << + // block->size_ << std::endl; + all_blocks_.erase(next); + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } else if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == + block->ptr_ && + next->is_free_ && + reinterpret_cast(block->ptr_) + block->size_ == + next->ptr_) { + // merge with pre and next + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + pre->size_ += (block->size_ + next->size_); + // std::cout << "merge5 " << pre->ptr_ << " " << block->ptr_ << " " << + // next->ptr_ << " " << pre->size_ << std::endl; + all_blocks_.erase(block); + all_blocks_.erase(next); + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + } else { + block->is_free_ = true; + // std::cout << "back4 " << block->ptr_ << " " << block->size_ << + // std::endl; + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } + } +} + +void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, + size_t size) { + std::lock_guard guard(spinlock_); + if (all_blocks_.empty()) { + all_blocks_.push_back(Block(ptr, size, true)); + // std::cout << "insert1 " << ptr << " " << size << std::endl; + free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); + return; + } + for (auto block_it = all_blocks_.begin(); block_it != all_blocks_.end(); + ++block_it) { + if (block_it->ptr_ > ptr) { + if (block_it == all_blocks_.begin()) { + // insert to front + if (block_it->is_free_ && + reinterpret_cast(ptr) + size == block_it->ptr_) { + // merge with next + free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); + // std::cout << "merge6 " << ptr << " " << 
block_it->ptr_ << " " << + // block_it->size_+size << std::endl; + block_it->ptr_ = ptr; + block_it->size_ += size; + free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_), + block_it); + } else { + // do not merge + all_blocks_.push_front(Block(ptr, size, true)); + // std::cout << "insert2 " << ptr << " " << size << std::endl; + free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); + } + } else { + // insert to middle + auto next = block_it; + block_it--; + auto pre = block_it; + if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == ptr && + !(next->is_free_ && + reinterpret_cast(ptr) + size == next->ptr_)) { + // merge with pre + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + pre->size_ += size; + // std::cout << "merge7 " << pre->ptr_ << " " << ptr << " " << + // pre->size_ << std::endl; + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + } else if (next->is_free_ && + reinterpret_cast(ptr) + size == next->ptr_ && + !(pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == + ptr)) { + // merge with next + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + // std::cout << "merge8 " << ptr << " " << next->ptr_ << " " << + // next->size_+size << std::endl; + next->ptr_ = ptr; + next->size_ += size; + free_blocks_.emplace(std::make_pair(next->size_, next->ptr_), next); + } else if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == ptr && + next->is_free_ && + reinterpret_cast(ptr) + size == next->ptr_) { + // merge with pre and next + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + // std::cout << "merge9 " << pre->ptr_ << " " << ptr << " " << + // next->ptr_ << " " << pre->size_+next->size_+size << std::endl; + pre->size_ += (size + next->size_); + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + all_blocks_.erase(next); + } else { + // do not merge + auto iter = 
all_blocks_.insert(next, Block(ptr, size, true)); + // std::cout << "insert3 " << ptr << " " << size << std::endl; + free_blocks_.emplace(std::make_pair(size, ptr), iter); + } + } + return; + } + } + + // insert to back + auto block_it = all_blocks_.end(); + block_it--; + if (block_it->is_free_ && + reinterpret_cast(block_it->ptr_) + block_it->size_ == ptr) { + // merge with pre + free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); + block_it->size_ += size; + // std::cout << "merge10 " << block_it->ptr_ << " " << ptr << " " << + // block_it->size_ << std::endl; + free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_), + block_it); + } else { + // do not merge + all_blocks_.push_back(Block(ptr, size, true)); + auto block_it = all_blocks_.end(); + block_it--; + // std::cout << "insert4 " << ptr << " " << size << std::endl; + free_blocks_.emplace(std::make_pair(size, ptr), block_it); + } +} + +Allocation *AutoGrowthBestFitAllocatorV2::AllocFromFreeBlocks(size_t size) { + std::lock_guard guard(spinlock_); + auto iter = free_blocks_.lower_bound(std::make_pair(size, nullptr)); + if (iter != free_blocks_.end()) { + std::list::iterator block_it = iter->second; + free_blocks_.erase(iter); + if (NeedSplit(block_it->size_, alignment_, size)) { + size_t remaining_size = block_it->size_ - size; + auto remaining_free_block = all_blocks_.insert( + block_it, Block(block_it->ptr_, remaining_size, true)); + free_blocks_.emplace(std::make_pair(remaining_size, block_it->ptr_), + remaining_free_block); + block_it->ptr_ = + reinterpret_cast(block_it->ptr_) + remaining_size; + block_it->size_ = size; + // std::cout << "split " << remaining_free_block->ptr_ << " " << + // remaining_free_block->size_ << " " << block_it->ptr_ << " " << + // block_it->size_ << std::endl; + } + + block_it->is_free_ = false; + return new BlockAllocation(block_it, place_); + } + + return nullptr; +} + +} // namespace allocation +} // namespace memory +} // namespace paddle diff 
--git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h new file mode 100644 index 0000000000000..ad21e3157da90 --- /dev/null +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h @@ -0,0 +1,90 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/allocation/spin_lock.h" + +namespace paddle { +namespace memory { +namespace allocation { + +struct Block { + Block(void *ptr, size_t size, bool is_free) + : ptr_(ptr), size_(size), is_free_(is_free) {} + + void *ptr_; + size_t size_; + bool is_free_; +}; + +struct Region { + explicit Region(AllocationPtr allocation) + : allocation_(std::move(allocation)) {} + + AllocationPtr allocation_; +}; + +struct RegionComp { + bool operator()(const Region &a, const Region &b) { + return a.allocation_->ptr() < b.allocation_->ptr(); + } +}; + +struct BlockAllocation : public Allocation { + explicit BlockAllocation(const std::list::iterator &it, + platform::Place place) + : Allocation(it->ptr_, it->size_, place), block_it_(it) {} + + std::list::iterator block_it_; +}; + +class AutoGrowthBestFitAllocatorV2 : public Allocator { + public: + AutoGrowthBestFitAllocatorV2( + const std::shared_ptr &underlying_allocator, size_t 
alignment); + + bool IsAllocThreadSafe() const override { return true; } + + protected: + Allocation *AllocateImpl(size_t size) override; + + void FreeImpl(Allocation *allocation) override; + + // Release the memory block which is not used in pool. + uint64_t ReleaseImpl(const platform::Place &place) override; + + private: + Allocation *AllocFromFreeBlocks(size_t size); + void TryMergeAlloctation2Blocks(void *ptr, size_t size); + void TryMergeBlock2Blocks(std::list::iterator iter); + + std::shared_ptr underlying_allocator_; + size_t alignment_; + + std::map, std::list::iterator> free_blocks_; + std::list all_blocks_; + std::set regions_; + platform::Place place_; + SpinLock spinlock_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc new file mode 100644 index 0000000000000..7925140fd6716 --- /dev/null +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -0,0 +1,193 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" +#include "paddle/fluid/platform/dynload/cuda_driver.h" + +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif + +#ifdef PADDLE_WITH_HIP +#include +#endif + +#include +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/gpu_info.h" + +#if CUDA_VERSION >= 10020 + +namespace paddle { +namespace memory { +namespace allocation { + +#define PADDLE_ENFORCE_CUDA_SUCCESS2(COND) \ + do { \ + auto __cond__ = (COND); \ + if (UNLIKELY(__cond__ != CUDA_SUCCESS)) { \ + auto __summary__ = \ + ::paddle::platform::errors::External("cu error %d", __cond__); \ + __THROW_ERROR_INTERNAL__(__summary__); \ + } \ + } while (0) + +CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( + const platform::CUDAPlace& place) + : place_(place) { + prop_.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop_.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop_.location.id = place.GetDeviceId(); + + access_desc_.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc_.location.id = place.GetDeviceId(); + access_desc_.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + PADDLE_ENFORCE_CUDA_SUCCESS2( + paddle::platform::dynload::cuMemGetAllocationGranularity( + &granularity_, &prop_, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + size_t actual_avail, actual_total; + paddle::platform::CUDADeviceGuard guard(place.GetDeviceId()); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total)); + + // virtual_mem_size_ = actual_total/2; + + virtual_mem_size_ = (actual_total + granularity_ - 1) & ~(granularity_ - 1); + + // std::cout << "virtual_mem_size_=" << virtual_mem_size_ << std::endl; + + PADDLE_ENFORCE_CUDA_SUCCESS2(paddle::platform::dynload::cuMemAddressReserve( + &virtual_mem_base_, virtual_mem_size_, 0, 0, 0)); + virtual_mem_alloced_offset_ = 0; +} + +CUDAVirtualMemAllocator::~CUDAVirtualMemAllocator() { + paddle::platform::CUDADeviceGuard 
guard(place_.GetDeviceId()); + for (auto& item : virtual_2_physical_map_) { + PADDLE_ENFORCE_CUDA_SUCCESS2( + paddle::platform::dynload::cuMemUnmap(item.first, item.second.second)); + PADDLE_ENFORCE_CUDA_SUCCESS2( + paddle::platform::dynload::cuMemRelease(item.second.first)); + } + + PADDLE_ENFORCE_CUDA_SUCCESS2(paddle::platform::dynload::cuMemAddressFree( + virtual_mem_base_, virtual_mem_size_)); +} + +bool CUDAVirtualMemAllocator::IsAllocThreadSafe() const { return false; } + +void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { + PADDLE_ENFORCE_EQ( + BOOST_GET_CONST(platform::CUDAPlace, allocation->place()), place_, + platform::errors::PermissionDenied( + "GPU memory is freed in incorrect device. This may be a bug")); + + auto iter = virtual_2_physical_map_.find( + reinterpret_cast(allocation->ptr())); + if (iter == virtual_2_physical_map_.end()) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Can not find virtual memory address at %s", allocation->ptr())); + } + + paddle::platform::CUDADeviceGuard guard(place_.GetDeviceId()); + PADDLE_ENFORCE_CUDA_SUCCESS2( + paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second)); + PADDLE_ENFORCE_CUDA_SUCCESS2( + paddle::platform::dynload::cuMemRelease(iter->second.first)); + + virtual_2_physical_map_.erase(iter); + + delete allocation; +} + +Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { + size = (size + granularity_ - 1) & ~(granularity_ - 1); + + CUdeviceptr ptr = virtual_mem_base_ + virtual_mem_alloced_offset_; + + if (ptr + size > virtual_mem_base_ + virtual_mem_size_) { + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( + "\n\nOut of memory error on GPU %d. " + "Cannot allocate %s memory on GPU %d, %s memory has been allocated and " + "available memory is only %s.\n\n" + "Please check whether there is any other process using GPU %d.\n" + "1. If yes, please stop them, or start PaddlePaddle on another GPU.\n" + "2. 
If no, please decrease the batch size of your model.\n\n", + place_.device, string::HumanReadableSize(size), place_.device, + string::HumanReadableSize(virtual_mem_alloced_offset_), + string::HumanReadableSize(virtual_mem_size_ - + virtual_mem_alloced_offset_), + place_.device)); + return nullptr; + } + + CUmemGenericAllocationHandle handle; + + paddle::platform::CUDADeviceGuard guard(place_.GetDeviceId()); + auto ret = paddle::platform::dynload::cuMemCreate(&handle, size, &prop_, 0); + + if (ret != CUDA_SUCCESS) { + if (ret == CUDA_ERROR_OUT_OF_MEMORY) { + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( + "\n\nOut of memory error on GPU %d. " + "Cannot allocate %s memory on GPU %d, %s memory has been allocated " + "and " + "available memory is only %s.\n\n" + "Please check whether there is any other process using GPU %d.\n" + "1. If yes, please stop them, or start PaddlePaddle on another GPU.\n" + "2. If no, please decrease the batch size of your model.\n\n", + place_.device, string::HumanReadableSize(size), place_.device, + string::HumanReadableSize(virtual_mem_alloced_offset_), + string::HumanReadableSize(virtual_mem_size_ - + virtual_mem_alloced_offset_), + place_.device)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS2(ret); + } + return nullptr; + } + + ret = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0); + + if (ret != CUDA_SUCCESS) { + paddle::platform::dynload::cuMemRelease(handle); + PADDLE_ENFORCE_CUDA_SUCCESS2(ret); + return nullptr; + } + + ret = paddle::platform::dynload::cuMemSetAccess(ptr, size, &access_desc_, 1); + + if (ret != CUDA_SUCCESS) { + paddle::platform::dynload::cuMemUnmap(ptr, size); + paddle::platform::dynload::cuMemRelease(handle); + PADDLE_ENFORCE_CUDA_SUCCESS2(ret); + return nullptr; + } + + virtual_2_physical_map_.emplace(ptr, std::make_pair(handle, size)); + + virtual_mem_alloced_offset_ += size; + + return new Allocation(reinterpret_cast(ptr), size, + platform::Place(place_)); +} + +} // namespace allocation 
+} // namespace memory +} // namespace paddle + +#endif diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h new file mode 100644 index 0000000000000..06f50b3462344 --- /dev/null +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h @@ -0,0 +1,60 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // NOLINT +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/dynload/cudnn.h" +#include "paddle/fluid/platform/place.h" + +#if CUDA_VERSION >= 10020 + +namespace paddle { +namespace memory { +namespace allocation { + +class CUDAVirtualMemAllocator : public Allocator { + public: + explicit CUDAVirtualMemAllocator(const platform::CUDAPlace& place); + ~CUDAVirtualMemAllocator(); + + bool IsAllocThreadSafe() const override; + + protected: + void FreeImpl(Allocation* allocation) override; + Allocation* AllocateImpl(size_t size) override; + + private: + platform::CUDAPlace place_; + + CUdeviceptr virtual_mem_base_; + size_t virtual_mem_size_; + size_t virtual_mem_alloced_offset_; + size_t granularity_; + + CUmemAllocationProp prop_; + CUmemAccessDesc access_desc_; + + std::map> + virtual_2_physical_map_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle + 
+#endif diff --git a/paddle/fluid/platform/dynload/cuda_driver.h b/paddle/fluid/platform/dynload/cuda_driver.h index 5799b084f5f31..242c11d511abe 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.h +++ b/paddle/fluid/platform/dynload/cuda_driver.h @@ -57,7 +57,17 @@ extern bool HasCUDADriver(); __macro(cuCtxCreate); \ __macro(cuCtxGetCurrent); \ __macro(cuDeviceGetCount); \ - __macro(cuDevicePrimaryCtxGetState) + __macro(cuDevicePrimaryCtxGetState); \ + __macro(cuMemGetAllocationGranularity); \ + __macro(cuMemAddressReserve); \ + __macro(cuMemCreate); \ + __macro(cuMemMap); \ + __macro(cuMemSetAccess); \ + __macro(cuMemUnmap); \ + __macro(cuMemRelease); \ + __macro(cuMemAddressFree); \ + __macro(cuDeviceGetAttribute); \ + __macro(cuDeviceGet) CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); From 4ca9d2f4e93e8de74ffc8d2e1e240acc43d34159 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 28 Sep 2021 11:47:20 +0000 Subject: [PATCH 02/40] refine, test=develop --- .../memory/allocation/allocator_facade.cc | 17 ++- .../allocation/cuda_virtual_mem_allocator.cc | 131 ++++++++++-------- paddle/fluid/platform/dynload/cuda_driver.h | 20 ++- 3 files changed, 100 insertions(+), 68 deletions(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index c50659ea49fc9..335d62869d76e 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -198,17 +198,20 @@ class AllocatorFacadePrivate { #if defined(PADDLE_WITH_CUDA) #if CUDA_VERSION >= 10020 CUdevice device; + auto result = + paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()); PADDLE_ENFORCE_EQ( - paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()), - CUDA_SUCCESS, platform::errors::Fatal("Call cuDeviceGet faild.")); + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuDeviceGet faild, return %d.", + result)); int val; + result = 
paddle::platform::dynload::cuDeviceGetAttribute( + &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, device); PADDLE_ENFORCE_EQ( - paddle::platform::dynload::cuDeviceGetAttribute( - &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, - device), - CUDA_SUCCESS, - platform::errors::Fatal("Call cuDeviceGetAttribute faild.")); + result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuDeviceGetAttribute faild, return %d.", result)); if (val > 0) { auto cuda_allocator = std::make_shared(p); diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index 7925140fd6716..9b107328ac6d0 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -11,21 +11,15 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
- -#include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" -#include "paddle/fluid/platform/dynload/cuda_driver.h" - #ifdef PADDLE_WITH_CUDA #include #include #endif -#ifdef PADDLE_WITH_HIP -#include -#endif - #include +#include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" #include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/dynload/cuda_driver.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gpu_info.h" @@ -35,57 +29,63 @@ namespace paddle { namespace memory { namespace allocation { -#define PADDLE_ENFORCE_CUDA_SUCCESS2(COND) \ - do { \ - auto __cond__ = (COND); \ - if (UNLIKELY(__cond__ != CUDA_SUCCESS)) { \ - auto __summary__ = \ - ::paddle::platform::errors::External("cu error %d", __cond__); \ - __THROW_ERROR_INTERNAL__(__summary__); \ - } \ - } while (0) - CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( const platform::CUDAPlace& place) : place_(place) { prop_.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop_.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop_.location.id = place.GetDeviceId(); + prop_.location.id = place.device; access_desc_.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - access_desc_.location.id = place.GetDeviceId(); + access_desc_.location.id = place.device; access_desc_.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - PADDLE_ENFORCE_CUDA_SUCCESS2( - paddle::platform::dynload::cuMemGetAllocationGranularity( + auto result = paddle::platform::dynload::cuMemGetAllocationGranularity( &granularity_, &prop_, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuDeviceGetAttribute faild, return %d.", result)); size_t actual_avail, actual_total; - paddle::platform::CUDADeviceGuard guard(place.GetDeviceId()); + paddle::platform::CUDADeviceGuard guard(place.device); PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total)); - // virtual_mem_size_ = actual_total/2; - 
virtual_mem_size_ = (actual_total + granularity_ - 1) & ~(granularity_ - 1); - // std::cout << "virtual_mem_size_=" << virtual_mem_size_ << std::endl; + result = paddle::platform::dynload::cuMemAddressReserve( + &virtual_mem_base_, virtual_mem_size_, 0, 0, 0); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuMemAddressReserve faild, return %d.", result)); - PADDLE_ENFORCE_CUDA_SUCCESS2(paddle::platform::dynload::cuMemAddressReserve( - &virtual_mem_base_, virtual_mem_size_, 0, 0, 0)); virtual_mem_alloced_offset_ = 0; } CUDAVirtualMemAllocator::~CUDAVirtualMemAllocator() { - paddle::platform::CUDADeviceGuard guard(place_.GetDeviceId()); + CUresult result; + paddle::platform::CUDADeviceGuard guard(place_.device); for (auto& item : virtual_2_physical_map_) { - PADDLE_ENFORCE_CUDA_SUCCESS2( - paddle::platform::dynload::cuMemUnmap(item.first, item.second.second)); - PADDLE_ENFORCE_CUDA_SUCCESS2( - paddle::platform::dynload::cuMemRelease(item.second.first)); + result = + paddle::platform::dynload::cuMemUnmap(item.first, item.second.second); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", + result)); + result = paddle::platform::dynload::cuMemRelease(item.second.first); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuMemRelease faild, return %d.", + result)); } - PADDLE_ENFORCE_CUDA_SUCCESS2(paddle::platform::dynload::cuMemAddressFree( - virtual_mem_base_, virtual_mem_size_)); + result = paddle::platform::dynload::cuMemAddressFree(virtual_mem_base_, + virtual_mem_size_); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuMemAddressFree faild, return %d.", result)); } bool CUDAVirtualMemAllocator::IsAllocThreadSafe() const { return false; } @@ -103,11 +103,16 @@ void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { "Can not find virtual memory address at %s", 
allocation->ptr())); } - paddle::platform::CUDADeviceGuard guard(place_.GetDeviceId()); - PADDLE_ENFORCE_CUDA_SUCCESS2( - paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second)); - PADDLE_ENFORCE_CUDA_SUCCESS2( - paddle::platform::dynload::cuMemRelease(iter->second.first)); + paddle::platform::CUDADeviceGuard guard(place_.device); + auto result = + paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second); + PADDLE_ENFORCE_EQ(result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuMemUnmap faild, return %d.", result)); + result = paddle::platform::dynload::cuMemRelease(iter->second.first); + PADDLE_ENFORCE_EQ(result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuMemUnmap faild, return %d.", result)); virtual_2_physical_map_.erase(iter); @@ -121,12 +126,11 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { if (ptr + size > virtual_mem_base_ + virtual_mem_size_) { PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( - "\n\nOut of memory error on GPU %d. " - "Cannot allocate %s memory on GPU %d, %s memory has been allocated and " + "\n\nOut of memory error on GPU Virtual Memory %d. " + "Cannot allocate %s memory on GPU Virtual Memory %d, %s memory has " + "been allocated and " "available memory is only %s.\n\n" - "Please check whether there is any other process using GPU %d.\n" - "1. If yes, please stop them, or start PaddlePaddle on another GPU.\n" - "2. 
If no, please decrease the batch size of your model.\n\n", + "Please decrease the batch size of your model.\n\n", place_.device, string::HumanReadableSize(size), place_.device, string::HumanReadableSize(virtual_mem_alloced_offset_), string::HumanReadableSize(virtual_mem_size_ - @@ -137,11 +141,16 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { CUmemGenericAllocationHandle handle; - paddle::platform::CUDADeviceGuard guard(place_.GetDeviceId()); - auto ret = paddle::platform::dynload::cuMemCreate(&handle, size, &prop_, 0); + paddle::platform::CUDADeviceGuard guard(place_.device); + auto result = + paddle::platform::dynload::cuMemCreate(&handle, size, &prop_, 0); + + if (result != CUDA_SUCCESS) { + if (result == CUDA_ERROR_OUT_OF_MEMORY) { + size_t actual_avail, actual_total; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total)); + size_t actual_allocated = actual_total - actual_avail; - if (ret != CUDA_SUCCESS) { - if (ret == CUDA_ERROR_OUT_OF_MEMORY) { PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( "\n\nOut of memory error on GPU %d. " "Cannot allocate %s memory on GPU %d, %s memory has been allocated " @@ -151,30 +160,32 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { "1. If yes, please stop them, or start PaddlePaddle on another GPU.\n" "2. 
If no, please decrease the batch size of your model.\n\n", place_.device, string::HumanReadableSize(size), place_.device, - string::HumanReadableSize(virtual_mem_alloced_offset_), - string::HumanReadableSize(virtual_mem_size_ - - virtual_mem_alloced_offset_), - place_.device)); + string::HumanReadableSize(actual_allocated), + string::HumanReadableSize(actual_avail), place_.device)); } else { - PADDLE_ENFORCE_CUDA_SUCCESS2(ret); + PADDLE_THROW(platform::errors::Fatal( + "Call CUDA API cuMemCreate faild, return %d.", result)); } return nullptr; } - ret = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0); + result = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0); - if (ret != CUDA_SUCCESS) { + if (result != CUDA_SUCCESS) { paddle::platform::dynload::cuMemRelease(handle); - PADDLE_ENFORCE_CUDA_SUCCESS2(ret); + PADDLE_THROW(platform::errors::Fatal( + "Call CUDA API cuMemMap faild, return %d.", result)); return nullptr; } - ret = paddle::platform::dynload::cuMemSetAccess(ptr, size, &access_desc_, 1); + result = + paddle::platform::dynload::cuMemSetAccess(ptr, size, &access_desc_, 1); - if (ret != CUDA_SUCCESS) { + if (result != CUDA_SUCCESS) { paddle::platform::dynload::cuMemUnmap(ptr, size); paddle::platform::dynload::cuMemRelease(handle); - PADDLE_ENFORCE_CUDA_SUCCESS2(ret); + PADDLE_THROW(platform::errors::Fatal( + "Call CUDA API cuMemSetAccess faild, return %d.", result)); return nullptr; } diff --git a/paddle/fluid/platform/dynload/cuda_driver.h b/paddle/fluid/platform/dynload/cuda_driver.h index 242c11d511abe..424d8d38bc943 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.h +++ b/paddle/fluid/platform/dynload/cuda_driver.h @@ -42,6 +42,7 @@ extern bool HasCUDADriver(); }; \ extern struct DynLoad__##__name __name +#if CUDA_VERSION >= 10020 /** * include all needed cuda driver functions **/ @@ -68,7 +69,24 @@ extern bool HasCUDADriver(); __macro(cuMemAddressFree); \ __macro(cuDeviceGetAttribute); \ __macro(cuDeviceGet) - +#else +/** + * 
include all needed cuda driver functions + **/ +#define CUDA_ROUTINE_EACH(__macro) \ + __macro(cuInit); \ + __macro(cuDriverGetVersion); \ + __macro(cuGetErrorString); \ + __macro(cuModuleLoadData); \ + __macro(cuModuleGetFunction); \ + __macro(cuModuleUnload); \ + __macro(cuOccupancyMaxActiveBlocksPerMultiprocessor); \ + __macro(cuLaunchKernel); \ + __macro(cuCtxCreate); \ + __macro(cuCtxGetCurrent); \ + __macro(cuDeviceGetCount); \ + __macro(cuDevicePrimaryCtxGetState) +#endif CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); #undef DECLARE_DYNAMIC_LOAD_CUDA_WRAP From 1fa73284b8ddf11218985a730bebac6b809cd437 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 28 Sep 2021 11:48:45 +0000 Subject: [PATCH 03/40] refine, test=develop --- paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc index 25029444b2b85..a35d8a73f7eda 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc @@ -97,7 +97,6 @@ Allocation *AutoGrowthBestFitAllocator::AllocateImpl(size_t size) { VLOG(2) << "Not found and reallocate " << realloc_size << ", and remaining " << remaining_size; } - return new BlockAllocation(block_it); } From 3544756553c65d3fc432c87a1cd58ec3069d4fed Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 28 Sep 2021 11:58:08 +0000 Subject: [PATCH 04/40] refine, test=develop --- .../auto_growth_best_fit_allocator_v2.cc | 42 ------------------- .../auto_growth_best_fit_allocator_v2.h | 3 -- .../allocation/cuda_virtual_mem_allocator.cc | 2 +- 3 files changed, 1 insertion(+), 46 deletions(-) diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc index 5903d5b33f59f..b6287139e8330 100644 --- 
a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -45,9 +45,6 @@ Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl(size_t size) { result = AllocFromFreeBlocks(size); } - // std::cout << "alloc " << result->ptr() << " " << result->size() << - // std::endl; - return result; } @@ -57,18 +54,12 @@ void AutoGrowthBestFitAllocatorV2::FreeImpl(Allocation *allocation) { delete allocation; } -uint64_t AutoGrowthBestFitAllocatorV2::ReleaseImpl( - const platform::Place &place) { - return 0; -} - void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( std::list::iterator block) { std::lock_guard guard(spinlock_); if (block->ptr_ == all_blocks_.front().ptr_ && block->ptr_ == all_blocks_.back().ptr_) { block->is_free_ = true; - // std::cout << "back1 " << block->ptr_ << " " << block->size_ << std::endl; free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else if (block->ptr_ == all_blocks_.front().ptr_) { block++; @@ -80,14 +71,10 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); block->size_ += next->size_; block->is_free_ = true; - // std::cout << "merge1 " << block->ptr_ << " " << next->ptr_ << " " << - // block->size_ << std::endl; all_blocks_.erase(next); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else { block->is_free_ = true; - // std::cout << "back2 " << block->ptr_ << " " << block->size_ << - // std::endl; free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } else if (block->ptr_ == all_blocks_.back().ptr_) { @@ -99,14 +86,10 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( // merge with pre free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += block->size_; - // std::cout << "merge2 " << pre->ptr_ << " " << block->ptr_ << " " << - // pre->size_ << std::endl; all_blocks_.erase(block); 
free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); } else { block->is_free_ = true; - // std::cout << "back3 " << block->ptr_ << " " << block->size_ << - // std::endl; free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } else { @@ -124,8 +107,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( // merge with pre free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += block->size_; - // std::cout << "merge3 " << pre->ptr_ << " " << block->ptr_ << " " << - // pre->size_ << std::endl; all_blocks_.erase(block); free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); } else if (next->is_free_ && @@ -138,8 +119,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); block->size_ += next->size_; block->is_free_ = true; - // std::cout << "merge4 " << block->ptr_ << " " << next->ptr_ << " " << - // block->size_ << std::endl; all_blocks_.erase(next); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else if (pre->is_free_ && @@ -152,15 +131,11 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); pre->size_ += (block->size_ + next->size_); - // std::cout << "merge5 " << pre->ptr_ << " " << block->ptr_ << " " << - // next->ptr_ << " " << pre->size_ << std::endl; all_blocks_.erase(block); all_blocks_.erase(next); free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); } else { block->is_free_ = true; - // std::cout << "back4 " << block->ptr_ << " " << block->size_ << - // std::endl; free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } @@ -171,7 +146,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, std::lock_guard guard(spinlock_); if (all_blocks_.empty()) { all_blocks_.push_back(Block(ptr, size, true)); - // std::cout << "insert1 
" << ptr << " " << size << std::endl; free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); return; } @@ -184,8 +158,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, reinterpret_cast(ptr) + size == block_it->ptr_) { // merge with next free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); - // std::cout << "merge6 " << ptr << " " << block_it->ptr_ << " " << - // block_it->size_+size << std::endl; block_it->ptr_ = ptr; block_it->size_ += size; free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_), @@ -193,7 +165,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, } else { // do not merge all_blocks_.push_front(Block(ptr, size, true)); - // std::cout << "insert2 " << ptr << " " << size << std::endl; free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); } } else { @@ -208,8 +179,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, // merge with pre free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += size; - // std::cout << "merge7 " << pre->ptr_ << " " << ptr << " " << - // pre->size_ << std::endl; free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); } else if (next->is_free_ && reinterpret_cast(ptr) + size == next->ptr_ && @@ -218,8 +187,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, ptr)) { // merge with next free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); - // std::cout << "merge8 " << ptr << " " << next->ptr_ << " " << - // next->size_+size << std::endl; next->ptr_ = ptr; next->size_ += size; free_blocks_.emplace(std::make_pair(next->size_, next->ptr_), next); @@ -230,15 +197,12 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, // merge with pre and next free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); - // std::cout << "merge9 " << pre->ptr_ << " " << 
ptr << " " << - // next->ptr_ << " " << pre->size_+next->size_+size << std::endl; pre->size_ += (size + next->size_); free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); all_blocks_.erase(next); } else { // do not merge auto iter = all_blocks_.insert(next, Block(ptr, size, true)); - // std::cout << "insert3 " << ptr << " " << size << std::endl; free_blocks_.emplace(std::make_pair(size, ptr), iter); } } @@ -254,8 +218,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, // merge with pre free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); block_it->size_ += size; - // std::cout << "merge10 " << block_it->ptr_ << " " << ptr << " " << - // block_it->size_ << std::endl; free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_), block_it); } else { @@ -263,7 +225,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, all_blocks_.push_back(Block(ptr, size, true)); auto block_it = all_blocks_.end(); block_it--; - // std::cout << "insert4 " << ptr << " " << size << std::endl; free_blocks_.emplace(std::make_pair(size, ptr), block_it); } } @@ -283,9 +244,6 @@ Allocation *AutoGrowthBestFitAllocatorV2::AllocFromFreeBlocks(size_t size) { block_it->ptr_ = reinterpret_cast(block_it->ptr_) + remaining_size; block_it->size_ = size; - // std::cout << "split " << remaining_free_block->ptr_ << " " << - // remaining_free_block->size_ << " " << block_it->ptr_ << " " << - // block_it->size_ << std::endl; } block_it->is_free_ = false; diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h index ad21e3157da90..d355b61a78be2 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h @@ -67,9 +67,6 @@ class AutoGrowthBestFitAllocatorV2 : public Allocator { void FreeImpl(Allocation *allocation) override; - // Release 
the memory block which is not used in pool. - uint64_t ReleaseImpl(const platform::Place &place) override; - private: Allocation *AllocFromFreeBlocks(size_t size); void TryMergeAlloctation2Blocks(void *ptr, size_t size); diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index 9b107328ac6d0..2dea2b3e5e7ee 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -41,7 +41,7 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( access_desc_.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; auto result = paddle::platform::dynload::cuMemGetAllocationGranularity( - &granularity_, &prop_, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + &granularity_, &prop_, CU_MEM_ALLOC_GRANULARITY_MINIMUM); PADDLE_ENFORCE_EQ( result, CUDA_SUCCESS, platform::errors::Fatal( From ae06a0ba8b186704e4e70baa8f7b6f80ebc3ad3d Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 28 Sep 2021 12:14:00 +0000 Subject: [PATCH 05/40] refine, test=develop --- paddle/fluid/memory/allocation/allocator_facade.cc | 4 ++-- .../memory/allocation/auto_growth_best_fit_allocator_v2.cc | 2 -- .../fluid/memory/allocation/cuda_virtual_mem_allocator.cc | 6 +++++- paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h | 6 ++++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 335d62869d76e..d73102a5d6d49 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -18,7 +18,6 @@ #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h" -#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" #include 
"paddle/fluid/memory/allocation/cpu_allocator.h" #include "paddle/fluid/memory/allocation/naive_best_fit_allocator.h" #ifdef PADDLE_WITH_ASCEND_CL @@ -34,12 +33,13 @@ #include "paddle/fluid/platform/gpu_info.h" #endif #if CUDA_VERSION >= 10020 +#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" #include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" +#include "paddle/fluid/platform/dynload/cuda_driver.h" #endif #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/xpu/xpu_info.h" #endif -#include "paddle/fluid/platform/dynload/cuda_driver.h" #include "paddle/fluid/platform/npu_info.h" PADDLE_DEFINE_EXPORTED_int64( diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc index b6287139e8330..47e3a9ff0c8fb 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -17,8 +17,6 @@ #include "paddle/fluid/memory/allocation/aligned_allocator.h" #include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" -#pragma GCC diagnostic ignored "-Wpointer-arith" - namespace paddle { namespace memory { namespace allocation { diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index 2dea2b3e5e7ee..52d6f43e621d8 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
+ #ifdef PADDLE_WITH_CUDA #include #include @@ -18,10 +19,13 @@ #include #include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" +#include "paddle/fluid/platform/enforce.h" + +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/dynload/cuda_driver.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gpu_info.h" +#endif #if CUDA_VERSION >= 10020 diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h index 06f50b3462344..b06f663a2d321 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h @@ -14,11 +14,13 @@ #pragma once +#ifdef PADDLE_WITH_CUDA #include +#include "paddle/fluid/platform/cuda_device_guard.h" +#endif + #include // NOLINT #include "paddle/fluid/memory/allocation/allocator.h" -#include "paddle/fluid/platform/cuda_device_guard.h" -#include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/place.h" #if CUDA_VERSION >= 10020 From 97234d4ce7491edfaeb7e90fb283675a670a4d40 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 28 Sep 2021 12:21:12 +0000 Subject: [PATCH 06/40] refine, test=develop --- .../auto_growth_best_fit_allocator_v2.cc | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc index 47e3a9ff0c8fb..af7e1ae7fe8da 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -54,10 +54,10 @@ void AutoGrowthBestFitAllocatorV2::FreeImpl(Allocation *allocation) { void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( std::list::iterator block) { - std::lock_guard guard(spinlock_); if (block->ptr_ == 
all_blocks_.front().ptr_ && block->ptr_ == all_blocks_.back().ptr_) { block->is_free_ = true; + std::lock_guard guard(spinlock_); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else if (block->ptr_ == all_blocks_.front().ptr_) { block++; @@ -66,13 +66,15 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( if (next->is_free_ && reinterpret_cast(block->ptr_) + block->size_ == next->ptr_) { // merge with next - free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); block->size_ += next->size_; block->is_free_ = true; + std::lock_guard guard(spinlock_); + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); all_blocks_.erase(next); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else { block->is_free_ = true; + std::lock_guard guard(spinlock_); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } else if (block->ptr_ == all_blocks_.back().ptr_) { @@ -82,12 +84,14 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( if (pre->is_free_ && reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_) { // merge with pre + std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += block->size_; all_blocks_.erase(block); free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); } else { block->is_free_ = true; + std::lock_guard guard(spinlock_); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } else { @@ -103,6 +107,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( reinterpret_cast(block->ptr_) + block->size_ == next->ptr_)) { // merge with pre + std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += block->size_; all_blocks_.erase(block); @@ -114,9 +119,10 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_)) { // merge with next - 
free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); block->size_ += next->size_; block->is_free_ = true; + std::lock_guard guard(spinlock_); + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); all_blocks_.erase(next); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else if (pre->is_free_ && @@ -126,6 +132,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( reinterpret_cast(block->ptr_) + block->size_ == next->ptr_) { // merge with pre and next + std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); pre->size_ += (block->size_ + next->size_); @@ -134,6 +141,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); } else { block->is_free_ = true; + std::lock_guard guard(spinlock_); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } @@ -141,8 +149,8 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, size_t size) { - std::lock_guard guard(spinlock_); if (all_blocks_.empty()) { + std::lock_guard guard(spinlock_); all_blocks_.push_back(Block(ptr, size, true)); free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); return; @@ -155,6 +163,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, if (block_it->is_free_ && reinterpret_cast(ptr) + size == block_it->ptr_) { // merge with next + std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); block_it->ptr_ = ptr; block_it->size_ += size; @@ -162,6 +171,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, block_it); } else { // do not merge + std::lock_guard guard(spinlock_); all_blocks_.push_front(Block(ptr, size, true)); free_blocks_.emplace(std::make_pair(size, ptr), 
all_blocks_.begin()); } @@ -175,6 +185,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, !(next->is_free_ && reinterpret_cast(ptr) + size == next->ptr_)) { // merge with pre + std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += size; free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); @@ -184,6 +195,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, reinterpret_cast(pre->ptr_) + pre->size_ == ptr)) { // merge with next + std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); next->ptr_ = ptr; next->size_ += size; @@ -193,6 +205,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, next->is_free_ && reinterpret_cast(ptr) + size == next->ptr_) { // merge with pre and next + std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); pre->size_ += (size + next->size_); @@ -200,6 +213,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, all_blocks_.erase(next); } else { // do not merge + std::lock_guard guard(spinlock_); auto iter = all_blocks_.insert(next, Block(ptr, size, true)); free_blocks_.emplace(std::make_pair(size, ptr), iter); } @@ -214,12 +228,14 @@ void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, if (block_it->is_free_ && reinterpret_cast(block_it->ptr_) + block_it->size_ == ptr) { // merge with pre + std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); block_it->size_ += size; free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_), block_it); } else { // do not merge + std::lock_guard guard(spinlock_); all_blocks_.push_back(Block(ptr, size, true)); auto block_it = all_blocks_.end(); block_it--; From 41a4b97819de913a097f7a3c20d6c248bd37fcf0 Mon Sep 17 00:00:00 2001 From: 
Wang Huan Date: Wed, 29 Sep 2021 03:19:14 +0000 Subject: [PATCH 07/40] refine, test=develop --- .../auto_growth_best_fit_allocator_v2.cc | 15 ++++++++++----- .../auto_growth_best_fit_allocator_v2.h | 17 ++--------------- .../allocation/cuda_virtual_mem_allocator.cc | 19 ++++++++++++------- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc index af7e1ae7fe8da..61eaddc6f92cd 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -37,9 +37,7 @@ Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl(size_t size) { auto result = AllocFromFreeBlocks(size); if (!result) { - auto allocateptr = underlying_allocator_->Allocate(size); - TryMergeAlloctation2Blocks(allocateptr->ptr(), allocateptr->size()); - regions_.emplace(std::move(allocateptr)); + ExtendAndMerge(size); result = AllocFromFreeBlocks(size); } @@ -147,8 +145,15 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( } } -void AutoGrowthBestFitAllocatorV2::TryMergeAlloctation2Blocks(void *ptr, - size_t size) { +void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { + void *ptr = nullptr; + { + std::lock_guard guard(spinlock_); + auto allocateptr = underlying_allocator_->Allocate(size); + ptr = allocateptr->ptr(); + size = allocateptr->size(); + allocations_.push_back(std::move(allocateptr)); // hold allocation + } if (all_blocks_.empty()) { std::lock_guard guard(spinlock_); all_blocks_.push_back(Block(ptr, size, true)); diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h index d355b61a78be2..e4035d5a61778 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h +++ 
b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h @@ -34,19 +34,6 @@ struct Block { bool is_free_; }; -struct Region { - explicit Region(AllocationPtr allocation) - : allocation_(std::move(allocation)) {} - - AllocationPtr allocation_; -}; - -struct RegionComp { - bool operator()(const Region &a, const Region &b) { - return a.allocation_->ptr() < b.allocation_->ptr(); - } -}; - struct BlockAllocation : public Allocation { explicit BlockAllocation(const std::list::iterator &it, platform::Place place) @@ -69,7 +56,7 @@ class AutoGrowthBestFitAllocatorV2 : public Allocator { private: Allocation *AllocFromFreeBlocks(size_t size); - void TryMergeAlloctation2Blocks(void *ptr, size_t size); + void ExtendAndMerge(size_t size); void TryMergeBlock2Blocks(std::list::iterator iter); std::shared_ptr underlying_allocator_; @@ -77,7 +64,7 @@ class AutoGrowthBestFitAllocatorV2 : public Allocator { std::map, std::list::iterator> free_blocks_; std::list all_blocks_; - std::set regions_; + std::list allocations_; platform::Place place_; SpinLock spinlock_; }; diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index 52d6f43e621d8..e88ec83bb515a 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -36,13 +36,18 @@ namespace allocation { CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( const platform::CUDAPlace& place) : place_(place) { - prop_.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop_.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop_.location.id = place.device; - - access_desc_.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - access_desc_.location.id = place.device; - access_desc_.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + CUmemAllocationProp prop = {}; + CUmemAccessDesc access_desc = {}; + + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = 
CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = place.device; + prop_ = prop; + + access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc.location.id = place.device; + access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + access_desc_ = access_desc; auto result = paddle::platform::dynload::cuMemGetAllocationGranularity( &granularity_, &prop_, CU_MEM_ALLOC_GRANULARITY_MINIMUM); From a21c9826f15cdd85e057f5ac15c382fb2e1c37b3 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 29 Sep 2021 04:23:26 +0000 Subject: [PATCH 08/40] refine, test=develop --- .../auto_growth_best_fit_allocator_v2.cc | 33 +++++-------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc index 61eaddc6f92cd..3fdacd03e6609 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -52,10 +52,10 @@ void AutoGrowthBestFitAllocatorV2::FreeImpl(Allocation *allocation) { void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( std::list::iterator block) { + std::lock_guard guard(spinlock_); if (block->ptr_ == all_blocks_.front().ptr_ && block->ptr_ == all_blocks_.back().ptr_) { block->is_free_ = true; - std::lock_guard guard(spinlock_); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else if (block->ptr_ == all_blocks_.front().ptr_) { block++; @@ -66,13 +66,11 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( // merge with next block->size_ += next->size_; block->is_free_ = true; - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); all_blocks_.erase(next); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else { block->is_free_ = true; - std::lock_guard guard(spinlock_); free_blocks_.emplace(std::make_pair(block->size_, 
block->ptr_), block); } } else if (block->ptr_ == all_blocks_.back().ptr_) { @@ -82,14 +80,12 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( if (pre->is_free_ && reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_) { // merge with pre - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += block->size_; all_blocks_.erase(block); free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); } else { block->is_free_ = true; - std::lock_guard guard(spinlock_); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } else { @@ -105,7 +101,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( reinterpret_cast(block->ptr_) + block->size_ == next->ptr_)) { // merge with pre - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += block->size_; all_blocks_.erase(block); @@ -119,7 +114,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( // merge with next block->size_ += next->size_; block->is_free_ = true; - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); all_blocks_.erase(next); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); @@ -130,7 +124,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( reinterpret_cast(block->ptr_) + block->size_ == next->ptr_) { // merge with pre and next - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); pre->size_ += (block->size_ + next->size_); @@ -139,7 +132,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); } else { block->is_free_ = true; - std::lock_guard guard(spinlock_); free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } @@ -147,15 +139,14 @@ void 
AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { void *ptr = nullptr; - { - std::lock_guard guard(spinlock_); - auto allocateptr = underlying_allocator_->Allocate(size); - ptr = allocateptr->ptr(); - size = allocateptr->size(); - allocations_.push_back(std::move(allocateptr)); // hold allocation - } + std::lock_guard guard(spinlock_); + + auto allocateptr = underlying_allocator_->Allocate(size); + ptr = allocateptr->ptr(); + size = allocateptr->size(); + allocations_.push_back(std::move(allocateptr)); // hold allocation + if (all_blocks_.empty()) { - std::lock_guard guard(spinlock_); all_blocks_.push_back(Block(ptr, size, true)); free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); return; @@ -168,7 +159,6 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { if (block_it->is_free_ && reinterpret_cast(ptr) + size == block_it->ptr_) { // merge with next - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); block_it->ptr_ = ptr; block_it->size_ += size; @@ -176,7 +166,6 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { block_it); } else { // do not merge - std::lock_guard guard(spinlock_); all_blocks_.push_front(Block(ptr, size, true)); free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); } @@ -190,7 +179,6 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { !(next->is_free_ && reinterpret_cast(ptr) + size == next->ptr_)) { // merge with pre - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); pre->size_ += size; free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); @@ -200,7 +188,6 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { reinterpret_cast(pre->ptr_) + pre->size_ == ptr)) { // merge with next - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); 
next->ptr_ = ptr; next->size_ += size; @@ -210,7 +197,6 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { next->is_free_ && reinterpret_cast(ptr) + size == next->ptr_) { // merge with pre and next - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); pre->size_ += (size + next->size_); @@ -218,7 +204,6 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { all_blocks_.erase(next); } else { // do not merge - std::lock_guard guard(spinlock_); auto iter = all_blocks_.insert(next, Block(ptr, size, true)); free_blocks_.emplace(std::make_pair(size, ptr), iter); } @@ -233,14 +218,12 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { if (block_it->is_free_ && reinterpret_cast(block_it->ptr_) + block_it->size_ == ptr) { // merge with pre - std::lock_guard guard(spinlock_); free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); block_it->size_ += size; free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_), block_it); } else { // do not merge - std::lock_guard guard(spinlock_); all_blocks_.push_back(Block(ptr, size, true)); auto block_it = all_blocks_.end(); block_it--; From d38050ba6cc87e75a3fdc84d09d89a566e45eee4 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 29 Sep 2021 09:07:06 +0000 Subject: [PATCH 09/40] window dll, test=develop --- paddle/fluid/platform/dynload/CMakeLists.txt | 4 ++-- paddle/fluid/platform/dynload/dynamic_loader.cc | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index c0d4b349a9e09..a841496e286d0 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -11,8 +11,8 @@ if (WITH_ROCM) endif() # There is no macOS version of NCCL. 
-# Disable nvrtc and cuda_driver api on MacOS and Windows, and only do a early test on Linux. -if (NOT APPLE AND NOT WIN32) +# Disable nvrtc and cuda_driver api on MacOS, and only do a early test on Linux and Windows. +if (NOT APPLE) list(APPEND CUDA_SRCS nvrtc.cc cuda_driver.cc) if (WITH_NCCL) list(APPEND CUDA_SRCS nccl.cc) diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index a83f085f7d2d8..788a2b1b0f1ec 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -21,6 +21,10 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cupti_lib_path.h" #include "paddle/fluid/platform/enforce.h" +#if defined(_WIN32) +#include +#endif + DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " "/usr/local/cudnn/lib. If empty [default], dlopen " @@ -398,6 +402,10 @@ void* GetCUDADsoHandle() { return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib", false); #elif defined(PADDLE_WITH_HIP) return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false); +#elif defined(_WIN32) + char system32_dir[MAX_PATH]; + GetSystemDirectory(system32_dir, MAX_PATH); + return GetDsoHandleFromSearchPath(system32_dir, "nvcuda.dll"); #else return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.so", false); #endif From 1296daa03ef517fe74c55a32bede596db122bd5e Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 8 Oct 2021 08:14:08 +0000 Subject: [PATCH 10/40] fix cuda error of CUDA_ERROR_NOT_INITIALIZED, test=develop --- python/paddle/fluid/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 8ddac967e7bd3..8dfce06ad2cae 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -97,6 +97,9 @@ from .transpiler import HashName, RoundRobin from .backward import append_backward +import 
multiprocessing as mp +mp.set_start_method('spawn') + Tensor = LoDTensor enable_imperative = enable_dygraph disable_imperative = disable_dygraph From cc368b5459d71f1e23bd9664a0eb897dc67889b7 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 8 Oct 2021 10:23:26 +0000 Subject: [PATCH 11/40] use autogrowthv2 for system allocator, test=develop --- .../memory/allocation/allocator_facade.cc | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 6fae8c3fc6fab..8e936e8851970 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -234,7 +234,36 @@ class AllocatorFacadePrivate { int device_count = platform::GetCUDADeviceCount(); for (int i = 0; i < device_count; ++i) { platform::CUDAPlace p(i); +#if CUDA_VERSION >= 10020 + CUdevice device; + auto result = + paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuDeviceGet faild, return %d.", + result)); + + int val; + result = paddle::platform::dynload::cuDeviceGetAttribute( + &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, + device); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuDeviceGetAttribute faild, return %d.", result)); + + if (val > 0) { + auto cuda_allocator = std::make_shared(p); + system_allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize()); + } else { + auto cuda_allocator = std::make_shared(p); + system_allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize()); + } +#else system_allocators_[p] = std::make_shared(p); +#endif } #endif } From c3889e754e118ad7fd043f4f01f2b6c5f83d6086 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 8 Oct 2021 10:52:09 +0000 Subject: [PATCH 12/40] remove 
~CUDAVirtualMemAllocator(), test=develop --- .../allocation/cuda_virtual_mem_allocator.cc | 25 ------------------- .../allocation/cuda_virtual_mem_allocator.h | 1 - 2 files changed, 26 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index e88ec83bb515a..a42ec16232df0 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -72,31 +72,6 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( virtual_mem_alloced_offset_ = 0; } -CUDAVirtualMemAllocator::~CUDAVirtualMemAllocator() { - CUresult result; - paddle::platform::CUDADeviceGuard guard(place_.device); - for (auto& item : virtual_2_physical_map_) { - result = - paddle::platform::dynload::cuMemUnmap(item.first, item.second.second); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", - result)); - result = paddle::platform::dynload::cuMemRelease(item.second.first); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuMemRelease faild, return %d.", - result)); - } - - result = paddle::platform::dynload::cuMemAddressFree(virtual_mem_base_, - virtual_mem_size_); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuMemAddressFree faild, return %d.", result)); -} - bool CUDAVirtualMemAllocator::IsAllocThreadSafe() const { return false; } void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h index b06f663a2d321..d96bae4870ae8 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h @@ -32,7 +32,6 @@ namespace allocation { class CUDAVirtualMemAllocator : public Allocator { public: 
explicit CUDAVirtualMemAllocator(const platform::CUDAPlace& place); - ~CUDAVirtualMemAllocator(); bool IsAllocThreadSafe() const override; From 4d8cfc16dbee46927aa2226a20b08ae8b643def6 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Sat, 9 Oct 2021 01:38:40 +0000 Subject: [PATCH 13/40] refine, test=develop --- python/paddle/fluid/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 8dfce06ad2cae..8ddac967e7bd3 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -97,9 +97,6 @@ from .transpiler import HashName, RoundRobin from .backward import append_backward -import multiprocessing as mp -mp.set_start_method('spawn') - Tensor = LoDTensor enable_imperative = enable_dygraph disable_imperative = disable_dygraph From 7863dfb04b75f1194bfcb3f6b79d5bc80cdafe72 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Sat, 9 Oct 2021 01:46:43 +0000 Subject: [PATCH 14/40] fix cuda error of CUDA_ERROR_NOT_INITIALIZED, test=develop --- python/paddle/fluid/reader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index dfc887292e7cf..8347636bd2b34 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -50,6 +50,8 @@ KEEP_DATA_LOADER_ORDER = True USE_PINNED_MEMORY = None +multiprocessing.set_start_method('spawn') + def keep_data_loader_order(*args): global KEEP_DATA_LOADER_ORDER From 34983a85e1c8507899e85d03b403836bc8a43c15 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Sat, 9 Oct 2021 01:53:54 +0000 Subject: [PATCH 15/40] fix cuda error of CUDA_ERROR_NOT_INITIALIZED, test=develop --- python/paddle/fluid/dataloader/dataloader_iter.py | 2 ++ python/paddle/fluid/reader.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 70c7b01b05ba3..efe9376e1d9f6 100644 --- 
a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -360,6 +360,8 @@ def __init__(self, loader): self._shutdown = False def _init_workers(self): + multiprocessing.set_start_method('spawn') + # multiprocess worker and indice queue list initial as empty self._workers = [] self._worker_status = [] diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 8347636bd2b34..dfc887292e7cf 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -50,8 +50,6 @@ KEEP_DATA_LOADER_ORDER = True USE_PINNED_MEMORY = None -multiprocessing.set_start_method('spawn') - def keep_data_loader_order(*args): global KEEP_DATA_LOADER_ORDER From c53a782dd93ece8fff8f2568b65dd193b794deda Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Sat, 9 Oct 2021 07:34:38 +0000 Subject: [PATCH 16/40] fix bug, test=develop --- .../allocation/cuda_virtual_mem_allocator.cc | 35 ++++++++++++++----- .../fluid/dataloader/dataloader_iter.py | 5 ++- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index a42ec16232df0..a70d53df1d141 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -54,7 +54,8 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( PADDLE_ENFORCE_EQ( result, CUDA_SUCCESS, platform::errors::Fatal( - "Call CUDA API cuDeviceGetAttribute faild, return %d.", result)); + "Call CUDA API cuMemGetAllocationGranularity faild, return %d.", + result)); size_t actual_avail, actual_total; paddle::platform::CUDADeviceGuard guard(place.device); @@ -87,16 +88,32 @@ void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { "Can not find virtual memory address at %s", allocation->ptr())); } - paddle::platform::CUDADeviceGuard guard(place_.device); + int prev_id; + 
cudaGetDevice(&prev_id); + if (prev_id != place_.device) { + cudaSetDevice(place_.device); + } + auto result = paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second); - PADDLE_ENFORCE_EQ(result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuMemUnmap faild, return %d.", result)); - result = paddle::platform::dynload::cuMemRelease(iter->second.first); - PADDLE_ENFORCE_EQ(result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuMemUnmap faild, return %d.", result)); + if (result != CUDA_ERROR_DEINITIALIZED) { + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", + result)); + } + + if (result != CUDA_ERROR_DEINITIALIZED) { + result = paddle::platform::dynload::cuMemRelease(iter->second.first); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", + result)); + } + + if (prev_id != place_.device) { + cudaSetDevice(prev_id); + } virtual_2_physical_map_.erase(iter); diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index efe9376e1d9f6..7c19814a22e45 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -360,7 +360,10 @@ def __init__(self, loader): self._shutdown = False def _init_workers(self): - multiprocessing.set_start_method('spawn') + try: + multiprocessing.set_start_method('spawn') + except RuntimeError: + pass # multiprocess worker and indice queue list initial as empty self._workers = [] From 8208d516e4a11110fb3c47c756dd61b4c76c91f6 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Sat, 9 Oct 2021 08:21:17 +0000 Subject: [PATCH 17/40] revert system allocator, test =develop --- .../memory/allocation/allocator_facade.cc | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc 
b/paddle/fluid/memory/allocation/allocator_facade.cc index 8e936e8851970..6fae8c3fc6fab 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -234,36 +234,7 @@ class AllocatorFacadePrivate { int device_count = platform::GetCUDADeviceCount(); for (int i = 0; i < device_count; ++i) { platform::CUDAPlace p(i); -#if CUDA_VERSION >= 10020 - CUdevice device; - auto result = - paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuDeviceGet faild, return %d.", - result)); - - int val; - result = paddle::platform::dynload::cuDeviceGetAttribute( - &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, - device); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuDeviceGetAttribute faild, return %d.", result)); - - if (val > 0) { - auto cuda_allocator = std::make_shared(p); - system_allocators_[p] = std::make_shared( - cuda_allocator, platform::GpuMinChunkSize()); - } else { - auto cuda_allocator = std::make_shared(p); - system_allocators_[p] = std::make_shared( - cuda_allocator, platform::GpuMinChunkSize()); - } -#else system_allocators_[p] = std::make_shared(p); -#endif } #endif } From 52d021fef038596959e56ca5cc8185c6bb939242 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 11 Oct 2021 07:32:20 +0000 Subject: [PATCH 18/40] revert multiprocessing, test=develop --- python/paddle/fluid/dataloader/dataloader_iter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 7c19814a22e45..cd77b0423a622 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -360,10 +360,10 @@ def __init__(self, loader): self._shutdown = False def _init_workers(self): - try: - 
multiprocessing.set_start_method('spawn') - except RuntimeError: - pass + # try: + # multiprocessing.set_start_method('spawn') + # except RuntimeError: + # pass # multiprocess worker and indice queue list initial as empty self._workers = [] From c01bf0af44370027878b9f887dbceb98fbc2c64e Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 11 Oct 2021 11:29:20 +0000 Subject: [PATCH 19/40] fix AutoGrowthBestFitAllocatorV2 mutxt, test=develop --- .../memory/allocation/auto_growth_best_fit_allocator_v2.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc index 3fdacd03e6609..5ae0c32f1f8e7 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -33,6 +33,7 @@ AutoGrowthBestFitAllocatorV2::AutoGrowthBestFitAllocatorV2( alignment_(alignment) {} Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl(size_t size) { + std::lock_guard guard(spinlock_); size = AlignedSize(size, alignment_); auto result = AllocFromFreeBlocks(size); @@ -45,6 +46,7 @@ Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl(size_t size) { } void AutoGrowthBestFitAllocatorV2::FreeImpl(Allocation *allocation) { + std::lock_guard guard(spinlock_); auto block_it = static_cast(allocation)->block_it_; TryMergeBlock2Blocks(block_it); delete allocation; @@ -52,7 +54,6 @@ void AutoGrowthBestFitAllocatorV2::FreeImpl(Allocation *allocation) { void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( std::list::iterator block) { - std::lock_guard guard(spinlock_); if (block->ptr_ == all_blocks_.front().ptr_ && block->ptr_ == all_blocks_.back().ptr_) { block->is_free_ = true; @@ -139,7 +140,6 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { void *ptr = nullptr; - std::lock_guard 
guard(spinlock_); auto allocateptr = underlying_allocator_->Allocate(size); ptr = allocateptr->ptr(); @@ -232,7 +232,6 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { } Allocation *AutoGrowthBestFitAllocatorV2::AllocFromFreeBlocks(size_t size) { - std::lock_guard guard(spinlock_); auto iter = free_blocks_.lower_bound(std::make_pair(size, nullptr)); if (iter != free_blocks_.end()) { std::list::iterator block_it = iter->second; From a2579842867f08a39625754e182b3abbc2b15d5d Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 11 Oct 2021 12:23:18 +0000 Subject: [PATCH 20/40] catch cudaErrorInitializationError when create allocator, test=develop --- .../memory/allocation/allocator_facade.cc | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 6fae8c3fc6fab..ee742abd9fb47 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -269,20 +269,25 @@ class AllocatorFacadePrivate { #if defined(PADDLE_WITH_CUDA) #if CUDA_VERSION >= 10020 CUdevice device; - auto result = - paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuDeviceGet faild, return %d.", - result)); - int val; - result = paddle::platform::dynload::cuDeviceGetAttribute( - &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, device); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuDeviceGetAttribute faild, return %d.", result)); + try { + auto result = + paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuDeviceGet faild, return %d.", + result)); + + result = paddle::platform::dynload::cuDeviceGetAttribute( + &val, 
CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, + device); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuDeviceGetAttribute faild, return %d.", result)); + } catch (...) { + val = 0; + } if (val > 0) { auto cuda_allocator = std::make_shared(p); From ec85a0fdadb6fc79aa320bc18ed27fa3a29474e1 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 14 Oct 2021 09:03:35 +0000 Subject: [PATCH 21/40] fix cuMemSetAccess use, test=develop --- .../allocation/cuda_virtual_mem_allocator.cc | 38 +++++++++++-------- .../allocation/cuda_virtual_mem_allocator.h | 2 +- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index a70d53df1d141..b06388bc08290 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -26,7 +26,6 @@ #include "paddle/fluid/platform/dynload/cuda_driver.h" #include "paddle/fluid/platform/gpu_info.h" #endif - #if CUDA_VERSION >= 10020 namespace paddle { @@ -37,25 +36,32 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( const platform::CUDAPlace& place) : place_(place) { CUmemAllocationProp prop = {}; - CUmemAccessDesc access_desc = {}; prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; prop.location.id = place.device; prop_ = prop; - access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - access_desc.location.id = place.device; - access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - access_desc_ = access_desc; + access_desc_.resize(platform::GetCUDADeviceCount()); + for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { + access_desc_[dev_id].location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc_[dev_id].location.id = dev_id; + access_desc_[dev_id].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + } - auto result = 
paddle::platform::dynload::cuMemGetAllocationGranularity( - &granularity_, &prop_, CU_MEM_ALLOC_GRANULARITY_MINIMUM); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuMemGetAllocationGranularity faild, return %d.", - result)); + granularity_ = 0; + for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { + size_t granularity; + prop.location.id = dev_id; + auto result = paddle::platform::dynload::cuMemGetAllocationGranularity( + &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal( + "Call CUDA API cuMemGetAllocationGranularity faild, return %d.", + result)); + granularity_ = std::max(granularity, granularity_); + } size_t actual_avail, actual_total; paddle::platform::CUDADeviceGuard guard(place.device); @@ -63,7 +69,7 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( virtual_mem_size_ = (actual_total + granularity_ - 1) & ~(granularity_ - 1); - result = paddle::platform::dynload::cuMemAddressReserve( + auto result = paddle::platform::dynload::cuMemAddressReserve( &virtual_mem_base_, virtual_mem_size_, 0, 0, 0); PADDLE_ENFORCE_EQ( result, CUDA_SUCCESS, @@ -179,8 +185,8 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { return nullptr; } - result = - paddle::platform::dynload::cuMemSetAccess(ptr, size, &access_desc_, 1); + result = paddle::platform::dynload::cuMemSetAccess( + ptr, size, access_desc_.data(), access_desc_.size()); if (result != CUDA_SUCCESS) { paddle::platform::dynload::cuMemUnmap(ptr, size); diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h index d96bae4870ae8..8a52ac4ab1ea7 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h @@ -48,7 +48,7 @@ class CUDAVirtualMemAllocator : public Allocator { size_t granularity_; 
CUmemAllocationProp prop_; - CUmemAccessDesc access_desc_; + std::vector access_desc_; std::map> virtual_2_physical_map_; From 329c568169729015aba538a923e7452a4cc57dfc Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 14 Oct 2021 12:22:43 +0000 Subject: [PATCH 22/40] refine cuda api use, test=develop --- .../allocation/cuda_virtual_mem_allocator.cc | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index b06388bc08290..daccc99088dc8 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -44,6 +44,18 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( access_desc_.resize(platform::GetCUDADeviceCount()); for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { + if (place.device != dev_id) { + int capable = 0; + auto result = cuDeviceCanAccessPeer(&capable, place.device, dev_id); + if (result != CUDA_SUCCESS) { + PADDLE_THROW(platform::errors::Fatal( + "Call CUDA API cuDeviceCanAccessPeer faild, return %d.", result)); + return; + } + if (!capable) { + continue; + } + } access_desc_[dev_id].location.type = CU_MEM_LOCATION_TYPE_DEVICE; access_desc_[dev_id].location.id = dev_id; access_desc_[dev_id].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; @@ -179,18 +191,24 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { result = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0); if (result != CUDA_SUCCESS) { - paddle::platform::dynload::cuMemRelease(handle); PADDLE_THROW(platform::errors::Fatal( "Call CUDA API cuMemMap faild, return %d.", result)); return nullptr; } + result = paddle::platform::dynload::cuMemRelease(handle); + + if (result != CUDA_SUCCESS) { + PADDLE_THROW(platform::errors::Fatal( + "Call CUDA API cuMemRelease faild, return %d.", result)); + return nullptr; + } + result 
= paddle::platform::dynload::cuMemSetAccess( ptr, size, access_desc_.data(), access_desc_.size()); if (result != CUDA_SUCCESS) { paddle::platform::dynload::cuMemUnmap(ptr, size); - paddle::platform::dynload::cuMemRelease(handle); PADDLE_THROW(platform::errors::Fatal( "Call CUDA API cuMemSetAccess faild, return %d.", result)); return nullptr; From 3ba198587c60b0d43e3e99aec6e1e2f141bc49e6 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 14 Oct 2021 12:33:28 +0000 Subject: [PATCH 23/40] refine, test=develop --- .../fluid/memory/allocation/cuda_virtual_mem_allocator.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index daccc99088dc8..ebd702e68e336 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -46,12 +46,8 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { if (place.device != dev_id) { int capable = 0; - auto result = cuDeviceCanAccessPeer(&capable, place.device, dev_id); - if (result != CUDA_SUCCESS) { - PADDLE_THROW(platform::errors::Fatal( - "Call CUDA API cuDeviceCanAccessPeer faild, return %d.", result)); - return; - } + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaDeviceCanAccessPeer(&capable, place.device, dev_id)); if (!capable) { continue; } From f253ffb18c034207ed6b98e615a0efa68559c890 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 15 Oct 2021 08:47:53 +0000 Subject: [PATCH 24/40] for test, test=develop --- paddle/fluid/memory/allocation/allocator_facade.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 09e59d61eee21..c559437a3bb9e 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ 
b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -294,8 +294,8 @@ class AllocatorFacadePrivate { if (val > 0) { auto cuda_allocator = std::make_shared(p); - allocators_[p] = std::make_shared( - cuda_allocator, platform::GpuMinChunkSize()); + allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk); } else { auto cuda_allocator = std::make_shared(p); allocators_[p] = std::make_shared( From 8312f3cada45337f6f13c941dfff8b8c6475df9e Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 18 Oct 2021 01:57:34 +0000 Subject: [PATCH 25/40] for test, test=develop --- .../allocation/cuda_virtual_mem_allocator.cc | 85 ++++++++++--------- 1 file changed, 43 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index ebd702e68e336..ad9cf67939b38 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -90,48 +90,49 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( bool CUDAVirtualMemAllocator::IsAllocThreadSafe() const { return false; } void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { - PADDLE_ENFORCE_EQ( - BOOST_GET_CONST(platform::CUDAPlace, allocation->place()), place_, - platform::errors::PermissionDenied( - "GPU memory is freed in incorrect device. 
This may be a bug")); - - auto iter = virtual_2_physical_map_.find( - reinterpret_cast(allocation->ptr())); - if (iter == virtual_2_physical_map_.end()) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Can not find virtual memory address at %s", allocation->ptr())); - } - - int prev_id; - cudaGetDevice(&prev_id); - if (prev_id != place_.device) { - cudaSetDevice(place_.device); - } - - auto result = - paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second); - if (result != CUDA_ERROR_DEINITIALIZED) { - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", - result)); - } - - if (result != CUDA_ERROR_DEINITIALIZED) { - result = paddle::platform::dynload::cuMemRelease(iter->second.first); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", - result)); - } - - if (prev_id != place_.device) { - cudaSetDevice(prev_id); - } - - virtual_2_physical_map_.erase(iter); - - delete allocation; + // PADDLE_ENFORCE_EQ( + // BOOST_GET_CONST(platform::CUDAPlace, allocation->place()), place_, + // platform::errors::PermissionDenied( + // "GPU memory is freed in incorrect device. 
This may be a bug")); + + // auto iter = virtual_2_physical_map_.find( + // reinterpret_cast(allocation->ptr())); + // if (iter == virtual_2_physical_map_.end()) { + // PADDLE_THROW(platform::errors::InvalidArgument( + // "Can not find virtual memory address at %s", allocation->ptr())); + // } + + // int prev_id; + // cudaGetDevice(&prev_id); + // if (prev_id != place_.device) { + // cudaSetDevice(place_.device); + // } + + // auto result = + // paddle::platform::dynload::cuMemUnmap(iter->first, + // iter->second.second); + // if (result != CUDA_ERROR_DEINITIALIZED) { + // PADDLE_ENFORCE_EQ( + // result, CUDA_SUCCESS, + // platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", + // result)); + // } + + // if (result != CUDA_ERROR_DEINITIALIZED) { + // result = paddle::platform::dynload::cuMemRelease(iter->second.first); + // PADDLE_ENFORCE_EQ( + // result, CUDA_SUCCESS, + // platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", + // result)); + // } + + // if (prev_id != place_.device) { + // cudaSetDevice(prev_id); + // } + + // virtual_2_physical_map_.erase(iter); + + // delete allocation; } Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { From 7f1891e4039f415b43f7224b17be9d3e1bac6361 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 18 Oct 2021 09:19:46 +0000 Subject: [PATCH 26/40] switch to v2, test=develop --- paddle/fluid/memory/allocation/allocator_facade.cc | 4 ++-- .../memory/allocation/auto_growth_best_fit_allocator_v2.cc | 6 ++++-- .../memory/allocation/auto_growth_best_fit_allocator_v2.h | 3 ++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index c559437a3bb9e..7b8d4d0f86579 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -294,8 +294,8 @@ class AllocatorFacadePrivate { if (val > 0) { auto cuda_allocator = 
std::make_shared(p); - allocators_[p] = std::make_shared( - cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk); + allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), p); } else { auto cuda_allocator = std::make_shared(p); allocators_[p] = std::make_shared( diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc index 5ae0c32f1f8e7..b9392820d49d4 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -27,10 +27,12 @@ bool NeedSplit(size_t block_size, size_t alignment, size_t allock_size) { } AutoGrowthBestFitAllocatorV2::AutoGrowthBestFitAllocatorV2( - const std::shared_ptr &underlying_allocator, size_t alignment) + const std::shared_ptr &underlying_allocator, size_t alignment, + const platform::CUDAPlace &place) : underlying_allocator_( std::make_shared(underlying_allocator, alignment)), - alignment_(alignment) {} + alignment_(alignment), + place_(place) {} Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl(size_t size) { std::lock_guard guard(spinlock_); diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h index e4035d5a61778..cb41c74689d9d 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h @@ -45,7 +45,8 @@ struct BlockAllocation : public Allocation { class AutoGrowthBestFitAllocatorV2 : public Allocator { public: AutoGrowthBestFitAllocatorV2( - const std::shared_ptr &underlying_allocator, size_t alignment); + const std::shared_ptr &underlying_allocator, size_t alignment, + const platform::CUDAPlace &place); bool IsAllocThreadSafe() const override { return true; } From ce93e11aba691eb712486dafacae8f3cc1a70656 Mon Sep 17 
00:00:00 2001 From: Wang Huan Date: Tue, 19 Oct 2021 03:17:47 +0000 Subject: [PATCH 27/40] refine virtual allocator, test=develop --- .../allocation/cuda_virtual_mem_allocator.cc | 95 +++++++++---------- 1 file changed, 44 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index ad9cf67939b38..7e5eda4b863a1 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -90,49 +90,48 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( bool CUDAVirtualMemAllocator::IsAllocThreadSafe() const { return false; } void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { - // PADDLE_ENFORCE_EQ( - // BOOST_GET_CONST(platform::CUDAPlace, allocation->place()), place_, - // platform::errors::PermissionDenied( - // "GPU memory is freed in incorrect device. This may be a bug")); - - // auto iter = virtual_2_physical_map_.find( - // reinterpret_cast(allocation->ptr())); - // if (iter == virtual_2_physical_map_.end()) { - // PADDLE_THROW(platform::errors::InvalidArgument( - // "Can not find virtual memory address at %s", allocation->ptr())); - // } - - // int prev_id; - // cudaGetDevice(&prev_id); - // if (prev_id != place_.device) { - // cudaSetDevice(place_.device); - // } - - // auto result = - // paddle::platform::dynload::cuMemUnmap(iter->first, - // iter->second.second); - // if (result != CUDA_ERROR_DEINITIALIZED) { - // PADDLE_ENFORCE_EQ( - // result, CUDA_SUCCESS, - // platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", - // result)); - // } - - // if (result != CUDA_ERROR_DEINITIALIZED) { - // result = paddle::platform::dynload::cuMemRelease(iter->second.first); - // PADDLE_ENFORCE_EQ( - // result, CUDA_SUCCESS, - // platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", - // result)); - // } - - // if (prev_id != place_.device) { - // 
cudaSetDevice(prev_id); - // } - - // virtual_2_physical_map_.erase(iter); - - // delete allocation; + PADDLE_ENFORCE_EQ( + BOOST_GET_CONST(platform::CUDAPlace, allocation->place()), place_, + platform::errors::PermissionDenied( + "GPU memory is freed in incorrect device. This may be a bug")); + + auto iter = virtual_2_physical_map_.find( + reinterpret_cast(allocation->ptr())); + if (iter == virtual_2_physical_map_.end()) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Can not find virtual memory address at %s", allocation->ptr())); + } + + int prev_id; + cudaGetDevice(&prev_id); + if (prev_id != place_.device) { + cudaSetDevice(place_.device); + } + + auto result = + paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second); + if (result != CUDA_ERROR_DEINITIALIZED) { + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", + result)); + } + + if (result != CUDA_ERROR_DEINITIALIZED) { + result = paddle::platform::dynload::cuMemRelease(iter->second.first); + PADDLE_ENFORCE_EQ( + result, CUDA_SUCCESS, + platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", + result)); + } + + if (prev_id != place_.device) { + cudaSetDevice(prev_id); + } + + virtual_2_physical_map_.erase(iter); + + delete allocation; } Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { @@ -188,24 +187,18 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { result = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0); if (result != CUDA_SUCCESS) { + paddle::platform::dynload::cuMemRelease(handle); PADDLE_THROW(platform::errors::Fatal( "Call CUDA API cuMemMap faild, return %d.", result)); return nullptr; } - result = paddle::platform::dynload::cuMemRelease(handle); - - if (result != CUDA_SUCCESS) { - PADDLE_THROW(platform::errors::Fatal( - "Call CUDA API cuMemRelease faild, return %d.", result)); - return nullptr; - } - result = 
paddle::platform::dynload::cuMemSetAccess( ptr, size, access_desc_.data(), access_desc_.size()); if (result != CUDA_SUCCESS) { paddle::platform::dynload::cuMemUnmap(ptr, size); + paddle::platform::dynload::cuMemRelease(handle); PADDLE_THROW(platform::errors::Fatal( "Call CUDA API cuMemSetAccess faild, return %d.", result)); return nullptr; From 6ab7de3b74cb2626d4aa02e9bdbeff14bd7492ac Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 19 Oct 2021 07:54:15 +0000 Subject: [PATCH 28/40] Record cuMemCreate and cuMemRelease, test=develop --- .../allocation/cuda_virtual_mem_allocator.cc | 9 ++-- paddle/fluid/platform/gpu_info.cc | 41 +++++++++++++++++++ paddle/fluid/platform/gpu_info.h | 14 +++++++ 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index 7e5eda4b863a1..c00e34335aedb 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -118,7 +118,8 @@ void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { } if (result != CUDA_ERROR_DEINITIALIZED) { - result = paddle::platform::dynload::cuMemRelease(iter->second.first); + result = platform::RecordedCuMemRelease(iter->second.first, + iter->second.second, place_.device); PADDLE_ENFORCE_EQ( result, CUDA_SUCCESS, platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", @@ -158,7 +159,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { paddle::platform::CUDADeviceGuard guard(place_.device); auto result = - paddle::platform::dynload::cuMemCreate(&handle, size, &prop_, 0); + platform::RecordedCuMemCreate(&handle, size, &prop_, 0, place_.device); if (result != CUDA_SUCCESS) { if (result == CUDA_ERROR_OUT_OF_MEMORY) { @@ -187,7 +188,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { result = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 
0); if (result != CUDA_SUCCESS) { - paddle::platform::dynload::cuMemRelease(handle); + platform::RecordedCuMemRelease(handle, size, place_.device); PADDLE_THROW(platform::errors::Fatal( "Call CUDA API cuMemMap faild, return %d.", result)); return nullptr; @@ -198,7 +199,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { if (result != CUDA_SUCCESS) { paddle::platform::dynload::cuMemUnmap(ptr, size); - paddle::platform::dynload::cuMemRelease(handle); + platform::RecordedCuMemRelease(handle, size, place_.device); PADDLE_THROW(platform::errors::Fatal( "Call CUDA API cuMemSetAccess faild, return %d.", result)); return nullptr; diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index c624ba94b74a3..d3e72c51bd1fb 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cudnn.h" #endif #include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/dynload/cuda_driver.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/lock_guard_ptr.h" #include "paddle/fluid/platform/macros.h" @@ -641,6 +642,30 @@ class RecordedCudaMallocHelper { uint64_t LimitSize() const { return limit_size_; } +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10020 + CUresult cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags) { // NOLINT + auto result = + paddle::platform::dynload::cuMemCreate(handle, size, prop, flags); + if (result == CUDA_SUCCESS) { + cur_size_.fetch_add(size); + } + return result; + } + + CUresult cuMemRelease(CUmemGenericAllocationHandle handle, size_t size) { + auto result = paddle::platform::dynload::cuMemRelease(handle); + if (result == CUDA_SUCCESS) { + cur_size_.fetch_sub(size); + } + return result; + } + +#endif +#endif + private: const int dev_id_; const uint64_t limit_size_; @@ -664,6 +689,22 @@ 
void RecordedCudaFree(void *p, size_t size, int dev_id) { return RecordedCudaMallocHelper::Instance(dev_id)->Free(p, size); } +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10020 +CUresult RecordedCuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags, int dev_id) { // NOLINT + return RecordedCudaMallocHelper::Instance(dev_id)->cuMemCreate(handle, size, + prop, flags); +} + +CUresult RecordedCuMemRelease(CUmemGenericAllocationHandle handle, size_t size, + int dev_id) { + return RecordedCudaMallocHelper::Instance(dev_id)->cuMemRelease(handle, size); +} +#endif +#endif + bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail, size_t *actual_total, int dev_id) { return RecordedCudaMallocHelper::Instance(dev_id)->GetMemInfo( diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index 401873dcd77da..93e787fcf36f5 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -131,6 +131,20 @@ gpuError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id); //! CudaFree with recorded info void RecordedCudaFree(void *p, size_t size, int dev_id); +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10020 + +//! cuMemCreate with recorded info +CUresult RecordedCuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags, int dev_id); // NOLINT + +//! cuMemRelease with recorded info +CUresult RecordedCuMemRelease(CUmemGenericAllocationHandle handle, size_t size, + int dev_id); +#endif +#endif + //! 
Get available and total gpu memory with considering limitation bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail, size_t *actual_total, int dev_id); From 1d9024680a64de50805e699ef9d7fe51fb977081 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 19 Oct 2021 08:02:36 +0000 Subject: [PATCH 29/40] refine, test=develop --- paddle/fluid/platform/gpu_info.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index d3e72c51bd1fb..1fa3bd95af411 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -26,7 +26,11 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cudnn.h" #endif #include "paddle/fluid/memory/malloc.h" +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10020 #include "paddle/fluid/platform/dynload/cuda_driver.h" +#endif +#endif #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/lock_guard_ptr.h" #include "paddle/fluid/platform/macros.h" From b9c04cc92d38e8c2875b531709c80a68f4584765 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 19 Oct 2021 09:18:10 +0000 Subject: [PATCH 30/40] avoid out of bounds, test=develop --- .../test_softmax_mask_fuse_upper_triangle_op.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py b/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py index 8b6d37882ba1a..a73ebd73e4946 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py @@ -43,7 +43,7 @@ def _get_softmax_upper(x, fp16=True): class TestSoftmaxMaskFuseOp(OpTest): def setUp(self): self.op_type = "fused_softmax_mask_upper_triangle" - x = np.random.random((1, 1, 32, 32)).astype("float16") + x = np.random.random((1, 4, 32, 32)).astype("float16") 
self.inputs = {'X': x} rst = _get_softmax_upper(x) self.outputs = {'Out': rst} @@ -60,7 +60,7 @@ def test_check_grad(self): class TestSoftmaxMaskFuseOp1(OpTest): def setUp(self): self.op_type = "fused_softmax_mask_upper_triangle" - x = np.random.random((1, 1, 32, 32)) + x = np.random.random((1, 4, 32, 32)) self.inputs = {'X': x} rst = _get_softmax_upper(x) self.outputs = {'Out': rst} @@ -90,10 +90,10 @@ def test_static(self): for dtype in self.dtypes: with fluid.program_guard(fluid.Program(), fluid.Program()): input_x = fluid.data( - name="x", shape=[1, 1, 32, 32], dtype=dtype) + name="x", shape=[1, 4, 32, 32], dtype=dtype) rst = incubate.softmax_mask_fuse_upper_triangle(input_x) - x_in_np = np.random.random((1, 1, 32, 32)).astype(dtype) + x_in_np = np.random.random((1, 4, 32, 32)).astype(dtype) rst_np = _get_softmax_upper(x_in_np, dtype == 'float16') exe = fluid.Executor(fluid.CUDAPlace(0)) @@ -105,7 +105,7 @@ def test_static(self): def test_dygraph(self): for dtype in self.dtypes: with fluid.dygraph.guard(fluid.CUDAPlace(0)): - x_in_np = np.random.random((1, 1, 32, 32)).astype(dtype) + x_in_np = np.random.random((1, 4, 32, 32)).astype(dtype) rst_np = _get_softmax_upper(x_in_np, dtype == 'float16') input_x = fluid.dygraph.to_variable(x_in_np) From 5fca3b0f2445d38e055c8bdbd70acb5e8db8c4cb Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 20 Oct 2021 10:20:51 +0000 Subject: [PATCH 31/40] rename allocator, test=develop --- paddle/fluid/memory/allocation/CMakeLists.txt | 4 ++-- .../memory/allocation/allocator_facade.cc | 12 +++++++---- ..._memory_auto_growth_best_fit_allocator.cc} | 20 ++++++++++--------- ...l_memory_auto_growth_best_fit_allocator.h} | 4 ++-- 4 files changed, 23 insertions(+), 17 deletions(-) rename paddle/fluid/memory/allocation/{auto_growth_best_fit_allocator_v2.cc => virtual_memory_auto_growth_best_fit_allocator.cc} (92%) rename paddle/fluid/memory/allocation/{auto_growth_best_fit_allocator_v2.h => virtual_memory_auto_growth_best_fit_allocator.h} 
(95%) diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 23c391b544a71..58979d6c3e185 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -78,7 +78,7 @@ else() cpu_allocator) endif() -list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator aligned_allocator retry_allocator buffered_allocator naive_best_fit_allocator auto_growth_best_fit_allocator auto_growth_best_fit_allocator_v2 best_fit_allocator) +list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator aligned_allocator retry_allocator buffered_allocator naive_best_fit_allocator auto_growth_best_fit_allocator virtual_memory_auto_growth_best_fit_allocator best_fit_allocator) if (WITH_ASCEND_CL) list(APPEND AllocatorFacadeDeps npu_pinned_allocator) @@ -113,7 +113,7 @@ cc_library(auto_growth_best_fit_allocator SRCS auto_growth_best_fit_allocator.cc cc_test(auto_growth_best_fit_allocator_facade_test SRCS auto_growth_best_fit_allocator_facade_test.cc DEPS cpu_allocator auto_growth_best_fit_allocator) cc_test(auto_growth_best_fit_allocator_test SRCS auto_growth_best_fit_allocator_test.cc DEPS auto_growth_best_fit_allocator) -cc_library(auto_growth_best_fit_allocator_v2 SRCS auto_growth_best_fit_allocator_v2.cc DEPS allocator aligned_allocator) +cc_library(virtual_memory_auto_growth_best_fit_allocator SRCS virtual_memory_auto_growth_best_fit_allocator.cc DEPS allocator aligned_allocator) if(NOT WIN32) cc_library(mmap_allocator SRCS mmap_allocator.cc DEPS allocator) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 7b8d4d0f86579..61d0e0cabd39e 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -34,8 +34,8 @@ #include "paddle/fluid/platform/gpu_info.h" #endif #if CUDA_VERSION >= 10020 -#include 
"paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" #include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" +#include "paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h" #include "paddle/fluid/platform/dynload/cuda_driver.h" #endif #ifdef PADDLE_WITH_CUDA @@ -56,6 +56,9 @@ PADDLE_DEFINE_EXPORTED_bool( "Whether to use system allocator to allocate CPU and GPU memory. " "Only used for unittests."); +PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false, + "Use VirtualMemoryAutoGrowthBestFitAllocator."); + DECLARE_string(allocator_strategy); namespace paddle { @@ -292,10 +295,11 @@ class AllocatorFacadePrivate { val = 0; } - if (val > 0) { + if (val > 0 && FLAGS_use_virtual_memory_auto_growth) { auto cuda_allocator = std::make_shared(p); - allocators_[p] = std::make_shared( - cuda_allocator, platform::GpuMinChunkSize(), p); + allocators_[p] = + std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), p); } else { auto cuda_allocator = std::make_shared(p); allocators_[p] = std::make_shared( diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc similarity index 92% rename from paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc rename to paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc index b9392820d49d4..059b74dd0cab4 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc +++ b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc @@ -15,7 +15,7 @@ #include #include "paddle/fluid/memory/allocation/aligned_allocator.h" -#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" +#include "paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h" namespace paddle { namespace memory { @@ -26,15 +26,16 @@ bool NeedSplit(size_t block_size, size_t 
alignment, size_t allock_size) { (block_size - allock_size) > alignment; } -AutoGrowthBestFitAllocatorV2::AutoGrowthBestFitAllocatorV2( - const std::shared_ptr &underlying_allocator, size_t alignment, - const platform::CUDAPlace &place) +VirtualMemoryAutoGrowthBestFitAllocator:: + VirtualMemoryAutoGrowthBestFitAllocator( + const std::shared_ptr &underlying_allocator, + size_t alignment, const platform::CUDAPlace &place) : underlying_allocator_( std::make_shared(underlying_allocator, alignment)), alignment_(alignment), place_(place) {} -Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl(size_t size) { +Allocation *VirtualMemoryAutoGrowthBestFitAllocator::AllocateImpl(size_t size) { std::lock_guard guard(spinlock_); size = AlignedSize(size, alignment_); auto result = AllocFromFreeBlocks(size); @@ -47,14 +48,14 @@ Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl(size_t size) { return result; } -void AutoGrowthBestFitAllocatorV2::FreeImpl(Allocation *allocation) { +void VirtualMemoryAutoGrowthBestFitAllocator::FreeImpl(Allocation *allocation) { std::lock_guard guard(spinlock_); auto block_it = static_cast(allocation)->block_it_; TryMergeBlock2Blocks(block_it); delete allocation; } -void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( +void VirtualMemoryAutoGrowthBestFitAllocator::TryMergeBlock2Blocks( std::list::iterator block) { if (block->ptr_ == all_blocks_.front().ptr_ && block->ptr_ == all_blocks_.back().ptr_) { @@ -140,7 +141,7 @@ void AutoGrowthBestFitAllocatorV2::TryMergeBlock2Blocks( } } -void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { +void VirtualMemoryAutoGrowthBestFitAllocator::ExtendAndMerge(size_t size) { void *ptr = nullptr; auto allocateptr = underlying_allocator_->Allocate(size); @@ -233,7 +234,8 @@ void AutoGrowthBestFitAllocatorV2::ExtendAndMerge(size_t size) { } } -Allocation *AutoGrowthBestFitAllocatorV2::AllocFromFreeBlocks(size_t size) { +Allocation *VirtualMemoryAutoGrowthBestFitAllocator::AllocFromFreeBlocks( + 
size_t size) { auto iter = free_blocks_.lower_bound(std::make_pair(size, nullptr)); if (iter != free_blocks_.end()) { std::list::iterator block_it = iter->second; diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h similarity index 95% rename from paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h rename to paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h index cb41c74689d9d..71a5cb12b0a98 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h +++ b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h @@ -42,9 +42,9 @@ struct BlockAllocation : public Allocation { std::list::iterator block_it_; }; -class AutoGrowthBestFitAllocatorV2 : public Allocator { +class VirtualMemoryAutoGrowthBestFitAllocator : public Allocator { public: - AutoGrowthBestFitAllocatorV2( + VirtualMemoryAutoGrowthBestFitAllocator( const std::shared_ptr &underlying_allocator, size_t alignment, const platform::CUDAPlace &place); From f164521dac3bd5b20901dc73d50fb5c757732af1 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 21 Oct 2021 06:24:23 +0000 Subject: [PATCH 32/40] refine, test=develop --- python/paddle/fluid/dataloader/dataloader_iter.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index cd77b0423a622..70c7b01b05ba3 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -360,11 +360,6 @@ def __init__(self, loader): self._shutdown = False def _init_workers(self): - # try: - # multiprocessing.set_start_method('spawn') - # except RuntimeError: - # pass - # multiprocess worker and indice queue list initial as empty self._workers = [] self._worker_status = [] From 55ef1007444d86f23ddf063e1af5100691877885 Mon Sep 17 
00:00:00 2001 From: Wang Huan Date: Fri, 22 Oct 2021 07:15:04 +0000 Subject: [PATCH 33/40] use PADDLE_ENFORCE_CUDA_SUCCESS, test=develop --- .../memory/allocation/allocator_facade.cc | 21 ++++------ .../allocation/cuda_virtual_mem_allocator.cc | 40 +++++-------------- paddle/fluid/platform/enforce.h | 12 ++++++ paddle/fluid/platform/external_error.proto | 1 + 4 files changed, 31 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 61d0e0cabd39e..3fbf7f910b5d5 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -277,20 +277,13 @@ class AllocatorFacadePrivate { CUdevice device; int val; try { - auto result = - paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuDeviceGet faild, return %d.", - result)); - - result = paddle::platform::dynload::cuDeviceGetAttribute( - &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, - device); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuDeviceGetAttribute faild, return %d.", result)); + PADDLE_ENFORCE_CUDA_SUCCESS( + paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId())); + + PADDLE_ENFORCE_CUDA_SUCCESS( + paddle::platform::dynload::cuDeviceGetAttribute( + &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, + device)); } catch (...) 
{ val = 0; } diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index c00e34335aedb..ab5535ff0a396 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -61,13 +61,9 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { size_t granularity; prop.location.id = dev_id; - auto result = paddle::platform::dynload::cuMemGetAllocationGranularity( - &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuMemGetAllocationGranularity faild, return %d.", - result)); + PADDLE_ENFORCE_CUDA_SUCCESS( + paddle::platform::dynload::cuMemGetAllocationGranularity( + &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); granularity_ = std::max(granularity, granularity_); } @@ -77,12 +73,8 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( virtual_mem_size_ = (actual_total + granularity_ - 1) & ~(granularity_ - 1); - auto result = paddle::platform::dynload::cuMemAddressReserve( - &virtual_mem_base_, virtual_mem_size_, 0, 0, 0); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal( - "Call CUDA API cuMemAddressReserve faild, return %d.", result)); + PADDLE_ENFORCE_CUDA_SUCCESS(paddle::platform::dynload::cuMemAddressReserve( + &virtual_mem_base_, virtual_mem_size_, 0, 0, 0)); virtual_mem_alloced_offset_ = 0; } @@ -111,19 +103,12 @@ void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { auto result = paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second); if (result != CUDA_ERROR_DEINITIALIZED) { - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", - result)); + PADDLE_ENFORCE_CUDA_SUCCESS(result); } if (result != CUDA_ERROR_DEINITIALIZED) { - 
result = platform::RecordedCuMemRelease(iter->second.first, - iter->second.second, place_.device); - PADDLE_ENFORCE_EQ( - result, CUDA_SUCCESS, - platform::errors::Fatal("Call CUDA API cuMemUnmap faild, return %d.", - result)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::RecordedCuMemRelease( + iter->second.first, iter->second.second, place_.device)); } if (prev_id != place_.device) { @@ -179,8 +164,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { string::HumanReadableSize(actual_allocated), string::HumanReadableSize(actual_avail), place_.device)); } else { - PADDLE_THROW(platform::errors::Fatal( - "Call CUDA API cuMemCreate faild, return %d.", result)); + PADDLE_ENFORCE_CUDA_SUCCESS(result); } return nullptr; } @@ -189,8 +173,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { if (result != CUDA_SUCCESS) { platform::RecordedCuMemRelease(handle, size, place_.device); - PADDLE_THROW(platform::errors::Fatal( - "Call CUDA API cuMemMap faild, return %d.", result)); + PADDLE_ENFORCE_CUDA_SUCCESS(result); return nullptr; } @@ -200,8 +183,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { if (result != CUDA_SUCCESS) { paddle::platform::dynload::cuMemUnmap(ptr, size); platform::RecordedCuMemRelease(handle, size, place_.device); - PADDLE_THROW(platform::errors::Fatal( - "Call CUDA API cuMemSetAccess faild, return %d.", result)); + PADDLE_ENFORCE_CUDA_SUCCESS(result); return nullptr; } diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 7427060add8b1..40fc966949d35 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -716,6 +716,7 @@ DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN); DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS); DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER); DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT); +DEFINE_EXTERNAL_API_TYPE(CUresult, 
CUDA_SUCCESS, CU); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL); @@ -730,6 +731,7 @@ inline const char* GetErrorMsgUrl(T status) { details::ExternalApiType<__CUDA_STATUS_TYPE__>::kProtoType; switch (proto_type) { case platform::proto::ApiType::CUDA: + case platform::proto::ApiType::CU: return "https://docs.nvidia.com/cuda/cuda-runtime-api/" "group__CUDART__TYPES.html#group__CUDART__TYPES_" "1g3f51e3575c2178246db0a94a430e0038"; @@ -844,6 +846,7 @@ template std::string GetExternalErrorMsg(cudnnStatus_t); template std::string GetExternalErrorMsg(cublasStatus_t); template std::string GetExternalErrorMsg(cusolverStatus_t); template std::string GetExternalErrorMsg(cufftResult_t); +template std::string GetExternalErrorMsg(CUresult); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) template std::string GetExternalErrorMsg(ncclResult_t); #endif @@ -913,6 +916,15 @@ inline std::string build_nvidia_error_msg(cufftResult_t stat) { return sout.str(); } +/*************** CUresult ERROR ***************/ +inline bool is_error(CUresult stat) { return stat != CUDA_SUCCESS; } + +inline std::string build_nvidia_error_msg(CUresult stat) { + std::ostringstream sout; + sout << "CU error(" << stat << "). 
" << GetExternalErrorMsg(stat); + return sout.str(); +} + /**************** NCCL ERROR ****************/ #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) inline bool is_error(ncclResult_t nccl_result) { diff --git a/paddle/fluid/platform/external_error.proto b/paddle/fluid/platform/external_error.proto index cbbf803492e64..fcbbb4162612d 100644 --- a/paddle/fluid/platform/external_error.proto +++ b/paddle/fluid/platform/external_error.proto @@ -25,6 +25,7 @@ enum ApiType { CUSOLVER = 4; NCCL = 5; CUFFT = 6; + CU = 7; } message MessageDesc { From fa077f38aa4d9f59a371f1816aa2e9659d785d57 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 22 Oct 2021 07:23:16 +0000 Subject: [PATCH 34/40] for test,test=develop --- paddle/fluid/memory/allocation/allocator_facade.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 3fbf7f910b5d5..ec4058f348295 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -56,7 +56,7 @@ PADDLE_DEFINE_EXPORTED_bool( "Whether to use system allocator to allocate CPU and GPU memory. 
" "Only used for unittests."); -PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false, +PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, true, "Use VirtualMemoryAutoGrowthBestFitAllocator."); DECLARE_string(allocator_strategy); From b82a059a62003403a7c76b649c857463a7dba301 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 22 Oct 2021 10:14:47 +0000 Subject: [PATCH 35/40] refine, test=develop --- paddle/fluid/memory/allocation/allocator_facade.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index ec4058f348295..3fbf7f910b5d5 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -56,7 +56,7 @@ PADDLE_DEFINE_EXPORTED_bool( "Whether to use system allocator to allocate CPU and GPU memory. " "Only used for unittests."); -PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, true, +PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false, "Use VirtualMemoryAutoGrowthBestFitAllocator."); DECLARE_string(allocator_strategy); From a4db9cb5d875410721bdc53899c303065796ef52 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 1 Nov 2021 03:14:10 +0000 Subject: [PATCH 36/40] refine, test=develop --- paddle/fluid/platform/dynload/cuda_driver.cc | 4 +- paddle/fluid/platform/dynload/cuda_driver.h | 40 +++++++------------- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/platform/dynload/cuda_driver.cc b/paddle/fluid/platform/dynload/cuda_driver.cc index 89a29bae7f337..052c19e3d03ae 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.cc +++ b/paddle/fluid/platform/dynload/cuda_driver.cc @@ -22,7 +22,9 @@ std::once_flag cuda_dso_flag; void* cuda_dso_handle = nullptr; #define DEFINE_WRAP(__name) DynLoad__##__name __name - +CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP); +#if CUDA_VERSION >= 10020 +#endif CUDA_ROUTINE_EACH(DEFINE_WRAP); 
bool HasCUDADriver() { diff --git a/paddle/fluid/platform/dynload/cuda_driver.h b/paddle/fluid/platform/dynload/cuda_driver.h index 424d8d38bc943..b5212c64cd14d 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.h +++ b/paddle/fluid/platform/dynload/cuda_driver.h @@ -42,7 +42,6 @@ extern bool HasCUDADriver(); }; \ extern struct DynLoad__##__name __name -#if CUDA_VERSION >= 10020 /** * include all needed cuda driver functions **/ @@ -59,34 +58,23 @@ extern bool HasCUDADriver(); __macro(cuCtxGetCurrent); \ __macro(cuDeviceGetCount); \ __macro(cuDevicePrimaryCtxGetState); \ - __macro(cuMemGetAllocationGranularity); \ - __macro(cuMemAddressReserve); \ - __macro(cuMemCreate); \ - __macro(cuMemMap); \ - __macro(cuMemSetAccess); \ - __macro(cuMemUnmap); \ - __macro(cuMemRelease); \ - __macro(cuMemAddressFree); \ __macro(cuDeviceGetAttribute); \ __macro(cuDeviceGet) -#else -/** - * include all needed cuda driver functions - **/ -#define CUDA_ROUTINE_EACH(__macro) \ - __macro(cuInit); \ - __macro(cuDriverGetVersion); \ - __macro(cuGetErrorString); \ - __macro(cuModuleLoadData); \ - __macro(cuModuleGetFunction); \ - __macro(cuModuleUnload); \ - __macro(cuOccupancyMaxActiveBlocksPerMultiprocessor); \ - __macro(cuLaunchKernel); \ - __macro(cuCtxCreate); \ - __macro(cuCtxGetCurrent); \ - __macro(cuDeviceGetCount); \ - __macro(cuDevicePrimaryCtxGetState) + +#if CUDA_VERSION >= 10020 +#define CUDA_ROUTINE_EACH_VVM(__macro) \ + __macro(cuMemGetAllocationGranularity); \ + __macro(cuMemAddressReserve); \ + __macro(cuMemCreate); \ + __macro(cuMemMap); \ + __macro(cuMemSetAccess); \ + __macro(cuMemUnmap); \ + __macro(cuMemRelease); \ + __macro(cuMemAddressFree) + +CUDA_ROUTINE_EACH_VVM(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); #endif + CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); #undef DECLARE_DYNAMIC_LOAD_CUDA_WRAP From 7eadf41c534266ccd5d9f1075e26372720e5a012 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Tue, 2 Nov 2021 01:41:37 +0000 Subject: [PATCH 37/40] refine, test=develop 
--- paddle/fluid/platform/dynload/cuda_driver.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/platform/dynload/cuda_driver.cc b/paddle/fluid/platform/dynload/cuda_driver.cc index 052c19e3d03ae..6110e6b6ba93f 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.cc +++ b/paddle/fluid/platform/dynload/cuda_driver.cc @@ -22,8 +22,9 @@ std::once_flag cuda_dso_flag; void* cuda_dso_handle = nullptr; #define DEFINE_WRAP(__name) DynLoad__##__name __name -CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP); + #if CUDA_VERSION >= 10020 +CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP); #endif CUDA_ROUTINE_EACH(DEFINE_WRAP); From 4b200916c2b306bde9da22bc94ad6322cc6e1661 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 3 Nov 2021 06:12:49 +0000 Subject: [PATCH 38/40] refine, test=develop --- paddle/fluid/platform/gpu_info.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 1fa3bd95af411..9dc6254234a97 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -648,9 +648,9 @@ class RecordedCudaMallocHelper { #ifdef PADDLE_WITH_CUDA #if CUDA_VERSION >= 10020 - CUresult cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, - const CUmemAllocationProp *prop, - unsigned long long flags) { // NOLINT + CUresult MemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags) { // NOLINT auto result = paddle::platform::dynload::cuMemCreate(handle, size, prop, flags); if (result == CUDA_SUCCESS) { @@ -659,7 +659,7 @@ class RecordedCudaMallocHelper { return result; } - CUresult cuMemRelease(CUmemGenericAllocationHandle handle, size_t size) { + CUresult MemRelease(CUmemGenericAllocationHandle handle, size_t size) { auto result = paddle::platform::dynload::cuMemRelease(handle); if (result == CUDA_SUCCESS) { cur_size_.fetch_sub(size); @@ -698,13 +698,13 @@ void 
RecordedCudaFree(void *p, size_t size, int dev_id) { CUresult RecordedCuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, const CUmemAllocationProp *prop, unsigned long long flags, int dev_id) { // NOLINT - return RecordedCudaMallocHelper::Instance(dev_id)->cuMemCreate(handle, size, - prop, flags); + return RecordedCudaMallocHelper::Instance(dev_id)->MemCreate(handle, size, + prop, flags); } CUresult RecordedCuMemRelease(CUmemGenericAllocationHandle handle, size_t size, int dev_id) { - return RecordedCudaMallocHelper::Instance(dev_id)->cuMemRelease(handle, size); + return RecordedCudaMallocHelper::Instance(dev_id)->MemRelease(handle, size); } #endif #endif From e8469f4e7caaf6bb1c8d26d3c7f30a1241ff417e Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 5 Nov 2021 08:13:59 +0000 Subject: [PATCH 39/40] refine, test=develop --- .../allocation/cuda_virtual_mem_allocator.cc | 27 +++++++++++++++++-- .../allocation/cuda_virtual_mem_allocator.h | 1 + ...l_memory_auto_growth_best_fit_allocator.cc | 19 ++++--------- ...al_memory_auto_growth_best_fit_allocator.h | 9 +++++++ 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc index ab5535ff0a396..ef64c3bdb355e 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -37,11 +37,18 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( : place_(place) { CUmemAllocationProp prop = {}; + // Setup the properties common for all the chunks + // The allocations will be device pinned memory. + // This property structure describes the physical location where the memory + // will be allocated via cuMemCreate along with additional properties. In this + // case, the allocation will be pinned device memory local to a given device.
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; prop.location.id = place.device; prop_ = prop; + // Prepare the access descriptor array indicating where and how the backings + // should be visible. access_desc_.resize(platform::GetCUDADeviceCount()); for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { if (place.device != dev_id) { @@ -52,11 +59,16 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( continue; } } + // Specify which device we are adding mappings for. access_desc_[dev_id].location.type = CU_MEM_LOCATION_TYPE_DEVICE; access_desc_[dev_id].location.id = dev_id; + + // Specify both read and write access. access_desc_[dev_id].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; } + // Get the minimum granularity needed for all devices + // (the max of the minimum granularity of each participating device) granularity_ = 0; for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { size_t granularity; @@ -71,8 +83,13 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( paddle::platform::CUDADeviceGuard guard(place.device); PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total)); - virtual_mem_size_ = (actual_total + granularity_ - 1) & ~(granularity_ - 1); + virtual_mem_size_ = AlignedSize(actual_total, granularity_); + // Reserve the required contiguous virtual address space for the allocations + // The maximum video memory size we can apply for is the video memory size of + // GPU, + // so the virtual address space size we reserve is equal to the GPU video + // memory size PADDLE_ENFORCE_CUDA_SUCCESS(paddle::platform::dynload::cuMemAddressReserve( &virtual_mem_base_, virtual_mem_size_, 0, 0, 0)); @@ -121,7 +138,7 @@ void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { } Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { - size = (size + granularity_ - 1) & ~(granularity_ - 1); + size = AlignedSize(size, granularity_); CUdeviceptr ptr = 
virtual_mem_base_ + virtual_mem_alloced_offset_; @@ -143,6 +160,8 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { CUmemGenericAllocationHandle handle; paddle::platform::CUDADeviceGuard guard(place_.device); + + // Create physical memory backing allocation. auto result = platform::RecordedCuMemCreate(&handle, size, &prop_, 0, place_.device); @@ -169,6 +188,9 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { return nullptr; } + // Assign the chunk to the appropriate VA range and release the handle. + // After mapping the memory, it can be referenced by virtual address. + // The allocation will be kept live until it is unmapped. result = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0); if (result != CUDA_SUCCESS) { @@ -177,6 +199,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { return nullptr; } + // Apply the access descriptors to the whole VA range. result = paddle::platform::dynload::cuMemSetAccess( ptr, size, access_desc_.data(), access_desc_.size()); diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h index 8a52ac4ab1ea7..c51b56566bb02 100644 --- a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h @@ -29,6 +29,7 @@ namespace paddle { namespace memory { namespace allocation { +// Allocate memory using NVIDIA's virtual memory management technology class CUDAVirtualMemAllocator : public Allocator { public: explicit CUDAVirtualMemAllocator(const platform::CUDAPlace& place); diff --git a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc index 059b74dd0cab4..5c7e8e2d933f3 100644 --- a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc +++ 
b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc @@ -62,9 +62,7 @@ void VirtualMemoryAutoGrowthBestFitAllocator::TryMergeBlock2Blocks( block->is_free_ = true; free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } else if (block->ptr_ == all_blocks_.front().ptr_) { - block++; - auto next = block; - block--; + auto next = std::next(block); if (next->is_free_ && reinterpret_cast(block->ptr_) + block->size_ == next->ptr_) { // merge with next @@ -78,9 +76,7 @@ void VirtualMemoryAutoGrowthBestFitAllocator::TryMergeBlock2Blocks( free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } else if (block->ptr_ == all_blocks_.back().ptr_) { - block--; - auto pre = block; - block++; + auto pre = std::prev(block); if (pre->is_free_ && reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_) { // merge with pre @@ -93,12 +89,8 @@ void VirtualMemoryAutoGrowthBestFitAllocator::TryMergeBlock2Blocks( free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); } } else { - block--; - auto pre = block; - block++; - block++; - auto next = block; - block--; + auto pre = std::prev(block); + auto next = std::next(block); if (pre->is_free_ && reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_ && !(next->is_free_ && @@ -175,8 +167,7 @@ void VirtualMemoryAutoGrowthBestFitAllocator::ExtendAndMerge(size_t size) { } else { // insert to middle auto next = block_it; - block_it--; - auto pre = block_it; + auto pre = std::prev(block_it); if (pre->is_free_ && reinterpret_cast(pre->ptr_) + pre->size_ == ptr && !(next->is_free_ && diff --git a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h index 71a5cb12b0a98..e1e43cacedd00 100644 --- a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h +++ b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h @@ 
-42,6 +42,15 @@ struct BlockAllocation : public Allocation { std::list::iterator block_it_; }; +/** + * Like AutoGrowthBestFitAllocator, VirtualMemoryAutoGrowthBestFitAllocator will + * gradually apply to GPU for video memory as the model uses more video memory. + * However, the difference is that virtualmemoryautogrowthbestfitallocator uses + * nvidia's virtual memory management technology and obtains the virtual memory + * address. If the video memory applied for twice is continuous, we can combine + * the two video memories later. This combination can greatly reduce + * fragmentation. + */ class VirtualMemoryAutoGrowthBestFitAllocator : public Allocator { public: VirtualMemoryAutoGrowthBestFitAllocator( From f7df2b886ec9beae60b5bc4ef7dfb1303cd611bf Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Fri, 5 Nov 2021 08:15:29 +0000 Subject: [PATCH 40/40] refine, test=develop --- .../allocation/virtual_memory_auto_growth_best_fit_allocator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h index e1e43cacedd00..5171e5b3cd1bf 100644 --- a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h +++ b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h @@ -45,7 +45,7 @@ struct BlockAllocation : public Allocation { /** * Like AutoGrowthBestFitAllocator, VirtualMemoryAutoGrowthBestFitAllocator will * gradually apply to GPU for video memory as the model uses more video memory. - * However, the difference is that virtualmemoryautogrowthbestfitallocator uses + * However, the difference is that VirtualMemoryAutoGrowthBestFitAllocator uses * nvidia's virtual memory management technology and obtains the virtual memory * address. If the video memory applied for twice is continuous, we can combine * the two video memories later.
This combination can greatly reduce