diff --git a/cmake/operators.cmake b/cmake/operators.cmake index b094f08619e3e..d1615b0efed98 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -218,7 +218,7 @@ function(op_library TARGET) "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op" -"fused_bn_add_activation_op" "fused_attention_op" "fused_feedforward_op") +"fused_bn_add_activation_op" "fused_attention_op" "fused_feedforward_op" "resnet_unit_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 0629f1b91504a..25110fe24f587 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -143,6 +143,8 @@ struct BuildStrategy { // Turn off inplace addto by default. bool enable_addto_{false}; + bool allow_cuda_graph_capture_{false}; + // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, // num_trainers is 1, so the current fields of build_strategy doesn't tell if // it's distributed model. diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 120bdd2bc9f56..75998e4582e2b 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -130,10 +130,12 @@ FetchResultType FastThreadedSSAGraphExecutor::Run( } } // Wait FetchOps. - ClearFetchOp(graph_, &fetch_ops); + if (!fetch_ops.empty()) { + ClearFetchOp(graph_, &fetch_ops); - for (auto &place : places_) { - fetch_ctxs_.Get(place)->Wait(); + for (auto &place : places_) { + fetch_ctxs_.Get(place)->Wait(); + } } return fetches; diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index fcfbfd0557e25..f4f4f5a42fd1b 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -86,19 +86,28 @@ struct ScaleLossGradFunctor { } }; +std::string ScaleLossGradOpHandle::LossGradName() const { + return static_cast(this->outputs_[0])->name(); +} + void ScaleLossGradOpHandle::RunImpl() { platform::RecordEvent record_event(Name()); - // Doesn't wait any event - std::string var_name = static_cast(this->outputs_[0])->name(); + RunOnVar(local_exec_scopes_[0]->FindVar(LossGradName()), true); +} - auto *tensor = - local_exec_scopes_[0]->FindVar(var_name)->GetMutable(); +void ScaleLossGradOpHandle::RunOnVar(Variable *var, bool record_event) { + auto *tensor = var->GetMutable(); tensor->Resize(make_ddim({1})); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_, this->dev_ctxes_.at(place_)); - this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); }); + if (record_event) { + this->RunAndRecordEvent( + [&] { framework::VisitDataType(out_dtype_, func); }); + } else { + framework::VisitDataType(out_dtype_, func); + } #else ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_, nullptr); framework::VisitDataType(out_dtype_, func); diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index 02e5aa88443df..88fe02a749fe4 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -46,6 +46,12 @@ struct ScaleLossGradOpHandle : public OpHandleBase { std::string Name() const override; + platform::Place GetPlace() const { return place_; } + + void RunOnVar(Variable *var, bool record_event = false); + + std::string LossGradName() const; + protected: void RunImpl() override; diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index ad47846c59a05..5d271d06b6922 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -22,7 +22,9 @@ #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" #include "paddle/fluid/platform/profiler.h" + namespace paddle { namespace framework { namespace details { @@ -49,8 +51,29 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( PrepareLocalExeScopes(); } +static void RunProgramDescs(const ProgramDescs &programs, + const std::vector &local_exec_scopes, + const std::vector &places) { + for (auto &program : programs) { + for (auto &op_desc : program.Block(0).AllOps()) { + for (size_t i = 0; i < local_exec_scopes.size(); ++i) { + auto op = OpRegistry::CreateOp(*op_desc); + op->Run(*local_exec_scopes[i], places[i]); + } + } + } +} + FetchResultType ScopeBufferedSSAGraphExecutor::Run( const std::vector &fetch_tensors, bool return_merged) { +#ifdef PADDLE_WITH_CUDA + if (platform::IsCUDAGraphCapturing()) { + strategy_.num_iteration_per_drop_scope_ = + std::numeric_limits::max(); + DropLocalExeScopes(/*need_wait=*/false); + } +#endif + if (drop_scope_counter_ == 0) { platform::RecordEvent e("InitLocalVars"); InitVariables(); @@ -84,7 +107,7 @@ FetchResultType ScopeBufferedSSAGraphExecutor::Run( ++drop_scope_counter_; if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ || DropScopeOrNot()) { - DropLocalExeScopes(); + DropLocalExeScopes(!platform::IsCUDAGraphCapturing()); } if (VLOG_IS_ON(5)) { @@ -128,15 +151,7 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() { if (graph.Has(details::kStartupProgramDescs)) { auto &program_descs = graph.Get(details::kStartupProgramDescs); - - for (auto &program_desc : program_descs) { - for (auto &op_desc : program_desc.Block(0).AllOps()) { - for (size_t i = 0; i < local_exec_scopes_.size(); ++i) { - auto op = OpRegistry::CreateOp(*op_desc); - op->Run(*local_exec_scopes_[i], places_[i]); - } - } - } + RunProgramDescs(program_descs, local_exec_scopes_, places_); } is_initialized_ = true; } @@ -144,23 +159,17 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() { if (graph.Has(details::kProgramDescs)) { auto &program_descs = graph.Get(details::kProgramDescs); - - for (auto &program_desc : program_descs) { - for (auto &op_desc : program_desc.Block(0).AllOps()) { - for (size_t i = 0; i < local_exec_scopes_.size(); ++i) { - auto op = OpRegistry::CreateOp(*op_desc); - op->Run(*local_exec_scopes_[i], places_[i]); - } - } - } + RunProgramDescs(program_descs, local_exec_scopes_, places_); } } -void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() { +void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes(bool need_wait) { platform::RecordEvent drop_scope_event("DropLocalExeScopes"); drop_scope_counter_ = 0; - for (auto &p : places_) { - platform::DeviceContextPool::Instance().Get(p)->Wait(); + if (need_wait) { + for (auto &p : places_) { + platform::DeviceContextPool::Instance().Get(p)->Wait(); + } } scope_monitor_.ClearHistoryLocalExecScopes(); for (size_t i = 0; i < local_exec_scopes_.size(); ++i) { diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index aa2b113c960a3..ea5a3c07957bf 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -53,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { FetchResultType Run(const std::vector& fetch_tensors, bool return_merged) override; - void DropLocalExeScopes(); + void DropLocalExeScopes(bool need_wait = true); bool NeedCreateLocalExeScope(); diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 17d15a94c7287..e7a25de96a947 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -115,6 +115,7 @@ message BuildStrategy { optional bool enable_auto_fusion = 11 [ default = false ]; optional bool enable_addto = 12 [ default = false ]; optional bool fix_op_run_order = 13 [ default = false ]; + optional bool allow_cuda_graph_capture = 14 [ default = false ]; } message ExecutionStrategy { diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/inplace_addto_op_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/inplace_addto_op_pass.cc index 849d0dabab779..d09de5be84c35 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/inplace_addto_op_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/inplace_addto_op_pass.cc @@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const { out_var_ptr->GeneratedOp()); // NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy - if (right_generated_op->Name() != "conv2d_grad") { + if (right_generated_op->Name() != "conv2d_grad" && + right_generated_op->Name() != "resnet_unit_grad") { continue; } @@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) { if (node.inputs.empty()) return false; auto *generated_op = node.inputs[0]; auto *op_desc = generated_op->Op(); - if (op_desc == nullptr || op_desc->Type() != "conv2d_grad") { + if (op_desc == nullptr || (op_desc->Type() != "conv2d_grad" && + op_desc->Type() != "resnet_unit_grad")) { return false; } const auto &outputs = op_desc->Outputs(); - auto iter = outputs.find(GradVarName("Input")); + std::string grad_var_name = op_desc->Type() == "conv2d_grad" ? "Input" : "X"; + auto iter = outputs.find(GradVarName(grad_var_name)); return iter != outputs.end() && !iter->second.empty() && iter->second[0] == node.Name() && !op_desc->GetAttrIfExists("use_addto"); diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt b/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt index 6764799d82866..fea12baf0651f 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt @@ -1,4 +1,4 @@ -cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) +cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle op_graph_view multi_devices_helper) cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper) cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper) diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc index 70b95c9154fd3..afd80e45cf65e 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h" @@ -21,14 +22,23 @@ namespace paddle { namespace framework { namespace ir { +template +static bool IsMatchedPlaceSingleDeviceOp(details::OpHandleBase *op_base, + const platform::Place &place) { + auto *op = dynamic_cast(op_base); + return op && op->GetPlace() == place; +} + static bool IsLockAndRecordEventFreeComputationOpHandle( details::ComputationOpHandle *op, const OpGraphView &graph_view) { if (!platform::is_gpu_place(op->GetPlace()) && !platform::is_xpu_place(op->GetPlace())) return false; for (auto &pending_op : graph_view.PendingOps(op)) { - auto *tmp = dynamic_cast(pending_op); - if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) { + if (!IsMatchedPlaceSingleDeviceOp( + pending_op, op->GetPlace()) && + !IsMatchedPlaceSingleDeviceOp( + pending_op, op->GetPlace())) { return false; } } diff --git a/paddle/fluid/framework/operator_kernel_configs.h b/paddle/fluid/framework/operator_kernel_configs.h index 68edb7c89dd87..ab812a30981f0 100644 --- a/paddle/fluid/framework/operator_kernel_configs.h +++ b/paddle/fluid/framework/operator_kernel_configs.h @@ -15,8 +15,10 @@ limitations under the License. */ #pragma once #include +#include #include #include +#include "glog/logging.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index adbbfb380bc45..d19ac0b65f4d1 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -27,6 +27,7 @@ limitations under the License. */ #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" @@ -34,6 +35,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h" #include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" #include "paddle/fluid/platform/event.h" #include "paddle/fluid/platform/profiler.h" @@ -43,6 +45,10 @@ limitations under the License. */ DECLARE_double(eager_delete_tensor_gb); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +DECLARE_bool(sync_nccl_allreduce); +#endif + #ifdef WITH_GPERFTOOLS #include "gperftools/profiler.h" #endif @@ -669,6 +675,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, // ncclOp std::vector async_graphs = CompileGraphWithBuildStrategy(graph, &graphs, loss_var_name); + PrepareForCUDAGraphCapture(graph); graph = member_->ApplyMemoryOptimizePass(graph); async_graphs[0] = graph; @@ -882,6 +889,23 @@ void ParallelExecutor::BCastParamsToDevices( FetchResultType ParallelExecutor::Run( const std::vector &fetch_tensors, bool return_merged) { VLOG(3) << "enter ParallelExecutor Run"; +#ifdef PADDLE_WITH_CUDA + if (platform::IsCUDAGraphCapturing()) { + PADDLE_ENFORCE_EQ(fetch_tensors.empty(), true, + platform::errors::InvalidArgument( + "Cannot fetch data when using CUDA Graph.")); + PADDLE_ENFORCE_EQ( + member_->build_strategy_.allow_cuda_graph_capture_, true, + platform::errors::InvalidArgument( + "You must turn on build_strategy.allow_cuda_graph_capture = True " + "to enable CUDA Graph capturing.")); + PADDLE_ENFORCE_EQ( + member_->places_[0], platform::CUDAGraphCapturingPlace(), + platform::errors::InvalidArgument("The place to capture CUDAGraph is " + "not the same as the place to run.")); + } +#endif + #ifdef WITH_GPERFTOOLS if (gProfileStarted) { ProfilerFlush(); @@ -932,6 +956,16 @@ void ParallelExecutor::SkipMemoryReuse( void ParallelExecutor::FeedTensorsIntoLocalScopes( const std::vector> &tensors) { + if (platform::IsCUDAGraphCapturing()) { + for (auto &tensor : tensors) { + PADDLE_ENFORCE_EQ( + tensor.empty(), true, + platform::errors::PermissionDenied( + "Feeding data is not permitted when capturing CUDA Graph.")); + } + return; + } + if (!member_->AllowPartialFeed()) { PADDLE_ENFORCE_EQ(tensors.size(), member_->local_scopes_.size(), platform::errors::Unimplemented( @@ -987,6 +1021,14 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes( void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( const std::unordered_map &tensors) { + if (platform::IsCUDAGraphCapturing()) { + PADDLE_ENFORCE_EQ( + tensors.empty(), true, + platform::errors::PermissionDenied( + "Feeding data is not permitted when capturing CUDA Graph.")); + return; + } + size_t num_places = member_->places_.size(); bool allow_partial_feed = member_->AllowPartialFeed(); @@ -1568,6 +1610,107 @@ const ir::Graph &ParallelExecutor::Graph() const { return member_->executor_->Graph(); } +void ParallelExecutor::PrepareForCUDAGraphCapture(ir::Graph *graph) { + const auto &build_strategy = member_->build_strategy_; + if (!build_strategy.allow_cuda_graph_capture_) return; +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_EQ( + build_strategy.async_mode_, false, + platform::errors::InvalidArgument( + "Async Executor does not support CUDA Graph capturing.")); + PADDLE_ENFORCE_EQ( + platform::IsCUDAGraphCapturing(), false, + platform::errors::PermissionDenied("CUDA Graph is not allowed to capture " + "when running the first batch.")); + PADDLE_ENFORCE_EQ( + member_->places_.size(), 1, + platform::errors::InvalidArgument( + "CUDA Graph is only supported when one GPU device is running.")); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(member_->places_[0]), true, + platform::errors::InvalidArgument( + "CUDA Graph is only supported on NVIDIA GPU device.")); + PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce, false, + platform::errors::InvalidArgument( + "FLAGS_sync_nccl_allreduce must be False to support " + "CUDA Graph capturing.")); + + std::unordered_map> all_vars; + for (auto &node : graph->Nodes()) { + if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { + auto *var_desc = node->Var(); + all_vars[var_desc->Name()].emplace_back(var_desc); + } + } + + auto mark_var_as_persistable = [&all_vars](const std::string &name) { + auto iter = all_vars.find(name); + if (iter != all_vars.end()) { + for (auto *var_desc : iter->second) { + var_desc->SetPersistable(true); + } + } + }; + + // Step 1: All fused vars must be persistable. + if (graph->Has(details::kFusedVars)) { + auto &fused_vars = graph->Get(details::kFusedVars); + for (auto &fused_var : fused_vars) { + fused_var.second.persistable_ = true; + mark_var_as_persistable(fused_var.first); + } + } + + // Step 2: All pinned vars must be persistable. + if (graph->Has(details::kPinnedVars)) { + auto &pinned_vars = graph->Get(details::kPinnedVars); + for (auto &pinned_var : pinned_vars) { + mark_var_as_persistable(pinned_var); + } + } + + // Step 3: Move all main programs to startup programs to make sure that + // the main programs would only be run once. + if (graph->Has(details::kProgramDescs)) { + auto &startup_programs = + graph->GetOrInit(details::kStartupProgramDescs); + auto &main_programs = + graph->Get(details::kProgramDescs); + for (auto &main_program : main_programs) { + startup_programs.emplace_back(main_program); + } + graph->Erase(details::kProgramDescs); + } + + // Step 4: Mark all vars in startup programs to be persistable. + if (graph->Has(details::kStartupProgramDescs)) { + auto &startup_programs = + graph->GetOrInit(details::kStartupProgramDescs); + for (auto &startup_program : startup_programs) { + for (auto &op_desc : startup_program.Block(0).AllOps()) { + for (auto &output : op_desc->OutputArgumentNames()) { + mark_var_as_persistable(output); + } + } + } + } + + // Step 5: ScaleLossGrad must be run beforehand to avoid H2D copy. + auto ops = ir::FilterByNodeWrapper(*graph); + auto *scope = member_->local_scopes_[0]; + for (auto *op : ops) { + auto *loss_grad_op = dynamic_cast(op); + if (loss_grad_op == nullptr) continue; + auto loss_grad_name = loss_grad_op->LossGradName(); + mark_var_as_persistable(loss_grad_name); + loss_grad_op->RunOnVar(scope->Var(loss_grad_name)); + loss_grad_op->SetSkipRunning(true); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "CUDA Graph is only supported on NVIDIA GPU device.")); +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 6c871a8d85815..78774f0489638 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -144,6 +144,8 @@ class ParallelExecutor { void SetReaderOpDeviceInfoOfGraphs( const std::vector &final_graphs); + void PrepareForCUDAGraphCapture(ir::Graph *graph); + ParallelExecutorPrivate *member_; std::vector> async_graphs_; std::vector var_infos_; diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 6b4afae9f8c75..4aa1900f53f5e 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -82,7 +82,11 @@ endif() cc_library(aligned_allocator SRCS aligned_allocator.cc DEPS allocator) cc_test(test_aligned_allocator SRCS test_aligned_allocator.cc DEPS aligned_allocator) cc_library(allocator_strategy SRCS allocator_strategy.cc DEPS gflags ${AllocatorFacadeDeps}) -cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy ) +cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy) + +if (WITH_GPU) + target_link_libraries(allocator_facade cuda_graph) +endif() cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator locked_allocator cpu_allocator) if (WITH_TESTING) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 78bce53b6f4ff..0388e2d13afb0 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -32,6 +32,9 @@ #include "paddle/fluid/memory/allocation/thread_local_allocator.h" #include "paddle/fluid/platform/gpu_info.h" #endif +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_graph.h" +#endif #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/xpu/xpu_info.h" #endif @@ -47,17 +50,64 @@ PADDLE_DEFINE_EXPORTED_bool( "Whether to use system allocator to allocate CPU and GPU memory. " "Only used for unittests."); +DECLARE_string(allocator_strategy); + namespace paddle { namespace memory { namespace allocation { +#ifdef PADDLE_WITH_CUDA +class CUDAGraphAllocator + : public Allocator, + public std::enable_shared_from_this { + private: + class PrivateAllocation : public Allocation { + public: + PrivateAllocation(CUDAGraphAllocator* allocator, + AllocationPtr underlying_allocation) + : Allocation(underlying_allocation->ptr(), + underlying_allocation->size(), + underlying_allocation->place()), + allocator_(allocator->shared_from_this()), + underlying_allocation_(std::move(underlying_allocation)) {} + + private: + std::shared_ptr allocator_; + AllocationPtr underlying_allocation_; + }; + + explicit CUDAGraphAllocator(const std::shared_ptr& allocator) + : underlying_allocator_(allocator) {} + + public: + static std::shared_ptr Create( + const std::shared_ptr& allocator) { + return std::shared_ptr(new CUDAGraphAllocator(allocator)); + } + + protected: + Allocation* AllocateImpl(size_t size) { + VLOG(10) << "Allocate " << size << " for CUDA Graph"; + return new PrivateAllocation(this, underlying_allocator_->Allocate(size)); + } + + void FreeImpl(Allocation* allocation) { + VLOG(10) << "delete for CUDA Graph"; + delete allocation; + } + + private: + std::shared_ptr underlying_allocator_; +}; +#endif + class AllocatorFacadePrivate { public: using AllocatorMap = std::map>; - AllocatorFacadePrivate() { - auto strategy = GetAllocatorStrategy(); - switch (strategy) { + explicit AllocatorFacadePrivate(bool allow_free_idle_chunk = true) { + strategy_ = GetAllocatorStrategy(); + switch (strategy_) { case AllocatorStrategy::kNaiveBestFit: { InitNaiveBestFitCPUAllocator(); #ifdef PADDLE_WITH_XPU @@ -91,7 +141,8 @@ class AllocatorFacadePrivate { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { - InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id)); + InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id), + allow_free_idle_chunk); } InitNaiveBestFitCUDAPinnedAllocator(); #endif @@ -117,7 +168,7 @@ class AllocatorFacadePrivate { default: { PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported allocator strategy: %d", static_cast(strategy))); + "Unsupported allocator strategy: %d", static_cast(strategy_))); } } InitZeroSizeAllocators(); @@ -130,11 +181,29 @@ class AllocatorFacadePrivate { CheckAllocThreadSafe(); } + inline const AllocatorMap& GetAllocatorMap() { +#ifdef PADDLE_WITH_CUDA + if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { + auto id = platform::CUDAGraph::CapturingID(); + auto iter = cuda_graph_allocator_map_.find(id); + PADDLE_ENFORCE_NE( + iter, cuda_graph_allocator_map_.end(), + platform::errors::PermissionDenied( + "No memory pool is prepared for CUDA Graph capturing.")); + return iter->second->allocators_; + } else { + return allocators_; + } +#else + return allocators_; +#endif + } + inline const std::shared_ptr& GetAllocator( const platform::Place& place, size_t size) { const auto& allocators = (size > 0 ? (UNLIKELY(FLAGS_use_system_allocator) ? system_allocators_ - : allocators_) + : GetAllocatorMap()) : zero_size_allocators_); auto iter = allocators.find(place); PADDLE_ENFORCE_NE(iter, allocators.end(), @@ -145,6 +214,7 @@ class AllocatorFacadePrivate { private: void InitSystemAllocators() { + if (!system_allocators_.empty()) return; system_allocators_[platform::CPUPlace()] = std::make_shared(); #ifdef PADDLE_WITH_XPU int device_count = platform::GetXPUDeviceCount(); @@ -183,10 +253,11 @@ class AllocatorFacadePrivate { allocators_[p] = std::make_shared(p); } - void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p) { + void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p, + bool allow_free_idle_chunk) { auto cuda_allocator = std::make_shared(p); allocators_[p] = std::make_shared( - cuda_allocator, platform::GpuMinChunkSize()); + cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk); } #endif @@ -226,6 +297,7 @@ class AllocatorFacadePrivate { }; void InitZeroSizeAllocators() { + if (!zero_size_allocators_.empty()) return; std::vector places; places.emplace_back(platform::CPUPlace()); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -279,12 +351,57 @@ class AllocatorFacadePrivate { } } +#ifdef PADDLE_WITH_CUDA + + public: + void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) { + PADDLE_ENFORCE_EQ(strategy_, AllocatorStrategy::kAutoGrowth, + platform::errors::InvalidArgument( + "CUDA Graph is only supported when the " + "FLAGS_allocator_strategy=\"auto_growth\", but got " + "FLAGS_allocator_strategy=\"%s\"", + FLAGS_allocator_strategy)); + auto& allocator = cuda_graph_allocator_map_[id]; + PADDLE_ENFORCE_EQ( + allocator.get(), nullptr, + platform::errors::InvalidArgument( + "The memory pool of the CUDA Graph with ID %d have been prepared.", + id)); + allocator.reset( + new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false)); + for (auto& item : allocator->allocators_) { + auto& old_allocator = item.second; + old_allocator = CUDAGraphAllocator::Create(old_allocator); + } + VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id; + } + + void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) { + auto iter = cuda_graph_allocator_map_.find(id); + PADDLE_ENFORCE_NE(iter, cuda_graph_allocator_map_.end(), + platform::errors::InvalidArgument( + "Cannot find CUDA Graph with ID = %d", id)); + cuda_graph_allocator_map_.erase(iter); + VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id; + } +#endif + private: AllocatorMap allocators_; - AllocatorMap zero_size_allocators_; - AllocatorMap system_allocators_; +#ifdef PADDLE_WITH_CUDA + std::unordered_map> + cuda_graph_allocator_map_; +#endif + AllocatorStrategy strategy_; + + static AllocatorMap zero_size_allocators_; + static AllocatorMap system_allocators_; }; +AllocatorFacadePrivate::AllocatorMap + AllocatorFacadePrivate::zero_size_allocators_; +AllocatorFacadePrivate::AllocatorMap AllocatorFacadePrivate::system_allocators_; + // Pimpl. Make interface clean. AllocatorFacade::AllocatorFacade() : m_(new AllocatorFacadePrivate()) {} // delete m_ may cause core dump when the destructor of python in conflict with @@ -316,6 +433,16 @@ const std::shared_ptr& AllocatorFacade::GetAllocator( return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1); } +#ifdef PADDLE_WITH_CUDA +void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) { + return m_->PrepareMemoryPoolForCUDAGraph(id); +} + +void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) { + return m_->RemoveMemoryPoolOfCUDAGraph(id); +} +#endif + } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h index 7f6ad561aa931..8d889ec38eed7 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.h +++ b/paddle/fluid/memory/allocation/allocator_facade.h @@ -18,6 +18,9 @@ #ifdef PADDLE_WITH_ASCEND_CL #include "paddle/fluid/memory/allocation/npu_pinned_allocator.h" #endif +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/gpu_info.h" +#endif #include "paddle/fluid/platform/place.h" namespace paddle { @@ -54,6 +57,11 @@ class AllocatorFacade { uint64_t Release(const platform::Place& place); const std::shared_ptr& GetAllocator(const platform::Place& place); +#ifdef PADDLE_WITH_CUDA + void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id); + void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id); +#endif + // TODO(yy): Allocate a Copy-On-Write allocation? private: AllocatorFacade(); diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc index a35d8a73f7eda..f36d589f907fb 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc @@ -39,11 +39,12 @@ namespace allocation { AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator( const std::shared_ptr &underlying_allocator, size_t alignment, - size_t chunk_size) + size_t chunk_size, bool allow_free_idle_chunk) : underlying_allocator_( std::make_shared(underlying_allocator, alignment)), alignment_(alignment), - chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)) {} + chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)), + allow_free_idle_chunk_(allow_free_idle_chunk) {} Allocation *AutoGrowthBestFitAllocator::AllocateImpl(size_t size) { size = AlignedSize(size, alignment_); @@ -139,6 +140,9 @@ void AutoGrowthBestFitAllocator::FreeImpl(Allocation *allocation) { } uint64_t AutoGrowthBestFitAllocator::FreeIdleChunks() { + if (!allow_free_idle_chunk_) { + return 0; + } uint64_t bytes = 0; for (auto chunk_it = chunks_.begin(); chunk_it != chunks_.end();) { auto &blocks = chunk_it->blocks_; diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h index 5ed6eb94f158f..d1fa6cce0164f 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h @@ -31,7 +31,7 @@ class AutoGrowthBestFitAllocator : public Allocator { public: AutoGrowthBestFitAllocator( const std::shared_ptr &underlying_allocator, size_t alignment, - size_t chunk_size = 0); + size_t chunk_size = 0, bool allow_free_idle_chunk = true); bool IsAllocThreadSafe() const override { return true; } @@ -86,6 +86,7 @@ class AutoGrowthBestFitAllocator : public Allocator { std::list chunks_; size_t alignment_; size_t chunk_size_; + bool allow_free_idle_chunk_; SpinLock spinlock_; }; diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index 4c0ef02074e2e..f4183bf570926 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator_kernel_configs.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" #include "paddle/fluid/platform/cudnn_desc.h" namespace paddle { namespace operators { @@ -480,6 +481,7 @@ struct SearchAlgorithm { static algo_t Find(const ConvArgs& args, bool exhaustive_search, bool deterministic, const framework::ExecutionContext& ctx) { + platform::CUDAGraphCaptureModeGuard guard; auto dtype = platform::CudnnDataType::type; size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; size_t workspace_size = 0; @@ -601,6 +603,7 @@ struct SearchAlgorithm { } static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + platform::CUDAGraphCaptureModeGuard guard; size_t workspace_size = 0; PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index b08808f353d4a..e96d3acd9357f 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -18,7 +18,8 @@ register_operators(EXCLUDES fused_bn_add_activation_op fused_attention_op fused_feedforward_op - fused_transformer_op) + fused_transformer_op + resnet_unit_op) # fusion_gru_op does not have CUDA kernel op_library(fusion_gru_op) @@ -86,4 +87,11 @@ if (WITH_GPU OR WITH_ROCM) op_library(fused_attention_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n") endif() + # resnet_unit needs cudnn 8.0 above + if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000)) + op_library(resnet_unit_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n") + cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory) + cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory) + endif() endif() diff --git a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc new file mode 100644 index 0000000000000..36477bb09a2ad --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc @@ -0,0 +1,784 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h" +#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/float16.h" + +DECLARE_bool(cudnn_batchnorm_spatial_persistent); + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace op = paddle::operators; +using Tensor = paddle::framework::Tensor; + +USE_OP(batch_norm); +USE_CUDA_ONLY_OP(fused_bn_add_activation); +USE_CUDA_ONLY_OP(fused_bn_add_activation_grad); + +template +void InitRandomTensor(const std::vector &dims, + framework::Tensor *cpu_out) { + T *cpu_out_ptr = cpu_out->mutable_data(framework::make_ddim(dims), + platform::CPUPlace()); + std::default_random_engine random(0); + std::uniform_real_distribution dis(-1.0, 1.0); + for (int i = 0; i < cpu_out->numel(); ++i) { + cpu_out_ptr[i] = static_cast(dis(random)); + } +} + +template +void InitConstantTensor(const std::vector &dims, T value, + framework::Tensor *cpu_out) { + T *cpu_out_ptr = cpu_out->mutable_data(framework::make_ddim(dims), + platform::CPUPlace()); + for (int i = 0; i < cpu_out->numel(); ++i) { + cpu_out_ptr[i] = value; + } +} + +template +void CheckOutput(std::string name, const framework::Tensor &cpu_res, + const framework::Tensor &cpu_base, float diff, + bool is_relative_atol = false) { + if (cpu_res.dims().size() == cpu_base.dims().size()) { + EXPECT_EQ(cpu_res.dims(), cpu_base.dims()); + } else { + EXPECT_EQ(cpu_res.numel(), cpu_base.numel()); + } + + const T *cpu_res_ptr = cpu_res.data(); + const T *cpu_base_ptr = cpu_base.data(); + float max_diff = 0; + int index = 0; + for (int i = 0; i < cpu_res.numel(); ++i) { + float cur_diff; + if (is_relative_atol) { + cur_diff = static_cast( + std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) / cpu_base_ptr[i])); + EXPECT_LT(static_cast(std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) / + cpu_base_ptr[i])), + diff); + } else { + cur_diff = static_cast(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])); + EXPECT_LT(static_cast(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])), + diff); + } + if (cur_diff > max_diff) { + max_diff = cur_diff; + index = i; + } + } + std::string error_type = is_relative_atol ? "relative" : "absolute"; + LOG(INFO) << "[" << name << "] The dims is [" << cpu_res.dims() + << "], maximum " << error_type << " error is " << max_diff << ": " + << cpu_res_ptr[index] << " vs " << cpu_base_ptr[index]; +} + +template +void ComputeSumAndSquareSum(const framework::Tensor &cpu_x, + framework::Tensor *cpu_sum, + framework::Tensor *cpu_sum_of_square) { + // x is in NHWC format. + auto dims = cpu_x.dims(); + int64_t c = dims[3]; + + const T *cpu_x_ptr = cpu_x.data(); + float *cpu_sum_ptr = + cpu_sum->mutable_data({1, 1, 1, c}, platform::CPUPlace()); + float *cpu_sum_square_ptr = cpu_sum_of_square->mutable_data( + {1, 1, 1, c}, platform::CPUPlace()); + + for (int j = 0; j < c; ++j) { + float tmp_sum = 0.0f; + float tmp_sum_of_squares = 0.0f; + for (int i = 0; i < cpu_x.numel() / c; ++i) { + float tmp_x = static_cast(cpu_x_ptr[i * c + j]); + tmp_sum += tmp_x; + tmp_sum_of_squares += tmp_x * tmp_x; + } + cpu_sum_ptr[j] = tmp_sum; + cpu_sum_square_ptr[j] = tmp_sum_of_squares; + } +} + +template +void ComputeInplaceAdd(const framework::Tensor &cpu_x, + framework::Tensor *cpu_y) { + EXPECT_EQ(cpu_x.dims(), cpu_y->dims()); + + const T *cpu_x_ptr = cpu_x.data(); + T *cpu_y_ptr = cpu_y->data(); + for (int64_t i = 0; i < cpu_x.numel(); ++i) { + cpu_y_ptr[i] += cpu_x_ptr[i]; + } +} + +template +void ComputeInplaceRelu(framework::Tensor *cpu_x) { + T *cpu_x_ptr = cpu_x->data(); + for (int64_t i = 0; i < cpu_x->numel(); ++i) { + cpu_x_ptr[i] = + cpu_x_ptr[i] > static_cast(0) ? cpu_x_ptr[i] : static_cast(0); + } +} + +void ComputeBatchNormForward(const platform::CUDADeviceContext &ctx, + const Tensor &cpu_x, const Tensor &cpu_scale, + const Tensor &cpu_bias, Tensor *cpu_mean, + Tensor *cpu_var, Tensor *cpu_saved_mean, + Tensor *cpu_saved_var, Tensor *cpu_y, + Tensor *saved_reserve_space) { + framework::Scope scope; + auto *x = scope.Var("X")->GetMutable(); + auto *scale = scope.Var("Scale")->GetMutable(); + auto *bias = scope.Var("Bias")->GetMutable(); + auto *mean = scope.Var("Mean")->GetMutable(); + auto *var = scope.Var("Variance")->GetMutable(); + auto *y = scope.Var("Y")->GetMutable(); + auto *saved_mean = scope.Var("SavedMean")->GetMutable(); + auto *saved_var = + scope.Var("SavedVariance")->GetMutable(); + auto *reserve_space = + scope.Var("ReserveSpace")->GetMutable(); + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_x, place, x); + TensorCopySync(cpu_scale, place, scale); + TensorCopySync(cpu_bias, place, bias); + TensorCopySync(*cpu_mean, place, mean); + TensorCopySync(*cpu_var, place, var); + + int64_t channels = x->dims()[3]; + scale->Resize({channels}); + bias->Resize({channels}); + mean->Resize({channels}); + var->Resize({channels}); + + framework::AttributeMap attrs; + std::string data_layout = "NHWC"; + attrs.insert({"data_layout", data_layout}); + + auto op = framework::OpRegistry::CreateOp( + "batch_norm", {{"X", {"X"}}, + {"Scale", {"Scale"}}, + {"Bias", {"Bias"}}, + {"Mean", {"Mean"}}, + {"Variance", {"Variance"}}}, + {{"Y", {"Y"}}, + {"MeanOut", {"Mean"}}, + {"VarianceOut", {"Variance"}}, + {"SavedMean", {"SavedMean"}}, + {"SavedVariance", {"SavedVariance"}}, + {"ReserveSpace", {"ReserveSpace"}}}, + attrs); + op->Run(scope, ctx.GetPlace()); + + TensorCopySync(*y, platform::CPUPlace(), cpu_y); + TensorCopySync(*mean, platform::CPUPlace(), cpu_mean); + TensorCopySync(*var, platform::CPUPlace(), cpu_var); + TensorCopySync(*saved_mean, platform::CPUPlace(), cpu_saved_mean); + TensorCopySync(*saved_var, platform::CPUPlace(), cpu_saved_var); + // reserved_space will stay on GPU and used in grad op. + saved_reserve_space->ShareDataWith(*reserve_space); +} + +void ComputeFusedBNAddReluForward(const platform::CUDADeviceContext &ctx, + const Tensor &cpu_x, const Tensor &cpu_z, + const Tensor &cpu_scale, + const Tensor &cpu_bias, Tensor *cpu_mean, + Tensor *cpu_var, Tensor *cpu_saved_mean, + Tensor *cpu_saved_var, Tensor *cpu_y, + Tensor *saved_reserve_space) { + framework::Scope scope; + auto *x = scope.Var("X")->GetMutable(); + auto *z = scope.Var("Z")->GetMutable(); + auto *scale = scope.Var("Scale")->GetMutable(); + auto *bias = scope.Var("Bias")->GetMutable(); + auto *mean = scope.Var("Mean")->GetMutable(); + auto *var = scope.Var("Variance")->GetMutable(); + auto *y = scope.Var("Y")->GetMutable(); + auto *saved_mean = scope.Var("SavedMean")->GetMutable(); + auto *saved_var = + scope.Var("SavedVariance")->GetMutable(); + auto *reserve_space = + scope.Var("ReserveSpace")->GetMutable(); + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_x, place, x); + TensorCopySync(cpu_z, place, z); + TensorCopySync(cpu_scale, place, scale); + TensorCopySync(cpu_bias, place, bias); + TensorCopySync(*cpu_mean, place, mean); + TensorCopySync(*cpu_var, place, var); + + int64_t channels = x->dims()[3]; + scale->Resize({channels}); + bias->Resize({channels}); + mean->Resize({channels}); + var->Resize({channels}); + + framework::AttributeMap attrs; + + auto op = framework::OpRegistry::CreateOp( + "fused_bn_add_activation", + {{"X", {"X"}}, {"Z", {"Z"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}}, + {{"Y", {"Y"}}, + {"MeanOut", {"Mean"}}, + {"VarianceOut", {"Variance"}}, + {"SavedMean", {"SavedMean"}}, + {"SavedVariance", {"SavedVariance"}}, + {"ReserveSpace", {"ReserveSpace"}}}, + attrs); + op->Run(scope, ctx.GetPlace()); + + TensorCopySync(*y, platform::CPUPlace(), cpu_y); + TensorCopySync(*mean, platform::CPUPlace(), cpu_mean); + TensorCopySync(*var, platform::CPUPlace(), cpu_var); + TensorCopySync(*saved_mean, platform::CPUPlace(), cpu_saved_mean); + TensorCopySync(*saved_var, platform::CPUPlace(), cpu_saved_var); + // reserved_space will stay on GPU and used in grad op. + saved_reserve_space->ShareDataWith(*reserve_space); +} + +void ComputeFusedBNAddReluBackward( + const platform::CUDADeviceContext &ctx, const Tensor &cpu_dy, + const Tensor &cpu_x, const Tensor &cpu_scale, const Tensor &cpu_bias, + const Tensor &cpu_saved_mean, const Tensor &cpu_saved_var, + const Tensor &cpu_y, const Tensor &saved_reserve_space, Tensor *cpu_dx, + Tensor *cpu_dz, Tensor *cpu_dscale, Tensor *cpu_dbias) { + framework::Scope scope; + auto *x = scope.Var("X")->GetMutable(); + auto *y = scope.Var("Y")->GetMutable(); + auto *dy = scope.Var("Y@GRAD")->GetMutable(); + auto *scale = scope.Var("Scale")->GetMutable(); + auto *bias = scope.Var("Bias")->GetMutable(); + auto *saved_mean = scope.Var("SavedMean")->GetMutable(); + auto *saved_var = + scope.Var("SavedVariance")->GetMutable(); + auto *reserve_space = + scope.Var("ReserveSpace")->GetMutable(); + auto *dx = scope.Var("X@GRAD")->GetMutable(); + auto *dz = scope.Var("Z@GRAD")->GetMutable(); + auto *dscale = scope.Var("Scale@GRAD")->GetMutable(); + auto *dbias = scope.Var("Bias@GRAD")->GetMutable(); + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_x, place, x); + TensorCopySync(cpu_y, place, y); + TensorCopySync(cpu_dy, place, dy); + TensorCopySync(cpu_scale, place, scale); + TensorCopySync(cpu_bias, place, bias); + TensorCopySync(cpu_saved_mean, place, saved_mean); + TensorCopySync(cpu_saved_var, place, saved_var); + reserve_space->ShareDataWith(saved_reserve_space); + + int64_t channels = x->dims()[3]; + scale->Resize({channels}); + bias->Resize({channels}); + saved_mean->Resize({channels}); + saved_var->Resize({channels}); + + framework::AttributeMap attrs; + float momentum = 0.9; + float epsilon = 1e-5; + std::string act_type = "relu"; + attrs.insert({"momentum", momentum}); + attrs.insert({"epsilon", epsilon}); + attrs.insert({"act_type", act_type}); + + auto op = framework::OpRegistry::CreateOp( + "fused_bn_add_activation_grad", {{"X", {"X"}}, + {"Y", {"Y"}}, + {"Y@GRAD", {"Y@GRAD"}}, + {"Scale", {"Scale"}}, + {"Bias", {"Bias"}}, + {"SavedMean", {"SavedMean"}}, + {"SavedVariance", {"SavedVariance"}}, + {"ReserveSpace", {"ReserveSpace"}}}, + {{"X@GRAD", {"X@GRAD"}}, + {"Z@GRAD", {"Z@GRAD"}}, + {"Scale@GRAD", {"Scale@GRAD"}}, + {"Bias@GRAD", {"Bias@GRAD"}}}, + attrs); + op->Run(scope, ctx.GetPlace()); + + TensorCopySync(*dx, platform::CPUPlace(), cpu_dx); + TensorCopySync(*dz, platform::CPUPlace(), cpu_dz); + TensorCopySync(*dscale, platform::CPUPlace(), cpu_dscale); + TensorCopySync(*dbias, platform::CPUPlace(), cpu_dbias); +} + +template +class CudnnBNAddReluTester { + public: + CudnnBNAddReluTester(int batch_size, int height, int width, int channels, + std::string act_type, bool fuse_add, bool has_shortcut) { + batch_size_ = batch_size; + height_ = height; + width_ = width; + channels_ = channels; + ele_count_ = batch_size_ * height_ * width_; + act_type_ = act_type; + fuse_add_ = fuse_add; + has_shortcut_ = has_shortcut; + SetUp(); + } + + ~CudnnBNAddReluTester() {} + + void CheckForward(float diff, bool is_relative_atol = false) { + LOG(INFO) << "[CheckForward, diff=" << diff + << ", is_relative_atol=" << is_relative_atol + << "] act_type=" << act_type_ << ", fuse_add=" << fuse_add_ + << ", has_shortcut=" << has_shortcut_; + platform::CUDADeviceContext *ctx = + static_cast( + platform::DeviceContextPool::Instance().Get( + platform::CUDAPlace(0))); + + auto select = [&](Tensor *in) { return has_shortcut_ ? in : nullptr; }; + + framework::Tensor cpu_mean_base_x; + framework::Tensor cpu_var_base_x; + framework::Tensor cpu_mean_base_z; + framework::Tensor cpu_var_base_z; + if (!has_shortcut_ && fuse_add_ && (act_type_ == "relu")) { + BaselineForwardFusedBNAddRelu( + *ctx, &cpu_mean_base_x, &cpu_var_base_x, &cpu_saved_mean_base_x_, + &cpu_saved_var_base_x_, &cpu_y_base_, &saved_reserve_space_x_); + } else { + BaselineForward( + *ctx, &cpu_mean_base_x, &cpu_var_base_x, &cpu_saved_mean_base_x_, + &cpu_saved_var_base_x_, &cpu_y_base_, &saved_reserve_space_x_, + select(&cpu_mean_base_z), select(&cpu_var_base_z), + select(&cpu_saved_mean_base_z_), select(&cpu_saved_var_base_z_), + select(&saved_reserve_space_z_)); + } + + framework::Tensor cpu_mean_x; + framework::Tensor cpu_var_x; + framework::Tensor cpu_y; + framework::Tensor cpu_mean_z; + framework::Tensor cpu_var_z; + FusedForward(*ctx, &cpu_mean_x, &cpu_var_x, &cpu_saved_mean_x_, + &cpu_saved_var_x_, &cpu_y, &cpu_bitmask_, select(&cpu_mean_z), + select(&cpu_var_z), select(&cpu_saved_mean_z_), + select(&cpu_saved_var_z_)); + + CheckOutput("Mean", cpu_mean_x, cpu_mean_base_x, diff, + is_relative_atol); + CheckOutput("Variance", cpu_var_x, cpu_var_base_x, diff, + is_relative_atol); + CheckOutput("SavedMean", cpu_saved_mean_x_, cpu_saved_mean_base_x_, + diff, is_relative_atol); + CheckOutput("SavedVariance", cpu_saved_var_x_, cpu_saved_var_base_x_, + diff, is_relative_atol); + if (has_shortcut_) { + CheckOutput("MeanZ", cpu_mean_z, cpu_mean_base_z, diff, + is_relative_atol); + CheckOutput("VarianceZ", cpu_var_z, cpu_var_base_z, diff, + is_relative_atol); + CheckOutput("SavedMeanZ", cpu_saved_mean_z_, + cpu_saved_mean_base_z_, diff, is_relative_atol); + CheckOutput("SavedVarianceZ", cpu_saved_var_z_, + cpu_saved_var_base_z_, diff, is_relative_atol); + } + CheckOutput("Y", cpu_y, cpu_y_base_, diff, is_relative_atol); + } + + void CheckBackward(float diff, bool is_relative_atol = false) { + platform::CUDADeviceContext *ctx = + static_cast( + platform::DeviceContextPool::Instance().Get( + platform::CUDAPlace(0))); + + framework::Tensor cpu_dx_base; + framework::Tensor cpu_dz_base; + framework::Tensor cpu_dscale_base; + framework::Tensor cpu_dbias_base; + BaselineBackwardFusedBNAddRelu(*ctx, &cpu_dx_base, &cpu_dz_base, + &cpu_dscale_base, &cpu_dbias_base); + + framework::Tensor cpu_dx; + framework::Tensor cpu_dz; + framework::Tensor cpu_dscale; + framework::Tensor cpu_dbias; + FusedBackward(*ctx, &cpu_dx, &cpu_dz, &cpu_dscale, &cpu_dbias); + + CheckOutput("DX", cpu_dx, cpu_dx_base, diff, is_relative_atol); + CheckOutput("DZ", cpu_dz, cpu_dz_base, diff, is_relative_atol); + CheckOutput("DScale", cpu_dscale, cpu_dscale_base, diff, + is_relative_atol); + CheckOutput("DBias", cpu_dbias, cpu_dbias_base, diff, + is_relative_atol); + } + + private: + void SetUp() { + InitRandomTensor({batch_size_, height_, width_, channels_}, &cpu_x_); + InitRandomTensor({channels_}, &cpu_bn_scale_x_); + InitRandomTensor({channels_}, &cpu_bn_bias_x_); + + if (has_shortcut_) { + InitRandomTensor({batch_size_, height_, width_, channels_}, &cpu_z_); + InitRandomTensor({channels_}, &cpu_bn_scale_z_); + InitRandomTensor({channels_}, &cpu_bn_bias_z_); + } else { + if (fuse_add_) { + InitRandomTensor({batch_size_, height_, width_, channels_}, &cpu_z_); + } + } + + InitRandomTensor({batch_size_, height_, width_, channels_}, &cpu_dy_); + } + + void InitMeanVar(Tensor *cpu_mean, Tensor *cpu_var, Tensor *cpu_saved_mean, + Tensor *cpu_saved_var) { + InitConstantTensor({channels_}, static_cast(0.0f), cpu_mean); + InitConstantTensor({channels_}, static_cast(1.0f), cpu_var); + InitConstantTensor({channels_}, static_cast(0.0f), + cpu_saved_mean); + InitConstantTensor({channels_}, static_cast(0.0f), + cpu_saved_var); + } + + void BaselineForward(const platform::CUDADeviceContext &ctx, + Tensor *cpu_mean_x, Tensor *cpu_var_x, + Tensor *cpu_saved_mean_x, Tensor *cpu_saved_var_x, + Tensor *cpu_y, Tensor *saved_reserve_space_x, + Tensor *cpu_mean_z = nullptr, + Tensor *cpu_var_z = nullptr, + Tensor *cpu_saved_mean_z = nullptr, + Tensor *cpu_saved_var_z = nullptr, + Tensor *saved_reserve_space_z = nullptr) { + InitMeanVar(cpu_mean_x, cpu_var_x, cpu_saved_mean_x, cpu_saved_var_x); + ComputeBatchNormForward(ctx, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_, + cpu_mean_x, cpu_var_x, cpu_saved_mean_x, + cpu_saved_var_x, cpu_y, saved_reserve_space_x); + if (has_shortcut_) { + framework::Tensor cpu_z_out; + InitMeanVar(cpu_mean_z, cpu_var_z, cpu_saved_mean_z, cpu_saved_var_z); + ComputeBatchNormForward( + ctx, cpu_z_, cpu_bn_scale_z_, cpu_bn_bias_z_, cpu_mean_z, cpu_var_z, + cpu_saved_mean_z, cpu_saved_var_z, &cpu_z_out, saved_reserve_space_z); + ComputeInplaceAdd(cpu_z_out, cpu_y); + } else { + if (fuse_add_) { + ComputeInplaceAdd(cpu_z_, cpu_y); + } + } + if (act_type_ == "relu") { + ComputeInplaceRelu(cpu_y); + } + } + + void BaselineForwardFusedBNAddRelu(const platform::CUDADeviceContext &ctx, + Tensor *cpu_mean, Tensor *cpu_var, + Tensor *cpu_saved_mean, + Tensor *cpu_saved_var, Tensor *cpu_y, + Tensor *saved_reserve_space) { + InitMeanVar(cpu_mean, cpu_var, cpu_saved_mean, cpu_saved_var); + ComputeFusedBNAddReluForward( + ctx, cpu_x_, cpu_z_, cpu_bn_scale_x_, cpu_bn_bias_x_, cpu_mean, cpu_var, + cpu_saved_mean, cpu_saved_var, cpu_y, saved_reserve_space); + } + + void BaselineBackwardFusedBNAddRelu(const platform::CUDADeviceContext &ctx, + Tensor *cpu_dx, Tensor *cpu_dz, + Tensor *cpu_dscale, Tensor *cpu_dbias) { + ComputeFusedBNAddReluBackward( + ctx, cpu_dy_, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_, + cpu_saved_mean_base_x_, cpu_saved_var_base_x_, cpu_y_base_, + saved_reserve_space_x_, cpu_dx, cpu_dz, cpu_dscale, cpu_dbias); + } + + void ComputeFusedBNStatsFinalize(const platform::CUDADeviceContext &ctx, + const Tensor &cpu_x, + const Tensor &cpu_bn_scale, + const Tensor &cpu_bn_bias, Tensor *sum, + Tensor *sum_of_square, Tensor *bn_scale, + Tensor *bn_bias, Tensor *mean, Tensor *var, + Tensor *saved_mean, Tensor *saved_var, + Tensor *equiv_scale, Tensor *equiv_bias) { + framework::Tensor cpu_sum; + framework::Tensor cpu_sum_of_square; + ComputeSumAndSquareSum(cpu_x, &cpu_sum, &cpu_sum_of_square); + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_sum, place, sum); + TensorCopySync(cpu_sum_of_square, place, sum_of_square); + TensorCopySync(cpu_bn_scale, place, bn_scale); + TensorCopySync(cpu_bn_bias, place, bn_bias); + + bn_scale->Resize({1, 1, 1, channels_}); + bn_bias->Resize({1, 1, 1, channels_}); + + // input + mean->Resize({1, 1, 1, channels_}); + var->Resize({1, 1, 1, channels_}); + + // output + equiv_scale->Resize({1, 1, 1, channels_}); + equiv_bias->Resize({1, 1, 1, channels_}); + saved_mean->Resize({1, 1, 1, channels_}); + saved_var->Resize({1, 1, 1, channels_}); + + auto param_shape = framework::vectorize(bn_scale->dims()); + op::CudnnBNStatsFinalize bn_op(ctx, param_shape); + bn_op.Forward(ctx, *sum, *sum_of_square, *bn_scale, *bn_bias, saved_mean, + saved_var, mean, var, equiv_scale, equiv_bias, eps_, + momentum_, ele_count_, true); + } + + // Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu + void FusedForward(const platform::CUDADeviceContext &ctx, Tensor *cpu_mean_x, + Tensor *cpu_var_x, Tensor *cpu_saved_mean_x, + Tensor *cpu_saved_var_x, Tensor *cpu_y, Tensor *cpu_bitmask, + Tensor *cpu_mean_z = nullptr, Tensor *cpu_var_z = nullptr, + Tensor *cpu_saved_mean_z = nullptr, + Tensor *cpu_saved_var_z = nullptr) { + framework::Tensor x; + framework::Tensor sum_x; + framework::Tensor sum_of_square_x; + framework::Tensor bn_scale_x; + framework::Tensor bn_bias_x; + + framework::Tensor z; + framework::Tensor sum_z; + framework::Tensor sum_of_square_z; + framework::Tensor bn_scale_z; + framework::Tensor bn_bias_z; + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_x_, place, &x); + if (fuse_add_ || has_shortcut_) { + TensorCopySync(cpu_z_, place, &z); + } + + framework::Tensor mean_x; + framework::Tensor var_x; + framework::Tensor saved_mean_x; + framework::Tensor saved_var_x; + framework::Tensor equiv_scale_x; + framework::Tensor equiv_bias_x; + + framework::Tensor mean_z; + framework::Tensor var_z; + framework::Tensor saved_mean_z; + framework::Tensor saved_var_z; + framework::Tensor equiv_scale_z; + framework::Tensor equiv_bias_z; + + framework::Tensor y; + framework::Tensor bitmask; + + InitMeanVar(cpu_mean_x, cpu_var_x, cpu_saved_mean_x, cpu_saved_var_x); + TensorCopySync(*cpu_mean_x, place, &mean_x); + TensorCopySync(*cpu_var_x, place, &var_x); + if (has_shortcut_) { + InitMeanVar(cpu_mean_z, cpu_var_z, cpu_saved_mean_z, cpu_saved_var_z); + TensorCopySync(*cpu_mean_z, place, &mean_z); + TensorCopySync(*cpu_var_z, place, &var_z); + } + + // 1. BN Stats Finalize + ComputeFusedBNStatsFinalize(ctx, cpu_x_, cpu_bn_scale_x_, cpu_bn_bias_x_, + &sum_x, &sum_of_square_x, &bn_scale_x, + &bn_bias_x, &mean_x, &var_x, &saved_mean_x, + &saved_var_x, &equiv_scale_x, &equiv_bias_x); + if (has_shortcut_) { + ComputeFusedBNStatsFinalize(ctx, cpu_z_, cpu_bn_scale_z_, cpu_bn_bias_z_, + &sum_z, &sum_of_square_z, &bn_scale_z, + &bn_bias_z, &mean_z, &var_z, &saved_mean_z, + &saved_var_z, &equiv_scale_z, &equiv_bias_z); + } + + y.Resize(framework::make_ddim({batch_size_, height_, width_, channels_})); + + int c = channels_; + int64_t nhw = ele_count_; + int32_t c_int32_elems = ((c + 63) & ~63) / 32; + int32_t nhw_int32_elems = (nhw + 31) & ~31; + bitmask.Resize(framework::make_ddim({nhw_int32_elems, c_int32_elems, 1})); + + auto data_shape = framework::vectorize(x.dims()); + auto param_shape = framework::vectorize(bn_scale_x.dims()); + auto bitmask_shape = framework::vectorize(bitmask.dims()); + + // 2. Scale Bias + Relu + op::CudnnScaleBiasAddRelu sbar_op(ctx, act_type_, fuse_add_, + has_shortcut_, data_shape, param_shape, + bitmask_shape); + sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, &z, &equiv_scale_z, + &equiv_bias_z, &y, &bitmask); + + TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x); + TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x); + TensorCopySync(saved_mean_x, platform::CPUPlace(), cpu_saved_mean_x); + TensorCopySync(saved_var_x, platform::CPUPlace(), cpu_saved_var_x); + if (has_shortcut_) { + TensorCopySync(mean_z, platform::CPUPlace(), cpu_mean_z); + TensorCopySync(var_z, platform::CPUPlace(), cpu_var_z); + TensorCopySync(saved_mean_z, platform::CPUPlace(), cpu_saved_mean_z); + TensorCopySync(saved_var_z, platform::CPUPlace(), cpu_saved_var_z); + } + TensorCopySync(y, platform::CPUPlace(), cpu_y); + TensorCopySync(bitmask, platform::CPUPlace(), cpu_bitmask); + } + + // Get backward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu + void FusedBackward(const platform::CUDADeviceContext &ctx, Tensor *cpu_dx, + Tensor *cpu_dz, Tensor *cpu_dscale, Tensor *cpu_dbias) { + framework::Tensor dy; + framework::Tensor x; + framework::Tensor bn_scale; + framework::Tensor bn_bias; + framework::Tensor saved_mean; + framework::Tensor saved_var; + framework::Tensor bitmask; + framework::Tensor dx; + framework::Tensor dz; + framework::Tensor dscale; + framework::Tensor dbias; + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_dy_, place, &dy); + TensorCopySync(cpu_x_, place, &x); + TensorCopySync(cpu_bn_scale_x_, place, &bn_scale); + TensorCopySync(cpu_bn_bias_x_, place, &bn_bias); + TensorCopySync(cpu_saved_mean_x_, place, &saved_mean); + TensorCopySync(cpu_saved_var_x_, place, &saved_var); + TensorCopySync(cpu_bitmask_, place, &bitmask); + + bn_scale.Resize({1, 1, 1, channels_}); + bn_bias.Resize({1, 1, 1, channels_}); + saved_mean.Resize({1, 1, 1, channels_}); + saved_var.Resize({1, 1, 1, channels_}); + + dx.Resize(framework::make_ddim({batch_size_, height_, width_, channels_})); + dz.Resize(framework::make_ddim({batch_size_, height_, width_, channels_})); + dscale.Resize(framework::make_ddim({1, 1, 1, channels_})); + dbias.Resize(framework::make_ddim({1, 1, 1, channels_})); + + auto data_shape = framework::vectorize(x.dims()); + auto param_shape = framework::vectorize(bn_scale.dims()); + auto bitmask_shape = framework::vectorize(bitmask.dims()); + + std::string act_type = "relu"; + op::CudnnScaleBiasAddRelu sbar_op(ctx, act_type, true, false, data_shape, + param_shape, bitmask_shape); + sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var, + &bitmask, &dx, &dz, &dscale, &dbias, eps_); + + TensorCopySync(dx, platform::CPUPlace(), cpu_dx); + TensorCopySync(dz, platform::CPUPlace(), cpu_dz); + TensorCopySync(dscale, platform::CPUPlace(), cpu_dscale); + TensorCopySync(dbias, platform::CPUPlace(), cpu_dbias); + } + + private: + int batch_size_; + int height_; + int width_; + int channels_; + int ele_count_; + + std::string act_type_; + bool fuse_add_; + bool has_shortcut_; + + // Forward input + framework::Tensor cpu_x_; + framework::Tensor cpu_bn_scale_x_; + framework::Tensor cpu_bn_bias_x_; + framework::Tensor cpu_z_; + framework::Tensor cpu_bn_scale_z_; + framework::Tensor cpu_bn_bias_z_; + + // Backward input + framework::Tensor cpu_dy_; + framework::Tensor cpu_bitmask_; + framework::Tensor cpu_saved_mean_x_; + framework::Tensor cpu_saved_var_x_; + framework::Tensor cpu_saved_mean_z_; + framework::Tensor cpu_saved_var_z_; + framework::Tensor cpu_saved_mean_base_x_; + framework::Tensor cpu_saved_var_base_x_; + framework::Tensor saved_reserve_space_x_; + framework::Tensor cpu_saved_mean_base_z_; + framework::Tensor cpu_saved_var_base_z_; + framework::Tensor saved_reserve_space_z_; + framework::Tensor cpu_y_base_; + + double eps_ = 1e-5; + float momentum_ = 0.9; +}; + +TEST(CudnnBNAddReluFp16, BNAdd) { + int batch_size = 4; + int height = 8; + int width = 8; + int channels = 64; + std::string act_type = ""; + bool has_shortcut = false; + FLAGS_cudnn_batchnorm_spatial_persistent = true; + for (auto fuse_add : {false, true}) { + CudnnBNAddReluTester test( + batch_size, height, width, channels, act_type, fuse_add, has_shortcut); + test.CheckForward(2e-3); + } +} + +TEST(CudnnBNAddReluFp16, BNAddRelu) { + int batch_size = 4; + int height = 8; + int width = 8; + int channels = 64; + std::string act_type = "relu"; + bool has_shortcut = false; + FLAGS_cudnn_batchnorm_spatial_persistent = true; + for (auto fuse_add : {false, true}) { + CudnnBNAddReluTester test( + batch_size, height, width, channels, act_type, fuse_add, has_shortcut); + test.CheckForward(2e-3); + if (fuse_add) { + test.CheckBackward(2e-4); + } + } +} + +TEST(CudnnBNAddReluFp16, HasShortcut) { + int batch_size = 4; + int height = 8; + int width = 8; + int channels = 64; + std::string act_type = ""; + bool fuse_add = false; + bool has_shortcut = true; + FLAGS_cudnn_batchnorm_spatial_persistent = true; + CudnnBNAddReluTester test( + batch_size, height, width, channels, act_type, fuse_add, has_shortcut); + test.CheckForward(5e-3); +} diff --git a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h new file mode 100644 index 0000000000000..1b995e1313f47 --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h @@ -0,0 +1,193 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" +#include "paddle/fluid/platform/cudnn_desc.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +namespace dynload = platform::dynload; +template +using BatchNormParamType = + typename platform::CudnnDataType::BatchNormParamType; + +#if CUDNN_VERSION >= 8000 + +template +struct BNStatsFinalizeArgs { + BNStatsFinalizeArgs() { + dtype = platform::CudnnDataType::type; + param_dtype = platform::CudnnDataType>::type; + format = CUDNN_TENSOR_NHWC; + } + + void Set(const std::vector ¶m_shape) { + PADDLE_ENFORCE_EQ( + param_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of param_shape is expected to 4. But recieved " + "param_shape's size is %d, param_shape is [%s].", + param_shape.size(), framework::make_ddim(param_shape))); + + in_desc.set(param_shape, format, param_dtype); + out_desc.set(param_shape, format, dtype); + } + + cudnnDataType_t dtype; + cudnnDataType_t param_dtype; + cudnnTensorFormat_t format; + + platform::TensorDescriptor in_desc; + platform::TensorDescriptor out_desc; +}; + +template +class CudnnBNStatsFinalize { + public: + CudnnBNStatsFinalize(const platform::CUDADeviceContext &ctx, + const std::vector ¶m_shape) + : train_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING), + inference_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE) { + args_.Set(param_shape); + } + ~CudnnBNStatsFinalize() {} + + void Forward(const platform::CUDADeviceContext &ctx, const Tensor &sum, + const Tensor &sum_of_squares, const Tensor &scale, + const Tensor &bias, Tensor *saved_mean, Tensor *saved_invstd, + Tensor *running_mean, Tensor *running_var, Tensor *equiv_scale, + Tensor *equiv_bias, double eps, float momentum, + int64_t ele_count, bool is_train) { + auto place = ctx.GetPlace(); + if (is_train) { + TrainInit(ctx); + } else { + InferenceInit(ctx); + } + auto &op = is_train ? train_op_ : inference_op_; + + // Set variant_param for both inference_op_ and train_op_ + float *sum_ptr = const_cast(sum.data()); + float *sum_of_squares_ptr = + const_cast(sum_of_squares.data()); + float *scale_ptr = const_cast(scale.data()); + float *bias_ptr = const_cast(bias.data()); + float *saved_mean_ptr = saved_mean->mutable_data(place); + float *saved_invstd_ptr = saved_invstd->mutable_data(place); + float *running_mean_ptr = running_mean->mutable_data(place); + float *running_var_ptr = running_var->mutable_data(place); + T *equiv_scale_ptr = equiv_scale->mutable_data(place); + T *equiv_bias_ptr = equiv_bias->mutable_data(place); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_VAR, running_var_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, equiv_scale_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, equiv_bias_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EPSILON, &eps); + + // Set extra variant_param only for train_op_: + if (is_train) { + op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD, saved_invstd_ptr); + double avg_factor = 1.0 - momentum; + op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT, + &ele_count); + op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR, + &avg_factor); + } + // fused op execute + auto handle = ctx.cudnn_handle(); + op.Execute(handle); + } + + private: + void TrainInit(const platform::CUDADeviceContext &ctx) { + // Set constant_param for train op + train_op_.SetOpConstParamAttr( + {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER, + CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER, + CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER, + CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER, + CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER, + CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER, + CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + // Set input and output desc for train op + train_op_.SetOpConstParamDesc( + {CUDNN_PARAM_YSTATS_DESC, CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC}, + args_.in_desc.desc()); + train_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC, + args_.out_desc.desc()); + + // Get workspace + auto handle = ctx.cudnn_handle(); + train_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + // Check workspace size, also creates plan. + size_t workspace_size_bytes = train_op_.GetWorkspaceSizeInBytes(handle); + PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U, + platform::errors::InvalidArgument( + "Unexpected non-zero workspace size for " + "CudnnBNStatsFinalize.")); + train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + static_cast(nullptr)); + train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + &workspace_size_bytes); + } + + void InferenceInit(const platform::CUDADeviceContext &ctx) { + // Set constant_param for inference op + inference_op_.SetOpConstParamAttr( + {CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER, + CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER, + CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER, + CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + // Set input and output desc for inference op + inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC, + args_.in_desc.desc()); + inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC, + args_.out_desc.desc()); + + // Get workspace + auto handle = ctx.cudnn_handle(); + inference_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + // Check workspace size, also creates plan. + size_t workspace_size_bytes = inference_op_.GetWorkspaceSizeInBytes(handle); + PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U, + platform::errors::InvalidArgument( + "Unexpected non-zero workspace size for " + "CudnnBNStatsFinalize.")); + inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + static_cast(nullptr)); + inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + &workspace_size_bytes); + } + + BNStatsFinalizeArgs args_; + CudnnFusionOp train_op_; + CudnnFusionOp inference_op_; +}; +#endif +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/cudnn_fusion_helper.h b/paddle/fluid/operators/fused/cudnn_fusion_helper.h new file mode 100644 index 0000000000000..78655416ffbc2 --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_fusion_helper.h @@ -0,0 +1,167 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/operator_kernel_configs.h" +#include "paddle/fluid/platform/dynload/cudnn.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { + +namespace dynload = platform::dynload; + +#if CUDNN_VERSION >= 8000 + +// A wrapper for cuDNN fused_op API. +class CudnnFusionOp { + public: + explicit CudnnFusionOp(cudnnFusedOps_t op_id) : plan_created_(false) { + // New 'fused op' descriptor creation + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateFusedOpsPlan(&op_, op_id)); + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cudnnCreateFusedOpsConstParamPack(&op_const_params_, op_id)); + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreateFusedOpsVariantParamPack( + &op_variant_params_, op_id)); + } + + ~CudnnFusionOp() PADDLE_MAY_THROW { + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyFusedOpsPlan(op_)); + } + + // Execute fused op + void Execute(cudnnHandle_t cudnn_handle) { + PADDLE_ENFORCE_EQ( + plan_created_, true, + platform::errors::Fatal( + "CudnnFusionOp exec requested without a valid 'plan', need: " + ", GetWorkspaceSizeBytes(), Execute().")); + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cudnnFusedOpsExecute(cudnn_handle, op_, op_variant_params_)); + } + + // Set const param pack attribute given a descriptor. + template + void SetOpConstParamDesc(cudnnFusedOpsConstParamLabel_t param_label, + T *param_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cudnnSetFusedOpsConstParamPackAttribute( + op_const_params_, param_label, param_ptr)); + plan_created_ = false; + } + + // Set multiple const param pack attribute given a descriptor. + template + void SetOpConstParamDesc( + const std::vector ¶m_labels, + T *param_ptr) { + for (auto param_label : param_labels) { + SetOpConstParamDesc(param_label, param_ptr); + } + } + + // Set const param pack attribute given a value of param. + template + void SetOpConstParamAttr(cudnnFusedOpsConstParamLabel_t param_label, + T param) { + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cudnnSetFusedOpsConstParamPackAttribute(op_const_params_, + param_label, ¶m)); + plan_created_ = false; + } + + // Set multiple const param pack attribute given a value of param. + template + void SetOpConstParamAttr( + const std::vector ¶m_labels, + T param) { + for (auto param_label : param_labels) { + SetOpConstParamAttr(param_label, param); + } + } + + // Set a variant param pack attribute given a reference to a param. + template + void SetOpVariantParamAttrPtr(cudnnFusedOpsVariantParamLabel_t param_label, + T *param_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cudnnSetFusedOpsVariantParamPackAttribute( + op_variant_params_, param_label, param_ptr)); + } + + // Set multiple const param pack attributes given a reference to a param. + template + void SetOpVariantParamAttrPtr( + const std::vector ¶m_labels, + const T *param_ptr) { + for (auto param_label : param_labels) { + SetOpVariantParamAttrPtr(param_label, param_ptr); + } + } + + // Get the workspace, which is required before Execute(). + size_t GetWorkspaceSizeInBytes(cudnnHandle_t cudnn_handle) { + if (!plan_created_) { + workspace_bytes_ = 0U; + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnMakeFusedOpsPlan( + cudnn_handle, op_, op_const_params_, &workspace_bytes_)); + plan_created_ = true; + } + return workspace_bytes_; + } + + private: + bool plan_created_; + size_t workspace_bytes_; + + cudnnFusedOpsPlan_t op_; + cudnnFusedOpsConstParamPack_t op_const_params_; + cudnnFusedOpsVariantParamPack_t op_variant_params_; +}; + +class CudnnFusionOpCache { + public: + static CudnnFusionOpCache &Instance() { + static CudnnFusionOpCache instance; + return instance; + } + + framework::AlgorithmsCache *GetForward() { + return &forward_cache_; + } + framework::AlgorithmsCache *GetBackward() { + return &backward_cache_; + } + + private: + CudnnFusionOpCache() {} + ~CudnnFusionOpCache() { + // Need to delete the memory of cache. + } + CudnnFusionOpCache(const CudnnFusionOpCache &) {} + + private: + framework::AlgorithmsCache forward_cache_; + framework::AlgorithmsCache backward_cache_; +}; + +#endif // CUDNN_VERSION >= 8000 +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h new file mode 100644 index 0000000000000..71ff12d6a35f2 --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h @@ -0,0 +1,385 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" +#include "paddle/fluid/platform/cudnn_desc.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +namespace dynload = platform::dynload; + +template +using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; + +#if CUDNN_VERSION >= 8000 + +static size_t RoundUp(int64_t a, int64_t b) { return (a + b - 1) / b * b; } + +template +struct NormConvolutionArgs { + NormConvolutionArgs() { + dtype = platform::CudnnDataType::type; + format = CUDNN_TENSOR_NHWC; + compute_type = platform::CudnnDataType::type; + } + + void Set(const platform::CUDADeviceContext &ctx, + const std::vector &input_shape, + const std::vector &filter_shape, + const std::vector &output_shape, int padding, int stride, + int dilation, int group) { + PADDLE_ENFORCE_EQ( + input_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of input_shape is expected to 4. But recieved " + "input_shape's size is %d, input_shape is [%s].", + input_shape.size(), framework::make_ddim(input_shape))); + PADDLE_ENFORCE_EQ( + filter_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of filter_shape is expected to 4. But recieved " + "filter_shape's size is %d, filter_shape is [%s].", + filter_shape.size(), framework::make_ddim(filter_shape))); + PADDLE_ENFORCE_EQ(filter_shape[1] == filter_shape[2] && + (filter_shape[1] == 1 || filter_shape[1] == 3), + true, + platform::errors::InvalidArgument( + "The filter_shape is expected to store as nhwc, and " + "h = w = 1 or 3. But recieved filter_shape is [%s].", + framework::make_ddim(filter_shape))); + PADDLE_ENFORCE_EQ((filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0), + true, + platform::errors::InvalidArgument( + "The input channel is expected to be multiple of 8, " + "and the output channel is expected to be multiple " + "of 32. But recieved input channel is %d, output " + "channel is %d.", + filter_shape[3], filter_shape[0])); + PADDLE_ENFORCE_EQ( + output_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of output_shape is expected to 4. But recieved " + "filter_shape's size is %d, filter_shape is [%s].", + output_shape.size(), framework::make_ddim(output_shape))); + is_support = IsSupport(ctx, filter_shape, stride, dilation, group); + PADDLE_ENFORCE_EQ( + is_support, true, + platform::errors::InvalidArgument( + "Current test is only supported in the platforms with " + "compatiblity greater than or equal to 70 and the kernel size " + "must be equal to 1 or 3. When the kernel size is 1, " + "the stride must be 1 if the compatiblity is equal to 70. " + "Besides, the dilation and group must be equal to 1. But recieved " + "compatiblity is %d, kernel size is %d, stride is %d, " + "dilation is %d, group is %d", + ctx.GetComputeCapability(), filter_shape[1], stride, dilation, + group)); + + for (size_t i = 0; i < input_shape.size(); ++i) { + in_dims.push_back(input_shape[i]); + } + for (size_t i = 0; i < filter_shape.size(); ++i) { + filter_dims.push_back(filter_shape[i]); + } + paddings = {padding, padding}; + strides = {stride, stride}; + dilations = {dilation, dilation}; + + in_desc.set(input_shape, format, dtype); + filter_desc.set(filter_shape, format, dtype, group); + out_desc.set(output_shape, format, dtype); + + int output_channel = filter_shape[0]; + std::vector stats_shape = {1, 1, 1, output_channel}; + out_stats_desc.set(stats_shape, format, compute_type); + + conv_desc.set(dtype, paddings, strides, dilations, false, group); + } + + bool IsSupport(const platform::CUDADeviceContext &ctx, + const std::vector &filter_shape, int stride, int dilation, + int group) { + int kernel_size = filter_shape[1]; + if (dilation != 1 || group != 1) { + return false; + } + if (ctx.GetComputeCapability() == 70) { + if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) { + return true; + } + } else if (ctx.GetComputeCapability() > 70) { + if ((kernel_size == 3) || (kernel_size == 1)) { + return true; + } + } + return false; + } + + cudnnDataType_t dtype; + cudnnTensorFormat_t format; + cudnnDataType_t compute_type; + + std::vector in_dims; + std::vector filter_dims; + std::vector strides; + std::vector paddings; + std::vector dilations; + + platform::TensorDescriptor in_desc; + platform::FilterDescriptor filter_desc; + platform::TensorDescriptor out_desc; + platform::TensorDescriptor out_stats_desc; + platform::ConvolutionDescriptor conv_desc; + + bool is_support; +}; + +template +class CudnnNormConvolution { + public: + CudnnNormConvolution(const platform::CUDADeviceContext &ctx, + const std::vector &input_shape, + const std::vector &filter_shape, + const std::vector &output_shape, const int &padding, + const int &stride, const int &dilation, + const int &group) { + args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride, + dilation, group); + } + ~CudnnNormConvolution() {} + + void Forward(const platform::CUDADeviceContext &ctx, const Tensor &input, + const Tensor &filter, Tensor *output, Tensor *sum, + Tensor *sum_of_squares) { + auto cudnn_handle = ctx.cudnn_handle(); + auto place = ctx.GetPlace(); + + CudnnFusionOp *fwd_op = GetForwardOp(ctx); + size_t workspace_size = RoundUp( + static_cast(fwd_op->GetWorkspaceSizeInBytes(cudnn_handle)), + 512); + + // Set variant_param + // input ptr + T *input_ptr = const_cast(input.data()); + T *filter_ptr = const_cast(filter.data()); + fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr); + fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr); + fwd_op->SetOpVariantParamAttrPtr( + CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size); + + // output ptr + T *output_ptr = output->mutable_data(place); + float *sum_ptr = sum->mutable_data(place); + float *sum_of_squares_ptr = sum_of_squares->mutable_data(place); + fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr); + fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr); + fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr); + + ctx.cudnn_workspace_handle().RunFunc( + [&](void *workspace_ptr) { + // workspace ptr + fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr); + // fused op execute + fwd_op->Execute(cudnn_handle); + }, + workspace_size); + } + + private: + CudnnFusionOp *GetForwardOp(const platform::CUDADeviceContext &ctx) { + framework::AlgorithmsCache &cache = + *(CudnnFusionOpCache::Instance().GetForward()); + + CudnnFusionOp *fwd_op = cache.GetAlgorithm( + args_.in_dims, args_.filter_dims, args_.strides, args_.paddings, + args_.dilations, 0, static_cast(args_.dtype), [&]() { + CudnnFusionOp *fwd_op = + new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS); + + // Set constant_param + fwd_op->SetOpConstParamAttr( + {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_WDATA_PLACEHOLDER, + CUDNN_PARAM_YDATA_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + fwd_op->SetOpConstParamAttr( + {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + + // conv desc + fwd_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC, + args_.conv_desc.desc()); + // input desc + fwd_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc()); + // filter desc + fwd_op->SetOpConstParamDesc(CUDNN_PARAM_WDESC, + args_.filter_desc.desc()); + // output desc + fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc()); + // output_stats desc + fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC, + args_.out_stats_desc.desc()); + // batch_norm mode + fwd_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + + // Make cudnn fused ops plan + fwd_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle()); + return fwd_op; + }); + return fwd_op; + } + + private: + NormConvolutionArgs args_; +}; + +template +class CudnnNormConvolutionGrad { + public: + CudnnNormConvolutionGrad(const platform::CUDADeviceContext &ctx, + const std::vector &input_shape, + const std::vector &filter_shape, + const std::vector &output_shape, + const int &padding, const int &stride, + const int &dilation, const int &group) { + args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride, + dilation, group); + dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } + ~CudnnNormConvolutionGrad() {} + + void Backward(const platform::CUDADeviceContext &ctx, const Tensor &input, + const Tensor &filter, const Tensor &output_grad, + Tensor *input_grad, Tensor *filter_grad, + bool use_addto = false) { + auto place = ctx.GetPlace(); + T *input_ptr = const_cast(input.data()); + T *filter_ptr = const_cast(filter.data()); + T *output_grad_ptr = const_cast(output_grad.data()); + + if (filter_grad) { + T *filter_grad_ptr = filter_grad->mutable_data(place); + BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr); + } + if (input_grad) { + T *input_grad_ptr = input_grad->mutable_data(place); + BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto); + } + } + + private: + void BackwardFilter(const platform::CUDADeviceContext &ctx, + T *output_grad_ptr, T *input_ptr, T *filter_grad_ptr) { + auto cudnn_handle = ctx.cudnn_handle(); + + CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx); + size_t workspace_size = RoundUp( + static_cast(wgrad_op->GetWorkspaceSizeInBytes(cudnn_handle)), + 512); + + wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr); + wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, output_grad_ptr); + wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DWDATA, filter_grad_ptr); + wgrad_op->SetOpVariantParamAttrPtr( + CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size); + + ctx.cudnn_workspace_handle().RunFunc( + [&](void *workspace_ptr) { + // workspace ptr + wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + workspace_ptr); + // fused op execute + wgrad_op->Execute(cudnn_handle); + }, + workspace_size); + } + + void BackwardData(const platform::CUDADeviceContext &ctx, T *output_grad_ptr, + T *filter_ptr, T *input_grad_ptr, bool use_addto = false) { + auto cudnn_handle = ctx.cudnn_handle(); + size_t workspace_size = GetWorkspaceSizeBwdData(ctx); + + // Convolution dgrad followed optionally by batchnorm dgrad + ScalingParamType alpha = 1.0f; + ScalingParamType beta = use_addto ? 1.0f : 0.0f; + ctx.cudnn_workspace_handle().RunFunc( + [&](void *cudnn_workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnConvolutionBackwardData( + cudnn_handle, &alpha, args_.filter_desc.desc(), filter_ptr, + args_.out_desc.desc(), output_grad_ptr, + args_.conv_desc.desc(), dgrad_algo_, cudnn_workspace_ptr, + workspace_size, &beta, args_.in_desc.desc(), input_grad_ptr)); + }, + workspace_size); + } + + CudnnFusionOp *GetBackwardFilterOp(const platform::CUDADeviceContext &ctx) { + framework::AlgorithmsCache &cache = + *(CudnnFusionOpCache::Instance().GetBackward()); + + CudnnFusionOp *wgrad_op = cache.GetAlgorithm( + args_.in_dims, args_.filter_dims, args_.strides, args_.paddings, + args_.dilations, 0, static_cast(args_.dtype), [&]() { + CudnnFusionOp *wgrad_op = + new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD); + + wgrad_op->SetOpConstParamAttr( + {CUDNN_PARAM_DYDATA_PLACEHOLDER, CUDNN_PARAM_XDATA_PLACEHOLDER, + CUDNN_PARAM_DWDATA_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + + // conv desc + wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC, + args_.conv_desc.desc()); + // input desc + wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC, + args_.in_desc.desc()); + // filter desc + wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DWDESC, + args_.filter_desc.desc()); + // output desc + wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DYDESC, + args_.out_desc.desc()); + wgrad_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + + // Make cudnn fused ops plan + wgrad_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle()); + return wgrad_op; + }); + return wgrad_op; + } + + size_t GetWorkspaceSizeBwdData(const platform::CUDADeviceContext &ctx) { + size_t workspace_size = 0U; + auto handle = ctx.cudnn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, args_.filter_desc.desc(), args_.out_desc.desc(), + args_.conv_desc.desc(), args_.in_desc.desc(), dgrad_algo_, + &workspace_size)); + return RoundUp(workspace_size, 512); + } + + private: + NormConvolutionArgs args_; + cudnnConvolutionBwdDataAlgo_t dgrad_algo_; +}; + +#endif +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc new file mode 100644 index 0000000000000..b930526a4901c --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc @@ -0,0 +1,458 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/float16.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace op = paddle::operators; +using Tensor = paddle::framework::Tensor; + +USE_OP(conv2d); +USE_OP(conv2d_grad); +USE_OP_DEVICE_KERNEL(conv2d, CUDNN); +USE_OP_DEVICE_KERNEL(conv2d_grad, CUDNN); + +template +void InitRandomTensor(const std::vector &dims, + framework::Tensor *cpu_out) { + T *cpu_out_ptr = cpu_out->mutable_data(framework::make_ddim(dims), + platform::CPUPlace()); + + std::default_random_engine random(0); + std::uniform_real_distribution dis(0.0, 1.0); + for (int i = 0; i < cpu_out->numel(); ++i) { + cpu_out_ptr[i] = static_cast(dis(random)); + } +} + +template +void TransposeNchwToNhwc(const framework::Tensor &cpu_in, + framework::Tensor *cpu_out) { + auto in_dims = cpu_in.dims(); + EXPECT_EQ(cpu_in.dims().size(), 4); + + const T *cpu_in_ptr = cpu_in.data(); + T *cpu_out_ptr = cpu_out->mutable_data( + {in_dims[0], in_dims[2], in_dims[3], in_dims[1]}, platform::CPUPlace()); + + int64_t n = in_dims[0]; + int64_t c = in_dims[1]; + int64_t hw = in_dims[2] * in_dims[3]; + for (int i = 0; i < n; ++i) { + for (int j = 0; j < hw; ++j) { + for (int k = 0; k < c; ++k) { + int dst_idx = i * hw * c + j * c + k; + int src_idx = i * c * hw + k * hw + j; + cpu_out_ptr[dst_idx] = cpu_in_ptr[src_idx]; + } + } + } +} + +template +void CheckOutput(const framework::Tensor &cpu_res, + const framework::Tensor &cpu_base, float diff, + bool is_relative_atol = false) { + EXPECT_EQ(cpu_res.dims(), cpu_base.dims()); + + const T *cpu_res_ptr = cpu_res.data(); + const T *cpu_base_ptr = cpu_base.data(); + for (int i = 0; i < cpu_res.numel(); ++i) { + if (is_relative_atol) { + EXPECT_LT(static_cast(std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) / + cpu_base_ptr[i])), + diff); + } else { + EXPECT_LT(static_cast(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])), + diff); + } + } +} + +// Use Paddle conv2d op results as baseline +void ComputeConv2DForward(const platform::CUDADeviceContext &ctx, + const Tensor &cpu_input, const Tensor &cpu_filter, + Tensor *cpu_output, int stride, int padding) { + framework::Scope scope; + auto *input = scope.Var("Input")->GetMutable(); + auto *filter = scope.Var("Filter")->GetMutable(); + auto *output = scope.Var("Output")->GetMutable(); + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_input, place, input); + TensorCopySync(cpu_filter, place, filter); + + framework::AttributeMap attrs; + bool use_cudnn = true; + std::string data_format = "NHWC"; + std::vector strides = {stride, stride}; + std::vector paddings = {padding, padding}; + attrs.insert({"strides", strides}); + attrs.insert({"paddings", paddings}); + attrs.insert({"use_cudnn", use_cudnn}); + attrs.insert({"data_format", data_format}); + + auto op = framework::OpRegistry::CreateOp( + "conv2d", {{"Input", {"Input"}}, {"Filter", {"Filter"}}}, + {{"Output", {"Output"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + + TensorCopySync(*output, platform::CPUPlace(), cpu_output); +} + +// Use Paddle conv2d_grad op results as baseline +void ComputeConv2DBackward(const platform::CUDADeviceContext &ctx, + const Tensor &cpu_input, const Tensor &cpu_filter, + const Tensor &cpu_output_grad, + framework::Tensor *cpu_input_grad, + framework::Tensor *cpu_filter_grad, int stride, + int padding, int dilation) { + framework::Scope scope; + auto *input = scope.Var("Input")->GetMutable(); + auto *filter = scope.Var("Filter")->GetMutable(); + auto *output_grad = + scope.Var("Output@GRAD")->GetMutable(); + auto *input_grad = + scope.Var("Input@GRAD")->GetMutable(); + auto *filter_grad = + scope.Var("Filter@GRAD")->GetMutable(); + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_input, place, input); + TensorCopySync(cpu_filter, place, filter); + TensorCopySync(cpu_output_grad, place, output_grad); + + framework::AttributeMap attrs; + bool use_cudnn = true; + std::string data_format = "NHWC"; + std::string padding_algorithm = "EXPLICIT"; + std::vector strides = {stride, stride}; + std::vector paddings = {padding, padding}; + std::vector dilations = {dilation, dilation}; + int groups = 1; + bool exhaustive_search = false; + bool use_addto = false; + attrs.insert({"use_cudnn", use_cudnn}); + attrs.insert({"data_format", data_format}); + attrs.insert({"padding_algorithm", padding_algorithm}); + attrs.insert({"strides", strides}); + attrs.insert({"paddings", paddings}); + attrs.insert({"dilations", dilations}); + attrs.insert({"groups", groups}); + attrs.insert({"exhaustive_search", exhaustive_search}); + attrs.insert({"use_addto", use_addto}); + + auto op = framework::OpRegistry::CreateOp( + "conv2d_grad", {{"Input", {"Input"}}, + {"Filter", {"Filter"}}, + {"Output@GRAD", {"Output@GRAD"}}}, + {{"Input@GRAD", {"Input@GRAD"}}, {"Filter@GRAD", {"Filter@GRAD"}}}, + attrs); + op->Run(scope, ctx.GetPlace()); + + TensorCopySync(*input_grad, platform::CPUPlace(), cpu_input_grad); + TensorCopySync(*filter_grad, platform::CPUPlace(), cpu_filter_grad); +} + +template +void ComputeSumAndSquareSum(const framework::Tensor &cpu_out, + framework::Tensor *cpu_sum, + framework::Tensor *cpu_sum_of_square) { + auto dims = cpu_out.dims(); + int64_t c = dims[3]; + + const T *cpu_out_ptr = cpu_out.data(); + float *cpu_sum_ptr = + cpu_sum->mutable_data({1, 1, 1, c}, platform::CPUPlace()); + float *cpu_sum_square_ptr = cpu_sum_of_square->mutable_data( + {1, 1, 1, c}, platform::CPUPlace()); + + for (int j = 0; j < c; ++j) { + float tmp_sum = 0.0f; + float tmp_sum_of_squares = 0.0f; + for (int i = 0; i < cpu_out.numel() / c; ++i) { + float tmp_out = static_cast(cpu_out_ptr[i * c + j]); + tmp_sum += tmp_out; + tmp_sum_of_squares += tmp_out * tmp_out; + } + cpu_sum_ptr[j] = tmp_sum; + cpu_sum_square_ptr[j] = tmp_sum_of_squares; + } +} + +template +class CudnnNormConvolutionTester { + public: + CudnnNormConvolutionTester(int batch_size, int height, int width, + int input_channels, int output_channels, + int kernel_size, int stride) { + batch_size_ = batch_size; + height_ = height; + width_ = width; + input_channels_ = input_channels; + output_channels_ = output_channels; + kernel_size_ = kernel_size; + stride_ = stride; + padding_ = (kernel_size_ - 1) / 2; + out_height_ = (height_ + 2 * padding_ - kernel_size_) / stride_ + 1; + out_width_ = (width_ + 2 * padding_ - kernel_size_) / stride_ + 1; + SetUp(); + } + + ~CudnnNormConvolutionTester() {} + + void CheckForward(float diff, bool is_relative_atol = false) { + platform::CUDADeviceContext *ctx = + static_cast( + platform::DeviceContextPool::Instance().Get( + platform::CUDAPlace(0))); + + framework::Tensor cpu_output_base; + framework::Tensor cpu_sum_base; + framework::Tensor cpu_sum_of_square_base; + BaselineForward(*ctx, &cpu_output_base, &cpu_sum_base, + &cpu_sum_of_square_base); + + framework::Tensor cpu_output; + framework::Tensor cpu_sum; + framework::Tensor cpu_sum_of_square; + FusedForward(*ctx, &cpu_output, &cpu_sum, &cpu_sum_of_square); + + // Check forward correctness between baseline and results of normconv. + CheckOutput(cpu_output, cpu_output_base, diff, is_relative_atol); + CheckOutput(cpu_sum, cpu_sum_base, diff, is_relative_atol); + CheckOutput(cpu_sum_of_square, cpu_sum_of_square_base, diff, + is_relative_atol); + } + + void CheckBackward(float diff, bool is_relative_atol = false) { + platform::CUDADeviceContext *ctx = + static_cast( + platform::DeviceContextPool::Instance().Get( + platform::CUDAPlace(0))); + + framework::Tensor cpu_input_grad_base; + framework::Tensor cpu_filter_nchw_grad_base; + framework::Tensor cpu_filter_nhwc_grad_base; + BaselineBackward(*ctx, &cpu_input_grad_base, &cpu_filter_nchw_grad_base); + TransposeNchwToNhwc(cpu_filter_nchw_grad_base, + &cpu_filter_nhwc_grad_base); + + framework::Tensor cpu_input_grad; + framework::Tensor cpu_filter_nhwc_grad; + FusedBackward(*ctx, &cpu_input_grad, &cpu_filter_nhwc_grad); + + // Check backward correctness between baseline and results of normconv. + CheckOutput(cpu_input_grad, cpu_input_grad_base, diff, is_relative_atol); + CheckOutput(cpu_filter_nhwc_grad, cpu_filter_nhwc_grad_base, diff, + is_relative_atol); + } + + private: + void SetUp() { + InitRandomTensor({batch_size_, height_, width_, input_channels_}, + &cpu_input_); + InitRandomTensor( + {output_channels_, input_channels_, kernel_size_, kernel_size_}, + &cpu_filter_nchw_); + // transpoes for filter, NCHW -> NHWC + TransposeNchwToNhwc(cpu_filter_nchw_, &cpu_filter_nhwc_); + InitRandomTensor( + {batch_size_, out_height_, out_width_, output_channels_}, + &cpu_output_grad_); + } + + void BaselineForward(const platform::CUDADeviceContext &ctx, + framework::Tensor *cpu_output_base, + framework::Tensor *cpu_sum_base, + framework::Tensor *cpu_sum_of_square_base) { + ComputeConv2DForward(ctx, cpu_input_, cpu_filter_nchw_, cpu_output_base, + stride_, padding_); + ComputeSumAndSquareSum(*cpu_output_base, cpu_sum_base, + cpu_sum_of_square_base); + } + + void BaselineBackward(const platform::CUDADeviceContext &ctx, + framework::Tensor *cpu_input_grad_base, + framework::Tensor *cpu_filter_grad_base) { + ComputeConv2DBackward(ctx, cpu_input_, cpu_filter_nchw_, cpu_output_grad_, + cpu_input_grad_base, cpu_filter_grad_base, stride_, + padding_, dilation_); + } + + // get forward results of cudnn_norm_conv + void FusedForward(const platform::CUDADeviceContext &ctx, + framework::Tensor *cpu_output, framework::Tensor *cpu_sum, + framework::Tensor *cpu_sum_of_square) { + framework::Tensor input; + framework::Tensor filter_nhwc; + framework::Tensor output; + framework::Tensor sum; + framework::Tensor sum_of_square; + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_input_, place, &input); + TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc); + + output.Resize(framework::make_ddim( + {batch_size_, out_height_, out_width_, output_channels_})); + sum.Resize(framework::make_ddim({1, 1, 1, output_channels_})); + sum_of_square.Resize(framework::make_ddim({1, 1, 1, output_channels_})); + + auto input_shape = framework::vectorize(input.dims()); + auto filter_shape = framework::vectorize(filter_nhwc.dims()); + auto output_shape = framework::vectorize(output.dims()); + op::CudnnNormConvolution conv_op(ctx, input_shape, filter_shape, + output_shape, padding_, stride_, + dilation_, group_); + conv_op.Forward(ctx, input, filter_nhwc, &output, &sum, &sum_of_square); + + TensorCopySync(output, platform::CPUPlace(), cpu_output); + TensorCopySync(sum, platform::CPUPlace(), cpu_sum); + TensorCopySync(sum_of_square, platform::CPUPlace(), cpu_sum_of_square); + } + + void FusedBackward(const platform::CUDADeviceContext &ctx, + framework::Tensor *cpu_input_grad, + framework::Tensor *cpu_filter_grad) { + framework::Tensor input; + framework::Tensor filter_nhwc; + framework::Tensor output_grad; + framework::Tensor input_grad; + framework::Tensor filter_grad; + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_input_, place, &input); + TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc); + TensorCopySync(cpu_output_grad_, place, &output_grad); + + input_grad.Resize(input.dims()); + filter_grad.Resize(filter_nhwc.dims()); + + auto input_shape = framework::vectorize(input.dims()); + auto filter_shape = framework::vectorize(filter_nhwc.dims()); + auto output_shape = framework::vectorize(output_grad.dims()); + op::CudnnNormConvolutionGrad conv_grad_op(ctx, input_shape, filter_shape, + output_shape, padding_, + stride_, dilation_, group_); + conv_grad_op.Backward(ctx, input, filter_nhwc, output_grad, &input_grad, + &filter_grad); + + TensorCopySync(input_grad, platform::CPUPlace(), cpu_input_grad); + TensorCopySync(filter_grad, platform::CPUPlace(), cpu_filter_grad); + } + + private: + int batch_size_; + int height_; + int width_; + int out_height_; + int out_width_; + int input_channels_; + int output_channels_; + int kernel_size_; + int stride_; + int padding_; + const int dilation_ = 1; + const int group_ = 1; + + // Forward input + framework::Tensor cpu_input_; + framework::Tensor cpu_filter_nchw_; + framework::Tensor cpu_filter_nhwc_; + + // Backward input + framework::Tensor cpu_output_grad_; +}; + +// test for fp16, kernel = 1, output_channels = input_channels +TEST(CudnnNormConvFp16, K1S1) { + int batch_size = 4; + int height = 56; + int width = 56; + int input_channels = 32; + int output_channels = 32; + int kernel_size = 1; + int stride = 1; + CudnnNormConvolutionTester test( + batch_size, height, width, input_channels, output_channels, kernel_size, + stride); + test.CheckForward(1e-3, true); + test.CheckBackward(1e-3, true); +} + +// test for fp16, kernel = 3, output_channels = input_channels +TEST(CudnnNormConvFp16, K3S1) { + int batch_size = 4; + int height = 56; + int width = 56; + int input_channels = 32; + int output_channels = 32; + int kernel_size = 3; + int stride = 1; + CudnnNormConvolutionTester test( + batch_size, height, width, input_channels, output_channels, kernel_size, + stride); + test.CheckForward(1e-3, true); + test.CheckBackward(1e-3, true); +} + +// test for fp16, kernel = 1, output_channels = input_channels * 4 +TEST(CudnnNormConvFp16, K1S1O4) { + int batch_size = 4; + int height = 56; + int width = 56; + int input_channels = 32; + int output_channels = 128; + int kernel_size = 1; + int stride = 1; + CudnnNormConvolutionTester test( + batch_size, height, width, input_channels, output_channels, kernel_size, + stride); + test.CheckForward(1e-3, true); + test.CheckBackward(1e-3, true); +} + +// test for fp16, kernel = 1, stride = 2, output_channels = input_channels * 4 +TEST(CudnnNormConvFp16, K1S2O4) { + int batch_size = 4; + int height = 8; + int width = 8; + int input_channels = 32; + int output_channels = 128; + int kernel_size = 1; + int stride = 2; + CudnnNormConvolutionTester test( + batch_size, height, width, input_channels, output_channels, kernel_size, + stride); + platform::CUDADeviceContext *ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0))); + + if (ctx->GetComputeCapability() <= 70) { + ASSERT_THROW(test.CheckForward(1e-3, true), + paddle::platform::EnforceNotMet); + ASSERT_THROW(test.CheckBackward(1e-3), paddle::platform::EnforceNotMet); + } else { + ASSERT_NO_THROW(test.CheckForward(1e-3, true)); + ASSERT_NO_THROW(test.CheckBackward(1e-3)); + } +} diff --git a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h new file mode 100644 index 0000000000000..d60a9174e1317 --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h @@ -0,0 +1,317 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" +#include "paddle/fluid/platform/cudnn_desc.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +template +using CudnnDataType = platform::CudnnDataType; +namespace dynload = platform::dynload; +template +using BatchNormParamType = + typename platform::CudnnDataType::BatchNormParamType; + +#if CUDNN_VERSION >= 8000 + +template +struct ScaleBiasAddReluArgs { + ScaleBiasAddReluArgs() { + dtype = platform::CudnnDataType::type; + param_dtype = platform::CudnnDataType>::type; + format = CUDNN_TENSOR_NHWC; + } + + void Set(const std::string &act_type, const std::vector &data_shape, + const std::vector ¶m_shape, + const std::vector &bitmask_shape) { + PADDLE_ENFORCE_EQ( + data_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of data_shape is expected to 4. But recieved " + "data_shape's size is %d, data_shape is [%s].", + data_shape.size(), framework::make_ddim(data_shape))); + PADDLE_ENFORCE_EQ( + param_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of param_shape is expected to 4. But recieved " + "param_shape's size is %d, param_shape is [%s].", + param_shape.size(), framework::make_ddim(param_shape))); + PADDLE_ENFORCE_EQ( + bitmask_shape.size(), 3U, + platform::errors::InvalidArgument( + "The size of bitmask_shape is expected to 3. But recieved " + "bitmask_shape's size is %d, bitmask_shape is [%s].", + bitmask_shape.size(), framework::make_ddim(bitmask_shape))); + + in_desc.set(data_shape, format, dtype); + out_desc.set(data_shape, format, dtype); + equiv_scale_bias_desc.set(param_shape, format, dtype); + scale_bias_mean_var_desc.set(param_shape, format, param_dtype); + bitmask_desc.set(bitmask_shape, format, CUDNN_DATA_INT32); + // set activation desc + cudnnActivationMode_t mode = CUDNN_ACTIVATION_IDENTITY; + if (act_type != "") { + PADDLE_ENFORCE_EQ( + act_type, "relu", + platform::errors::InvalidArgument( + "Only relu activation supported in normalized convolution.")); + mode = CUDNN_ACTIVATION_RELU; + } + double dummy_clip = 0.0; + activation_desc.set(mode, dummy_clip); + } + + cudnnDataType_t dtype; + cudnnDataType_t param_dtype; + cudnnTensorFormat_t format; + + platform::TensorDescriptor in_desc; + platform::TensorDescriptor out_desc; + platform::TensorDescriptor equiv_scale_bias_desc; + platform::TensorDescriptor scale_bias_mean_var_desc; + platform::TensorDescriptor bitmask_desc; + platform::ActivationDescriptor activation_desc; +}; + +template +class CudnnScaleBiasAddRelu { + public: + CudnnScaleBiasAddRelu(const platform::CUDADeviceContext &ctx, + const std::string &act_type, bool fuse_add, + bool has_shortcut, const std::vector &data_shape, + const std::vector ¶m_shape, + const std::vector &bitmask_shape) + : fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK), + bwd_op_(CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM) { + fuse_add_ = fuse_add; + has_shortcut_ = has_shortcut; + args_.Set(act_type, data_shape, param_shape, bitmask_shape); + } + + ~CudnnScaleBiasAddRelu() {} + + void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x, + const Tensor &x_scale, const Tensor &x_bias, const Tensor *z, + const Tensor *z_scale, const Tensor *z_bias, Tensor *out, + Tensor *bitmask) { + ForwardInit(ctx); + auto handle = ctx.cudnn_handle(); + auto place = ctx.GetPlace(); + auto workspace_handle = ctx.cudnn_workspace_handle(); + fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle); + // Set variant_param + // input ptr + T *x_ptr = const_cast(x.data()); + T *x_scale_ptr = const_cast(x_scale.data()); + T *x_bias_ptr = const_cast(x_bias.data()); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr); + if (has_shortcut_) { + T *z_ptr = const_cast(z->data()); + T *z_scale_ptr = const_cast(z_scale->data()); + T *z_bias_ptr = const_cast(z_bias->data()); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr); + } else { + if (fuse_add_) { + T *z_ptr = const_cast(z->data()); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr); + } + } + + fwd_op_.SetOpVariantParamAttrPtr( + CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_); + + // output ptr + T *out_ptr = out->mutable_data(place); + int32_t *bitmask_ptr = bitmask->mutable_data(place); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr); + + workspace_handle.RunFunc( + [&](void *workspace_ptr) { + // workspace ptr + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr); + // workspace ptr + fwd_op_.Execute(handle); + }, + fwd_workspace_byte_); + } + + void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy, + const Tensor &x, const Tensor &scale, const Tensor &bias, + const Tensor &saved_mean, const Tensor &saved_invstd, + const Tensor *bitmask, Tensor *dx, Tensor *dz, Tensor *dscale, + Tensor *dbias, double eps) { + BackwardInit(ctx); + auto handle = ctx.cudnn_handle(); + auto place = ctx.GetPlace(); + auto workspace_handle = ctx.cudnn_workspace_handle(); + bwd_workspace_byte_ = bwd_op_.GetWorkspaceSizeInBytes(handle); + // Set variant_param + // input ptr + T *dy_ptr = const_cast(dy.data()); + T *x_ptr = const_cast(x.data()); + float *scale_ptr = const_cast(scale.data()); + float *bias_ptr = const_cast(bias.data()); + float *saved_mean_ptr = const_cast(saved_mean.data()); + float *saved_invstd_ptr = const_cast(saved_invstd.data()); + int32_t *bitmask_ptr = + bitmask ? const_cast(bitmask->data()) : nullptr; + T *dx_ptr = dx->mutable_data(place); + T *dz_ptr = dz ? dz->mutable_data(place) : nullptr; + float *dscale_ptr = dscale ? dscale->mutable_data(place) : nullptr; + float *dbias_ptr = dbias ? dbias->mutable_data(place) : nullptr; + + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, dy_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD, + saved_invstd_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr); + + bwd_op_.SetOpVariantParamAttrPtr( + CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &bwd_workspace_byte_); + + // output ptr + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DXDATA, dx_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DSCALE, dscale_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DBIAS, dbias_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EPSILON, + &eps); + if (has_shortcut_ || fuse_add_) { + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DZDATA, dz_ptr); + } + + workspace_handle.RunFunc( + [&](void *workspace_ptr) { + // workspace ptr + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr); + // workspace ptr + bwd_op_.Execute(handle); + }, + bwd_workspace_byte_); + } + + private: + void ForwardInit(const platform::CUDADeviceContext &ctx) { + // Set constant_param + fwd_op_.SetOpConstParamAttr( + {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, + CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER, CUDNN_PARAM_YDATA_PLACEHOLDER, + CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + if (has_shortcut_) { + fwd_op_.SetOpConstParamAttr( + {CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER, + CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + } else if (fuse_add_) { + fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_ZDATA_PLACEHOLDER, + CUDNN_PTR_16B_ALIGNED); + } + + // input desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc()); + if (has_shortcut_ || fuse_add_) { + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, args_.in_desc.desc()); + } + + // equiv scale/bias desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC, + args_.equiv_scale_bias_desc.desc()); + if (has_shortcut_) { + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC, + args_.equiv_scale_bias_desc.desc()); + } + + // output desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc()); + + // bitmask desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC, + args_.bitmask_desc.desc()); + + // activation desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC, + args_.activation_desc.desc()); + + // others + fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + } + + void BackwardInit(const platform::CUDADeviceContext &ctx) { + // Set constant_param + bwd_op_.SetOpConstParamAttr( + {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_DYDATA_PLACEHOLDER, + CUDNN_PARAM_DXDATA_PLACEHOLDER, CUDNN_PARAM_BN_SCALE_PLACEHOLDER, + CUDNN_PARAM_BN_BIAS_PLACEHOLDER, CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER, + CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER, + CUDNN_PARAM_BN_DSCALE_PLACEHOLDER, CUDNN_PARAM_BN_DBIAS_PLACEHOLDER, + CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + if (has_shortcut_ || fuse_add_) { + bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_DZDATA_PLACEHOLDER, + CUDNN_PTR_16B_ALIGNED); + } + + // input desc + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc()); + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DXDESC, args_.in_desc.desc()); + if (has_shortcut_ || fuse_add_) { + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DZDESC, args_.in_desc.desc()); + } + + // scale/bias/mean/var desc for backward + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC, + args_.scale_bias_mean_var_desc.desc()); + + // output desc + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DYDESC, args_.out_desc.desc()); + + // bitmask desc + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC, + args_.bitmask_desc.desc()); + + // activation desc + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC, + args_.activation_desc.desc()); + + // others + bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + } + + bool fuse_add_ = false; + bool has_shortcut_ = false; + size_t fwd_workspace_byte_; + size_t bwd_workspace_byte_; + ScaleBiasAddReluArgs args_; + CudnnFusionOp fwd_op_; + CudnnFusionOp bwd_op_; +}; +#endif +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cc b/paddle/fluid/operators/fused/resnet_unit_op.cc new file mode 100644 index 0000000000000..d2ac089d4d1d2 --- /dev/null +++ b/paddle/fluid/operators/fused/resnet_unit_op.cc @@ -0,0 +1,411 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +// Shape of bitmask +static framework::DDim GetBitmaskDims(std::vector out_shape) { + int c = out_shape.back(); + int64_t nhw = std::accumulate(out_shape.begin(), out_shape.end(), 1, + std::multiplies()) / + c; + int32_t c_int32_elems = ((c + 63) & ~63) / 32; + int32_t nhw_int32_elems = ((nhw + 31) & ~31); + std::vector bitmask_shape = {nhw_int32_elems, c_int32_elems, 1}; + return framework::make_ddim(bitmask_shape); +} + +class ResNetUnitOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const { + // Check input + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("FilterX"), "Input", "FilterX", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX", "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX", "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("MeanX"), "Input", "MeanX", "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("VarX"), "Input", "VarX", "ResNetUnitOp"); + + bool fuse_add = ctx->Attrs().Get("fuse_add"); + bool has_shortcut = ctx->Attrs().Get("has_shortcut"); + if (fuse_add || has_shortcut) { + OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitOp"); + } + if (has_shortcut) { + OP_INOUT_CHECK(ctx->HasInput("FilterZ"), "Input", "FilterZ", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("ScaleZ"), "Input", "ScaleZ", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ", "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("MeanZ"), "Input", "MeanZ", "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasInput("VarZ"), "Input", "VarZ", "ResNetUnitOp"); + } + + // Check output + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("BitMask"), "Output", "BitMask", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("ConvX"), "Output", "ConvX", "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("SavedMeanX"), "Output", "SavedMeanX", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdX"), "Output", "SavedInvstdX", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("RunningMeanX"), "Output", "RunningMeanX", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("RunningVarX"), "Output", "RunningVarX", + "ResNetUnitOp"); + if (has_shortcut) { + OP_INOUT_CHECK(ctx->HasOutput("ConvZ"), "Output", "ConvZ", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("SavedMeanZ"), "Output", "SavedMeanZ", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdZ"), "Output", "SavedInvstdZ", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("RunningMeanZ"), "Output", "RunningMeanZ", + "ResNetUnitOp"); + OP_INOUT_CHECK(ctx->HasOutput("RunningVarZ"), "Output", "RunningVarZ", + "ResNetUnitOp"); + } + + // make sure Mean/RunningMean and Var/RunningVar share memory + PADDLE_ENFORCE_EQ( + ctx->Inputs("MeanX")[0], ctx->Outputs("RunningMeanX")[0], + platform::errors::InvalidArgument( + "MeanX and RunningMeanX should share the same memory")); + PADDLE_ENFORCE_EQ(ctx->Inputs("VarX")[0], ctx->Outputs("RunningVarX")[0], + platform::errors::InvalidArgument( + "VarX and RunningVarX should share the same memory")); + if (has_shortcut) { + PADDLE_ENFORCE_EQ( + ctx->Inputs("MeanZ")[0], ctx->Outputs("RunningMeanZ")[0], + platform::errors::InvalidArgument( + "MeanZ and RunningMeanZ should share the same memory")); + PADDLE_ENFORCE_EQ( + ctx->Inputs("VarZ")[0], ctx->Outputs("RunningVarZ")[0], + platform::errors::InvalidArgument( + "VarZ and RunningVarZ should share the same memory")); + } + + // Check dims of inputs + const auto x_dims = ctx->GetInputDim("X"); + const auto w_dims = ctx->GetInputDim("FilterX"); + const auto bn_param_dims = ctx->GetInputDim("ScaleX"); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, platform::errors::InvalidArgument( + "The dimensions of input " + "must equal to 4." + "But received: the shape of input " + "= [%s], the dimension of input = " + "[%d]", + x_dims, x_dims.size())); + PADDLE_ENFORCE_EQ(w_dims.size(), 4, + platform::errors::InvalidArgument( + "The dimensions of filter " + "must equal to 4." + "But received: the shape of filter " + "= [%s], the dimension of filter = [%d] ", + w_dims, w_dims.size())); + PADDLE_ENFORCE_EQ(bn_param_dims.size(), 4, + platform::errors::InvalidArgument( + "The dimensions of bn param " + "must equal to 4." + "But received: the shape of bn param " + "= [%s], the dimension of bn param = [%d] ", + bn_param_dims, bn_param_dims.size())); + auto data_format = ctx->Attrs().Get("data_format"); + PADDLE_ENFORCE_EQ( + data_format, "NHWC", + platform::errors::InvalidArgument("The data format must equal to NHWC. " + "But received: the data format " + "= [%s]", + data_format)); + // Calculate the dims of outputs + int batch = x_dims[0]; + int output_channel = w_dims[0]; + int filter_size = w_dims[2]; + int stride = ctx->Attrs().Get("stride"); + int padding = ctx->Attrs().Get("padding"); + int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1; + int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1; + std::vector out_shape = {batch, out_h, out_w, output_channel}; + + auto y_dims = framework::make_ddim(out_shape); + auto bitmask_dims = GetBitmaskDims(out_shape); + // Set dims of outputs + ctx->SetOutputDim("Y", y_dims); + ctx->SetOutputDim("BitMask", bitmask_dims); + ctx->SetOutputDim("ConvX", y_dims); + ctx->SetOutputDim("SavedMeanX", bn_param_dims); + ctx->SetOutputDim("SavedInvstdX", bn_param_dims); + ctx->SetOutputDim("RunningMeanX", bn_param_dims); + ctx->SetOutputDim("RunningVarX", bn_param_dims); + if (has_shortcut) { + ctx->SetOutputDim("ConvZ", y_dims); + ctx->SetOutputDim("SavedMeanZ", bn_param_dims); + ctx->SetOutputDim("SavedInvstdZ", bn_param_dims); + ctx->SetOutputDim("RunningMeanZ", bn_param_dims); + ctx->SetOutputDim("RunningVarZ", bn_param_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + // By default, the type of the scale, bias, mean, + // and var tensors should be float when input tensor's dtype is float16. + auto bn_param_type = framework::proto::VarType::FP32; + + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("ScaleX")->type(), + platform::errors::InvalidArgument( + "Scale input should be of float type")); + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("BiasX")->type(), + platform::errors::InvalidArgument( + "Bias input should be of float type")); + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library); + } +}; + +class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "The input 1 tensor"); + AddInput("FilterX", "Filter tensor of input 1"); + AddInput("ScaleX", "Scale tensor of input 1 used in batchnorm"); + AddInput("BiasX", "Bias tensor of input 1 used in batchnorm"); + AddInput("MeanX", "Mean tensor of input 1 used in batchnorm"); + AddInput("VarX", "Variance tensor of input 1 used in batchnorm"); + AddInput("Z", "The input 2 tensor").AsDispensable(); + AddInput("FilterZ", "Filter tensor of input 2").AsDispensable(); + AddInput("ScaleZ", "Scale tensor of input 2").AsDispensable(); + AddInput("BiasZ", "Bias tensor of input 2").AsDispensable(); + AddInput("MeanZ", "Mean tensor of input 2").AsDispensable(); + AddInput("VarZ", "Variance tensor of input 2").AsDispensable(); + AddOutput("Y", "The result of the resnet unit"); + AddOutput("BitMask", "The bitmask generated after relu"); + AddOutput("ConvX", "The output of input 1 after conv"); + AddOutput("SavedMeanX", "Mean of input 1 in the current batch"); + AddOutput("SavedInvstdX", "Invstd of input 1 in the current batch"); + AddOutput("RunningMeanX", "Shared memory with MeanX"); + AddOutput("RunningVarX", "Shared memory with VarX"); + AddOutput("ConvZ", "The output of input 2 after conv").AsDispensable(); + AddOutput("SavedMeanZ", "Mean of input 1 in the current batch") + .AsDispensable(); + AddOutput("SavedInvstdZ", "Invstd of input 1 in the current batch") + .AsDispensable(); + AddOutput("RunningMeanZ", "Shared memory with MeanZ").AsDispensable(); + AddOutput("RunningVarZ", "Shared memory with VarZ").AsDispensable(); + AddAttr("stride", "").SetDefault(1); + AddAttr("stride_z", "").SetDefault(1); + AddAttr("padding", "").SetDefault(0); + AddAttr("dilation", "").SetDefault(1); + AddAttr("group", "").SetDefault(1); + AddAttr("momentum", "").SetDefault(0.9); + AddAttr("epsilon", "").SetDefault(1e-5); + AddAttr("data_format", "").SetDefault("NHWC"); + AddAttr("fuse_add", "").SetDefault(false); + AddAttr("has_shortcut", "").SetDefault(false); + AddAttr("use_global_stats", "").SetDefault(false); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("use_addto", "").SetDefault(false); + AddAttr("act_type", "The activation type to be fused.") + .SetDefault("relu"); + AddComment(R"DOC( +Fusion op of the basic unit of resnet block. + +The implementation is based on the latest fusion op interface in cuDNN v8.0. +For more details: +https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnFusedOps_t + +)DOC"); + } +}; + +class ResNetUnitGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const { + // check input + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("FilterX"), "Input", "FilterX", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("ConvX"), "Input", "ConvX", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("SavedMeanX"), "Input", "SavedMeanX", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("SavedInvstdX"), "Input", "SavedInvstdX", + "ResNetUnitGradOp"); + + bool fuse_add = ctx->Attrs().Get("fuse_add"); + bool has_shortcut = ctx->Attrs().Get("has_shortcut"); + if (fuse_add || has_shortcut) { + OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitGradOp"); + } + if (has_shortcut) { + OP_INOUT_CHECK(ctx->HasInput("FilterZ"), "Input", "FilterZ", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("ConvZ"), "Input", "ConvZ", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("ScaleZ"), "Input", "ScaleZ", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("SavedMeanZ"), "Input", "SavedMeanZ", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("SavedInvstdZ"), "Input", "SavedInvstdZ", + "ResNetUnitGradOp"); + } + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput("BitMask"), "Input", "BitMask", + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input", + framework::GradVarName("Y"), "ResNetUnitGradOp"); + + // check output + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterX")), "Output", + framework::GradVarName("FilterX"), "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleX")), "Output", + framework::GradVarName("ScaleX"), "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasX")), "Output", + framework::GradVarName("BiasX"), "ResNetUnitGradOp"); + if (fuse_add) { + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Z")), "Output", + framework::GradVarName("Z"), "ResNetUnitGradOp"); + } + if (has_shortcut) { + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterZ")), + "Output", framework::GradVarName("FilterZ"), + "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleZ")), "Output", + framework::GradVarName("ScaleZ"), "ResNetUnitGradOp"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasZ")), "Output", + framework::GradVarName("BiasZ"), "ResNetUnitGradOp"); + } + const auto x_dims = ctx->GetInputDim("X"); + const auto filter_x_dims = ctx->GetInputDim("FilterX"); + const auto param_dims = ctx->GetInputDim("ScaleX"); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->SetOutputDim(framework::GradVarName("FilterX"), filter_x_dims); + ctx->SetOutputDim(framework::GradVarName("ScaleX"), param_dims); + ctx->SetOutputDim(framework::GradVarName("BiasX"), param_dims); + if (fuse_add || has_shortcut) { + const auto z_dims = ctx->GetInputDim("Z"); + ctx->SetOutputDim(framework::GradVarName("Z"), z_dims); + } + if (has_shortcut) { + const auto filter_z_dims = ctx->GetInputDim("FilterZ"); + ctx->SetOutputDim(framework::GradVarName("FilterZ"), filter_z_dims); + ctx->SetOutputDim(framework::GradVarName("ScaleZ"), param_dims); + ctx->SetOutputDim(framework::GradVarName("BiasZ"), param_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar(framework::GradVarName("Y")), + platform::errors::NotFound( + "Can not find Y@GRAD in the execution context.")); + + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout, library); + } +}; + +template +class ResNetUnitGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("resnet_unit_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("FilterX", this->Input("FilterX")); + op->SetInput("ConvX", this->Output("ConvX")); + op->SetInput("ScaleX", this->Input("ScaleX")); + op->SetInput("BiasX", this->Input("BiasX")); + op->SetInput("SavedMeanX", this->Output("SavedMeanX")); + op->SetInput("SavedInvstdX", this->Output("SavedInvstdX")); + op->SetInput("Z", this->Input("Z")); + op->SetInput("FilterZ", this->Input("FilterZ")); + op->SetInput("ConvZ", this->Output("ConvZ")); + op->SetInput("ScaleZ", this->Input("ScaleZ")); + op->SetInput("BiasZ", this->Input("BiasZ")); + op->SetInput("SavedMeanZ", this->Output("SavedMeanZ")); + op->SetInput("SavedInvstdZ", this->Output("SavedInvstdZ")); + op->SetInput("Y", this->Output("Y")); + op->SetInput("BitMask", this->Output("BitMask")); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + + op->SetAttrMap(this->Attrs()); + + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("FilterX"), + this->InputGrad("FilterX")); + op->SetOutput(framework::GradVarName("ScaleX"), this->InputGrad("ScaleX")); + op->SetOutput(framework::GradVarName("BiasX"), this->InputGrad("BiasX")); + op->SetOutput(framework::GradVarName("Z"), this->InputGrad("Z")); + op->SetOutput(framework::GradVarName("FilterZ"), + this->InputGrad("FilterZ")); + op->SetOutput(framework::GradVarName("ScaleZ"), this->InputGrad("ScaleZ")); + op->SetOutput(framework::GradVarName("BiasZ"), this->InputGrad("BiasZ")); + } +}; + +class ResNetUnitOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Y"}}; + return m; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(resnet_unit, ops::ResNetUnitOp, ops::ResNetUnitOpMaker, + ops::ResNetUnitOpInferVarType, + ops::ResNetUnitGradOpMaker, + ops::ResNetUnitGradOpMaker); +REGISTER_OPERATOR(resnet_unit_grad, ops::ResNetUnitGradOp); diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cu b/paddle/fluid/operators/fused/resnet_unit_op.cu new file mode 100644 index 0000000000000..b121864f80e4d --- /dev/null +++ b/paddle/fluid/operators/fused/resnet_unit_op.cu @@ -0,0 +1,299 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h" +#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h" +#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class ResNetUnitKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("It must use CUDAPlace.")); + PADDLE_ENFORCE_EQ(platform::CudnnDataType::type, CUDNN_DATA_HALF, + platform::errors::Unavailable( + "ResNetUnitOp only supports float16 for now.")); + + // input x + const Tensor *input_x = ctx.Input("X"); + const Tensor *filter_x = ctx.Input("FilterX"); + const Tensor *scale_x = ctx.Input("ScaleX"); + const Tensor *bias_x = ctx.Input("BiasX"); + // norm conv + Tensor *conv_out_x = ctx.Output("ConvX"); + // bn finalize + Tensor *saved_mean_x = ctx.Output("SavedMeanX"); + Tensor *saved_invstd_x = ctx.Output("SavedInvstdX"); + Tensor *running_mean_x = ctx.Output("RunningMeanX"); + Tensor *running_var_x = ctx.Output("RunningVarX"); + // sbar + Tensor *output = ctx.Output("Y"); + Tensor *bitmask = ctx.Output("BitMask"); + // attrs + int padding = ctx.Attr("padding"); + int stride = ctx.Attr("stride"); + int stride_z = ctx.Attr("stride_z"); + int dilation = ctx.Attr("dilation"); + int group = ctx.Attr("group"); + double eps = static_cast(ctx.Attr("epsilon")); + double momentum = static_cast(ctx.Attr("momentum")); + bool has_shortcut = ctx.Attr("has_shortcut"); + bool fuse_add = ctx.Attr("fuse_add"); + bool use_global_stats = ctx.Attr("use_global_stats"); + bool is_test = ctx.Attr("is_test"); + bool is_train = !is_test && !use_global_stats; + std::string act_type = ctx.Attr("act_type"); + + auto input_x_shape = framework::vectorize(input_x->dims()); + auto filter_x_shape = framework::vectorize(filter_x->dims()); + auto param_dims = scale_x->dims(); + auto param_shape = framework::vectorize(scale_x->dims()); + auto output_shape = framework::vectorize(output->dims()); + auto bitmask_shape = framework::vectorize(bitmask->dims()); + int output_channel = filter_x_shape[0]; + int64_t ele_count = + std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()) / + output_channel; + + auto place = ctx.GetPlace(); + auto &dev_ctx = ctx.template device_context(); + + // 1. Conv + Tensor sum_x; + Tensor sum_of_squares_x; + sum_x.Resize(param_dims); + sum_of_squares_x.Resize(param_dims); + CudnnNormConvolution conv_x_op(dev_ctx, input_x_shape, filter_x_shape, + output_shape, padding, stride, dilation, + group); + conv_x_op.Forward(dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x, + &sum_of_squares_x); + + // 2. BN + Tensor equiv_scale_x; + Tensor equiv_bias_x; + equiv_scale_x.Resize(param_dims); + equiv_bias_x.Resize(param_dims); + CudnnBNStatsFinalize bn_x_op(dev_ctx, param_shape); + bn_x_op.Forward(dev_ctx, sum_x, sum_of_squares_x, *scale_x, *bias_x, + saved_mean_x, saved_invstd_x, running_mean_x, running_var_x, + &equiv_scale_x, &equiv_bias_x, eps, momentum, ele_count, + is_train); + + // 3. scale + bias + add + relu + CudnnScaleBiasAddRelu sbar_op(dev_ctx, act_type, fuse_add, has_shortcut, + output_shape, param_shape, bitmask_shape); + if (has_shortcut) { + // input z + const Tensor *input_z = ctx.Input("Z"); + const Tensor *filter_z = ctx.Input("FilterZ"); + const Tensor *scale_z = ctx.Input("ScaleZ"); + const Tensor *bias_z = ctx.Input("BiasZ"); + // norm conv + Tensor *conv_out_z = ctx.Output("ConvZ"); + // bn finalize + Tensor *saved_mean_z = ctx.Output("SavedMeanZ"); + Tensor *saved_invstd_z = ctx.Output("SavedInvstdZ"); + Tensor *running_mean_z = ctx.Output("RunningMeanZ"); + Tensor *running_var_z = ctx.Output("RunningVarZ"); + + auto input_z_shape = framework::vectorize(input_z->dims()); + auto filter_z_shape = framework::vectorize(filter_z->dims()); + + // 3.1 Conv for second input + Tensor sum_z; + Tensor sum_of_squares_z; + sum_z.Resize(param_dims); + sum_of_squares_z.Resize(param_dims); + CudnnNormConvolution conv_z_op(dev_ctx, input_z_shape, filter_z_shape, + output_shape, padding, stride_z, + dilation, group); + conv_z_op.Forward(dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z, + &sum_of_squares_z); + + // 3.2 BN for second input + Tensor equiv_scale_z; + Tensor equiv_bias_z; + equiv_scale_z.Resize(param_dims); + equiv_bias_z.Resize(param_dims); + CudnnBNStatsFinalize bn_z_op(dev_ctx, param_shape); + bn_z_op.Forward(dev_ctx, sum_z, sum_of_squares_z, *scale_z, *bias_z, + saved_mean_z, saved_invstd_z, running_mean_z, + running_var_z, &equiv_scale_z, &equiv_bias_z, eps, + momentum, ele_count, is_train); + // 3.3 sbar + sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x, + conv_out_z, &equiv_scale_z, &equiv_bias_z, output, + bitmask); + } else { + const Tensor *input_z = fuse_add ? ctx.Input("Z") : nullptr; + sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x, + input_z, nullptr, nullptr, output, bitmask); + } + } +}; + +template +class ResNetUnitGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("It must use CUDAPlace.")); + PADDLE_ENFORCE_EQ(platform::CudnnDataType::type, CUDNN_DATA_HALF, + platform::errors::Unavailable( + "ResNetUnitOp only supports float16 for now.")); + + const Tensor *y_grad = ctx.Input(framework::GradVarName("Y")); + + const Tensor *x = ctx.Input("X"); + const Tensor *filter_x = ctx.Input("FilterX"); + const Tensor *scale_x = ctx.Input("ScaleX"); + const Tensor *bias_x = ctx.Input("BiasX"); + const Tensor *saved_mean_x = ctx.Input("SavedMeanX"); + const Tensor *saved_invstd_x = ctx.Input("SavedInvstdX"); + + const Tensor *conv_out_x = ctx.Input("ConvX"); + const Tensor *output = ctx.Input("Y"); + const Tensor *bitmask = ctx.Input("BitMask"); + + Tensor *x_grad = ctx.Output(framework::GradVarName("X")); + Tensor *filter_x_grad = + ctx.Output(framework::GradVarName("FilterX")); + Tensor *scale_x_grad = ctx.Output(framework::GradVarName("ScaleX")); + Tensor *bias_x_grad = ctx.Output(framework::GradVarName("BiasX")); + + int padding = ctx.Attr("padding"); + int stride = ctx.Attr("stride"); + int stride_z = ctx.Attr("stride_z"); + int dilation = ctx.Attr("dilation"); + int group = ctx.Attr("group"); + double eps = static_cast(ctx.Attr("epsilon")); + double momentum = static_cast(ctx.Attr("momentum")); + bool has_shortcut = ctx.Attr("has_shortcut"); + bool fuse_add = ctx.Attr("fuse_add"); + bool use_global_stats = ctx.Attr("use_global_stats"); + std::string act_type = ctx.Attr("act_type"); + + auto x_shape = framework::vectorize(x->dims()); + auto filter_x_shape = framework::vectorize(filter_x->dims()); + auto param_shape = framework::vectorize(scale_x->dims()); + auto output_shape = framework::vectorize(output->dims()); + auto bitmask_shape = framework::vectorize(bitmask->dims()); + + auto place = ctx.GetPlace(); + auto &dev_ctx = ctx.template device_context(); + + // 1. Backward of BN (+ Add + Relu) for x, get conv_out_x_grad, + // scale_x_grad, bias_x_grad + Tensor conv_out_x_grad; + conv_out_x_grad.Resize(conv_out_x->dims()); + CudnnScaleBiasAddRelu sbar_x_op(dev_ctx, act_type, fuse_add, + has_shortcut, output_shape, param_shape, + bitmask_shape); + if (has_shortcut) { + // X Z + // | | + // NormConv NormConv + // | | + // BNStatsFinalize BNStatsFinalize + // \ / + // ScaleBiasAddRelu + // | + // Y + const Tensor *z = ctx.Input("Z"); + const Tensor *filter_z = ctx.Input("FilterZ"); + const Tensor *scale_z = ctx.Input("ScaleZ"); + const Tensor *bias_z = ctx.Input("BiasZ"); + const Tensor *saved_mean_z = ctx.Input("SavedMeanZ"); + const Tensor *saved_invstd_z = ctx.Input("SavedInvstdZ"); + const Tensor *conv_out_z = ctx.Input("ConvZ"); + + Tensor *z_grad = ctx.Output(framework::GradVarName("Z")); + Tensor *filter_z_grad = + ctx.Output(framework::GradVarName("FilterZ")); + Tensor *scale_z_grad = + ctx.Output(framework::GradVarName("ScaleZ")); + Tensor *bias_z_grad = ctx.Output(framework::GradVarName("BiasZ")); + + // 1.1 Backward of BN + Add (+ Relu) for x, get conv_out_x_grad, + // scale_x_grad, bias_x_grad and z_grad_temp + Tensor z_grad_temp; + z_grad_temp.Resize(conv_out_z->dims()); + sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x, + *saved_mean_x, *saved_invstd_x, bitmask, + &conv_out_x_grad, &z_grad_temp, scale_x_grad, + bias_x_grad, eps); + + // 1.2 bn backward for z, get conv_out_z_grad, dscale_z, dbias_z + Tensor conv_out_z_grad; + conv_out_z_grad.Resize(conv_out_z->dims()); + CudnnScaleBiasAddRelu sbar_z_op( + dev_ctx, "", false, false, output_shape, param_shape, bitmask_shape); + sbar_z_op.Backward(dev_ctx, z_grad_temp, *conv_out_z, *scale_z, *bias_z, + *saved_mean_z, *saved_invstd_z, nullptr, + &conv_out_z_grad, nullptr, scale_z_grad, bias_z_grad, + eps); + + // 1.3 Backward of Conv for z, get z_grad and filter_z_grad + auto z_shape = framework::vectorize(z->dims()); + auto filter_z_shape = framework::vectorize(filter_z->dims()); + CudnnNormConvolutionGrad conv_z_op(dev_ctx, z_shape, filter_z_shape, + output_shape, padding, stride_z, + dilation, group); + conv_z_op.Backward(dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad, + filter_z_grad); + } else { + // 1.1 Backward of BN (+ Add + Relu) for x, get conv_out_x_grad, + // scale_x_grad, bias_x_grad (and z_grad) + Tensor *z_grad = + fuse_add ? ctx.Output(framework::GradVarName("Z")) : nullptr; + sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x, + *saved_mean_x, *saved_invstd_x, bitmask, + &conv_out_x_grad, z_grad, scale_x_grad, bias_x_grad, + eps); + } + + // 2. Backward of Conv for x, get x_grad and filter_x_grad + bool use_addto = ctx.Attr("use_addto"); + CudnnNormConvolutionGrad conv_x_op(dev_ctx, x_shape, filter_x_shape, + output_shape, padding, stride, + dilation, group); + conv_x_op.Backward(dev_ctx, *x, *filter_x, conv_out_x_grad, x_grad, + filter_x_grad, use_addto); + } +}; + +} // namespace operators +} // namespace paddle + +#if CUDNN_VERSION >= 8000 +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(resnet_unit, ops::ResNetUnitKernel); +REGISTER_OP_CUDA_KERNEL(resnet_unit_grad, + ops::ResNetUnitGradKernel); +#endif diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index 0ffdfa9a2d8d5..84a970a9a2606 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -16,66 +16,141 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/pooling.h" +#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/gpu_launch_config.h" +#ifdef __HIPCC__ +#define POOLING_BLOCK_SIZE 256 +#else +#define POOLING_BLOCK_SIZE 512 +#endif + namespace paddle { namespace operators { namespace math { +struct FastDivModForPooling { + public: + platform::FastDivMod channel; + platform::FastDivMod width; + platform::FastDivMod height; + + explicit HOSTDEVICE FastDivModForPooling(const int channels, + const int output_width, + const int output_height) { + channel = platform::FastDivMod(channels); + width = platform::FastDivMod(output_width); + height = platform::FastDivMod(output_height); + } +}; + +struct FastDivModForPoolingWithMoreStaff { + public: + platform::FastDivMod channel; + platform::FastDivMod width; + platform::FastDivMod height; + platform::FastDivMod ksize_w; + platform::FastDivMod ksize_h; + platform::FastDivMod stride_w; + platform::FastDivMod stride_h; + + explicit HOSTDEVICE FastDivModForPoolingWithMoreStaff( + const int channels, const int input_width, const int input_height, + const int ksize_width, const int ksize_height, const int stride_width, + const int stride_height) { + channel = platform::FastDivMod(channels); + width = platform::FastDivMod(input_width); + height = platform::FastDivMod(input_height); + ksize_w = platform::FastDivMod(ksize_width); + ksize_h = platform::FastDivMod(ksize_height); + stride_w = platform::FastDivMod(stride_width); + stride_h = platform::FastDivMod(stride_height); + } +}; + +template +__device__ void OffsetPreparationFor4Dimension( + int index, bool channel_last, FastDivModForPooling divmods, + const int pad_width, const int pad_height, const int aux_width, + const int aux_height, int* w_offset, int* h_offset, int* c_offset, + int* stride) { + if (!channel_last) { /* NCHW */ + auto input_width_divmod = divmods.width.Divmod(index); + auto input_height_divmod = divmods.height.Divmod(input_width_divmod.val[0]); + auto channel_divmod = divmods.channel.Divmod(input_height_divmod.val[0]); + *w_offset = input_width_divmod.val[1] + pad_width; + *h_offset = input_height_divmod.val[1] + pad_height; + *c_offset = channel_divmod.val[1]; + *stride = (channel_divmod.val[0] * divmods.channel.divisor + *c_offset) * + aux_height * aux_width; + } else { /* NHWC */ + auto c_divmod = divmods.channel.Divmod(index); + auto input_width_divmod = divmods.width.Divmod(c_divmod.val[0]); + auto input_height_divmod = divmods.height.Divmod(input_width_divmod.val[0]); + *c_offset = c_divmod.val[1]; + *w_offset = input_width_divmod.val[1] + pad_width; + *h_offset = input_height_divmod.val[1] + pad_height; + *stride = input_height_divmod.val[0] * aux_height * aux_width * + divmods.channel.divisor; + } +} + +int GetThreadsPerBlock(const platform::CUDADeviceContext& ctx, + int threads_per_block, int64_t numel) { + int sm_count = ctx.GetSMCount(); + if (numel / (sm_count << 1) < threads_per_block) { + // Round up threads number into an exponential multiple of 2, while number + // of acitve blocks is about twice of SM, to acquire better performance. + threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 1)); + } else if (numel / (sm_count << 2) < threads_per_block) { + // Round up threads number into an exponential multiple of 2, while number + // of acitve blocks is about 4 times of SM, to acquire better performance. + threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 2)); + } + // Number of threads per block shall be larger than 64. + return std::max(64, threads_per_block); +} + template -__global__ void KernelPool2D(const int nthreads, const T* input_data, - const int channels, const int input_height, - const int input_width, const int output_height, - const int output_width, const int ksize_height, - const int ksize_width, const int stride_height, - const int stride_width, const int padding_height, - const int padding_width, PoolProcess pool_process, - bool exclusive, bool adaptive, T* output_data, - bool channel_last = false) { +__global__ void KernelPool2D( + const int nthreads, const T* input_data, const int channels, + const int input_height, const int input_width, const int output_height, + const int output_width, const int ksize_height, const int ksize_width, + const int stride_height, const int stride_width, const int padding_height, + const int padding_width, FastDivModForPooling divmods, + PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data, + bool channel_last = false) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int pw, ph, c, batch_idx; - if (!channel_last) { /*NCHW*/ - pw = index % output_width; - ph = (index / output_width) % output_height; - c = (index / output_width / output_height) % channels; - batch_idx = index / output_width / output_height / channels; - } else { /*NHWC*/ - c = index % channels; - pw = (index / channels) % output_width; - ph = (index / channels / output_width) % output_height; - batch_idx = index / channels / output_width / output_height; - } + int hstart, hend, wstart, wend; + int w_offset, h_offset, c_offset, input_offset; + OffsetPreparationFor4Dimension( + index, channel_last, divmods, 0, 0, input_width, input_height, + &w_offset, &h_offset, &c_offset, &input_offset); + input_data += input_offset; - int hstart, hend; - int wstart, wend; if (adaptive) { - hstart = AdaptStartIndex(ph, input_height, output_height); - hend = AdaptEndIndex(ph, input_height, output_height); - - wstart = AdaptStartIndex(pw, input_width, output_width); - wend = AdaptEndIndex(pw, input_width, output_width); + hstart = AdaptStartIndex(h_offset, input_height, output_height); + hend = AdaptEndIndex(h_offset, input_height, output_height); + wstart = AdaptStartIndex(w_offset, input_width, output_width); + wend = AdaptEndIndex(w_offset, input_width, output_width); } else { - hstart = ph * stride_height - padding_height; + hstart = h_offset * stride_height - padding_height; hend = min(hstart + ksize_height, input_height); hstart = max(hstart, 0); - - wstart = pw * stride_width - padding_width; + wstart = w_offset * stride_width - padding_width; wend = min(wstart + ksize_width, input_width); wstart = max(wstart, 0); } - if (!channel_last) { - input_data += (batch_idx * channels + c) * input_height * input_width; - } else { - input_data += batch_idx * input_height * input_width * channels; - } T ele = pool_process.initial(); for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - auto input_idx = channel_last ? (h * input_width + w) * channels + c - : h * input_width + w; + auto input_idx = channel_last + ? (h * input_width + w) * channels + c_offset + : h * input_width + w; pool_process.compute(input_data[input_idx], &ele); } } @@ -85,91 +160,109 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, output_data[index] = ele; } } -template + +template __global__ void KernelPool2DGrad( - const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, const int channels, const int input_height, - const int input_width, const int output_height, const int output_width, - const int ksize_height, const int ksize_width, const int stride_height, - const int stride_width, const int padding_height, const int padding_width, - PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad, - bool channel_last = false) { + const int nthreads, const T* __restrict__ input_data, + const T* __restrict__ output_data, const const T* __restrict__ output_grad, + const int output_width, const int output_height, const int input_width, + const int input_height, const int ksize_width, const int ksize_height, + const int stride_width, const int stride_height, const int padding_width, + const int padding_height, FastDivModForPoolingWithMoreStaff divmods, + PoolProcess pool_process, bool exclusive, bool adaptive, + T* __restrict__ input_grad, bool channel_last = false) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int w_offset, h_offset, offsetC, batch_idx; - if (!channel_last) { /* NCHW */ - w_offset = index % input_width + padding_width; - h_offset = (index / input_width) % input_height + padding_height; - offsetC = (index / input_width / input_height) % channels; - batch_idx = index / input_width / input_height / channels; - } else { /* NHWC */ - offsetC = index % channels; - w_offset = (index / channels) % input_width + padding_width; - h_offset = - (index / channels / input_width) % input_height + padding_height; - batch_idx = index / channels / input_width / input_height; + T input = static_cast(0); + T input_grad_data = static_cast(0); + int phstart, phend, pwstart, pwend; + int w_offset, h_offset, c_offset, output_offset; + OffsetPreparationFor4Dimension<>(index, channel_last, divmods, + padding_width, padding_height, + output_width, output_height, &w_offset, + &h_offset, &c_offset, &output_offset); + if (pool_process.use_x) { + input = input_data[index]; + output_data += output_offset; } + output_grad += output_offset; - int phstart, phend; - int pwstart, pwend; if (adaptive) { - phstart = AdaptStartIndex(h_offset, output_height, input_height); - phend = AdaptEndIndex(h_offset, output_height, input_height); + auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height); + auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width); + phstart = divmods.height.Div(h_offset * output_height); + pwstart = divmods.width.Div(w_offset * output_width); + phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0]; + pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0]; - pwstart = AdaptStartIndex(w_offset, output_width, input_width); - pwend = AdaptEndIndex(w_offset, output_width, input_width); - } else { - phstart = (h_offset < ksize_height) - ? 0 - : (h_offset - ksize_height) / stride_height + 1; - pwstart = (w_offset < ksize_width) - ? 0 - : (w_offset - ksize_width) / stride_width + 1; - phend = min(h_offset / stride_height + 1, output_height); - pwend = min(w_offset / stride_width + 1, output_width); - } - T gradient = static_cast(0.0); - T input = input_data[index]; - - int output_stride; - if (!channel_last) { - output_stride = - (batch_idx * channels + offsetC) * output_height * output_width; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + auto ksize_w_divmod = divmods.ksize_w.Divmod(input_width); + auto ksize_h_divmod = divmods.ksize_h.Divmod(input_height); + auto tmp_width = ksize_w_divmod.val[1] > 0 ? ksize_w_divmod.val[0] + 1 + : ksize_w_divmod.val[0]; + auto tmp_height = ksize_h_divmod.val[1] > 0 + ? ksize_h_divmod.val[0] + 1 + : ksize_h_divmod.val[0]; + int pool_size = tmp_height * tmp_width; + int tmp_idx = ph * output_width + pw; + int output_sub_idx = + channel_last ? tmp_idx * divmods.channel.divisor + c_offset + : tmp_idx; + T ouput_value = pool_process.use_x ? output_data[output_sub_idx] + : static_cast(0); + pool_process.compute(input, ouput_value, output_grad[output_sub_idx], + static_cast(1.0 / pool_size), + &input_grad_data); + } + } } else { - output_stride = batch_idx * output_height * output_width * channels; - } - - output_data += output_stride; - output_grad += output_stride; - - for (int ph = phstart; ph < phend; ++ph) { - for (int pw = pwstart; pw < pwend; ++pw) { - int pool_size; - if (adaptive) { - pool_size = static_cast(ceil(static_cast(input_height) / - ksize_height)) * - static_cast( - ceil(static_cast(input_width) / ksize_width)); - } else { - int hstart = ph * stride_height - padding_height; - int wstart = pw * stride_width - padding_width; - int hend = min(hstart + ksize_height, input_height); - int wend = min(wstart + ksize_width, input_width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - pool_size = exclusive ? (hend - hstart) * (wend - wstart) - : ksize_height * ksize_width; + auto stride_height_div = divmods.stride_h.Div(h_offset - ksize_height); + auto stride_width_div = divmods.stride_w.Div(w_offset - ksize_width); + phstart = (h_offset < ksize_height) ? 0 : stride_height_div + 1; + pwstart = (w_offset < ksize_width) ? 0 : stride_width_div + 1; + phend = min(divmods.stride_h.Div(h_offset) + 1, output_height); + pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width); + + if (exclusive) { + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + int pool_size = (hend - hstart) * (wend - wstart); + int tmp_idx = ph * output_width + pw; + int output_sub_idx = + channel_last ? tmp_idx * divmods.channel.divisor + c_offset + : tmp_idx; + T ouput_value = pool_process.use_x ? output_data[output_sub_idx] + : static_cast(0); + pool_process.compute( + input, ouput_value, output_grad[output_sub_idx], + static_cast(1.0 / pool_size), &input_grad_data); + } + } + } else { + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + int pool_size = ksize_height * ksize_width; + int tmp_idx = ph * output_width + pw; + int output_sub_idx = + channel_last ? tmp_idx * divmods.channel.divisor + c_offset + : tmp_idx; + T ouput_value = pool_process.use_x ? output_data[output_sub_idx] + : static_cast(0); + pool_process.compute( + input, ouput_value, output_grad[output_sub_idx], + static_cast(1.0 / pool_size), &input_grad_data); + } } - - int output_sub_idx = channel_last - ? (ph * output_width + pw) * channels + offsetC - : ph * output_width + pw; - pool_process.compute(input, output_data[output_sub_idx], - output_grad[output_sub_idx], - static_cast(1.0 / pool_size), &gradient); } } - input_grad[index] = gradient; + input_grad[index] = input_grad_data; } } @@ -180,45 +273,32 @@ __global__ void KernelMaxPool2DGrad( const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width, - T* input_grad, bool channel_last = false) { + T* input_grad, FastDivModForPooling divmods, bool channel_last = false) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int pw, ph, c, batch_idx; - if (!channel_last) { /* NCHW */ - pw = index % output_width; - ph = (index / output_width) % output_height; - c = (index / output_width / output_height) % channels; - batch_idx = index / output_width / output_height / channels; - } else { /* NHWC */ - c = index % channels; - pw = (index / channels) % output_width; - ph = (index / channels / output_width) % output_height; - batch_idx = index / channels / output_width / output_height; - } - int hstart = ph * stride_height - padding_height; + int w_offset, h_offset, c_offset, input_offset; + OffsetPreparationFor4Dimension( + index, channel_last, divmods, 0, 0, input_width, input_height, + &w_offset, &h_offset, &c_offset, &input_offset); + input_data += input_offset; + input_grad += input_offset; + + int hstart = h_offset * stride_height - padding_height; int hend = min(hstart + ksize_height, input_height); hstart = max(hstart, 0); - int wstart = pw * stride_width - padding_width; + int wstart = w_offset * stride_width - padding_width; int wend = min(wstart + ksize_width, input_width); wstart = max(wstart, 0); - int input_stride; - if (!channel_last) { - input_stride = (batch_idx * channels + c) * input_height * input_width; - } else { - input_stride = batch_idx * input_height * input_width * channels; - } - input_data += input_stride; - input_grad += input_stride; - T ele = output_data[index]; int maxIndex = -1; bool stop = false; for (int h = hstart; h < hend && !stop; ++h) { for (int w = wstart; w < wend && !stop; ++w) { - int input_data_idx = channel_last ? (h * input_width + w) * channels + c - : h * input_width + w; + int input_data_idx = channel_last + ? (h * input_width + w) * channels + c_offset + : h * input_width + w; if (ele == input_data[input_data_idx]) { maxIndex = input_data_idx; stop = true; @@ -264,10 +344,13 @@ void Pool2dDirectCUDAFunctor::operator()( dim3 threads(thread_num, 1); dim3 grid(blocks, 1); + auto pool_divmods = + FastDivModForPooling(input_channels, output_width, output_height); KernelPool2D<<>>( nthreads, input, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, - padding_height, padding_width, pool_compute, exclusive, adaptive, output); + padding_height, padding_width, pool_divmods, pool_compute, exclusive, + adaptive, output); } /* @@ -311,11 +394,14 @@ class Pool2dFunctor { int blocks = (nthreads + thread_num - 1) / thread_num; dim3 threads(thread_num, 1); dim3 grid(blocks, 1); + + auto pool_divmods = + FastDivModForPooling(input_channels, output_width, output_height); KernelPool2D<<>>( nthreads, input_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, - stride_width, padding_height, padding_width, pool_process, exclusive, - adaptive, output_data); + stride_width, padding_height, padding_width, pool_divmods, pool_process, + exclusive, adaptive, output_data); } void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, const std::vector& ksize, @@ -357,11 +443,14 @@ class Pool2dFunctor { int blocks = (nthreads + thread_num - 1) / thread_num; dim3 threads(thread_num, 1); dim3 grid(blocks, 1); + + auto pool_divmods = + FastDivModForPooling(input_channels, output_width, output_height); KernelPool2D<<>>( nthreads, input_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, - stride_width, padding_height, padding_width, pool_process, exclusive, - adaptive, output_data, channel_last); + stride_width, padding_height, padding_width, pool_divmods, pool_process, + exclusive, adaptive, output_data, channel_last); } }; /* @@ -402,15 +491,18 @@ class Pool2dGradFunctor { T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_height * input_width; - int blocks = (nthreads + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); - - KernelPool2DGrad<<>>( - nthreads, input_data, output_data, output_grad_data, input_channels, - input_height, input_width, output_height, output_width, ksize_height, - ksize_width, stride_height, stride_width, padding_height, padding_width, - pool_process, exclusive, adaptive, input_grad_data); + int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads); + int grids = (nthreads + blocks - 1) / blocks; + + auto pool_divmods = FastDivModForPoolingWithMoreStaff( + input_channels, input_width, input_height, ksize_width, ksize_height, + stride_width, stride_height); + + KernelPool2DGrad<<>>( + nthreads, input_data, output_data, output_grad_data, output_width, + output_height, input_width, input_height, ksize_width, ksize_height, + stride_width, stride_height, padding_width, padding_height, + pool_divmods, pool_process, exclusive, adaptive, input_grad_data); } void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, @@ -424,7 +516,6 @@ class Pool2dGradFunctor { bool channel_last = (data_format == "NHWC"); const int batch_size = input.dims()[0]; - const int input_channels = channel_last ? input.dims()[3] : input.dims()[1]; const int input_height = channel_last ? input.dims()[1] : input.dims()[2]; const int input_width = channel_last ? input.dims()[2] : input.dims()[3]; @@ -447,19 +538,22 @@ class Pool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_height * input_width; - int blocks = (nthreads + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); - - KernelPool2DGrad<<>>( - nthreads, input_data, output_data, output_grad_data, input_channels, - input_height, input_width, output_height, output_width, ksize_height, - ksize_width, stride_height, stride_width, padding_height, padding_width, - pool_process, exclusive, adaptive, input_grad_data, channel_last); + int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads); + int grids = (nthreads + blocks - 1) / blocks; + + auto pool_divmods = FastDivModForPoolingWithMoreStaff( + input_channels, input_width, input_height, ksize_width, ksize_height, + stride_width, stride_height); + + KernelPool2DGrad<<>>( + nthreads, input_data, output_data, output_grad_data, output_width, + output_height, input_width, input_height, ksize_width, ksize_height, + stride_width, stride_height, padding_width, padding_height, + pool_divmods, pool_process, exclusive, adaptive, input_grad_data, + channel_last); } }; @@ -505,11 +599,13 @@ class MaxPool2dGradFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); + auto pool_divmods = + FastDivModForPooling(input_channels, output_width, output_height); KernelMaxPool2DGrad<<>>( nthreads, input_data, output_data, output_grad_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, padding_height, padding_width, - input_grad_data); + input_grad_data, pool_divmods); } void operator()( const platform::CUDADeviceContext& context, @@ -550,11 +646,14 @@ class MaxPool2dGradFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); + auto pool_divmods = + FastDivModForPooling(input_channels, output_width, output_height); + KernelMaxPool2DGrad<<>>( nthreads, input_data, output_data, output_grad_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, padding_height, padding_width, - input_grad_data, channel_last); + input_grad_data, pool_divmods, channel_last); } }; @@ -689,35 +788,40 @@ __global__ void KernelPool3D( } } -template +template __global__ void KernelPool3DGrad( - const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, const int channels, const int input_depth, - const int input_height, const int input_width, const int output_depth, - const int output_height, const int output_width, const int ksize_depth, - const int ksize_height, const int ksize_width, const int stride_depth, - const int stride_height, const int stride_width, const int padding_depth, - const int padding_height, const int padding_width, PoolProcess pool_process, - bool exclusive, bool adaptive, T* input_grad, bool channel_last = false) { + const int nthreads, const T* __restrict__ input_data, + const T* __restrict__ output_data, const T* __restrict__ output_grad, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width, PoolProcess pool_process, bool exclusive, + bool adaptive, T* input_grad, bool channel_last = false) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int w_offset, h_offset, d_offset, offsetC, batch_idx; + int w_offset, h_offset, d_offset, c_offset, batch_idx, output_stride; + T input = static_cast(0); if (!channel_last) { /* "NCDHW" */ w_offset = index % input_width + padding_width; h_offset = (index / input_width) % input_height + padding_height; d_offset = (index / input_width / input_height) % input_depth + padding_depth; - offsetC = (index / input_width / input_height / input_depth) % channels; + c_offset = (index / input_width / input_height / input_depth) % channels; batch_idx = index / input_width / input_height / input_depth / channels; - + output_stride = (batch_idx * channels + c_offset) * output_depth * + output_height * output_width; } else { /* "NDHWC" */ - offsetC = index % channels; + c_offset = index % channels; w_offset = (index / channels) % input_width + padding_width; h_offset = (index / channels / input_width) % input_height + padding_height; d_offset = (index / channels / input_width / input_height) % input_depth + padding_depth; batch_idx = index / channels / input_width / input_height / input_depth; + output_stride = + batch_idx * output_depth * output_height * output_width * channels; } int pdstart, pdend; @@ -746,20 +850,12 @@ __global__ void KernelPool3DGrad( phend = min((h_offset) / stride_height + 1, output_height); pwend = min((w_offset) / stride_width + 1, output_width); } - - T gradient = static_cast(0.0); - T input = input_data[index]; - - int output_stride; - if (!channel_last) { - output_stride = (batch_idx * channels + offsetC) * output_depth * - output_height * output_width; - } else { - output_stride = - batch_idx * output_depth * output_height * output_width * channels; + if (pool_process.use_x) { + input = input_data[index]; + output_data += output_stride; } - output_data += output_stride; output_grad += output_stride; + T input_grad_data = static_cast(0.0); for (int pd = pdstart; pd < pdend; ++pd) { for (int ph = phstart; ph < phend; ++ph) { @@ -792,16 +888,17 @@ __global__ void KernelPool3DGrad( int output_sub_idx = channel_last ? ((pd * output_height + ph) * output_width + pw) * channels + - offsetC + c_offset : (pd * output_height + ph) * output_width + pw; - - pool_process.compute(input, output_data[output_sub_idx], - output_grad[output_sub_idx], - static_cast(1.0 / pool_size), &gradient); + T ouput_value = pool_process.use_x ? output_data[output_sub_idx] + : static_cast(0); + pool_process.compute(input, ouput_value, output_grad[output_sub_idx], + static_cast(1.0 / pool_size), + &input_grad_data); } } } - input_grad[index] = gradient; + input_grad[index] = input_grad_data; } } @@ -1088,7 +1185,7 @@ class Pool3dGradFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); - KernelPool3DGrad<<>>( + KernelPool3DGrad<<>>( nthreads, input_data, output_data, output_grad_data, input_channels, input_depth, input_height, input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, @@ -1142,7 +1239,7 @@ class Pool3dGradFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); - KernelPool3DGrad<<>>( + KernelPool3DGrad<<>>( nthreads, input_data, output_data, output_grad_data, input_channels, input_depth, input_height, input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, @@ -1315,33 +1412,33 @@ __global__ void KernelMaxPool2dWithIdx( const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, - const int padding_width, bool adaptive, T1* output_data, T2* mask_data) { + const int padding_width, bool adaptive, T1* output_data, T2* mask_data, + FastDivModForPooling divmods) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int pw = index % output_width; - int ph = (index / output_width) % output_height; - int c = (index / output_width / output_height) % channels; - int batch_idx = index / output_width / output_height / channels; + int hstart, hend, wstart, wend; + int w_offset, h_offset, c_offset, input_offset; + OffsetPreparationFor4Dimension( + index, false, divmods, 0, 0, input_width, input_height, &w_offset, + &h_offset, &c_offset, &input_offset); + input_data += input_offset; - int hstart, hend; - int wstart, wend; if (adaptive) { - hstart = AdaptStartIndex(ph, input_height, output_height); - hend = AdaptEndIndex(ph, input_height, output_height); + hstart = AdaptStartIndex(h_offset, input_height, output_height); + hend = AdaptEndIndex(h_offset, input_height, output_height); - wstart = AdaptStartIndex(pw, input_width, output_width); - wend = AdaptEndIndex(pw, input_width, output_width); + wstart = AdaptStartIndex(w_offset, input_width, output_width); + wend = AdaptEndIndex(w_offset, input_width, output_width); } else { - hstart = ph * stride_height - padding_height; + hstart = h_offset * stride_height - padding_height; hend = min(hstart + ksize_height, input_height); hstart = max(hstart, 0); - wstart = pw * stride_width - padding_width; + wstart = w_offset * stride_width - padding_width; wend = min(wstart + ksize_width, input_width); wstart = max(wstart, 0); } - input_data += (batch_idx * channels + c) * input_height * input_width; T1 ele = -FLT_MAX; int max_index = -1; for (int h = hstart; h < hend; ++h) { @@ -1365,16 +1462,17 @@ __global__ void KernelMaxPool2DWithIdxGrad( const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width, bool adaptive, - T1* input_grad) { + T1* input_grad, FastDivModForPooling divmods) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int w_offset = index % input_width; - int h_offset = (index / input_width) % input_height; - int offsetC = (index / input_width / input_height) % channels; - int batch_idx = index / input_width / input_height / channels; + int phstart, phend, pwstart, pwend; + int w_offset, h_offset, c_offset, output_offset; + OffsetPreparationFor4Dimension( + index, false, divmods, 0, 0, output_width, output_height, &w_offset, + &h_offset, &c_offset, &output_offset); + mask_data += output_offset; + output_grad += output_offset; - int phstart, phend; - int pwstart, pwend; if (adaptive) { phstart = h_offset * output_height / input_height; phend = @@ -1396,20 +1494,15 @@ __global__ void KernelMaxPool2DWithIdxGrad( pwend = min((w_offset + padding_width) / stride_width + 1, output_width); } - T1 gradient = 0; + T1 input_grad_data = 0; int input_current_featuremap_idx = h_offset * input_width + w_offset; - int output_idx = - (batch_idx * channels + offsetC) * output_height * output_width; - - mask_data += output_idx; - output_grad += output_idx; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { if (mask_data[ph * output_width + pw] == input_current_featuremap_idx) - gradient += output_grad[ph * output_width + pw]; + input_grad_data += output_grad[ph * output_width + pw]; } } - input_grad[index] = gradient; + input_grad[index] = input_grad_data; } } @@ -1453,11 +1546,14 @@ class MaxPool2dWithIndexFunctor { int blocks = (nthreads + thread_num - 1) / thread_num; dim3 threads(thread_num, 1); dim3 grid(blocks, 1); + + auto pool_divmods = + FastDivModForPooling(input_channels, output_width, output_height); KernelMaxPool2dWithIdx<<>>( nthreads, input_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, padding_height, padding_width, adaptive, output_data, - mask_data); + mask_data, pool_divmods); } }; @@ -1497,11 +1593,13 @@ class MaxPool2dWithIndexGradFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); + auto pool_divmods = + FastDivModForPooling(input_channels, input_width, input_height); KernelMaxPool2DWithIdxGrad<<>>( nthreads, output_grad_data, mask_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, padding_height, padding_width, adaptive, - input_grad_data); + input_grad_data, pool_divmods); } }; @@ -1590,7 +1688,8 @@ __global__ void KernelMaxPool3DWithIdxGrad( int w_offset = index % input_width; int h_offset = (index / input_width) % input_height; int d_offset = (index / input_width / input_height) % input_depth; - int offsetC = (index / input_width / input_height / input_depth) % channels; + int c_offset = + (index / input_width / input_height / input_depth) % channels; int batch_idx = index / input_width / input_height / input_depth / channels; int pdstart, pdend; @@ -1625,10 +1724,10 @@ __global__ void KernelMaxPool3DWithIdxGrad( pwend = min((w_offset + padding_width) / stride_width + 1, output_width); } - T1 gradient = 0; + T1 input_grad_data = 0; int input_current_feature_map_idx = (d_offset * input_height + h_offset) * input_width + w_offset; - int output_idx = (batch_idx * channels + offsetC) * output_depth * + int output_idx = (batch_idx * channels + c_offset) * output_depth * output_height * output_width; mask += output_idx; output_grad += output_idx; @@ -1638,12 +1737,12 @@ __global__ void KernelMaxPool3DWithIdxGrad( for (int pw = pwstart; pw < pwend; ++pw) { if (mask[(pd * output_height + ph) * output_width + pw] == input_current_feature_map_idx) - gradient += + input_grad_data += output_grad[(pd * output_height + ph) * output_width + pw]; } } } - input_grad[index] = gradient; + input_grad[index] = input_grad_data; } } diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h index 2a94d113b6f65..4743f0dc9faf1 100644 --- a/paddle/fluid/operators/math/pooling.h +++ b/paddle/fluid/operators/math/pooling.h @@ -68,8 +68,9 @@ class AvgPool { template class MaxPoolGrad { public: - DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale, - T* dx) { + static constexpr bool use_x = true; + HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale, + T* dx) { *dx += dy * static_cast(x == y); } }; @@ -77,8 +78,9 @@ class MaxPoolGrad { template class AvgPoolGrad { public: - DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale, - T* dx) { + static constexpr bool use_x = false; + HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale, + T* dx) { *dx += (scale * dy); } }; diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc index 8f30dd5b2e68a..b0a85b7d5b92f 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -13,46 +13,158 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/optimizers/lars_momentum_op.h" -#include "paddle/fluid/operators/optimizers/momentum_op.h" namespace paddle { namespace operators { +class LarsMomentumOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasInputs("Velocity"), "Input", "Velocity", + "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasInputs("LearningRate"), "Input", "LearningRate", + "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasOutputs("ParamOut"), "Output", "ParamOut", + "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"), "Output", "VelocityOut", + "LarsMomentum"); + PADDLE_ENFORCE_EQ( + ctx->GetInputsVarType("Param").front(), + framework::proto::VarType::LOD_TENSOR, + platform::errors::InvalidArgument( + "The input var's type should be LoDTensor, but the received is %s", + ctx->GetInputsVarType("Param").front())); + + auto lr_dims = ctx->GetInputsDim("LearningRate"); + auto grad_dim = ctx->GetInputsDim("Grad"); + auto param_dim = ctx->GetInputsDim("Param"); + auto velocity_dim = ctx->GetInputsDim("Velocity"); + auto lars_weight_decays = + ctx->Attrs().Get>("lars_weight_decay"); + auto multi_precision = ctx->Attrs().Get("multi_precision"); + + PADDLE_ENFORCE_EQ( + param_dim.size(), grad_dim.size(), + platform::errors::InvalidArgument( + "Input(Param) and Input(Grad) of LarsMomentumOp should have " + "same quantity. But number of Param is [%d] and Grad is [%d].", + param_dim.size(), grad_dim.size())); + PADDLE_ENFORCE_EQ( + param_dim.size(), velocity_dim.size(), + platform::errors::InvalidArgument( + "Input(Param) and Input(Velocity) of LarsMomentumOp should " + "have same quantity. But number of Param is [%d] and Velocity " + "is [%d].", + param_dim.size(), velocity_dim.size())); + PADDLE_ENFORCE_EQ( + lars_weight_decays.size(), grad_dim.size(), + platform::errors::InvalidArgument( + "Attr(Lars_weight_decay) and " + "Input(Grad) of LarsMomentumOp should have same quantity. " + "But number of Lars_weight_decay is [%d] and Grad is [%d].", + lars_weight_decays.size(), grad_dim.size())); + + if (multi_precision) { + OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam", + "LarsMomentumMultiPrecision"); + OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"), "Output", + "MasterParamOut", "LarsMomentumMultiPrecision"); + } + for (size_t i = 0; i < lr_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1, + platform::errors::InvalidArgument( + "Learning_rate should be a scalar. But Received " + "LearningRate's dim [%s]", + framework::product(lr_dims[i]))); + } + + for (size_t i = 0; i < param_dim.size(); ++i) { + PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i], + framework::proto::VarType::LOD_TENSOR, + platform::errors::InvalidArgument( + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx->Inputs("Grad")[i].front(), + ctx->GetInputsVarType("Grad")[i])); + PADDLE_ENFORCE_EQ( + param_dim[i], grad_dim[i], + platform::errors::InvalidArgument( + "Input(Param) and Input(Grad) input of LarsMomentumOp shall " + "have same dimension. But Param`s dim is [%s] and Grad's dim " + "is [%s].", + param_dim[i], grad_dim[i])); + PADDLE_ENFORCE_EQ( + param_dim[i], velocity_dim[i], + platform::errors::InvalidArgument( + "Input(Param) and Input(Velocity) of LarsMomentumOp shall have " + "same dimension. But Param dim [%s] differs with Velocity dim " + "[%s].", + param_dim[i], velocity_dim[i])); + } + ctx->SetOutputsDim("ParamOut", param_dim); + ctx->SetOutputsDim("VelocityOut", param_dim); + if (ctx->HasOutputs("MasterParamOut")) { + ctx->SetOutputsDim("MasterParamOut", param_dim); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Param"); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Param", "(LoDTensor, default LoDTensor) " - "Input parameter that has to be updated"); + "Input parameter that has to be updated") + .AsDuplicable(); AddInput("Grad", "(LoDTensor, default LoDTensor) " - "Input gradient of the parameter"); + "Input gradient of the parameter") + .AsDuplicable(); AddInput("Velocity", "(LoDTensor, default LoDTensor) " "Input velocity (corresponding to the parameter) " - "that has to be updated"); + "that has to be updated") + .AsDuplicable(); AddInput("LearningRate", "(LoDTensor, default LoDTensor) " - "Input learning rate"); - AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable(); - + "Input learning rate") + .AsDuplicable(); + AddInput("MasterParam", "FP32 master weight for AMP.") + .AsDuplicable() + .AsDispensable(); AddOutput("ParamOut", "(LoDTensor) This output is updated parameter. " - "It shared memory with Input(Param)."); + "It shared memory with Input(Param).") + .AsDuplicable(); AddOutput("VelocityOut", "(LoDTensor) This output is updated velocity. " - "It shared memory with Input(Velocity)."); + "It shared memory with Input(Velocity).") + .AsDuplicable(); AddOutput("MasterParamOut", "The updated FP32 master weight for AMP. " "It shared memory with Input(MasterParam).") + .AsDuplicable() .AsDispensable(); - AddAttr("mu", "(float) Momentum coefficient"); AddAttr("lars_coeff", "(float, default 0.001) LARS coefficient.") .SetDefault(0.001); - AddAttr("lars_weight_decay", - "(float, default 0.0005) LARS weight decay") - .SetDefault(0.0005); + AddAttr>( + "lars_weight_decay", + "(std::vector, default 0.0005) LARS weight decay params") + .SetDefault({0.0005}); AddAttr("epsilon", "(float, default 0.0) epsilon to avoid Division by Zero.") .SetDefault(0.0); @@ -68,10 +180,8 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Lars Momentum Optimizer. - This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each weight using a local learning rate: - $$ local\_lr = \eta * \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\ @@ -79,10 +189,8 @@ velocity = mu * velocity + local\_lr * (grad + \beta * param) \\ param = param - velocity. \\ $$ - Note that we use lars_weight_decay here to decay weights, you may need not to use L2 regularizers in case of using LARS. - )DOC"); } }; @@ -96,7 +204,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference { namespace ops = paddle::operators; REGISTER_OPERATOR( - lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker, + lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::LarsMomentumOpVarTypeInference); diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 42477232e7ca1..2c27a2135c14b 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -14,7 +14,21 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/math/math_cuda_utils.h" #include "paddle/fluid/operators/optimizers/lars_momentum_op.h" +#include "paddle/fluid/platform/fast_divmod.h" + +#if CUDA_VERSION >= 11000 +#include +#endif + +#ifdef __HIPCC__ +#define LARS_BLOCK_SIZE 256 +#else +#define LARS_BLOCK_SIZE 512 +#endif + +#define LARS_MAX_MERGED_OPS 60 namespace paddle { namespace operators { @@ -22,124 +36,472 @@ namespace operators { template using MultiPrecisionType = typename details::MPTypeTrait::Type; -template -__global__ void MomentumLarsKernel( - const T* p, const T* g, const MT* v, - const MultiPrecisionType* learning_rate, const MT mu, const int64_t num, - const MT lars_coeff, const MT lars_weight_decay, - const MultiPrecisionType* p_norm, const MultiPrecisionType* g_norm, - T* p_out, MT* v_out, const MT epsilon, const MT* master_p, MT* master_p_out, - const MultiPrecisionType rescale_grad) { - const MT lr = static_cast(learning_rate[0]); - MT local_lr = lr; - const MT p_n = static_cast(p_norm[0]); - const MT g_n = static_cast(g_norm[0]); +__device__ __forceinline__ float Sqrt(float x) { return sqrtf(x); } +__device__ __forceinline__ double Sqrt(double x) { return sqrt(x); } +__device__ __forceinline__ float Fma(float x, float y, float z) { + return fmaf(x, y, z); +} +__device__ __forceinline__ double Fma(double x, double y, double z) { + return fma(x, y, z); +} + +template +class LarsThreadConfig { + public: + int grid_for_norm; + int grid_for_lars; +#if CUDA_VERSION >= 11000 - if (lars_weight_decay > static_cast(0) && p_n > static_cast(0) && - g_n > static_cast(0)) { - local_lr = - lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon); + private: + int grid_stride; + + public: + explicit LarsThreadConfig(int64_t numel, int sm_num, int num_blocks_per_sm) { + int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE; + grid_for_lars = + std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE); + grid_stride = LARS_BLOCK_SIZE * grid_for_lars; } - CUDA_KERNEL_LOOP(i, num) { - MT grad = static_cast(g[i]) * static_cast(rescale_grad); - MT param = master_p ? master_p[i] : static_cast(p[i]); - MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param); - MT p_new = param - v_new; + int GetRepeatTimes(int64_t numel) { + return (numel + grid_stride - 1) / grid_stride - 1; + } +#else + int repeat_times; + explicit LarsThreadConfig(const int64_t numel) { + int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE; + grid_for_norm = std::min(grid, LARS_BLOCK_SIZE); + const int grid_stride = grid_for_norm * LARS_BLOCK_SIZE; + repeat_times = (numel + grid_stride - 1) / grid_stride - 1; + // Determine to read 4 fp16 or float data once, but 2 double data once. + grid_for_lars = + std::is_same::value + ? (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1) + : (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2); + } +#endif +}; + +template +__device__ inline void VectorizeLarsUpdate( + const T* __restrict__ grad, const MT* param, const MT* velocity, + T* param_out, MT* velocity_out, const MT mu, MT local_lr, + const MT lars_weight_decay, const MT rescale_grad, const int tid, + const int grid_stride, const int numel, MT* master_param_out = nullptr) { + using VecType = paddle::platform::AlignedVector; + using VecMType = paddle::platform::AlignedVector; + int main = numel >> (VecSize >> 1); + int tail_offset = main * VecSize; - v_out[i] = v_new; - p_out[i] = static_cast(p_new); - if (master_p_out) master_p_out[i] = p_new; + const VecType* grad_vec = reinterpret_cast(grad); + const VecMType* param_vec = reinterpret_cast(param); + const VecMType* velocity_vec = reinterpret_cast(velocity); + VecType* param_out_vec = reinterpret_cast(param_out); + VecMType* velocity_out_vec = reinterpret_cast(velocity_out); + + VecMType* master_param_out_vec; + if (IsAmp) { + master_param_out_vec = reinterpret_cast(master_param_out); + } + + for (int i = tid; i < main; i += grid_stride) { + VecType param_out_tmp; + VecMType velocity_tmp, param_tmp; + VecType grad_data = grad_vec[i]; + VecMType param_data = param_vec[i]; + VecMType velocity_data = velocity_vec[i]; +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + MT grad_val = static_cast(grad_data[j]) * rescale_grad; + velocity_tmp[j] = + Fma(velocity_data[j], mu, + local_lr * Fma(lars_weight_decay, param_data[j], grad_val)); + param_tmp[j] = param_data[j] - velocity_tmp[j]; + param_out_tmp[j] = static_cast(param_tmp[j]); + } + param_out_vec[i] = param_out_tmp; + velocity_out_vec[i] = velocity_tmp; + if (IsAmp) { + master_param_out_vec[i] = param_tmp; + } + } + + for (int i = tid + tail_offset; i < numel; i += grid_stride) { + MT grad_val = static_cast(grad[i]) * rescale_grad; + MT param_val = param[i]; + MT velocity_tmp = Fma(velocity[i], mu, local_lr * Fma(lars_weight_decay, + param_val, grad_val)); + MT param_tmp = param_val - velocity_tmp; + param_out[i] = static_cast(param_tmp); + velocity_out[i] = velocity_tmp; + if (IsAmp) { + master_param_out[i] = param_tmp; + } } } -template -class LarsMomentumOpCUDAKernel : public framework::OpKernel { - using MPDType = MultiPrecisionType; +#if CUDA_VERSION >= 11000 +/* Once CUDA_VERSION is beyond 11, cooperative_groups can be involved in without + --rdc=true compile flag, then L2_norm kernel can be set with __device__ and + cooperative_groups::grid_group also can be involved. Otherwise, adding this + flag may affect much, L2_norm kernel shall be set with __global__.*/ +// TODO(limingshu): declaration of cooperative_groups wapper is invalid in host. +template +__forceinline__ __device__ void L2NormKernel( + const cooperative_groups::grid_group* cg, +#else +template +__global__ void L2NormKernel( +#endif + const T* p_data, const T* __restrict__ g_data, MT* __restrict__ p_buffer, + MT* __restrict__ g_buffer, const int64_t numel, const int repeat_times, + const MT rescale_grad, const int thresh = 0, MT* __restrict__ p_n = nullptr, + MT* __restrict__ g_n = nullptr) { + __shared__ MT s_buffer[2]; + int tid = threadIdx.x + blockDim.x * blockIdx.x; + int grid_stride = LARS_BLOCK_SIZE * gridDim.x; - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const bool multi_precision = ctx.Attr("multi_precision"); - if (multi_precision) { - InnerCompute(ctx, multi_precision); + MT p_tmp = static_cast(0); + MT g_tmp = static_cast(0); + while (tid < numel) { + MT tmp0 = static_cast(p_data[tid]); + MT tmp1 = static_cast(g_data[tid]); + p_tmp += (tmp0 * tmp0); + g_tmp += (tmp1 * tmp1); + tid += grid_stride; + } + p_tmp = math::blockReduceSum(p_tmp, FINAL_MASK); + g_tmp = math::blockReduceSum(g_tmp, FINAL_MASK); + + if (threadIdx.x == 0) { + p_buffer[blockIdx.x] = p_tmp; + g_buffer[blockIdx.x] = g_tmp; + } +#if CUDA_VERSION >= 11000 + cg->sync(); // Grid sync for writring partial result to gloabl memory + MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0; + MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0; + MT tmp0 = math::blockReduceSum(p_part_sum, FINAL_MASK); + MT tmp1 = math::blockReduceSum(g_part_sum, FINAL_MASK); + if (threadIdx.x == 0) { + s_buffer[0] = tmp0; + s_buffer[1] = tmp1; + } + __syncthreads(); + *p_n = Sqrt(s_buffer[0]); + *g_n = rescale_grad * Sqrt(s_buffer[1]); +#endif +} + +template +__forceinline__ __device__ void MomentumUpdate( + const T* param, const T* __restrict__ grad, const MT* velocity, + T* param_out, MT* velocity_out, const MT* master_param, + MT* master_param_out, const MT* __restrict__ learning_rate, const MT mu, + const MT lars_weight_decay, const MT lars_coeff, const MT epsilon, + const MT rescale_grad, const MT param_norm, const MT grad_norm, + const int tid, const int grid_stride, const int64_t numel, + const bool is_amp) { + const MT lr = learning_rate[0]; + MT local_lr = lr; + if (lars_weight_decay > static_cast(0)) { + local_lr = lr * lars_coeff * param_norm / + (fma(lars_weight_decay, param_norm, grad_norm) + epsilon); + } + if (is_amp) { + VectorizeLarsUpdate( + grad, master_param, velocity, param_out, velocity_out, mu, local_lr, + lars_weight_decay, rescale_grad, tid, grid_stride, numel, + master_param_out); + } else { + if (std::is_same::value || + std::is_same::value) { + /* TODO(limingshu): pointer cast may damage memory accessing for fp16 */ + VectorizeLarsUpdate( + grad, reinterpret_cast(param), velocity, param_out, + velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid, + grid_stride, numel); } else { - InnerCompute(ctx, multi_precision); + VectorizeLarsUpdate( + grad, reinterpret_cast(param), velocity, param_out, + velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid, + grid_stride, numel); } } +} - private: - template - void InnerCompute(const framework::ExecutionContext& ctx, - const bool multi_precision) const { - auto param_out = ctx.Output("ParamOut"); - auto velocity_out = ctx.Output("VelocityOut"); - auto param = ctx.Input("Param"); - auto velocity = ctx.Input("Velocity"); - auto grad = ctx.Input("Grad"); - auto learning_rate = ctx.Input("LearningRate"); - - const framework::Tensor* master_param = nullptr; - framework::Tensor* master_param_out = nullptr; - if (multi_precision) { - bool has_master = - ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); - PADDLE_ENFORCE_EQ(has_master, true, - platform::errors::InvalidArgument( - "The Input(MasterParam) and Output(MasterParamOut) " - "should not be null when " - "the attr `multi_precision` is true")); - master_param = ctx.Input("MasterParam"); - master_param_out = ctx.Output("MasterParamOut"); - } +#if CUDA_VERSION >= 11000 +template +struct LarsParamWarpper { + int64_t numel_arr[LARS_MAX_MERGED_OPS]; + int repeat_arr[LARS_MAX_MERGED_OPS]; + const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS]; + const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS]; + T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS]; + MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS]; + MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS]; + MT weight_decay_arr[LARS_MAX_MERGED_OPS]; +}; - const MT* master_p = multi_precision ? master_param->data() : nullptr; - MT* master_p_out = multi_precision - ? master_param_out->mutable_data(ctx.GetPlace()) - : nullptr; +template +__global__ void MergedMomentumLarsKernel(LarsParamWarpper lars_warpper, + MT* __restrict__ p_buffer, + MT* __restrict__ g_buffer, + const int op_num, const MT mu, + const MT lars_coeff, const MT epsilon, + const MT rescale_grad, + const bool is_amp) { + int grid_stride = gridDim.x * LARS_BLOCK_SIZE; + int tid = threadIdx.x + blockIdx.x * blockDim.x; + const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); + for (int i = 0; i < op_num; ++i) { + int numel = lars_warpper.numel_arr[i]; + MT param_norm = static_cast(0); + MT grad_norm = static_cast(0); + L2NormKernel(&cg, lars_warpper.p_out_arr[i], lars_warpper.g_arr[i], + p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i], + rescale_grad, 0, ¶m_norm, &grad_norm); + MomentumUpdate( + lars_warpper.p_out_arr[i], lars_warpper.g_arr[i], + lars_warpper.v_out_arr[i], lars_warpper.p_out_arr[i], + lars_warpper.v_out_arr[i], lars_warpper.master_p_out_arr[i], + lars_warpper.master_p_out_arr[i], lars_warpper.lr_arr[i], mu, + lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, rescale_grad, + param_norm, grad_norm, tid, grid_stride, numel, is_amp); + } +} +#endif - T* p_out = param_out->mutable_data(ctx.GetPlace()); - MT* v_out = velocity_out->mutable_data(ctx.GetPlace()); +template +__global__ void MomentumLarsKernel( + const T* param, const T* __restrict__ grad, const MT* velocity, + T* param_out, MT* velocity_out, const MT* master_param, + MT* master_param_out, const MT* __restrict__ learning_rate, + MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu, + const MT lars_coeff, const MT lars_weight_decay, const MT epsilon, + const MT rescale_grad, const int repeat_times, const int thresh, + const int64_t numel, const bool is_amp) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int grid_stride = gridDim.x * LARS_BLOCK_SIZE; +#if CUDA_VERSION >= 11000 + const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); + MT param_norm = static_cast(0); + MT grad_norm = static_cast(0); + L2NormKernel(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times, + rescale_grad, gridDim.x, ¶m_norm, &grad_norm); +#else + const MT rescale_grad_pow = rescale_grad * rescale_grad; + MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; + MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0; + __syncthreads(); + MT param_norm = Sqrt(math::blockReduceSum(param_part_norm, FINAL_MASK)); + MT grad_norm = Sqrt(rescale_grad_pow * + math::blockReduceSum(grad_part_norm, FINAL_MASK)); +#endif + MomentumUpdate(param, grad, velocity, param_out, velocity_out, + master_param, master_param_out, learning_rate, mu, + lars_weight_decay, lars_coeff, epsilon, rescale_grad, + param_norm, grad_norm, tid, grid_stride, numel, is_amp); +} + +template +inline void SeparatedLarsMomentumOpCUDAKernel( + const platform::CUDADeviceContext& cuda_ctx, const T* param_data, + T* param_out_data, const MT* velocity_data, MT* velocity_out_data, + const T* grad_data, const MT* lr, MT* p_buffer, MT* g_buffer, const MT mu, + const MT lars_coeff, const MT weight_decay, const MT epsilon, + const MT rescale_grad, const int64_t numel, const MT* master_param_data, + MT* master_out_data, const bool is_amp) { + LarsThreadConfig lars_thread_config(numel); + L2NormKernel<<>>( + param_data, grad_data, p_buffer, g_buffer, numel, + lars_thread_config.repeat_times, rescale_grad); + + MomentumLarsKernel<<>>( + param_data, grad_data, velocity_data, param_out_data, velocity_out_data, + master_param_data, master_out_data, lr, p_buffer, g_buffer, mu, + lars_coeff, weight_decay, epsilon, rescale_grad, 0, + lars_thread_config.grid_for_norm, numel, is_amp); +} + +template +class LarsMomentumOpCUDAKernel : public framework::OpKernel { + using MT = MultiPrecisionType; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int num_blocks_per_sm = 0; + bool multi_precision = ctx.Attr("multi_precision"); + auto& cuda_ctx = ctx.template device_context(); + int sm_num = cuda_ctx.GetSMCount(); + framework::Tensor tmp_buffer_t = + ctx.AllocateTmpTensor( + {LARS_BLOCK_SIZE << 1}, cuda_ctx); + auto* p_buffer = tmp_buffer_t.mutable_data(ctx.GetPlace()); + auto* g_buffer = p_buffer + LARS_BLOCK_SIZE; MT mu = static_cast(ctx.Attr("mu")); MT lars_coeff = static_cast(ctx.Attr("lars_coeff")); - MT lars_weight_decay = - static_cast(ctx.Attr("lars_weight_decay")); MT epsilon = static_cast(ctx.Attr("epsilon")); - MPDType rescale_grad = - static_cast(ctx.Attr("rescale_grad")); - - auto* p = param->data(); - auto* g = grad->data(); - auto* v = velocity->data(); - auto* lr = learning_rate->data(); - - int block = 512; - int grid = (param->numel() + block - 1) / block; - - auto eigen_p = framework::EigenVector::Flatten(*param); - auto eigen_g = framework::EigenVector::Flatten(*grad); - // calculate norms using eigein and launch the kernel. - framework::Tensor p_norm_t, g_norm_t; - p_norm_t.Resize({1}); - g_norm_t.Resize({1}); - auto* p_norm_data = p_norm_t.mutable_data(ctx.GetPlace()); - auto* g_norm_data = g_norm_t.mutable_data(ctx.GetPlace()); - auto ep_norm = framework::EigenScalar::From(p_norm_t); - auto eg_norm = framework::EigenScalar::From(g_norm_t); - - auto* place = ctx.template device_context().eigen_device(); - - // eigen unsupport fp16 l2-norm - ep_norm.device(*place) = - eigen_p.template cast().square().sum().sqrt(); - eg_norm.device(*place) = - (eigen_g.template cast() * rescale_grad).square().sum().sqrt(); - - MomentumLarsKernel< - T, MT><<>>( - p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay, - p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p, master_p_out, - rescale_grad); + MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); + + auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); + auto grad = ctx.MultiInput("Grad"); + auto param = ctx.MultiInput("Param"); + auto velocity = ctx.MultiInput("Velocity"); + auto param_out = ctx.MultiOutput("ParamOut"); + auto velocity_out = ctx.MultiOutput("VelocityOut"); + auto learning_rate = ctx.MultiInput("LearningRate"); + auto master_param = ctx.MultiInput("MasterParam"); + auto master_param_out = + ctx.MultiOutput("MasterParamOut"); + + int op_num = grad.size(); +#if CUDA_VERSION >= 11000 + if (op_num > 1) { + LarsParamWarpper lars_warpper; + PADDLE_ENFORCE_LT( + op_num, LARS_MAX_MERGED_OPS, + platform::errors::InvalidArgument( + "The maximum number of merged-ops supported is (%d), but" + "lars op required for trainning this model is (%d)\n", + LARS_MAX_MERGED_OPS, op_num)); + + /* Implementation of lars optimizer consists of following two steps: + 1. Figure out the L2 norm statistic result of grad data and param data. + 2. Update param and velocity with usage of L2 norm statistic result. + Step1 and step2 can be merged with api provided by nvida + cudaLaunchCooperativeKernel: + - The thread quantity shall less than pyhsical SM limited threads + - Launche as thread-block can synchronizlly execute. */ + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, MergedMomentumLarsKernel, LARS_BLOCK_SIZE, + sizeof(MT) << 1); + + size_t total_numel = 0; + for (int i = 0; i < op_num; ++i) { + size_t temp_numel = param[i]->numel(); + total_numel += temp_numel; + lars_warpper.numel_arr[i] = temp_numel; + lars_warpper.g_arr[i] = grad[i]->data(); + lars_warpper.lr_arr[i] = learning_rate[i]->data(); + lars_warpper.p_out_arr[i] = + param_out[i]->mutable_data(ctx.GetPlace()); + lars_warpper.v_out_arr[i] = + velocity_out[i]->mutable_data(ctx.GetPlace()); + lars_warpper.weight_decay_arr[i] = static_cast(weight_decay_arr[i]); + PADDLE_ENFORCE_EQ( + param[i]->data(), lars_warpper.p_out_arr[i], + platform::errors::InvalidArgument( + "Input(Param) and Output(ParamOut) must be the same Tensors.")); + PADDLE_ENFORCE_EQ(velocity[i]->data(), lars_warpper.v_out_arr[i], + platform::errors::InvalidArgument( + "Input(Velocity) and Output(VelocityOut) must be " + "the same Tensors.")); + } + int64_t avg_numel = total_numel / op_num; + LarsThreadConfig lars_thread_config(avg_numel, sm_num, + num_blocks_per_sm); + for (int i = 0; i < op_num; ++i) { + lars_warpper.repeat_arr[i] = + lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]); + } + if (multi_precision) { + for (int i = 0; i < op_num; ++i) { + lars_warpper.master_p_out_arr[i] = + master_param_out[i]->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE_EQ(master_param[i]->data(), + lars_warpper.master_p_out_arr[i], + platform::errors::InvalidArgument( + "Input(MasterParam) and Output(MasterParamOut) " + "must be the same Tensors.")); + } + } + void* cuda_param[] = {reinterpret_cast(&lars_warpper), + reinterpret_cast(&p_buffer), + reinterpret_cast(&g_buffer), + reinterpret_cast(&op_num), + reinterpret_cast(&mu), + reinterpret_cast(&lars_coeff), + reinterpret_cast(&epsilon), + reinterpret_cast(&rescale_grad), + reinterpret_cast(&multi_precision)}; + // Lanuch all sm theads, and thead of each block synchronizedly cooperate. + cudaLaunchCooperativeKernel( + reinterpret_cast(MergedMomentumLarsKernel), + lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0, + cuda_ctx.stream()); + } else { + auto* param_data = param[0]->data(); + auto* grad_data = grad[0]->data(); + auto* velocity_data = velocity[0]->data(); + auto* lr = learning_rate[0]->data(); + auto* param_out_data = param_out[0]->mutable_data(ctx.GetPlace()); + auto* velocity_out_data = + velocity_out[0]->mutable_data(ctx.GetPlace()); + const MT* master_param_data = + multi_precision ? master_param[0]->data() : nullptr; + MT* master_param_out_data = + multi_precision + ? master_param_out[0]->mutable_data(ctx.GetPlace()) + : nullptr; + int64_t numel = param[0]->numel(); + MT lars_weight_decay = weight_decay_arr[0]; + + // Figure out how many blocks can be active in each sm. + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, MomentumLarsKernel, LARS_BLOCK_SIZE, + sizeof(MT) << 1); + LarsThreadConfig lars_thread_config(numel, sm_num, + num_blocks_per_sm); + int repeat_times = lars_thread_config.GetRepeatTimes(numel); + int thresh = 0; + void* cuda_param[] = { + reinterpret_cast(¶m_data), + reinterpret_cast(&grad_data), + reinterpret_cast(&velocity_data), + reinterpret_cast(¶m_out_data), + reinterpret_cast(&velocity_out_data), + reinterpret_cast(&master_param_data), + reinterpret_cast(&master_param_out_data), + reinterpret_cast(&lr), + reinterpret_cast(&p_buffer), + reinterpret_cast(&g_buffer), + reinterpret_cast(&mu), + reinterpret_cast(&lars_coeff), + reinterpret_cast(&lars_weight_decay), + reinterpret_cast(&epsilon), + reinterpret_cast(&rescale_grad), + reinterpret_cast(&repeat_times), + reinterpret_cast(&thresh), // Just a placeholder + reinterpret_cast(&numel), + reinterpret_cast(&multi_precision)}; + // Lanuch all sm theads. + cudaLaunchCooperativeKernel( + reinterpret_cast(MomentumLarsKernel), + lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0, + cuda_ctx.stream()); + } +#else + for (int i = 0; i < op_num; ++i) { + const MT* master_param_data = + multi_precision ? master_param[i]->data() : nullptr; + MT* master_param_out_data = + multi_precision + ? master_param_out[i]->mutable_data(ctx.GetPlace()) + : nullptr; + SeparatedLarsMomentumOpCUDAKernel( + cuda_ctx, param[i]->data(), + param_out[i]->mutable_data(ctx.GetPlace()), + velocity[i]->data(), + velocity_out[i]->mutable_data(ctx.GetPlace()), grad[i]->data(), + learning_rate[i]->data(), p_buffer, g_buffer, mu, lars_coeff, + weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(), + master_param_data, master_param_out_data, multi_precision); + } +#endif } }; diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h old mode 100755 new mode 100644 index 55775bc08fb5e..ff836fdca1cc4 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.h +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,54 +23,48 @@ template class LarsMomentumOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto param_out = ctx.Output("ParamOut"); - auto velocity_out = ctx.Output("VelocityOut"); - auto param = ctx.Input("Param"); - auto velocity = ctx.Input("Velocity"); - auto learning_rate = ctx.Input("LearningRate"); - auto* grad_var = ctx.InputVar("Grad"); - // only support dense for now. - PADDLE_ENFORCE_EQ(grad_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Grad").front(), - framework::ToTypeName(grad_var->Type()))); - auto grad = ctx.Input("Grad"); - - param_out->mutable_data(ctx.GetPlace()); - velocity_out->mutable_data(ctx.GetPlace()); - + auto param_out = ctx.MultiOutput("ParamOut"); + auto velocity_out = ctx.MultiOutput("VelocityOut"); + auto param = ctx.MultiInput("Param"); + auto velocity = ctx.MultiInput("Velocity"); + auto learning_rate = ctx.MultiInput("LearningRate"); + auto grad = ctx.MultiInput("Grad"); + auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); T mu = static_cast(ctx.Attr("mu")); T lars_coeff = ctx.Attr("lars_coeff"); - T lars_weight_decay = ctx.Attr("lars_weight_decay"); T epsilon = ctx.Attr("epsilon"); - auto p_out = framework::EigenVector::Flatten(*param_out); - auto v_out = framework::EigenVector::Flatten(*velocity_out); + int op_num = param.size(); + for (int i = 0; i < op_num; ++i) { + auto* lr = learning_rate[i]->data(); + T lars_weight_decay = weight_decay_arr[i]; + param_out[i]->mutable_data(ctx.GetPlace()); + velocity_out[i]->mutable_data(ctx.GetPlace()); - auto p = framework::EigenVector::Flatten(*param); - auto v = framework::EigenVector::Flatten(*velocity); - auto g = framework::EigenVector::Flatten(*grad); - auto* lr = learning_rate->data(); + auto p_out = framework::EigenVector::Flatten(*(param_out[i])); + auto v_out = framework::EigenVector::Flatten(*(velocity_out[i])); + auto p = framework::EigenVector::Flatten(*(param[i])); + auto v = framework::EigenVector::Flatten(*(velocity[i])); + auto g = framework::EigenVector::Flatten(*(grad[i])); - framework::Tensor p_norm_t, g_norm_t; - p_norm_t.Resize({1}); - g_norm_t.Resize({1}); - p_norm_t.mutable_data(ctx.GetPlace()); - g_norm_t.mutable_data(ctx.GetPlace()); - auto ep_norm = framework::EigenScalar::From(p_norm_t); - auto eg_norm = framework::EigenScalar::From(g_norm_t); + framework::Tensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + p_norm_t.mutable_data(ctx.GetPlace()); + g_norm_t.mutable_data(ctx.GetPlace()); + auto ep_norm = framework::EigenScalar::From(p_norm_t); + auto eg_norm = framework::EigenScalar::From(g_norm_t); + ep_norm = p.square().sum().sqrt(); + eg_norm = g.square().sum().sqrt(); - ep_norm = p.square().sum().sqrt(); - eg_norm = g.square().sum().sqrt(); - T local_lr = lr[0]; - if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) { - local_lr = lr[0] * lars_coeff * ep_norm(0) / - (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon); + T local_lr = lr[0]; + if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) { + local_lr = lr[0] * lars_coeff * ep_norm(0) / + (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon); + } + v_out = v * mu + local_lr * (g + lars_weight_decay * p); + p_out = p - v_out; } - v_out = v * mu + local_lr * (g + lars_weight_decay * p); - p_out = p - v_out; } }; diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cc b/paddle/fluid/operators/optimizers/merged_momentum_op.cc new file mode 100644 index 0000000000000..6c63376b5eb42 --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/optimizers/merged_momentum_op.h" + +namespace paddle { +namespace operators { + +class MergedMomentumOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override {} + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto param_dtype = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param"); + return framework::OpKernelType(param_dtype, ctx.GetPlace()); + } +}; + +class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter that has to be updated") + .AsDuplicable(); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter") + .AsDuplicable(); + AddInput("Velocity", + "(Tensor, default Tensor) " + "Input velocity (corresponding to the parameter) " + "that has to be updated") + .AsDuplicable(); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "Input learning rate"); + AddInput("MasterParam", "FP32 master weight for AMP.") + .AsDispensable() + .AsDuplicable(); + AddOutput("ParamOut", + "(Tensor) This output is updated parameter. " + "It shared memory with Input(Param).") + .AsDuplicable(); + AddOutput("VelocityOut", + "(Tensor) This output is updated velocity. " + "It shared memory with Input(Velocity).") + .AsDuplicable(); + AddOutput("MasterParamOut", + "The updated FP32 master weight for AMP. " + "It shared memory with Input(MasterParam).") + .AsDispensable() + .AsDuplicable(); + AddAttr("mu", "(float) Momentum coefficient"); + AddAttr("multi_precision", + "(bool, default false) " + "Whether to use multi-precision during weight updating.") + .SetDefault(false); + AddAttr( + "rescale_grad", + "(float, default 1.0) Multiply the gradient with `rescale_grad`" + "before updating. Often choose to be `1.0/batch_size`.") + .SetDefault(1.0f); + AddComment(R"DOC(Merged Momentum Optimizer.)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(merged_momentum, ops::MergedMomentumOp, + ops::MergedMomentumOpMaker); + +REGISTER_OP_CPU_KERNEL( + merged_momentum, ops::MergedMomentumOpKernel, + ops::MergedMomentumOpKernel); diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cu b/paddle/fluid/operators/optimizers/merged_momentum_op.cu new file mode 100644 index 0000000000000..7e4bbd9807938 --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cu @@ -0,0 +1,24 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/optimizers/merged_momentum_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + merged_momentum, + ops::MergedMomentumOpKernel, + ops::MergedMomentumOpKernel, + ops::MergedMomentumOpKernel); diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.h b/paddle/fluid/operators/optimizers/merged_momentum_op.h new file mode 100644 index 0000000000000..4dfaa4de3ad44 --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.h @@ -0,0 +1,197 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/platform/for_range.h" +#include "paddle/fluid/platform/macros.h" + +namespace paddle { +namespace operators { + +template +struct MergedMomentumMasterParams { + MT *PADDLE_RESTRICT master_params[kParamNum]; + + HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; } + HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; } +}; + +template +struct MergedMomentumMasterParams { + HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; } + HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {} +}; + +template +struct MergedMomentumKernelParam + : public MergedMomentumMasterParams { + static constexpr auto N = kParamNum; + size_t sizes[N]; + T *PADDLE_RESTRICT params[N]; + const T *PADDLE_RESTRICT grads[N]; + MT *PADDLE_RESTRICT velocitys[N]; + const MT *PADDLE_RESTRICT lr; + MT mu; + MT rescale_grad; + uint32_t param_num; + + HOSTDEVICE void operator()(size_t i) const { + const auto lr_val = *lr; + for (uint32_t idx = 0; idx < param_num; ++idx) { + auto size = sizes[idx]; + if (i >= size) continue; + + auto param_p = params[idx]; + auto grad_p = grads[idx]; + auto velocity_p = velocitys[idx]; + auto master_param_p = this->MasterParam(idx); + + const MT param = + master_param_p ? master_param_p[i] : static_cast(param_p[i]); + const MT grad = static_cast(grad_p[i]) * rescale_grad; + const MT velocity = velocity_p[i]; + const MT velocity_out = velocity * mu + grad; + const MT param_out = param - lr_val * velocity_out; + velocity_p[i] = velocity_out; + param_p[i] = static_cast(param_out); + if (master_param_p) { + master_param_p[i] = param_out; + } + } + } +}; + +template +class MergedMomentumOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto params = ctx.MultiInput("Param"); + auto params_out = ctx.MultiOutput("ParamOut"); + size_t n = params.size(); + PADDLE_ENFORCE_EQ( + n, params_out.size(), + platform::errors::InvalidArgument( + "Output(ParamOut) number must be equal to Input(Param) number.")); + for (size_t i = 0; i < n; ++i) { + PADDLE_ENFORCE_EQ( + params[i], params_out[i], + platform::errors::InvalidArgument( + "Input(Param) and Output(ParamOut) must be the same Tensors.")); + } + + auto grads = ctx.MultiInput("Grad"); + PADDLE_ENFORCE_EQ( + n, grads.size(), + platform::errors::InvalidArgument( + "Input(Grad) number must be equal to Input(Param) number.")); + + auto velocitys = ctx.MultiInput("Velocity"); + PADDLE_ENFORCE_EQ(n, velocitys.size(), + platform::errors::InvalidArgument( + "Input(Velocity) number and Input(Param) number.")); + + auto velocitys_out = ctx.MultiOutput("VelocityOut"); + PADDLE_ENFORCE_EQ( + n, velocitys_out.size(), + platform::errors::InvalidArgument("Output(VelocityOut) number must be " + "equal to Input(Param) number.")); + for (size_t i = 0; i < n; ++i) { + PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i], + platform::errors::InvalidArgument( + "Input(Velocity) and Output(VelocityOut) must be " + "the same Tensors.")); + } + + auto master_params = ctx.MultiInput("MasterParam"); + auto master_params_out = + ctx.MultiOutput("MasterParamOut"); + auto multi_precision = ctx.Attr("multi_precision"); + if (multi_precision) { + PADDLE_ENFORCE_EQ( + n, master_params.size(), + platform::errors::InvalidArgument("Input(MasterParam) number must be " + "equal to Input(Param) number.")); + PADDLE_ENFORCE_EQ(n, master_params_out.size(), + platform::errors::InvalidArgument( + "Output(MasterParamOut) number must be equal to " + "Input(MasterParam) number.")); + for (size_t i = 0; i < n; ++i) { + PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i], + platform::errors::InvalidArgument( + "Input(MasterParam) and Output(MasterParamOut) " + "must be the same Tensors.")); + PADDLE_ENFORCE_NOT_NULL(master_params[i], + platform::errors::InvalidArgument( + "Input(MasterParam) must be provided when " + "multi_precision=True.")); + } + } else { + master_params.clear(); + master_params_out.clear(); + } + + auto lr = ctx.Input("LearningRate"); + auto mu = ctx.Attr("mu"); + auto rescale_grad = ctx.Attr("rescale_grad"); + using MPType = typename operators::details::MPTypeTrait::Type; + + auto &dev_ctx = ctx.template device_context(); + +#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \ + MergedMomentumKernelParam kernel_params; \ + constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \ + size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \ + kernel_params.mu = static_cast(mu); \ + kernel_params.rescale_grad = static_cast(rescale_grad); \ + kernel_params.lr = lr->data(); \ + for (size_t i = 0; i < kernel_num; ++i) { \ + size_t start = i * kMaxMergedNum; \ + size_t end = std::min((i + 1) * kMaxMergedNum, n); \ + kernel_params.param_num = static_cast(end - start); \ + size_t max_size = 0; \ + for (size_t j = 0; j < kernel_params.param_num; ++j) { \ + auto size = static_cast(params_out[j + start]->numel()); \ + max_size = std::max(max_size, size); \ + kernel_params.sizes[j] = size; \ + kernel_params.params[j] = params_out[j + start]->data(); \ + kernel_params.grads[j] = grads[j + start]->data(); \ + kernel_params.velocitys[j] = velocitys_out[j + start]->data(); \ + kernel_params.SetMasterParam( \ + j, kMultiPrecision ? master_params_out[j + start]->data() \ + : nullptr); \ + } \ + platform::ForRange for_range(dev_ctx, max_size); \ + for_range(kernel_params); \ + VLOG(10) << "Launch MergedMomentum kernel " << i << " " \ + << kernel_params.param_num; \ + } + + if (multi_precision) { + PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true); + } else { + PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false); + } + +#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index f461dec66c0e7..2d713308fd938 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor { } }; -template +template class DenseMomentumFunctor; // NOTE(dzh) for performance. // avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two // functor. -template -class DenseMomentumFunctor { +template +class DenseMomentumFunctor { private: const T* param_; const T* grad_; @@ -193,7 +194,6 @@ class DenseMomentumFunctor { T* param_out_; MT* velocity_out_; MT* master_param_out_; - const RegularizationType regularization_flag_; const MT regularization_coeff_; public: @@ -201,7 +201,6 @@ class DenseMomentumFunctor { const MultiPrecisionType* learning_rate, const MT* master_param, const MT mu, const MT rescale_grad, const int64_t num, - const RegularizationType regularization_flag, const MT regularization_coeff, T* param_out, MT* velocity_out, MT* master_param_out) : param_(param), @@ -215,7 +214,6 @@ class DenseMomentumFunctor { param_out_(param_out), velocity_out_(velocity_out), master_param_out_(master_param_out), - regularization_flag_(regularization_flag), regularization_coeff_(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) const { // put memory access in register @@ -225,9 +223,9 @@ class DenseMomentumFunctor { const MT lr = static_cast(lr_[0]); const MT velocity = velocity_[i]; - grad = regularization_flag_ == RegularizationType::kL2DECAY - ? grad + regularization_coeff_ * param - : grad; + if (kRegType == RegularizationType::kL2DECAY) { + grad += regularization_coeff_ * param; + } MT velocity_out = velocity * mu_ + grad; MT param_out = param - (grad + velocity_out * mu_) * lr; @@ -240,8 +238,8 @@ class DenseMomentumFunctor { } }; -template -class DenseMomentumFunctor { +template +class DenseMomentumFunctor { private: const T* param_; const T* grad_; @@ -254,7 +252,6 @@ class DenseMomentumFunctor { T* param_out_; MT* velocity_out_; MT* master_param_out_; - const RegularizationType regularization_flag_; const MT regularization_coeff_; public: @@ -262,7 +259,6 @@ class DenseMomentumFunctor { const MultiPrecisionType* learning_rate, const MT* master_param, const MT mu, const MT rescale_grad, const int64_t num, - const RegularizationType regularization_flag, const MT regularization_coeff, T* param_out, MT* velocity_out, MT* master_param_out) : param_(param), @@ -276,7 +272,6 @@ class DenseMomentumFunctor { param_out_(param_out), velocity_out_(velocity_out), master_param_out_(master_param_out), - regularization_flag_(regularization_flag), regularization_coeff_(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) const { // put memory access in register @@ -286,9 +281,9 @@ class DenseMomentumFunctor { const MT lr = static_cast(lr_[0]); const MT velocity = velocity_[i]; - grad = regularization_flag_ == RegularizationType::kL2DECAY - ? grad + regularization_coeff_ * param - : grad; + if (kRegType == RegularizationType::kL2DECAY) { + grad += regularization_coeff_ * param; + } MT velocity_out = velocity * mu_ + grad; MT param_out = param - lr * velocity_out; @@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel { platform::ForRange for_range( static_cast(ctx.device_context()), param->numel()); - if (use_nesterov) { - DenseMomentumFunctor functor( - param->data(), grad->data(), velocity->data(), - learning_rate->data(), master_in_data, mu, rescale_grad, - param->numel(), regularization_flag, regularization_coeff, - param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace()), master_out_data); - for_range(functor); +#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \ + DenseMomentumFunctor functor( \ + param->data(), grad->data(), velocity->data(), \ + learning_rate->data(), master_in_data, mu, rescale_grad, \ + param->numel(), regularization_coeff, \ + param_out->mutable_data(ctx.GetPlace()), \ + velocity_out->mutable_data(ctx.GetPlace()), master_out_data); \ + for_range(functor); + if (use_nesterov) { + if (regularization_flag == RegularizationType::kL2DECAY) { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, + RegularizationType::kL2DECAY); + } else { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, + RegularizationType::kNONE); + } } else { - DenseMomentumFunctor functor( - param->data(), grad->data(), velocity->data(), - learning_rate->data(), master_in_data, mu, rescale_grad, - param->numel(), regularization_flag, regularization_coeff, - param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace()), master_out_data); - for_range(functor); + if (regularization_flag == RegularizationType::kL2DECAY) { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov, + RegularizationType::kL2DECAY); + } else { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov, + RegularizationType::kNONE); + } } } diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc new file mode 100644 index 0000000000000..4d919c94f616b --- /dev/null +++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +class Pow2DecayWithLinearWarmupOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + auto dim = framework::make_ddim({1}); + ctx->SetOutputDim("LearningRateOut", dim); + ctx->SetOutputDim("StepOut", dim); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "LearningRate"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class Pow2DecayWithLinearWarmupOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("LearningRate", "(Tensor) The input learning rate Tensor."); + AddInput("Step", "(Tensor) The input global step Tensor."); + AddOutput("LearningRateOut", + "(Tensor) The output learning rate Tensor. Same with " + "Input(LearningRate)."); + AddOutput( + "StepOut", + "(Tensor) The output learning rate Tensor. Same with Input(Step)."); + AddAttr("warmup_steps", "(int64_t) The warmup steps."); + AddAttr( + "total_steps", + "(int64_t) The total steps for changing the learning rate."); + AddAttr("base_lr", + "(float) The final learning rate value after warmup."); + AddAttr("end_lr", + "(float) The final learning rate value after total_steps."); + AddComment(R"DOC( +The Pow2DecayWithLinearWarmup learning rate scheduler. + +When step_num < warmup_steps, lr = base_lr * step_num / warmup_steps + +When warmup_steps <= step_num <= total_steps, + factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps) + lr = (base_lr - end_lr) * factor * factor + end_lr + +When step_num > total_steps, lr = end_lr + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(pow2_decay_with_linear_warmup, + ops::Pow2DecayWithLinearWarmupOp, + ops::Pow2DecayWithLinearWarmupOpMaker); +REGISTER_OP_CPU_KERNEL( + pow2_decay_with_linear_warmup, + ops::Pow2DecayWithLinearWarmupOpKernel, + ops::Pow2DecayWithLinearWarmupOpKernel); diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu new file mode 100644 index 0000000000000..6695778dbac06 --- /dev/null +++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu @@ -0,0 +1,24 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + pow2_decay_with_linear_warmup, + ops::Pow2DecayWithLinearWarmupOpKernel, + ops::Pow2DecayWithLinearWarmupOpKernel); diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h new file mode 100644 index 0000000000000..74cf762745077 --- /dev/null +++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h @@ -0,0 +1,115 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/for_range.h" +#include "paddle/fluid/platform/macros.h" + +namespace paddle { +namespace operators { + +template +struct Pow2DecayWithLinearWarmupFunctor { + template + using RestrictPtr = U *PADDLE_RESTRICT; + + public: + HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr lr, + RestrictPtr step, + size_t warmup_steps, + size_t total_steps, AttrT base_lr, + AttrT end_lr) + : lr_(lr), + step_(step), + warmup_steps_(warmup_steps), + total_steps_(total_steps), + base_lr_(base_lr), + end_lr_(end_lr) {} + + HOSTDEVICE void operator()(size_t) const { + size_t step = static_cast(*step_) + 1; + *step_ = static_cast(step); + if (step <= warmup_steps_) { + auto new_lr = static_cast(step) / warmup_steps_ * base_lr_; + *lr_ = static_cast(new_lr); + } else if (step < total_steps_) { + auto factor = 1 - + static_cast(step - warmup_steps_) / + (total_steps_ - warmup_steps_); + auto new_lr = + static_cast(base_lr_ - end_lr_) * (factor * factor) + end_lr_; + *lr_ = static_cast(new_lr); + } else { + *lr_ = static_cast(end_lr_); + } + } + + private: + RestrictPtr lr_; + RestrictPtr step_; + size_t warmup_steps_; + size_t total_steps_; + AttrT base_lr_; + AttrT end_lr_; +}; + +template +class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const { + const auto *lr = ctx.Input("LearningRate"); + const auto *step = ctx.Input("Step"); + auto *lr_out = ctx.Output("LearningRateOut"); + auto *step_out = ctx.Output("StepOut"); + PADDLE_ENFORCE_EQ( + lr, lr_out, platform::errors::InvalidArgument("Input(LearningRate) and " + "Output(LearningRateOut) " + "must be the same.")); + PADDLE_ENFORCE_NOT_NULL(lr, + platform::errors::InvalidArgument( + "Input(LearingRate) should not be nullptr.")); + PADDLE_ENFORCE_EQ(step, step_out, + platform::errors::InvalidArgument( + "Input(Step) and Output(StepOut) must be the same.")); + PADDLE_ENFORCE_NOT_NULL(step, platform::errors::InvalidArgument( + "Input(Step) should not be nullptr.")); + PADDLE_ENFORCE_EQ( + step->IsInitialized(), true, + platform::errors::InvalidArgument("Input(Step) must be initialized.")); + + auto warmup_steps = static_cast(ctx.Attr("warmup_steps")); + auto total_steps = static_cast(ctx.Attr("total_steps")); + PADDLE_ENFORCE_LE(warmup_steps, total_steps, + platform::errors::InvalidArgument( + "warmup_steps must not be larger than total_steps.")); + auto base_lr = ctx.Attr("base_lr"); + auto end_lr = ctx.Attr("end_lr"); + + auto *lr_data = lr_out->data(); + auto *step_data = step_out->data(); + auto &dev_ctx = ctx.template device_context(); + platform::ForRange for_range(dev_ctx, 1); + using AttrT = double; + Pow2DecayWithLinearWarmupFunctor functor( + lr_data, step_data, warmup_steps, total_steps, + static_cast(base_lr), static_cast(end_lr)); + for_range(functor); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 2540170ed54fb..21213f9e6ff21 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -59,9 +59,14 @@ cc_library(cpu_info SRCS cpu_info.cc DEPS ${CPU_INFO_DEPS}) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) IF(WITH_GPU) + nv_library(cuda_graph SRCS cuda_graph.cc DEPS enforce allocator_facade) nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda) nv_library(cuda_profiler SRCS cuda_profiler.cc DEPS enforce) + nv_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade cuda_graph) +ELSE() + cc_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade) ENDIF() + IF(WITH_ROCM) hip_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda) ENDIF() diff --git a/paddle/fluid/platform/cuda_graph.cc b/paddle/fluid/platform/cuda_graph.cc new file mode 100644 index 0000000000000..693a592799027 --- /dev/null +++ b/paddle/fluid/platform/cuda_graph.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/cuda_graph.h" + +namespace paddle { +namespace platform { + +std::unique_ptr CUDAGraph::capturing_graph_{nullptr}; + +void CUDAGraph::Reset() { + if (is_reset_) return; +#if CUDA_VERSION >= 10010 + if (graph_) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph_)); + graph_ = nullptr; + } + if (exec_graph_) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph_)); + exec_graph_ = nullptr; + } +#endif + // callback should be called in reverse order because the latter added + // callback may rely on the former added callback. + for (auto iter = callbacks_.rbegin(); iter != callbacks_.rend(); ++iter) { + (*iter)(); + } + callbacks_.clear(); + is_reset_ = true; +} + +void CUDAGraph::Replay() { +#if CUDA_VERSION >= 10010 + PADDLE_ENFORCE_EQ(is_reset_, false, + errors::PermissionDenied( + "Cannot replay the CUDA Graph after reset is called.")); + PADDLE_ENFORCE_NOT_NULL(exec_graph_, + errors::PermissionDenied( + "CUDA Graph must be captured before replaying.")); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph_, stream_)); +#endif +} + +void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream, + cudaStreamCaptureMode mode) { + ThrowErrorIfNotSupportCUDAGraph(); + PADDLE_ENFORCE_EQ( + IsCapturing(), false, + errors::PermissionDenied("CUDA Graph can only captured one by one.")); + PADDLE_ENFORCE_NOT_NULL( + stream, errors::PermissionDenied( + "CUDA Graph cannot be captured in default CUDA stream 0.")); + capturing_graph_.reset(new CUDAGraph()); + capturing_graph_->place_ = place; + capturing_graph_->stream_ = stream; + + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamBeginCapture(capturing_graph_->stream_, mode)); + cudaStreamCaptureStatus status; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamGetCaptureInfo( + capturing_graph_->stream_, &status, &(capturing_graph_->id_))); + PADDLE_ENFORCE_EQ(IsValidCapturing(), true, + platform::errors::PermissionDenied( + "CUDA Graph should not be invalidated.")); + VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_; +} + +std::unique_ptr CUDAGraph::EndCapture() { + ThrowErrorIfNotSupportCUDAGraph(); +#if CUDA_VERSION >= 10010 + PADDLE_ENFORCE_EQ(IsCapturing(), true, + errors::PermissionDenied("No CUDA Graph is capturing.")); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamEndCapture( + capturing_graph_->stream_, &(capturing_graph_->graph_))); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaGraphInstantiate(&(capturing_graph_->exec_graph_), + capturing_graph_->graph_, nullptr, nullptr, 0)); + VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_; + return std::move(capturing_graph_); +#endif +} + +bool CUDAGraph::IsValidCapturing() { + if (!IsCapturing()) return false; + cudaStreamCaptureStatus status; + CUDAGraphID id; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id)); + return status == cudaStreamCaptureStatusActive; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cuda_graph.h b/paddle/fluid/platform/cuda_graph.h new file mode 100644 index 0000000000000..55ec463556b45 --- /dev/null +++ b/paddle/fluid/platform/cuda_graph.h @@ -0,0 +1,142 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "cuda.h" // NOLINT +#include "cuda_runtime.h" // NOLINT +#include "paddle/fluid/platform/type_defs.h" + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace platform { + +#if CUDA_VERSION >= 10010 +static void ThrowErrorIfNotSupportCUDAGraph() {} +#else +enum cudaStreamCaptureMode { + cudaStreamCaptureModeGlobal = 0, + cudaStreamCaptureModeThreadLocal = 1, + cudaStreamCaptureModeRelaxed = 2 +}; +static void ThrowErrorIfNotSupportCUDAGraph() { + PADDLE_THROW(platform::errors::Unimplemented( + "CUDA Graph is only supported when CUDA version >= 10.1")); +} +#endif + +// NOTE: Currently, we do not support to capture CUDA graph in parallel +// NOTE: Do not use this class directly because it should be used with +// the memory pool. +class CUDAGraph { + DISABLE_COPY_AND_ASSIGN(CUDAGraph); + + // Since the constructor would throw error is CUDA_VERSION < 10010. + // The non-static method of CUDAGraph need not check CUDA_VERSION + // again. + CUDAGraph() { ThrowErrorIfNotSupportCUDAGraph(); } + + public: + ~CUDAGraph() { Reset(); } + + CUDAGraphID ID() const { return id_; } + + void Replay(); + + void Reset(); + + void AddResetCallback(std::function callback) { + std::lock_guard guard(mtx_); + callbacks_.push_back(std::move(callback)); + } + + static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream, + cudaStreamCaptureMode mode); + static std::unique_ptr EndCapture(); + static void AddResetCallbackDuringCapturing(std::function callback) { + capturing_graph_->AddResetCallback(std::move(callback)); + } + + // No need to add CUDA_VERSION macro because capturing_graph_ would + // always be nullptr (constructor throws error) + static bool IsCapturing() { return capturing_graph_ != nullptr; } + + static CUDAGraphID CapturingID() { return capturing_graph_->id_; } + + static platform::CUDAPlace CapturingPlace() { + return capturing_graph_->place_; + } + + // This API can be used to debug which GPU operation is not + // supported during capturing CUDA Graph. + static bool IsValidCapturing(); + + private: +#if CUDA_VERSION >= 10010 + cudaGraph_t graph_{nullptr}; + cudaGraphExec_t exec_graph_{nullptr}; +#endif + cudaStream_t stream_{nullptr}; + platform::CUDAPlace place_; + CUDAGraphID id_{0}; + std::vector> callbacks_; + bool is_reset_{false}; + std::mutex mtx_; + + static std::unique_ptr capturing_graph_; +}; + +#if CUDA_VERSION >= 10010 +class CUDAGraphCaptureModeGuard { + DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); + + public: + explicit CUDAGraphCaptureModeGuard( + cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) { + if (UNLIKELY(CUDAGraph::IsCapturing())) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode)); + // After cudaThreadExchangeStreamCaptureMode is called, + // the variable "mode" would be set to the old capturing mode. + old_mode_ = mode; + } + } + + ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW { + if (UNLIKELY(CUDAGraph::IsCapturing())) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaThreadExchangeStreamCaptureMode(&old_mode_)); + } + } + + private: + cudaStreamCaptureMode old_mode_; +}; +#else +class CUDAGraphCaptureModeGuard { + DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); + + public: + explicit CUDAGraphCaptureModeGuard( + cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {} +}; +#endif + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc new file mode 100644 index 0000000000000..4804d3f6ed301 --- /dev/null +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace platform { + +#ifdef PADDLE_WITH_CUDA +void BeginCUDAGraphCapture(platform::CUDAPlace place, + cudaStreamCaptureMode mode) { + auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place); + dev_ctx->cudnn_workspace_handle().ResetWorkspace(); + + auto stream = dev_ctx->stream(); + CUDAGraph::BeginCapture(place, stream, mode); + auto id = CUDAGraph::CapturingID(); + memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph( + id); + AddResetCallbackIfCapturingCUDAGraph([id] { + memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph( + id); + }); +} + +std::unique_ptr EndCUDAGraphCapture() { + auto place = CUDAGraph::CapturingPlace(); + auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place); + dev_ctx->cudnn_workspace_handle().ResetWorkspace(); + return CUDAGraph::EndCapture(); +} +#endif + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.h b/paddle/fluid/platform/cuda_graph_with_memory_pool.h new file mode 100644 index 0000000000000..f9f0248e5153b --- /dev/null +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.h @@ -0,0 +1,64 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_graph.h" +#endif + +namespace paddle { +namespace platform { + +// NOTE: These APIs are not thread-safe. +#ifdef PADDLE_WITH_CUDA +void BeginCUDAGraphCapture(platform::CUDAPlace place, + cudaStreamCaptureMode mode); +std::unique_ptr EndCUDAGraphCapture(); +#endif + +inline bool IsCUDAGraphCapturing() { +#ifdef PADDLE_WITH_CUDA + return CUDAGraph::IsCapturing(); +#else + return false; +#endif +} + +inline platform::CUDAPlace CUDAGraphCapturingPlace() { +#ifdef PADDLE_WITH_CUDA + return CUDAGraph::CapturingPlace(); +#else + PADDLE_THROW(platform::errors::Unimplemented( + "CUDA Graph is only supported on NVIDIA GPU device.")); +#endif +} + +// Add reset callback if CUDA Graph is capturing. +// Otherwise, invoke callback directly. +template +inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) { +#ifdef PADDLE_WITH_CUDA + if (UNLIKELY(IsCUDAGraphCapturing())) { + return CUDAGraph::AddResetCallbackDuringCapturing( + std::forward(callback)); + } +#endif + callback(); +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cudnn_desc.h b/paddle/fluid/platform/cudnn_desc.h index 486b3346c3760..318c85ee484be 100644 --- a/paddle/fluid/platform/cudnn_desc.h +++ b/paddle/fluid/platform/cudnn_desc.h @@ -44,6 +44,9 @@ inline cudnnDataType_t ToCudnnDataType(const T& t) { inline std::vector TransformDimOrder(const std::vector& dims) { std::vector transformed_dims(dims.begin(), dims.end()); + if (dims.size() < 4) { + return transformed_dims; + } int H, W, D, C; if (dims.size() == 4) { H = dims[1]; @@ -155,8 +158,8 @@ class TensorDescriptor { dims_with_group.data(), strides.data())); } - void set(const Tensor& tensor, const cudnnTensorFormat_t format) { - auto dims = framework::vectorize(tensor.dims()); + void set(const std::vector& dims, const cudnnTensorFormat_t format, + const cudnnDataType_t dtype) { std::vector transformed_dims; if (format == CUDNN_TENSOR_NHWC) { transformed_dims = TransformDimOrder(dims); @@ -164,8 +167,14 @@ class TensorDescriptor { transformed_dims = dims; } PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptorEx( - desc_.get(), format, ToCudnnDataType(tensor.type()), - transformed_dims.size(), transformed_dims.data())); + desc_.get(), format, dtype, transformed_dims.size(), + transformed_dims.data())); + } + + void set(const Tensor& tensor, const cudnnTensorFormat_t format) { + auto dims = framework::vectorize(tensor.dims()); + auto dtype = ToCudnnDataType(tensor.type()); + set(dims, format, dtype); } private: @@ -191,9 +200,8 @@ class FilterDescriptor { T* desc() { return desc_.get(); } T* desc() const { return desc_.get(); } - void set(const Tensor& tensor, const cudnnTensorFormat_t format, - const int groups = 1) { - auto dims = framework::vectorize(tensor.dims()); + void set(const std::vector& dims, const cudnnTensorFormat_t format, + const cudnnDataType_t dtype, const int groups = 1) { std::vector transformed_dims; if (format == CUDNN_TENSOR_NHWC) { transformed_dims = TransformDimOrder(dims); @@ -204,8 +212,15 @@ class FilterDescriptor { transformed_dims[1] = transformed_dims[1] / groups; } PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetFilterNdDescriptor( - desc_.get(), ToCudnnDataType(tensor.type()), format, - transformed_dims.size(), transformed_dims.data())); + desc_.get(), dtype, format, transformed_dims.size(), + transformed_dims.data())); + } + + void set(const Tensor& tensor, const cudnnTensorFormat_t format, + const int groups = 1) { + auto dims = framework::vectorize(tensor.dims()); + auto dtype = ToCudnnDataType(tensor.type()); + set(dims, format, dtype, groups); } private: diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 4828a97e4df4d..3420c38fe9639 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -180,7 +180,18 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif #if CUDNN_VERSION >= 8000 -#define CUDNN_DNN_ROUTINE_EACH_R8(__macro) __macro(cudnnSetRNNDescriptor_v8); +#define CUDNN_DNN_ROUTINE_EACH_R8(__macro) \ + __macro(cudnnSetRNNDescriptor_v8); \ + __macro(cudnnCreateFusedOpsPlan); \ + __macro(cudnnCreateFusedOpsConstParamPack); \ + __macro(cudnnCreateFusedOpsVariantParamPack); \ + __macro(cudnnDestroyFusedOpsPlan); \ + __macro(cudnnDestroyFusedOpsConstParamPack); \ + __macro(cudnnDestroyFusedOpsVariantParamPack); \ + __macro(cudnnFusedOpsExecute); \ + __macro(cudnnSetFusedOpsConstParamPackAttribute); \ + __macro(cudnnSetFusedOpsVariantParamPackAttribute); \ + __macro(cudnnMakeFusedOpsPlan); CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index c4ac5aa3046a9..c624ba94b74a3 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -22,6 +22,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_HIP #include "paddle/fluid/platform/dynload/miopen.h" #else +#include "paddle/fluid/platform/cuda_graph.h" #include "paddle/fluid/platform/dynload/cudnn.h" #endif #include "paddle/fluid/memory/malloc.h" @@ -557,6 +558,7 @@ class RecordedCudaMallocHelper { #ifdef PADDLE_WITH_HIP auto result = hipMalloc(ptr, size); #else + CUDAGraphCaptureModeGuard capture_mode_guard; auto result = cudaMalloc(ptr, size); #endif if (result == gpuSuccess) { diff --git a/paddle/fluid/platform/macros.h b/paddle/fluid/platform/macros.h index fb5cf9fb31915..bf089ac117d41 100644 --- a/paddle/fluid/platform/macros.h +++ b/paddle/fluid/platform/macros.h @@ -30,3 +30,9 @@ limitations under the License. */ #define FLT_MAX __FLT_MAX__ #endif // __FLT_MAX__ #endif // PADDLE_WITH_MUSL + +#if defined(__NVCC__) || defined(__HIPCC__) +#define PADDLE_RESTRICT __restrict__ +#else +#define PADDLE_RESTRICT +#endif diff --git a/paddle/fluid/platform/type_defs.h b/paddle/fluid/platform/type_defs.h index f46bd1a0bdfa4..88a2d16472fa7 100644 --- a/paddle/fluid/platform/type_defs.h +++ b/paddle/fluid/platform/type_defs.h @@ -36,4 +36,5 @@ using gpuEvent_t = cudaEvent_t; using gpuDeviceProp = cudaDeviceProp; #endif +using CUDAGraphID = unsigned long long; // NOLINT } // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 22778013f2390..875e6af9652a2 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -7,7 +7,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator - cost_model) + cost_model cuda_graph_with_memory_pool) if (WITH_PSCORE) set(PYBIND_DEPS ${PYBIND_DEPS} ps_service) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d4f4323238e9f..2fb7395be34e9 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -125,6 +125,8 @@ limitations under the License. */ #include "paddle/fluid/platform/xpu/xpu_info.h" #endif +#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" + #ifdef PADDLE_WITH_CRYPTO #include "paddle/fluid/pybind/crypto.h" #endif @@ -485,6 +487,17 @@ static int GetNCCLVersion() { } #endif +template +static void TensorCopyFrom(framework::Tensor *dst, const framework::Tensor &src, + const PlaceType &place, int64_t batch_size) { + if (batch_size < 0) { + framework::TensorCopy(src, place, dst); + } else { + auto sliced = src.Slice(0, batch_size); + framework::TensorCopy(sliced, place, dst); + } +} + #ifdef PADDLE_WITH_AVX PYBIND11_MODULE(core_avx, m) { #else @@ -520,6 +533,19 @@ PYBIND11_MODULE(core_noavx, m) { m.def("nccl_version", &GetNCCLVersion); #endif + m.def("is_cuda_graph_capturing", &platform::IsCUDAGraphCapturing); +#ifdef PADDLE_WITH_CUDA + py::class_(m, "CUDAGraph") + .def_static("begin_capture", + [](platform::CUDAPlace place, int mode) { + platform::BeginCUDAGraphCapture( + place, static_cast(mode)); + }) + .def_static("end_capture", &platform::EndCUDAGraphCapture) + .def("replay", &platform::CUDAGraph::Replay) + .def("reset", &platform::CUDAGraph::Reset); +#endif + m.def("wait_device", [](const platform::Place &place) { platform::DeviceContextPool::Instance().Get(place)->Wait(); }); @@ -721,6 +747,18 @@ PYBIND11_MODULE(core_noavx, m) { paddle::framework::proto::VarType::Type type) { return reinterpret_cast(self.mutable_data(place, type)); }) + .def("_copy_from", &TensorCopyFrom, + py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) + .def("_copy_from", &TensorCopyFrom, + py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) + .def("_copy_from", &TensorCopyFrom, + py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) + .def("_copy_from", &TensorCopyFrom, + py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) + .def("_copy_from", &TensorCopyFrom, + py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) + .def("_copy_from", &TensorCopyFrom, + py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, @@ -2301,7 +2339,14 @@ All parameter, weight, gradient are variables in Paddle. m.def("op_support_gpu", OpSupportGPU); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) m.def("get_cuda_device_count", platform::GetCUDADeviceCount); - m.def("cuda_empty_cache", platform::EmptyCache); + m.def("cuda_empty_cache", [] { + for (int dev_id : platform::GetSelectedDevices()) { + auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace( + platform::CUDAPlace(dev_id)); + dev_ctx->cudnn_workspace_handle().ResetWorkspace(); + } + platform::EmptyCache(); + }); m.def("get_device_properties", [](int id) -> const gpuDeviceProp & { return platform::GetDeviceProperties(id); @@ -3213,6 +3258,13 @@ All parameter, weight, gradient are variables in Paddle. [](BuildStrategy &self, bool fix_op_run_order) { self.fix_op_run_order_ = fix_op_run_order; }) + .def_property("allow_cuda_graph_capture", + [](const BuildStrategy &self) { + return self.allow_cuda_graph_capture_; + }, + [](BuildStrategy &self, bool allow_cuda_graph_capture) { + self.allow_cuda_graph_capture_ = allow_cuda_graph_capture; + }) .def("_copy", [](const BuildStrategy &self) { auto new_bs = self; diff --git a/python/paddle/device/cuda/graphs.py b/python/paddle/device/cuda/graphs.py new file mode 100644 index 0000000000000..612f4d2c8cebd --- /dev/null +++ b/python/paddle/device/cuda/graphs.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace + +if is_compiled_with_cuda() and not is_compiled_with_rocm(): + from paddle.fluid.core import CUDAGraph as CoreCUDAGraph + + class CUDAGraph: + def __init__(self, place=None, mode="thread_local"): + ALL_MODES = ["global", "thread_local", "relaxed"] + self._graph = None + if place is None: + place = CUDAPlace(0) + self._place = place + assert mode in ALL_MODES + self._mode = ALL_MODES.index(mode) + + def capture_begin(self): + CoreCUDAGraph.begin_capture(self._place, self._mode) + + def capture_end(self): + self._graph = CoreCUDAGraph.end_capture() + + def replay(self): + self._graph.replay() + + def reset(self): + self._graph.reset() +else: + + class CUDAGraph: + def __init__(self, place=None, mode="thread_local"): + raise NotImplementedError() + + def capture_begin(self): + raise NotImplementedError() + + def capture_end(self): + raise NotImplementedError() + + def replay(self): + raise NotImplementedError() + + def reset(self): + raise NotImplementedError() diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index ded963b3a39c7..6965c997e7943 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -1932,3 +1932,38 @@ def build_program(main_program, startup_program): attrs=attrs) return batch_norm_out + + +def pow2_decay_with_linear_warmup(warmup_steps, + total_steps, + base_lr, + end_lr, + dtype='float32', + name=None): + if paddle.fluid.in_dygraph_mode(): + raise NotImplementedError( + "pow2_decay_with_linear_warmup does not support dygraph mode yet.") + + helper = LayerHelper("pow2_decay_with_linear_warmup", **locals()) + lr = helper.create_global_variable(persistable=True, dtype=dtype, shape=[1]) + helper.set_variable_initializer( + lr, Constant(value=float(base_lr) / warmup_steps)) + + step = helper.create_global_variable( + persistable=True, dtype='int64', shape=[1]) + helper.set_variable_initializer(step, Constant(value=0)) + assert warmup_steps <= total_steps, "warmup_steps cannot be larger than total_steps" + + helper.append_op( + type="pow2_decay_with_linear_warmup", + inputs={"LearningRate": lr, + "Step": step}, + outputs={"LearningRateOut": lr, + "StepOut": step}, + attrs={ + "warmup_steps": warmup_steps, + "total_steps": total_steps, + "base_lr": base_lr, + "end_lr": end_lr, + }) + return lr diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index e55ee8621867f..36546c1de1204 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -127,11 +127,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): num_cast_ops = 0 for in_name in op.input_names: - if src_dtype == core.VarDesc.VarType.FP32 and op.type in [ - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - ]: - if in_name not in {'X', 'Z'}: - continue + if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(op, + in_name): + continue for in_var_name in op.input(in_name): in_var = block._find_var_recursive(in_var_name) if in_var.type not in _valid_types or in_var.dtype == dest_dtype: @@ -184,9 +182,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): op._set_attr('in_dtype', dest_dtype) if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16: for out_name in op.output_names: - if op.type in [ - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - ] and out_name != 'Y': + if _keep_fp32_output(op, out_name): continue for out_var_name in op.output(out_name): out_var = block.var(out_var_name) @@ -401,9 +397,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): keep_fp32_ops.add(op) continue # processed below for in_name in op.input_names: - if op.type in { - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - } and in_name not in {'X', 'Z'}: + if _keep_fp32_input(op, in_name): continue for in_var_name in op.input(in_name): in_var = None @@ -431,9 +425,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): format(op.type, in_var_name, in_var.dtype)) for out_name in op.output_names: - if op.type in { - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - } and out_name != 'Y': + if _keep_fp32_output(op, out_name): continue for out_var_name in op.output(out_name): out_var = None diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 99009a4743c20..4d424fc09f14a 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1041,9 +1041,15 @@ def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, lr_value = lr_sheduler() lr_var = program._program.global_block().vars[lr_sheduler._var_name] lr_tensor = _as_lodtensor(lr_value, core.CPUPlace(), lr_var.dtype) - exe.feed_and_split_tensor_into_local_scopes({ - lr_sheduler._var_name: lr_tensor - }) + if core.is_cuda_graph_capturing(): + warnings.warn( + "Caution!!! When capturing CUDA Graph, the learning rate scheduler would not " + "take any effect! Please set the learning rate manually before each batch!" + ) + else: + exe.feed_and_split_tensor_into_local_scopes({ + lr_sheduler._var_name: lr_tensor + }) fetch_var_names = list(map(_to_name_str, fetch_list)) tensors = exe.run(fetch_var_names, return_merged)._move_to_list() diff --git a/python/paddle/fluid/memory_analysis.py b/python/paddle/fluid/memory_analysis.py new file mode 100644 index 0000000000000..0bcfeed351615 --- /dev/null +++ b/python/paddle/fluid/memory_analysis.py @@ -0,0 +1,77 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import core +import numpy as np + + +def get_var_and_memory_size(block, var_name, batch_size=None): + var = block._find_var_recursive(var_name) + assert var is not None, "Variable {} cannot be found".format(var_name) + assert var.type == core.VarDesc.VarType.LOD_TENSOR, "Variable {} is not Tensor".format( + var_name) + shape = list(var.shape) + if not shape: + return var, 0 + + has_none = False + for i, s in enumerate(shape): + if s is None or s < 0: + assert not has_none + shape[i] = batch_size + has_none = True + assert all( + [s >= 0 for s in shape]), "shape {} is not deterministic".format(shape) + mem_size = int(np.prod(shape)) * core.size_of_dtype(var.dtype) + return var, mem_size + + +def pre_allocate_memory(size, place): + t = core.LoDTensor() + t._set_dims([size]) + t._mutable_data(place, core.VarDesc.VarType.INT8) + del t + + +# NOTE: does not consider inplace yet. +def get_max_memory_info(program, batch_size=None): + assert program.num_blocks == 1, "only support to analysis program with only one block" + cur_tmp_mem = 0 + max_tmp_mem = 0 + max_persistable_mem = 0 + visited_vars = set() + alived_vars = [] + + block = program.global_block() + gc_vars = core._get_eager_deletion_vars(program.desc, [])[0] + for i, op in enumerate(block.ops): + var_names = op.input_arg_names + op.output_arg_names + for var_name in var_names: + if var_name in visited_vars: + continue + visited_vars.add(var_name) + var, mem_size = get_var_and_memory_size(block, var_name, batch_size) + if var.persistable: + max_persistable_mem += mem_size + else: + cur_tmp_mem += mem_size + max_tmp_mem = max(max_tmp_mem, cur_tmp_mem) + + cur_gc_vars = gc_vars[i] + for var_name in var_names: + if var_name not in cur_gc_vars: + continue + _, mem_size = get_var_and_memory_size(block, var_name, batch_size) + cur_tmp_mem -= mem_size + return max_tmp_mem, max_persistable_mem diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index f809f1bda079d..719bc609785e3 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -2064,8 +2064,9 @@ def _append_optimize_op(self, block, param_and_grad): attrs = { "mu": self._momentum, "lars_coeff": self._lars_coeff, - "lars_weight_decay": _lars_weight_decay, + "lars_weight_decay": [_lars_weight_decay], "multi_precision": find_master, + "epsilon": self._epsilon, "rescale_grad": self._rescale_grad } @@ -2084,7 +2085,7 @@ def _append_optimize_op(self, block, param_and_grad): # create the momentum optimize op momentum_op = block.append_op( - type=self.type, + type=self.type if _lars_weight_decay != 0.0 else 'momentum', inputs=inputs, outputs=outputs, attrs=attrs, diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph.py b/python/paddle/fluid/tests/unittests/test_cuda_graph.py new file mode 100644 index 0000000000000..7d1317473531e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cuda_graph.py @@ -0,0 +1,147 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +from paddle.device.cuda.graphs import CUDAGraph +import unittest +import numpy as np +from paddle.fluid.dygraph.base import switch_to_static_graph +from simple_nets import simple_fc_net_with_inputs + + +class TestCUDAGraph(unittest.TestCase): + def setUp(self): + if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm( + ): + fluid.set_flags({ + 'FLAGS_allocator_strategy': 'auto_growth', + 'FLAGS_sync_nccl_allreduce': False, + 'FLAGS_cudnn_deterministic': True + }) + + def random_tensor(self, shape): + return paddle.to_tensor( + np.random.randint( + low=0, high=10, size=shape).astype("float32")) + + @switch_to_static_graph + def test_cuda_graph_static_graph(self): + if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm(): + return + + seed = 100 + loss_cuda_graph = self.cuda_graph_static_graph_main( + seed, use_cuda_graph=True) + loss_no_cuda_graph = self.cuda_graph_static_graph_main( + seed, use_cuda_graph=False) + self.assertEqual(loss_cuda_graph, loss_no_cuda_graph) + + def cuda_graph_static_graph_main(self, seed, use_cuda_graph): + batch_size = 1 + class_num = 10 + image_shape = [batch_size, 784] + label_shape = [batch_size, 1] + + paddle.seed(seed) + np.random.seed(seed) + startup = paddle.static.Program() + main = paddle.static.Program() + with paddle.static.program_guard(main, startup): + image = paddle.static.data( + name="image", shape=image_shape, dtype='float32') + label = paddle.static.data( + name="label", shape=label_shape, dtype='int64') + image.persistable = True + label.persistable = True + loss = simple_fc_net_with_inputs(image, label, class_num) + loss.persistable = True + lr = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04]) + optimizer = paddle.optimizer.SGD(learning_rate=lr) + optimizer.minimize(loss) + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + scope = paddle.static.Scope() + with paddle.static.scope_guard(scope): + exe.run(startup) + build_strategy = paddle.static.BuildStrategy() + build_strategy.allow_cuda_graph_capture = True + build_strategy.fix_op_run_order = True + build_strategy.fuse_all_optimizer_ops = True + compiled_program = paddle.static.CompiledProgram( + main).with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + places=place) + image_t = scope.var(image.name).get_tensor() + label_t = scope.var(label.name).get_tensor() + loss_t = scope.var(loss.name).get_tensor() + lr_var = main.global_block().var(lr._var_name) + self.assertTrue(lr_var.persistable) + lr_t = scope.var(lr_var.name).get_tensor() + cuda_graph = None + for batch_id in range(20): + image_t.set( + np.random.rand(*image_shape).astype('float32'), place) + label_t.set(np.random.randint( + low=0, high=class_num, size=label_shape, dtype='int64'), + place) + + if batch_id == 1 and use_cuda_graph: + cuda_graph = CUDAGraph(place, mode="global") + cuda_graph.capture_begin() + exe.run(compiled_program) + cuda_graph.capture_end() + + if cuda_graph: + lr_t.set(np.array([lr()], dtype='float32'), place) + cuda_graph.replay() + else: + exe.run(compiled_program) + lr.step() + if cuda_graph: + cuda_graph.reset() + return np.array(loss_t) + + def test_cuda_graph_dynamic_graph(self): + if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm(): + return + + shape = [2, 3] + x = self.random_tensor(shape) + z = self.random_tensor(shape) + + g = CUDAGraph() + g.capture_begin() + y = x + 10 + z.add_(x) + g.capture_end() + + for _ in range(10): + z_np_init = z.numpy() + x_new = self.random_tensor(shape) + x.copy_(x_new, False) + g.replay() + x_np = x_new.numpy() + y_np = y.numpy() + z_np = z.numpy() + self.assertTrue((y_np - x_np == 10).all()) + self.assertTrue((z_np - z_np_init == x_np).all()) + + g.reset() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py index e4cc3682d1a24..bee6acf732460 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py @@ -103,7 +103,7 @@ def test_lars_exclude_fn(self): 'op_role_var')[0] or ".b" in op.attr('op_role_var')[0]) ] for op in ops_without_wd: - self.assertEqual(op.attr('lars_weight_decay'), 0) + self.assertEqual(op.attr('lars_weight_decay')[0], 0) def test_lars_apply_with_amp(self): role = role_maker.PaddleCloudRoleMaker(is_collective=True) diff --git a/python/paddle/fluid/tests/unittests/test_memory_analysis.py b/python/paddle/fluid/tests/unittests/test_memory_analysis.py new file mode 100644 index 0000000000000..9388e07dbf891 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_memory_analysis.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +from paddle.fluid.memory_analysis import pre_allocate_memory, get_max_memory_info +from simple_nets import simple_fc_net + + +class TestMemoryAnalysis(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def test_get_memory_info(self): + loss = simple_fc_net() + optimizer = paddle.optimizer.Adam(learning_rate=1e-3) + optimizer.minimize(loss) + main_prog = paddle.static.default_main_program() + max_tmp_mem_1, max_persitable_mem_1 = get_max_memory_info( + main_prog, batch_size=32) + self.assertGreater(max_tmp_mem_1, 0) + self.assertGreater(max_persitable_mem_1, 0) + max_tmp_mem_2, max_persitable_mem_2 = get_max_memory_info( + main_prog, batch_size=64) + self.assertEqual(max_persitable_mem_1, max_persitable_mem_2) + self.assertLess(max_tmp_mem_1, max_tmp_mem_2) + + +class TestPreAllocateMemory(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def test_pre_allocate(self): + size = 32 * 1024 * 1024 + pre_allocate_memory(size, paddle.CPUPlace()) + if paddle.is_compiled_with_cuda(): + pre_allocate_memory(size, paddle.CUDAPlace(0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py new file mode 100644 index 0000000000000..96e458795a3c0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py @@ -0,0 +1,197 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import numpy as np +from paddle.fluid.layer_helper import LayerHelper +from collections import OrderedDict + + +def run_momentum_op(params, + grads, + velocitys, + master_params, + learning_rate, + place, + multi_precision, + mu=0.9, + rescale_grad=0.01, + use_merged=False): + assert len(params) == len(grads) + assert len(params) == len(velocitys) + if multi_precision: + assert len(params) == len(master_params) + op_type = 'merged_momentum' if use_merged else 'momentum' + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + helper = LayerHelper(op_type, **locals()) + attrs = { + 'mu': mu, + 'multi_precision': multi_precision, + 'rescale_grad': rescale_grad, + } + + param_vars = [ + helper.create_variable( + persistable=True, shape=p.shape, dtype=p.dtype) for p in params + ] + grad_vars = [ + helper.create_variable( + shape=g.shape, dtype=g.dtype) for g in grads + ] + velocity_vars = [ + helper.create_variable( + persistable=True, shape=v.shape, dtype=v.dtype) + for v in velocitys + ] + lr_var = helper.create_variable( + persistable=True, + shape=learning_rate.shape, + dtype=learning_rate.dtype) + + feed_dict = OrderedDict() + + feed_dict.update( + OrderedDict([(p_var.name, p_val) + for p_var, p_val in zip(param_vars, params)])) + feed_dict.update( + OrderedDict([(v_var.name, v_val) + for v_var, v_val in zip(velocity_vars, velocitys)])) + fetch_list = list(feed_dict.keys()) + + feed_dict.update( + OrderedDict([(g_var.name, g_val) + for g_var, g_val in zip(grad_vars, grads)])) + feed_dict.update({lr_var.name: learning_rate}) + + if multi_precision: + master_param_vars = [ + helper.create_variable( + persistable=True, shape=p.shape, dtype=p.dtype) + for p in master_params + ] + feed_dict.update( + OrderedDict([(mp_var.name, mp_val) + for mp_var, mp_val in zip(master_param_vars, + master_params)])) + # CPUPlace does not use MasterParam + if isinstance(place, paddle.CUDAPlace): + fetch_list = fetch_list + [ + mp_var.name for mp_var in master_param_vars + ] + else: + master_param_vars = None + + if not use_merged: + for i, (p, g, + v) in enumerate(zip(param_vars, grad_vars, velocity_vars)): + inputs = { + 'Param': p, + 'Grad': g, + 'Velocity': v, + 'LearningRate': lr_var, + } + outputs = {'ParamOut': p, 'VelocityOut': v} + if multi_precision: + inputs['MasterParam'] = master_param_vars[i] + outputs['MasterParamOut'] = master_param_vars[i] + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + else: + inputs = { + 'Param': param_vars, + 'Grad': grad_vars, + 'Velocity': velocity_vars, + 'LearningRate': lr_var, + } + outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars} + if multi_precision: + inputs['MasterParam'] = master_param_vars + outputs['MasterParamOut'] = master_param_vars + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + + exe = paddle.static.Executor(place) + with paddle.static.scope_guard(paddle.static.Scope()): + exe.run(startup) + return exe.run(main, feed=feed_dict, fetch_list=fetch_list) + + +class TestMergedMomentum(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] + self.seed = 10 + + def gen_rand_data(self, shapes, dtype): + return [np.random.random(s).astype(dtype) for s in shapes] + + def prepare_data(self, shapes, multi_precision, seed, place): + np.random.seed(seed) + mp_dtype = np.float32 + dtype = np.float16 if multi_precision and isinstance( + place, paddle.CUDAPlace) else np.float32 + params = self.gen_rand_data(shapes, dtype) + grads = self.gen_rand_data(shapes, dtype) + velocitys = self.gen_rand_data(shapes, mp_dtype) + learning_rate = self.gen_rand_data([[1]], mp_dtype)[0] + if multi_precision: + master_params = [p.astype(mp_dtype) for p in params] + else: + master_params = None + return params, grads, velocitys, master_params, learning_rate + + def check_with_place(self, place, multi_precision): + params, grads, velocitys, master_params, learning_rate = self.prepare_data( + self.shapes, multi_precision, self.seed, place) + + def run_op(use_merged): + # FIXME(zengjinle): CPU Momentum Op does not support rescale_grad + rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01 + return run_momentum_op( + params, + grads, + velocitys, + master_params, + learning_rate, + place, + multi_precision, + rescale_grad=rescale_grad, + use_merged=use_merged) + + outs1 = run_op(True) + outs2 = run_op(False) + self.assertEqual(len(outs1), len(outs2)) + for i, (out1, out2) in enumerate(zip(outs1, outs2)): + if isinstance(place, paddle.CUDAPlace): + self.assertTrue(np.array_equal(out1, out2)) + else: + self.assertTrue(np.allclose(out1, out2, atol=1e-7)) + + def get_places(self): + places = [paddle.CPUPlace()] + if paddle.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + return places + + def test_main(self): + for multi_precision in [False, True]: + for place in self.get_places(): + self.check_with_place(place, multi_precision) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index b42de853c00d5..34e057a5a8a61 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -138,50 +138,70 @@ def test_check_output(self): "core is not compiled with CUDA") class TestLarsMomentumOpWithMP(OpTest): def setUp(self): + self.config() self.op_type = "lars_momentum" - - master_param = np.random.random((123, 321)).astype("float32") - param = master_param.astype("float16") - grad = np.random.random((123, 321)).astype("float16") - velocity = np.zeros((123, 321)).astype("float32") - learning_rate = np.array([0.001]).astype("float32") mu = 0.0001 lars_coeff = 0.001 lars_weight_decay = 0.0005 rescale_grad = 1.0 + params = [] + grads = [] + velocitys = [] + learning_rates = [] + master_params = [] + param_outs = [] + velocity_outs = [] + master_param_outs = [] + for i in range(self.params_num): + master_param = np.random.random((123, 321)).astype("float32") + param = master_param.astype("float16") + grad = np.random.random((123, 321)).astype("float16") + velocity = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.001]).astype("float32") + + fp32_grad = grad.astype("float32") + pnorm = np.sqrt(np.square(master_param).sum()) + gnorm = np.sqrt(np.square(fp32_grad).sum()) + local_lr = learning_rate * lars_coeff * pnorm / ( + gnorm + lars_weight_decay * pnorm) + fp32_grad = fp32_grad * rescale_grad + velocity_out = mu * velocity + local_lr * ( + fp32_grad + lars_weight_decay * master_param) + p_new = master_param - velocity_out + param_out = p_new.astype("float16") + master_param_out = p_new + + params.append(("SubParam_" + str(i), param)) + grads.append(("SubGrad_" + str(i), grad)) + velocitys.append(("SubVelocity_" + str(i), velocity)) + learning_rates.append(("SubLearning_rate_" + str(i), learning_rate)) + velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out)) + param_outs.append(("SubParam_out_" + str(i), param_out)) + master_params.append(("SubMasterParam_" + str(i), master_param)) + master_param_outs.append( + ("SubMasterParamOut_" + str(i), master_param_out)) + self.inputs = { - 'Param': param, - 'Grad': grad, - 'Velocity': velocity, - 'LearningRate': learning_rate, - 'MasterParam': master_param, + 'Param': params, + 'Grad': grads, + 'Velocity': velocitys, + 'LearningRate': learning_rates, + 'MasterParam': master_params, } self.attrs = { 'mu': mu, 'lars_coeff': lars_coeff, - 'lars_weight_decay': lars_weight_decay, + 'lars_weight_decay': [lars_weight_decay], 'multi_precision': True, 'rescale_grad': rescale_grad } - fp32_grad = grad.astype("float32") - pnorm = np.sqrt(np.square(master_param).sum()) - gnorm = np.sqrt(np.square(fp32_grad).sum()) - local_lr = learning_rate * lars_coeff * pnorm / ( - gnorm + lars_weight_decay * pnorm) - fp32_grad = fp32_grad * rescale_grad - velocity_out = mu * velocity + local_lr * (fp32_grad + lars_weight_decay - * master_param) - p_new = master_param - velocity_out - param_out = p_new.astype("float16") - master_param_out = p_new - self.outputs = { - 'ParamOut': param_out, - 'VelocityOut': velocity_out, - 'MasterParamOut': master_param_out + 'ParamOut': param_outs, + 'VelocityOut': velocity_outs, + 'MasterParamOut': master_param_outs } def test_check_output(self): @@ -191,46 +211,65 @@ def test_check_output(self): if core.is_float16_supported(place): self.check_output_with_place(place) + def config(self): + self.params_num = 1 + class TestLarsMomentumOp(OpTest): def setUp(self): + self.config() self.op_type = "lars_momentum" - - param = np.random.random((123, 321)).astype("float32") - grad = np.random.random((123, 321)).astype("float32") - velocity = np.zeros((123, 321)).astype("float32") - learning_rate = np.array([0.001]).astype("float32") mu = 0.0001 lars_coeff = 0.001 lars_weight_decay = 0.0005 + params = [] + grads = [] + velocitys = [] + param_outs = [] + velocity_outs = [] + learning_rates = [] + for i in range(self.params_num): + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + velocity = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.001]).astype("float32") + pnorm = np.sqrt(np.square(param).sum()) + gnorm = np.sqrt(np.square(grad).sum()) + local_lr = learning_rate * lars_coeff * pnorm / ( + gnorm + lars_weight_decay * param) + velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay + * param) + param_out = param - velocity_out + + params.append(("SubParam_" + str(i), param)) + grads.append(("SubGrad_" + str(i), grad)) + velocitys.append(("SubVelocity_" + str(i), velocity)) + learning_rates.append(("SubLearning_rate_" + str(i), learning_rate)) + velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out)) + param_outs.append(("SubParam_out_" + str(i), param_out)) + self.inputs = { - 'Param': param, - 'Grad': grad, - 'Velocity': velocity, - 'LearningRate': learning_rate + 'Param': params, + 'Grad': grads, + 'Velocity': velocitys, + 'LearningRate': learning_rates } self.attrs = { 'mu': mu, 'lars_coeff': lars_coeff, - 'lars_weight_decay': lars_weight_decay + 'lars_weight_decay': [lars_weight_decay] } - - pnorm = np.sqrt(np.square(param).sum()) - gnorm = np.sqrt(np.square(grad).sum()) - local_lr = learning_rate * lars_coeff * pnorm / ( - gnorm + lars_weight_decay * param) - velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay * - param) - param_out = param - velocity_out - - self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} + self.outputs = {'ParamOut': param_outs, 'VelocityOut': velocity_outs} def test_check_output(self): paddle.enable_static() self.check_output() + def config(self): + self.params_num = 1 + class TestSparseMomentumOp(unittest.TestCase): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py b/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py new file mode 100644 index 0000000000000..056db5b8590ab --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.contrib.layers.nn import pow2_decay_with_linear_warmup +from paddle.optimizer.lr import LinearWarmup +from paddle.optimizer.lr import PolynomialDecay +import unittest + + +def gen_pow2_warmup_op_lr(warmup_steps, total_steps, base_lr, end_lr, place): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, base_lr, + end_lr) + exe = paddle.static.Executor(place) + with paddle.static.scope_guard(paddle.static.Scope()): + exe.run(startup) + while True: + lr_np = exe.run(main, fetch_list=[lr])[0] + yield lr_np[0] + + +class Pow2Warmup(LinearWarmup): + def __init__(self, warmup_steps, total_steps, base_lr, end_lr): + assert total_steps > warmup_steps + lr_sch = PolynomialDecay( + learning_rate=base_lr, + decay_steps=total_steps - warmup_steps, + end_lr=end_lr, + power=2) + + super(Pow2Warmup, self).__init__( + learning_rate=lr_sch, + warmup_steps=warmup_steps, + start_lr=0.0, + end_lr=base_lr) + + +def gen_pow2_warmup_py_lr(warmup_steps, total_steps, base_lr, end_lr, place): + lr_sch = Pow2Warmup(warmup_steps, total_steps, base_lr, end_lr) + lr_sch.step() + while True: + yield lr_sch() + lr_sch.step() + + +class TestPow2WarmupLRScheduler(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.params = { + 'warmup_steps': 30, + 'total_steps': 100, + 'base_lr': 0.02, + 'end_lr': 0.001, + } + self.step_num = 1000 + + def check_with_place(self, place): + kwargs = dict(self.params) + kwargs['place'] = place + lr_sch_op = gen_pow2_warmup_op_lr(**kwargs) + lr_sch_py = gen_pow2_warmup_py_lr(**kwargs) + for i, (lr_op, lr_py) in enumerate(zip(lr_sch_op, lr_sch_py)): + self.assertLess(abs(lr_op - lr_py), 1e-6) + if i > self.step_num: + break + + def test_main(self): + self.check_with_place(paddle.CPUPlace()) + if paddle.is_compiled_with_cuda(): + self.check_with_place(paddle.CUDAPlace(0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_copy_from.py b/python/paddle/fluid/tests/unittests/test_tensor_copy_from.py new file mode 100644 index 0000000000000..6a91c2182d1c5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_copy_from.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import numpy as np +from paddle.fluid.core import LoDTensor as Tensor + + +class TestTensorCopyFrom(unittest.TestCase): + def test_main(self): + place = paddle.CPUPlace() + np_value = np.random.random(size=[10, 30]).astype('float32') + t_src = Tensor() + t_src.set(np_value, place) + self.assertTrue(np.array_equal(np_value, t_src)) + + t_dst1 = Tensor() + t_dst1._copy_from(t_src, place) + self.assertTrue(np.array_equal(np_value, t_dst1)) + + t_dst2 = Tensor() + t_dst2._copy_from(t_src, place, 5) + self.assertTrue(np.array_equal(np.array(np_value[0:5]), t_dst2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/operators/__init__.py b/python/paddle/incubate/operators/__init__.py index 694cde4f28624..9a6710d095097 100644 --- a/python/paddle/incubate/operators/__init__.py +++ b/python/paddle/incubate/operators/__init__.py @@ -14,3 +14,4 @@ from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle # noqa: F401 from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401 +from .resnet_unit import ResNetUnit #noqa: F401 diff --git a/python/paddle/incubate/operators/resnet_unit.py b/python/paddle/incubate/operators/resnet_unit.py new file mode 100644 index 0000000000000..cba1d4863cbd4 --- /dev/null +++ b/python/paddle/incubate/operators/resnet_unit.py @@ -0,0 +1,269 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import collections +import itertools +import six +import math +import sys +import warnings +from functools import partial, reduce + +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle import framework +from paddle.device import get_device, get_cudnn_version +from paddle.nn import initializer as I +from paddle.nn import Layer, LayerList +from paddle.fluid.layers import utils +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as +from paddle.fluid.data_feeder import convert_dtype +from paddle.fluid.param_attr import ParamAttr +from paddle import _C_ops +__all__ = ['resnet_unit', 'ResNetUnit'] + + +def resnet_unit(x, filter_x, scale_x, bias_x, mean_x, var_x, z, filter_z, + scale_z, bias_z, mean_z, var_z, stride, stride_z, padding, + dilation, groups, momentum, eps, data_format, fuse_add, + has_shortcut, use_global_stats, is_test, act): + + helper = LayerHelper('resnet_unit', **locals()) + bn_param_dtype = fluid.core.VarDesc.VarType.FP32 + bit_mask_dtype = fluid.core.VarDesc.VarType.INT32 + out = helper.create_variable_for_type_inference(x.dtype) + bit_mask = helper.create_variable_for_type_inference( + dtype=bit_mask_dtype, stop_gradient=True) + # intermediate_out for x + conv_x = helper.create_variable_for_type_inference( + dtype=x.dtype, stop_gradient=True) + saved_mean_x = helper.create_variable_for_type_inference( + dtype=bn_param_dtype, stop_gradient=True) + saved_invstd_x = helper.create_variable_for_type_inference( + dtype=bn_param_dtype, stop_gradient=True) + running_mean_x = mean_x + running_var_x = var_x + # intermediate_out for z + conv_z = helper.create_variable_for_type_inference( + dtype=x.dtype, stop_gradient=True) + saved_mean_z = helper.create_variable_for_type_inference( + dtype=bn_param_dtype, stop_gradient=True) + saved_invstd_z = helper.create_variable_for_type_inference( + dtype=bn_param_dtype, stop_gradient=True) + running_mean_z = helper.create_variable_for_type_inference( + dtype=bn_param_dtype, stop_gradient=True) if mean_z is None else mean_z + running_var_z = helper.create_variable_for_type_inference( + dtype=bn_param_dtype, stop_gradient=True) if var_z is None else var_z + + inputs = { + 'X': x, + 'FilterX': filter_x, + 'ScaleX': scale_x, + 'BiasX': bias_x, + 'MeanX': mean_x, + 'VarX': var_x, + 'Z': z, + 'FilterZ': filter_z, + 'ScaleZ': scale_z, + 'BiasZ': bias_z, + 'MeanZ': mean_z, + 'VarZ': var_z + } + + attrs = { + 'stride': stride, + 'stride_z': stride_z, + 'padding': padding, + 'dilation': dilation, + 'group': groups, + 'momentum': momentum, + 'epsilon': eps, + 'data_format': data_format, + 'fuse_add': fuse_add, + 'has_shortcut': has_shortcut, + 'use_global_stats': use_global_stats, + 'is_test': is_test, + 'act_type': act + } + + outputs = { + 'Y': out, + 'BitMask': bit_mask, + 'ConvX': conv_x, + 'SavedMeanX': saved_mean_x, + 'SavedInvstdX': saved_invstd_x, + 'RunningMeanX': running_mean_x, + 'RunningVarX': running_var_x, + 'ConvZ': conv_z, + 'SavedMeanZ': saved_mean_z, + 'SavedInvstdZ': saved_invstd_z, + 'RunningMeanZ': running_mean_z, + 'RunningVarZ': running_var_z, + } + + helper.append_op( + type='resnet_unit', inputs=inputs, outputs=outputs, attrs=attrs) + + return out + + +class ResNetUnit(Layer): + r""" + ******Temporary version******. + ResNetUnit is designed for optimize the performence by using cudnnv8 API. + """ + + def __init__(self, + num_channels_x, + num_filters, + filter_size, + stride=1, + momentum=0.9, + eps=1e-5, + data_format='NHWC', + act='relu', + fuse_add=False, + has_shortcut=False, + use_global_stats=False, + is_test=False, + filter_x_attr=None, + scale_x_attr=None, + bias_x_attr=None, + moving_mean_x_name=None, + moving_var_x_name=None, + num_channels_z=1, + stride_z=1, + filter_z_attr=None, + scale_z_attr=None, + bias_z_attr=None, + moving_mean_z_name=None, + moving_var_z_name=None): + super(ResNetUnit, self).__init__() + self._stride = stride + self._stride_z = stride_z + self._dilation = 1 + self._kernel_size = utils.convert_to_list(filter_size, 2, 'kernel_size') + self._padding = (filter_size - 1) // 2 + self._groups = 1 + self._momentum = momentum + self._eps = eps + self._data_format = data_format + self._act = act + self._fuse_add = fuse_add + self._has_shortcut = has_shortcut + self._use_global_stats = use_global_stats + self._is_test = is_test + + # check format + valid_format = {'NHWC'} + if data_format not in valid_format: + raise ValueError( + "conv_format must be one of {}, but got conv_format='{}'". + format(valid_format, data_format)) + + def _get_default_param_initializer(channels): + filter_elem_num = np.prod(self._kernel_size) * channels + std = (2.0 / filter_elem_num)**0.5 + return I.Normal(0.0, std) + + # initial filter + bn_param_dtype = fluid.core.VarDesc.VarType.FP32 + bn_param_shape = [1, 1, 1, num_filters] + filter_x_shape = [num_filters, filter_size, filter_size, num_channels_x] + filter_z_shape = [num_filters, filter_size, filter_size, num_channels_z] + + self.filter_x = self.create_parameter( + shape=filter_x_shape, + attr=filter_x_attr, + default_initializer=_get_default_param_initializer(num_channels_x)) + self.scale_x = self.create_parameter( + shape=bn_param_shape, + attr=scale_x_attr, + dtype=bn_param_dtype, + default_initializer=I.Constant(1.0)) + self.bias_x = self.create_parameter( + shape=bn_param_shape, + attr=bias_x_attr, + dtype=bn_param_dtype, + is_bias=True) + self.mean_x = self.create_parameter( + attr=ParamAttr( + name=moving_mean_x_name, + initializer=I.Constant(0.0), + trainable=False), + shape=bn_param_shape, + dtype=bn_param_dtype) + self.mean_x.stop_gradient = True + self.var_x = self.create_parameter( + attr=ParamAttr( + name=moving_var_x_name, + initializer=I.Constant(1.0), + trainable=False), + shape=bn_param_shape, + dtype=bn_param_dtype) + self.var_x.stop_gradient = True + if has_shortcut: + self.filter_z = self.create_parameter( + shape=filter_z_shape, + attr=filter_z_attr, + default_initializer=_get_default_param_initializer( + num_channels_z)) + self.scale_z = self.create_parameter( + shape=bn_param_shape, + attr=scale_z_attr, + dtype=bn_param_dtype, + default_initializer=I.Constant(1.0)) + self.bias_z = self.create_parameter( + shape=bn_param_shape, + attr=bias_z_attr, + dtype=bn_param_dtype, + is_bias=True) + self.mean_z = self.create_parameter( + attr=ParamAttr( + name=moving_mean_z_name, + initializer=I.Constant(0.0), + trainable=False), + shape=bn_param_shape, + dtype=bn_param_dtype) + self.mean_z.stop_gradient = True + self.var_z = self.create_parameter( + attr=ParamAttr( + name=moving_var_z_name, + initializer=I.Constant(1.0), + trainable=False), + shape=bn_param_shape, + dtype=bn_param_dtype) + self.var_z.stop_gradient = True + else: + self.filter_z = None + self.scale_z = None + self.bias_z = None + self.mean_z = None + self.var_z = None + + def forward(self, x, z=None): + if self._fuse_add and z is None: + raise ValueError("z can not be None") + + out = resnet_unit( + x, self.filter_x, self.scale_x, self.bias_x, self.mean_x, + self.var_x, z, self.filter_z, self.scale_z, self.bias_z, + self.mean_z, self.var_z, self._stride, self._stride_z, + self._padding, self._dilation, self._groups, self._momentum, + self._eps, self._data_format, self._fuse_add, self._has_shortcut, + self._use_global_stats, self._is_test, self._act) + return out