From f4b746dc34127d6dc9e52d74192a7289a1e9bd28 Mon Sep 17 00:00:00 2001
From: ZzSean <18818272991@163.com>
Date: Thu, 23 Sep 2021 06:32:46 +0000
Subject: [PATCH] Add bn_add_relu test

---
 paddle/fluid/operators/fused/CMakeLists.txt   |   7 +-
 .../operators/fused/cudnn_bn_add_relu_test.cu | 353 ++++++++++++++++++
 .../fused/cudnn_bn_stats_finalize.cu.h        | 142 +++++++
 .../fused/cudnn_scale_bias_add_relu.cu.h      | 198 ++++++++++
 4 files changed, 699 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cu
 create mode 100644 paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h
 create mode 100644 paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h

diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 599be6912b760..cc3b66c302e75 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -16,7 +16,8 @@ register_operators(EXCLUDES
     fusion_gru_op
     fusion_lstm_op
     fused_bn_add_activation_op
-    fused_transformer_op)
+    fused_transformer_op
+    resnet_unit_op)
 
 # fusion_gru_op does not have CUDA kernel
 op_library(fusion_gru_op)
@@ -78,7 +79,11 @@ if (WITH_GPU OR WITH_ROCM)
     nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
     nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
   endif()
+  # resnet_unit needs cudnn 8.0 above
   if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
+    op_library(resnet_unit_op)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n")
     cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
+    nv_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cu DEPS batch_norm_op tensor op_registry device_context generator memory)
   endif()
 endif()
diff --git a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cu b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cu
new file mode 100644
index 0000000000000..590204b43b59b
--- /dev/null
+++ b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cu
@@ -0,0 +1,353 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <random>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
+#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+namespace op = paddle::operators;
+using Tensor = paddle::framework::Tensor;
+
+USE_OP(batch_norm);
+
+// Run Paddle's native batch_norm op to get the baseline results.
+void GetBatchNormOp(const Tensor &x, const Tensor &scale, const Tensor &bias,
+                    Tensor *mean, Tensor *var, Tensor *y, Tensor *saved_mean,
+                    Tensor *saved_var, Tensor *reserve_space,
+                    const framework::DDim &data_dim,
+                    const framework::DDim &param_dim,
+                    const platform::CUDADeviceContext &ctx) {
+  framework::Scope scope;
+  auto var_x = scope.Var("X");
+  auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
+  auto var_scale = scope.Var("Scale");
+  auto tensor_scale = var_scale->GetMutable<framework::LoDTensor>();
+  auto var_bias = scope.Var("Bias");
+  auto tensor_bias = var_bias->GetMutable<framework::LoDTensor>();
+  auto var_mean = scope.Var("Mean");
+  auto tensor_mean = var_mean->GetMutable<framework::LoDTensor>();
+  auto var_var = scope.Var("Variance");
+  auto tensor_var = var_var->GetMutable<framework::LoDTensor>();
+  auto var_y = scope.Var("Y");
+  auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
+  auto var_saved_mean = scope.Var("SavedMean");
+  auto tensor_saved_mean = var_saved_mean->GetMutable<framework::LoDTensor>();
+  auto var_saved_var = scope.Var("SavedVariance");
+  auto tensor_saved_var = var_saved_var->GetMutable<framework::LoDTensor>();
+  auto var_reserve = scope.Var("ReserveSpace");
+  auto tensor_reserve = var_reserve->GetMutable<framework::LoDTensor>();
+
+  auto place = ctx.GetPlace();
+  TensorCopySync(x, place, tensor_x);
+  TensorCopySync(scale, place, tensor_scale);
+  TensorCopySync(bias, place, tensor_bias);
+  TensorCopySync(*mean, place, tensor_mean);
+  TensorCopySync(*var, place, tensor_var);
+
+  framework::AttributeMap attrs;
+  std::string data_layout = "NHWC";
+  attrs.insert({"data_layout", data_layout});
+
+  auto op = framework::OpRegistry::CreateOp(
+      "batch_norm", {{"X", {"X"}},
+                     {"Scale", {"Scale"}},
+                     {"Bias", {"Bias"}},
+                     {"Mean", {"Mean"}},
+                     {"Variance", {"Variance"}}},
+      {{"Y", {"Y"}},
+       {"MeanOut", {"Mean"}},
+       {"VarianceOut", {"Variance"}},
+       {"SavedMean", {"SavedMean"}},
+       {"SavedVariance", {"SavedVariance"}},
+       {"ReserveSpace", {"ReserveSpace"}}},
+      attrs);
+  op->Run(scope, ctx.GetPlace());
+
+  TensorCopySync(*tensor_y, place, y);
+  TensorCopySync(*tensor_mean, place, mean);
+  TensorCopySync(*tensor_var, place, var);
+  TensorCopySync(*tensor_saved_mean, place, saved_mean);
+  TensorCopySync(*tensor_saved_var, place, saved_var);
+  TensorCopySync(*tensor_reserve, place, reserve_space);
+  ctx.Wait();
+}
+
+template <typename T>
+class TestCudnnBNAddReluForward {
+ public:
+  TestCudnnBNAddReluForward() {
+    batch_size_ = 2;
+    height_ = 8;
+    width_ = 8;
+    channels_ = 32;
+    ele_count_ = batch_size_ * height_ * width_;
+    ctx_ = new platform::CUDADeviceContext(place_);
+  }
+
+  TestCudnnBNAddReluForward(int batch_size, int height, int width,
+                            int channels) {
+    batch_size_ = batch_size;
+    height_ = height;
+    width_ = width;
+    channels_ = channels;
+    ele_count_ = batch_size_ * height_ * width_;
+    ctx_ = new platform::CUDADeviceContext(place_);
+  }
+
+  ~TestCudnnBNAddReluForward() { delete ctx_; }
+
+  void SetUp() {
+    data_size_ = batch_size_ * height_ * width_ * channels_;
+    param_size_ = channels_;
+
+    x_vec_.resize(data_size_);
+    sum_vec_.resize(param_size_);
+    sum_of_squares_vec_.resize(param_size_);
+    scale_vec_.resize(param_size_);
+    bias_vec_.resize(param_size_);
+    mean_vec_.resize(param_size_);
+    var_vec_.resize(param_size_);
+    y_vec_.resize(data_size_);
+    saved_mean_vec_.resize(param_size_);
+    saved_var_vec_.resize(param_size_);
+    equiv_scale_vec_.resize(param_size_);
+    equiv_bias_vec_.resize(param_size_);
+    base_y_vec_.resize(data_size_);
+    base_mean_vec_.resize(param_size_);
+    base_var_vec_.resize(param_size_);
+    base_saved_mean_vec_.resize(param_size_);
+    base_saved_var_vec_.resize(param_size_);
+
+    // initialize data
+    std::default_random_engine random(0);
+    std::uniform_real_distribution<float> dis(0.0, 1.0);
+    for (int c = 0; c < channels_; ++c) {
+      float sum = 0;
+      float sum_of_squares = 0;
+      for (int n = 0; n < batch_size_; ++n) {
+        for (int h = 0; h < height_; ++h) {
+          for (int w = 0; w < width_; ++w) {
+            float temp = dis(random);
+            float ttemp = static_cast<float>(static_cast<T>(temp));
+            int idx = n * height_ * width_ * channels_ +
+                      h * width_ * channels_ + w * channels_ + c;
+            sum += ttemp;
+            sum_of_squares += ttemp * ttemp;
+            x_vec_[idx] = static_cast<T>(temp);
+          }
+        }
+      }
+      sum_vec_[c] = sum;
+      sum_of_squares_vec_[c] = sum_of_squares;
+    }
+    for (int i = 0; i < param_size_; ++i) {
+      scale_vec_[i] = 1.0;
+      bias_vec_[i] = 0.0;
+      mean_vec_[i] = 0.0;
+      var_vec_[i] = 1.0;
+      saved_mean_vec_[i] = 0.0;
+      saved_var_vec_[i] = 0.0;
+      base_mean_vec_[i] = 0.0;
+      base_var_vec_[i] = 1.0;
+      base_saved_mean_vec_[i] = 0.0;
+      base_saved_var_vec_[i] = 0.0;
+    }
+    for (int i = 0; i < data_size_; ++i) {
+      y_vec_[i] = static_cast<T>(0.0f);
+      base_y_vec_[i] = static_cast<T>(0.0f);
+    }
+
+    // input
+    framework::TensorFromVector(x_vec_, *ctx_, &x_);
+    x_.Resize({batch_size_, height_, width_, channels_});
+    framework::TensorFromVector(sum_vec_, *ctx_, &sum_);
+    sum_.Resize({1, 1, 1, channels_});
+    framework::TensorFromVector(sum_of_squares_vec_, *ctx_,
+                                &sum_of_squares_);
+    sum_of_squares_.Resize({1, 1, 1, channels_});
+    framework::TensorFromVector(scale_vec_, *ctx_, &scale_);
+    scale_.Resize({1, 1, 1, channels_});
+    framework::TensorFromVector(bias_vec_, *ctx_, &bias_);
+    bias_.Resize({1, 1, 1, channels_});
+    framework::TensorFromVector(mean_vec_, *ctx_, &mean_);
+    mean_.Resize({1, 1, 1, channels_});
+    framework::TensorFromVector(var_vec_, *ctx_, &var_);
+    var_.Resize({1, 1, 1, channels_});
+    // baseline
+    framework::TensorFromVector(base_mean_vec_, *ctx_, &base_mean_);
+    base_mean_.Resize({1, 1, 1, channels_});
+    framework::TensorFromVector(base_var_vec_, *ctx_, &base_var_);
+    base_var_.Resize({1, 1, 1, channels_});
+    // output
+    y_.Resize({batch_size_, height_, width_, channels_});
+    equiv_scale_.Resize({1, 1, 1, channels_});
+    equiv_bias_.Resize({1, 1, 1, channels_});
+    saved_mean_.Resize({1, 1, 1, channels_});
+    saved_var_.Resize({1, 1, 1, channels_});
+    // baseline
+    base_y_.Resize({batch_size_, height_, width_, channels_});
+    base_saved_mean_.Resize({1, 1, 1, channels_});
+    base_saved_var_.Resize({1, 1, 1, channels_});
+    // bitmask
+    int C = channels_;
+    int64_t NHW = ele_count_;
+    int32_t C_int32Elems = ((C + 63) & ~63) / 32;
+    int32_t NHW_int32Elems = (NHW + 31) & ~31;
+    bitmask_.Resize({NHW_int32Elems, C_int32Elems, 1});
+
+    ctx_->Wait();
+  }
+
+  void BaselineForward() {
+    GetBatchNormOp(x_, scale_, bias_, &base_mean_, &base_var_, &base_y_,
+                   &base_saved_mean_, &base_saved_var_, &reserve_space_,
+                   x_.dims(), framework::make_ddim({channels_}), *ctx_);
+    ctx_->Wait();
+  }
+
+  // Get forward results of cudnn_bn_stats_finalize + cudnn_scale_bias_add_relu.
+  void FusedForward() {
+    auto data_shape = framework::vectorize<int>(x_.dims());
+    auto param_shape = framework::vectorize<int>(scale_.dims());
+    auto bitmask_shape = framework::vectorize<int>(bitmask_.dims());
+    T *x_ptr = x_.data<T>();
+    float *sum_ptr = sum_.data<float>();
+    float *sum_of_squares_ptr = sum_of_squares_.data<float>();
+    float *scale_ptr = scale_.data<float>();
+    float *bias_ptr = bias_.data<float>();
+    float *mean_ptr = mean_.data<float>();
+    float *var_ptr = var_.data<float>();
+    float *saved_mean_ptr = saved_mean_.mutable_data<float>(place_);
+    float *saved_var_ptr = saved_var_.mutable_data<float>(place_);
+    T *equiv_scale_ptr = equiv_scale_.mutable_data<T>(place_);
+    T *equiv_bias_ptr = equiv_bias_.mutable_data<T>(place_);
+    T *y_ptr = y_.mutable_data<T>(place_);
+    int32_t *bitmask_ptr = bitmask_.mutable_data<int32_t>(place_);
+
+    // 1. BN Stats Finalize
+    std::shared_ptr<op::CudnnBNStatsFinalizeOp<T>> bn_op(
+        new op::CudnnBNStatsFinalizeOp<T>());
+    bn_op->Init(*ctx_, param_shape);
+    bn_op->Forward(*ctx_, sum_ptr, sum_of_squares_ptr, scale_ptr, bias_ptr,
+                   saved_mean_ptr, saved_var_ptr, mean_ptr, var_ptr,
+                   equiv_scale_ptr, equiv_bias_ptr, eps_, momentum_,
+                   ele_count_, true);
+    // 2. Scale + Bias + Relu (without the fused add)
+    std::string act_type = "";
+    std::shared_ptr<op::CudnnScaleBiasAddReluOp<T>> sbar_op(
+        new op::CudnnScaleBiasAddReluOp<T>(false, false));
+    sbar_op->Init(*ctx_, act_type, data_shape, bitmask_shape, data_shape,
+                  param_shape);
+    sbar_op->Forward(*ctx_, x_ptr, equiv_scale_ptr, equiv_bias_ptr, y_ptr,
+                     bitmask_ptr);
+
+    ctx_->Wait();
+  }
+
+  void Run() {
+    SetUp();
+    BaselineForward();
+    FusedForward();
+  }
+
+  // Check forward correctness between the baseline and the fused ops.
+  void CheckOut(const float diff, bool is_relative_atol = false) {
+    TensorToVector(y_, *ctx_, &y_vec_);
+    TensorToVector(mean_, *ctx_, &mean_vec_);
+    TensorToVector(var_, *ctx_, &var_vec_);
+    TensorToVector(saved_mean_, *ctx_, &saved_mean_vec_);
+    TensorToVector(saved_var_, *ctx_, &saved_var_vec_);
+    TensorToVector(base_y_, *ctx_, &base_y_vec_);
+    TensorToVector(base_mean_, *ctx_, &base_mean_vec_);
+    TensorToVector(base_var_, *ctx_, &base_var_vec_);
+    TensorToVector(base_saved_mean_, *ctx_, &base_saved_mean_vec_);
+    TensorToVector(base_saved_var_, *ctx_, &base_saved_var_vec_);
+    ctx_->Wait();
+
+    for (int i = 0; i < data_size_; ++i) {
+      if (is_relative_atol) {
+        EXPECT_LT(std::abs((y_vec_[i] - base_y_vec_[i]) / base_y_vec_[i]),
+                  static_cast<T>(diff));
+      } else {
+        EXPECT_LT(std::abs(y_vec_[i] - base_y_vec_[i]), static_cast<T>(diff));
+      }
+    }
+
+    for (int i = 0; i < param_size_; ++i) {
+      if (is_relative_atol) {
+        EXPECT_LT(
+            std::abs((mean_vec_[i] - base_mean_vec_[i]) / base_mean_vec_[i]),
+            diff);
+        EXPECT_LT(
+            std::abs((var_vec_[i] - base_var_vec_[i]) / base_var_vec_[i]),
+            diff);
+        EXPECT_LT(std::abs((saved_mean_vec_[i] - base_saved_mean_vec_[i]) /
+                           base_saved_mean_vec_[i]),
+                  diff);
+        EXPECT_LT(std::abs((saved_var_vec_[i] - base_saved_var_vec_[i]) /
+                           base_saved_var_vec_[i]),
+                  diff);
+      } else {
+        EXPECT_LT(std::abs(mean_vec_[i] - base_mean_vec_[i]), diff);
+        EXPECT_LT(std::abs(var_vec_[i] - base_var_vec_[i]), diff);
+        EXPECT_LT(std::abs(saved_mean_vec_[i] - base_saved_mean_vec_[i]), diff);
+        EXPECT_LT(std::abs(saved_var_vec_[i] - base_saved_var_vec_[i]), diff);
+      }
+    }
+  }
+
+ private:
+  int batch_size_, height_, width_, channels_;
+  int data_size_, param_size_;
+
+  framework::Tensor x_, scale_, bias_, mean_, var_, sum_, sum_of_squares_;
+  framework::Tensor y_, saved_mean_, saved_var_, equiv_scale_, equiv_bias_,
+      bitmask_;
+  std::vector<T> x_vec_, y_vec_, equiv_scale_vec_, equiv_bias_vec_;
+  std::vector<float> sum_vec_, sum_of_squares_vec_, scale_vec_, bias_vec_;
+  std::vector<float> mean_vec_, var_vec_, saved_mean_vec_, saved_var_vec_;
+  // baseline
+  framework::Tensor base_y_, base_mean_, base_var_, base_saved_mean_,
+      base_saved_var_, reserve_space_;
+  std::vector<T> base_y_vec_;
+  std::vector<float> base_mean_vec_, base_var_vec_, base_saved_mean_vec_,
+      base_saved_var_vec_;
+
+  double eps_ = 1e-5;
+  float momentum_ = 0.9;
+  int ele_count_;
+  platform::CUDAPlace place_;
+  platform::CUDADeviceContext *ctx_;
+};
+
+TEST(CudnnBNAddReluForward, GPUCudnnBNAddReluForwardFp16) {
+  int batch_size = 4;
+  int height = 8;
+  int width = 8;
+  int channels = 64;
+  TestCudnnBNAddReluForward<paddle::platform::float16> test(
+      batch_size, height, width, channels);
+  test.Run();
+  test.CheckOut(2e-3);
+}
diff --git a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h
new file mode 100644
index 0000000000000..a1004fb6670db
--- /dev/null
+++ b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h
@@ -0,0 +1,142 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+namespace dynload = platform::dynload;
+
+#if CUDNN_VERSION >= 8000
+template <typename T>
+class CudnnBNStatsFinalizeOp {
+ public:
+  CudnnBNStatsFinalizeOp()
+      : train_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING),
+        inference_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE) {
+    dtype_ = platform::CudnnDataType<T>::type;
+    dtype_param_ = (dtype_ == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT : dtype_;
+  }
+
+  ~CudnnBNStatsFinalizeOp() {}
+
+  void Init(const platform::CUDADeviceContext &ctx,
+            const std::vector<int> &param_shape) {
+    InitDescriptors(ctx, param_shape);
+
+    // Set constant_param for train op
+    train_op_.SetOpConstParamAttr(
+        {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER,
+         CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER,
+         CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER,
+         CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER,
+         CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER,
+         CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER,
+         CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER},
+        CUDNN_PTR_ELEM_ALIGNED);
+    // Set input and output desc for train op
+    train_op_.SetOpConstParamDesc(
+        {CUDNN_PARAM_YSTATS_DESC, CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC},
+        in_desc_.desc());
+    train_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
+                                  out_desc_.desc());
+
+    // Set constant_param for inference op
+    inference_op_.SetOpConstParamAttr(
+        {CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER,
+         CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER,
+         CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER,
+         CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER},
+        CUDNN_PTR_ELEM_ALIGNED);
+    // Set input and output desc for inference op
+    inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC,
+                                      in_desc_.desc());
+    inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
+                                      out_desc_.desc());
+
+    // Get workspace
+    auto handle = ctx.cudnn_handle();
+    for (auto op : {&train_op_, &inference_op_}) {
+      op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, CUDNN_BATCHNORM_SPATIAL);
+      // Check workspace size, also creates plan.
+      size_t workspace_size_bytes = op->GetWorkspaceSizeInBytes(handle);
+      PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U,
+                        platform::errors::InvalidArgument(
+                            "Unexpected non-zero workspace size for "
+                            "CudnnBNStatsFinalize op."));
+      op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
+                                   static_cast<void *>(nullptr));
+      op->SetOpVariantParamAttrPtr(CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES,
+                                   &workspace_size_bytes);
+    }
+  }
+
+  void Forward(const platform::CUDADeviceContext &ctx, float *sum_ptr,
+               float *sum_of_squares_ptr, float *scale_ptr, float *bias_ptr,
+               float *saved_mean_ptr, float *saved_invstd_ptr,
+               float *running_mean_ptr, float *running_var_ptr,
+               T *equiv_scale_ptr, T *equiv_bias_ptr, double eps,
+               float momentum, int64_t ele_count, bool is_train) {
+    auto &op = is_train ? train_op_ : inference_op_;
+
+    // Set variant_param for both inference_op_ and train_op_
+    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
+    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
+    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr);
+    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_VAR, running_var_ptr);
+    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, equiv_scale_ptr);
+    op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, equiv_bias_ptr);
+    op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EPSILON, &eps);
+
+    // Set extra variant_param only for train_op_
+    if (is_train) {
+      op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
+      op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
+      op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr);
+      op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD, saved_invstd_ptr);
+      double avg_factor = 1.0 - momentum;
+      op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT,
+                                  &ele_count);
+      op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR,
+                                  &avg_factor);
+    }
+    // Execute the fused op
+    auto handle = ctx.cudnn_handle();
+    op.Execute(handle);
+  }
+
+  // TBD
+  void Backward(const platform::CUDADeviceContext &ctx) {}
+
+ private:
+  void InitDescriptors(const platform::CUDADeviceContext &ctx,
+                       const std::vector<int> &param_shape) {
+    cudnnTensorFormat_t format = CUDNN_TENSOR_NHWC;
+    in_desc_.set(param_shape, format, dtype_param_);
+    out_desc_.set(param_shape, format, dtype_);
+  }
+
+  cudnnDataType_t dtype_;
+  cudnnDataType_t dtype_param_;
+  platform::TensorDescriptor in_desc_;
+  platform::TensorDescriptor out_desc_;
+
+  CudnnFusionOp train_op_;
+  CudnnFusionOp inference_op_;
+};
+#endif
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h
new file mode 100644
index 0000000000000..3e106836f8d0a
--- /dev/null
+++ b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h
@@ -0,0 +1,198 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+
+#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
+#include "paddle/fluid/operators/fused/resnet_unit_op.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+namespace dynload = platform::dynload;
+
+#if CUDNN_VERSION >= 8000
+template <typename T>
+class CudnnScaleBiasAddReluOp {
+ public:
+  CudnnScaleBiasAddReluOp(bool fused_add, bool has_shortcut)
+      : fused_add_(fused_add),
+        has_shortcut_(has_shortcut),
+        fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK) {}
+
+  ~CudnnScaleBiasAddReluOp() {}
+
+  void Init(const platform::CUDADeviceContext &ctx,
+            const std::string &act_type, const std::vector<int> &out_shape,
+            const std::vector<int> &bitmask_shape,
+            const std::vector<int> &x_shape,
+            const std::vector<int> &param_shape,
+            std::vector<int> z_shape = {}) {
+    dtype_ = platform::CudnnDataType<T>::type;
+    format_ = CUDNN_TENSOR_NHWC;
+    InitDescriptors(ctx, act_type, out_shape, bitmask_shape, x_shape,
+                    param_shape, z_shape);
+    GetWorkspaceSize(ctx);
+  }
+
+  void Forward(const platform::CUDADeviceContext &ctx, T *x_ptr,
+               T *x_scale_ptr, T *x_bias_ptr, T *out_ptr,
+               int32_t *bitmask_ptr, T *z_ptr = nullptr,
+               T *z_scale_ptr = nullptr, T *z_bias_ptr = nullptr) {
+    auto handle = ctx.cudnn_handle();
+    auto workspace_handle = ctx.cudnn_workspace_handle();
+    // Set variant_param
+    // input ptr
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
+    if (has_shortcut_) {
+      fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
+      fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
+      fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
+    } else {
+      if (fused_add_) {
+        fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
+      }
+    }
+
+    fwd_op_.SetOpVariantParamAttrPtr(
+        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_);
+
+    // output ptr
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr);
+    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr);
+
+    workspace_handle.RunFunc(
+        [&](void *workspace_ptr) {
+          // workspace ptr
+          fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
+          // execute the fused op
+          fwd_op_.Execute(handle);
+        },
+        fwd_workspace_byte_);
+  }
+
+  void Backward(const platform::CUDADeviceContext &ctx) {}
+
+ private:
+  void InitDescriptors(const platform::CUDADeviceContext &ctx,
+                       const std::string &act_type,
+                       const std::vector<int> &out_shape,
+                       const std::vector<int> &bitmask_shape,
+                       const std::vector<int> &x_shape,
+                       const std::vector<int> &param_shape,
+                       std::vector<int> z_shape = {}) {
+    // Set constant_param
+    if (has_shortcut_) {
+      fwd_op_.SetOpConstParamAttr(
+          {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER,
+           CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER, CUDNN_PARAM_YDATA_PLACEHOLDER,
+           CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER,
+           CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER,
+           CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
+          CUDNN_PTR_16B_ALIGNED);
+    } else {
+      if (fused_add_) {
+        fwd_op_.SetOpConstParamAttr(
+            {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER,
+             CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER, CUDNN_PARAM_YDATA_PLACEHOLDER,
+             CUDNN_PARAM_ZDATA_PLACEHOLDER,
+             CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
+            CUDNN_PTR_16B_ALIGNED);
+      } else {
+        fwd_op_.SetOpConstParamAttr(
+            {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER,
+             CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER, CUDNN_PARAM_YDATA_PLACEHOLDER,
+             CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
+            CUDNN_PTR_16B_ALIGNED);
+      }
+    }
+
+    // set input desc
+    in_x_desc_.set(x_shape, format_, dtype_);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, in_x_desc_.desc());
+    if (has_shortcut_ || fused_add_) {
+      in_z_desc_.set(z_shape, format_, dtype_);
+      fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, in_z_desc_.desc());
+    }
+
+    // set scale/bias desc
+    equiv_x_scale_bias_desc_.set(param_shape, format_, dtype_);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
+                                equiv_x_scale_bias_desc_.desc());
+    if (has_shortcut_) {
+      equiv_z_scale_bias_desc_.set(param_shape, format_, dtype_);
+      fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC,
+                                  equiv_z_scale_bias_desc_.desc());
+    }
+
+    // set output desc
+    out_desc_.set(out_shape, format_, dtype_);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YDESC, out_desc_.desc());
+
+    // set bitmask desc
+    bitmask_desc_.set(bitmask_shape, format_, CUDNN_DATA_INT32);
+    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC,
+                                bitmask_desc_.desc());
+
+    // set activation desc
+    cudnnActivationMode_t mode = CUDNN_ACTIVATION_IDENTITY;
+    if (act_type != "") {
+      PADDLE_ENFORCE_EQ(
+          act_type, "relu",
+          platform::errors::InvalidArgument(
+              "Only relu activation supported in normalized convolution."));
+      mode = CUDNN_ACTIVATION_RELU;
+    }
+    double dummy_clip = 0.0;
+    activation_desc_.set(mode, dummy_clip);
+    if (mode != CUDNN_ACTIVATION_IDENTITY) {
+      fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC,
+                                  activation_desc_.desc());
+    }
+
+    // others
+    fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
+                                CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
+  }
+
+  void GetWorkspaceSize(const platform::CUDADeviceContext &ctx) {
+    // Make op plan and get workspace size
+    auto handle = ctx.cudnn_handle();
+    fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle);
+  }
+
+  bool fused_add_ = false;
+  bool has_shortcut_ = false;
+  size_t fwd_workspace_byte_;
+
+  cudnnDataType_t dtype_;
+  cudnnTensorFormat_t format_;
+
+  platform::TensorDescriptor in_x_desc_;
+  platform::TensorDescriptor in_z_desc_;
+  platform::TensorDescriptor out_desc_;
+  platform::TensorDescriptor bitmask_desc_;
+  platform::TensorDescriptor equiv_x_scale_bias_desc_;
+  platform::TensorDescriptor equiv_z_scale_bias_desc_;
+  platform::ActivationDescriptor activation_desc_;
+
+  CudnnFusionOp fwd_op_;
+};
+#endif
+}  // namespace operators
+}  // namespace paddle