From 5ae0019361c8f2b5f5b157ef3976e8e2507f8b1d Mon Sep 17 00:00:00 2001 From: Zero Rains Date: Tue, 10 Oct 2023 14:30:51 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=205th=20No.102=E3=80=91=20mo?= =?UTF-8?q?ve=20fused=5Fembedding=5Feltwise=5Flayernorm/fusion=5Ftranspose?= =?UTF-8?q?=5Fflatten=5Fconcat/fused=5Ffc=5Felementwise=5Flayernorm=20to?= =?UTF-8?q?=20phi=20=20(#57865)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * transplant fused_embedding_elt_wise_layer_norm_kernel * fix the error * fix some bug * move the transpose to phi but new IR have a bug in output==nullptr. embedding_eltwise_op also have the bug in new IR. because the wrong memory accesss * remove some useless code * move fused_fc_elementwise_layernorm to phi, but have a bug in making * fix the bug in build the fused_fc_elementwise_layernorm_kernel and pass the test with new IR * try to fix the bug --- paddle/fluid/operators/fused/CMakeLists.txt | 10 - .../fused_embedding_eltwise_layernorm_op.cc | 176 ---------- .../fused_embedding_eltwise_layernorm_op.cu | 162 --------- .../fused_fc_elementwise_layernorm_op.cc | 294 ---------------- .../fusion_transpose_flatten_concat_op.cc | 132 ------- .../fusion_transpose_flatten_concat_op.cu.cc | 128 ------- .../fusion_transpose_flatten_concat_op.h | 51 --- .../operators/math/bert_encoder_functor.cu | 126 ------- .../operators/math/bert_encoder_functor.h | 29 -- paddle/phi/api/yaml/fused_ops.yaml | 28 ++ paddle/phi/api/yaml/op_compat.yaml | 39 +++ paddle/phi/infermeta/fusion.cc | 325 ++++++++++++++++++ paddle/phi/infermeta/fusion.h | 30 ++ paddle/phi/kernels/funcs/common_shape.h | 25 ++ .../funcs/emb_eltwise_layer_norm_functor.cu | 210 +++++++++++ .../funcs/emb_eltwise_layer_norm_functor.h | 51 +++ ...used_embedding_eltwise_layernorm_kernel.cu | 156 +++++++++ .../fused_fc_elementwise_layernorm_kernel.cu} | 246 +++++++------ .../fusion_transpose_flatten_concat_kernel.cu | 127 +++++++ ...test_fusion_transpose_flatten_concat_op.py | 2 +- 20 files changed, 1125 insertions(+), 1222 deletions(-) delete mode 100644 paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc delete mode 100644 paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu delete mode 100644 paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cc delete mode 100644 paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cc delete mode 100644 paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc delete mode 100644 paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h create mode 100644 paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.cu create mode 100644 paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.h create mode 100644 paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu rename paddle/{fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu => phi/kernels/fusion/gpu/fused_fc_elementwise_layernorm_kernel.cu} (71%) create mode 100644 paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 89ea5def6fa6b..42c41effb80ed 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -7,14 +7,11 @@ register_operators( EXCLUDES fused_bn_activation_op conv_fusion_op - fusion_transpose_flatten_concat_op fusion_conv_inception_op - fused_fc_elementwise_layernorm_op self_dp_attention_op 
skip_layernorm_op yolo_box_head_op yolo_box_post_op - fused_embedding_eltwise_layernorm_op fusion_group_op fusion_gru_op fusion_lstm_op @@ -61,22 +58,15 @@ if(WITH_GPU OR WITH_ROCM) if(NOT ${CUDNN_VERSION} VERSION_LESS 7100) op_library(conv_fusion_op) endif() - # fusion_transpose_flatten_concat_op # HIP not support cudnnTransformTensor - if(NOT WITH_ROCM) - op_library(fusion_transpose_flatten_concat_op) - endif() # fusion_conv_inception_op needs cudnn 7 above # HIP not support cudnnConvolutionBiasActivationForward if((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100)) op_library(fusion_conv_inception_op) endif() - # fused_fc_elementwise_layernorm_op - op_library(fused_fc_elementwise_layernorm_op) op_library(skip_layernorm_op) op_library(yolo_box_head_op) op_library(yolo_box_post_op) - op_library(fused_embedding_eltwise_layernorm_op DEPS bert_encoder_functor) op_library(fused_gate_attention_op) # fusion_group if(NOT APPLE AND NOT WIN32) diff --git a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc deleted file mode 100644 index 6f2c61a5cf470..0000000000000 --- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc +++ /dev/null @@ -1,176 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/errors.h" - -namespace paddle { -namespace operators { - -class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* context) const override { - PADDLE_ENFORCE_EQ( - context->Inputs("Ids").size(), - context->Inputs("Embs").size(), - platform::errors::InvalidArgument( - "Two inputs of EmbeddingEltWiseLayerNormOp shoube be " - "the same size, but received the size of input Ids = %d," - " the size of input Embs = %d", - context->Inputs("Ids").size(), - context->Inputs("Embs").size())); - PADDLE_ENFORCE_GE(context->Inputs("Embs").size(), - 2UL, - platform::errors::InvalidArgument( - "Input Embs of EmbeddingEltWiseLayerNormOp should " - "have at least 2 tensors")); - PADDLE_ENFORCE_GE(context->Inputs("Ids").size(), - 2UL, - platform::errors::InvalidArgument( - "Input Ids of EmbeddingEltWiseLayerNormOp should " - "have at least 2 tensors")); - - PADDLE_ENFORCE_EQ( - context->HasInput("Bias"), - true, - platform::errors::InvalidArgument( - "Input(Bias) of EmbeddingEltWiseLayerNormOp should not be null.")); - - PADDLE_ENFORCE_EQ( - context->HasInput("Scale"), - true, - platform::errors::InvalidArgument( - "Input(Scale) of EmbeddingEltWiseLayerNormOp should not be null.")); - - PADDLE_ENFORCE_EQ( - context->HasOutput("Out"), - true, - platform::errors::InvalidArgument( - "Output(Out) of EmbeddingEltWiseLayerNormOp should not be null.")); - - // batch * seq_len * 1 - auto ids_dims = context->GetInputsDim("Ids"); - // word_num * hidden - auto embs_dims = context->GetInputsDim("Embs"); - // hidden - auto dims_bias = context->GetInputDim("Bias"); - int batch = ids_dims[0][0]; - int seq_len = ids_dims[0][1]; - int hidden = embs_dims[0][1]; - for (auto& embs_dim : embs_dims) { - PADDLE_ENFORCE_EQ(embs_dim.size(), - 2, - platform::errors::InvalidArgument( - "The Emb dim's size shoule be 2, but found %d.", - embs_dim.size())); - PADDLE_ENFORCE_EQ( - embs_dim[1], - dims_bias[0], - platform::errors::InvalidArgument( - "The second dims (%d) of the Embedding should be equal " - "to the Bias's size(%d).", - embs_dim[1], - dims_bias[0])); - PADDLE_ENFORCE_EQ( - embs_dim[1], - hidden, - platform::errors::InvalidArgument( - "The second dimension size(%d) of the Embedding should be " - "equal to the hidden's size(%d)", - embs_dim[1], - hidden)); - } - - auto dim_output = phi::make_ddim({batch, seq_len, hidden}); - context->SetOutputDim("Out", dim_output); - context->ShareLoD("Ids", /*->*/ "Out"); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto inputs = ctx.MultiInput("Embs"); - auto input_data_type = framework::proto::VarType::Type(0); - bool flag = false; - for (auto* input : inputs) { - if (input->IsInitialized() && input->numel() > 0) { - input_data_type = framework::TransToProtoVarType(input->dtype()); - flag = true; - break; - } - } - if (flag == 0) { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "All Inputs of fused_embedding_eltwise_layernorm OP are Empty!")); - } - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } -}; - -class EmbeddingEltWiseLayerNormOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Ids", "Input id tensors of EmbeddingEltWiseLayerNorm op") - .AsDuplicable(); - AddInput("Embs", "Input emb tensors of 
EmbeddingEltWiseLayerNorm op") - .AsDuplicable(); - AddInput("Bias", "The LayerNorm Bias of EmbeddingEltWiseLayerNorm op"); - AddInput("Scale", "The LayerNorm Scale of EmbeddingEltWiseLayerNorm op"); - AddOutput("Out", "The output of EmbeddingEltWiseLayerNorm op"); - AddAttr("epsilon", - "Constant for numerical stability [default 1e-5].") - .SetDefault(1e-5) - .AddCustomChecker([](const float& epsilon) { - PADDLE_ENFORCE_GE( - epsilon, - 0.0f, - platform::errors::InvalidArgument( - "'epsilon' is %f, but it should be between 0.0 and 0.001", - epsilon)); - PADDLE_ENFORCE_LE( - epsilon, - 0.001f, - platform::errors::InvalidArgument( - "'epsilon' is %f, but it should be between 0.0 and 0.001.", - epsilon)); - }); - AddComment(R"DOC( -EmbeddingEltWiseLayerNorm Operator. - -This op is used for optimize the following structure in ernie model. -id1 -> lookup_table_op -> data1 -id2 -> lookup_table_op -> data2 - ... -idn -> lookup_table_op -> data_n -data1 + data2 + ... + data_n -> Y -Y -> layer_norm -> Out - -Not suggest to use in other case except has same structure as ernie. -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(fused_embedding_eltwise_layernorm, - ops::EmbeddingEltWiseLayerNormOp, - ops::EmbeddingEltWiseLayerNormOpMaker); diff --git a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu deleted file mode 100644 index 35574331e17d7..0000000000000 --- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include -#include -#include - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/operators/math/bert_encoder_functor.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" - -namespace paddle { -namespace operators { - -template -class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto &device_ctx = context.template device_context(); - auto ids = context.MultiInput("Ids"); - auto embs = context.MultiInput("Embs"); - int input_num = static_cast(ids.size()); - - phi::DenseTensor in_ids_( - framework::TransToPhiDataType(framework::proto::VarType::INT64)), - in_embs_( - framework::TransToPhiDataType(framework::proto::VarType::INT64)); - framework::DDim in_dim{input_num}; - int device_id; -#ifdef PADDLE_WITH_HIP - hipGetDevice(&device_id); -#else - cudaGetDevice(&device_id); -#endif - - auto &dev_ctx = context.template device_context(); - - in_ids_.Resize(in_dim); - in_embs_.Resize(in_dim); - - int64_t *in_ids_d = dev_ctx.template Alloc( - &in_ids_, in_ids_.numel() * sizeof(int64_t)); - int64_t *in_embs_d = dev_ctx.template Alloc( - &in_embs_, in_embs_.numel() * sizeof(int64_t)); - - std::vector in1s, in2s; - for (int i = 0; i < input_num; ++i) { - in1s.push_back(reinterpret_cast(ids[i]->data())); - in2s.push_back(reinterpret_cast(embs[i]->data())); - } -#ifdef PADDLE_WITH_HIP - hipMemcpyAsync(in_ids_d, - in1s.data(), - sizeof(int64_t) * input_num, - hipMemcpyHostToDevice, - device_ctx.stream()); - hipMemcpyAsync(in_embs_d, - in2s.data(), - sizeof(int64_t) * input_num, - hipMemcpyHostToDevice, - device_ctx.stream()); -#else - cudaMemcpyAsync(in_ids_d, - in1s.data(), - sizeof(int64_t) * input_num, - cudaMemcpyHostToDevice, - device_ctx.stream()); - cudaMemcpyAsync(in_embs_d, - in2s.data(), - sizeof(int64_t) * input_num, - cudaMemcpyHostToDevice, - device_ctx.stream()); -#endif - - auto *bias = context.Input("Bias"); - auto *scale = context.Input("Scale"); - auto *out = context.Output("Out"); - - // should be (B * S * hidden) - auto id0_dims = ids[0]->dims(); - auto emb0_dims = embs[0]->dims(); - - int batch = id0_dims[0]; - int seq_len = id0_dims[1]; - int hidden = emb0_dims[1]; - - auto *bias_d = bias->data(); - auto *scale_d = scale->data(); - auto *output_d = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); - - float eps = context.Attr("epsilon"); - - if (std::is_same::value) { - const half *scale_new = reinterpret_cast(scale_d); - const half *bias_new = reinterpret_cast(bias_d); - half *output_new = reinterpret_cast(output_d); - - math::EmbEltwiseLayerNormFunctor emb_eltwise_layernorm_func; - emb_eltwise_layernorm_func(batch, - seq_len, - hidden, - in_ids_d, - scale_new, - bias_new, - in_embs_d, - output_new, - eps, - input_num, - device_ctx.stream()); - } else { - math::EmbEltwiseLayerNormFunctor emb_eltwise_layernorm_func; - emb_eltwise_layernorm_func(batch, - seq_len, - hidden, - in_ids_d, - scale_d, - bias_d, - in_embs_d, - output_d, - eps, - input_num, - device_ctx.stream()); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000 -PD_REGISTER_STRUCT_KERNEL(fused_embedding_eltwise_layernorm, - GPU, - ALL_LAYOUT, - 
ops::EmbeddingEltWiseLayerNormKernel, - float, - plat::float16) {} -#else -PD_REGISTER_STRUCT_KERNEL(fused_embedding_eltwise_layernorm, - GPU, - ALL_LAYOUT, - ops::EmbeddingEltWiseLayerNormKernel, - float) {} -#endif diff --git a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cc b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cc deleted file mode 100644 index 6f00b160d98df..0000000000000 --- a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cc +++ /dev/null @@ -1,294 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class FusedFCElementwiseLayerNormOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK( - ctx->HasInput("X"), "Input", "X", "FusedFcElementwiseLayernorm"); - OP_INOUT_CHECK( - ctx->HasInput("W"), "Input", "W", "FusedFcElementwiseLayernorm"); - OP_INOUT_CHECK( - ctx->HasInput("Y"), "Input", "Y", "FusedFcElementwiseLayernorm"); - OP_INOUT_CHECK( - ctx->HasOutput("Out"), "Output", "Out", "FusedFcElementwiseLayernorm"); - - auto w_dims = ctx->GetInputDim("W"); - PADDLE_ENFORCE_EQ( - w_dims.size(), - 2, - platform::errors::InvalidArgument( - "The input Weight of fc is expected to be a 2-D tensor. " - "But received the number of Weight's dimensions is %d, ", - "Weight's shape is %s.", - w_dims.size(), - w_dims)); - - if (ctx->HasInput("Bias0")) { - auto bias0_dims = ctx->GetInputDim("Bias0"); - - PADDLE_ENFORCE_LE(bias0_dims.size(), - 2, - platform::errors::InvalidArgument( - "The input Bias of fc is expected to be an 1-D or " - "2-D tensor. But received the number of Bias's " - "dimensions is %d, Bias's shape is %s.", - bias0_dims.size(), - bias0_dims)); - - PADDLE_ENFORCE_EQ( - bias0_dims[bias0_dims.size() - 1], - w_dims[1], - platform::errors::InvalidArgument( - "The last dimension of input Bias is expected be equal " - "to the actual width of input Weight. But received the last " - "dimension of Bias is %d, Bias's shape is %s; " - "the actual width of Weight is %d, Weight's shape is %s.", - bias0_dims[bias0_dims.size() - 1], - bias0_dims, - w_dims[1], - w_dims)); - - if (bias0_dims.size() == 2) { - PADDLE_ENFORCE_EQ( - bias0_dims[0], - 1, - platform::errors::InvalidArgument( - "The first dimension of input Bias is expected to be 1, " - "but received %d, Bias's shape is %s.", - bias0_dims[0], - bias0_dims)); - } - } - - auto x_dims = ctx->GetInputDim("X"); - int x_num_col_dims = ctx->Attrs().Get("x_num_col_dims"); - PADDLE_ENFORCE_LT( - x_num_col_dims, - x_dims.size(), - platform::errors::InvalidArgument( - "The attribute x_num_col_dims used to flatten input X to " - "a 2-D tensor, is expected to be less than the number of " - "input X's dimensions. 
But received x_num_col_dims is %d, " - "the number of input X's dimensions is %d, input X's shape is %s.", - x_num_col_dims, - x_dims.size(), - x_dims)); - - auto x_mat_dims = phi::flatten_to_2d(x_dims, x_num_col_dims); - PADDLE_ENFORCE_EQ( - x_mat_dims[1], - w_dims[0], - platform::errors::InvalidArgument( - "The input's second dimension and weight's first dimension is " - "expected to be the same. But received input's second dimension is " - "%d, input's shape is %s; weight's first dimension is %d, weight's " - "shape is %s.", - x_mat_dims[1], - x_mat_dims, - w_dims[0], - w_dims)); - - std::vector fc_out_dims; - for (int i = 0; i < x_num_col_dims; ++i) { - fc_out_dims.push_back(x_dims[i]); - } - fc_out_dims.push_back(w_dims[1]); - - auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(phi::make_ddim(fc_out_dims), - y_dims, - platform::errors::InvalidArgument( - "The output's shape of fc is expected to be equal to " - "that of input Y. But received output's shape of fc " - "is %s, input Y's shape is %s.", - phi::make_ddim(fc_out_dims), - y_dims)); - - auto begin_norm_axis = ctx->Attrs().Get("begin_norm_axis"); - PADDLE_ENFORCE_LT( - begin_norm_axis, - y_dims.size(), - platform::errors::InvalidArgument( - "The attribute begin_norm_axis used to flatten input Y to a 2-D " - "tensor, is expected to be less than the number of input Y's " - "dimensions. But received begin_norm_axis is %d, the number of " - "input Y's dimensions is %d, input Y's shape is %s.", - begin_norm_axis, - y_dims.size(), - y_dims)); - - auto y_mat_dim = phi::flatten_to_2d(y_dims, begin_norm_axis); - int64_t dim_0 = y_mat_dim[0]; - int64_t dim_1 = y_mat_dim[1]; - if (ctx->HasInput("Scale")) { - auto scale_dims = ctx->GetInputDim("Scale"); - PADDLE_ENFORCE_EQ(scale_dims.size(), - 1, - platform::errors::InvalidArgument( - "The input Scale is expected to be an 1-D tensor. " - "But received the number of input Scale's " - "dimensions is %d, input Scale's shape is %s.", - scale_dims.size(), - scale_dims)); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - scale_dims[0], - dim_1, - platform::errors::InvalidArgument( - "The first dimension of input Scale is expected to be equal to " - "the second dimension of input Y after flattened. " - "But received the first dimension of input Scale is %d, input " - "Scale's shape is %s; the second dimension of flattened input " - "Y is %d, input Y's shape is %s, flattened axis is %d.", - scale_dims[0], - scale_dims, - dim_1, - y_dims, - begin_norm_axis)); - } - } - if (ctx->HasInput("Bias1")) { - auto bias1_dims = ctx->GetInputDim("Bias1"); - PADDLE_ENFORCE_EQ( - bias1_dims.size(), - 1, - platform::errors::InvalidArgument( - "The input Bias1 is expected to be an 1-D tensor. " - "But received the number of input Bias1's dimension is %d, " - "input Bias1's shape is %s.", - bias1_dims.size(), - bias1_dims)); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - bias1_dims[0], - dim_1, - platform::errors::InvalidArgument( - "The first dimension of input Bias1 is expected to be equal to " - "the second dimension of input Y after flattened. 
" - "But received the first dimension of input Bias1 is %d, input " - "Bias1's shape is %s; the second dimension of flatten input " - "Y is %d, input Y's shape is %s, flattened axis is %d.", - bias1_dims[0], - bias1_dims, - dim_1, - y_dims, - begin_norm_axis)); - } - } - - ctx->SetOutputDim("Out", y_dims); - if (ctx->HasOutput("Mean")) { - ctx->SetOutputDim("Mean", {dim_0}); - } - if (ctx->HasOutput("Variance")) { - ctx->SetOutputDim("Variance", {dim_0}); - } - ctx->ShareLoD("X", "Out"); - } -}; - -class FusedFCElementwiseLayerNormOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of fully connected operation"); - AddInput("W", - "(Tensor), The weight tensor of fully connected operation. It is " - "a 2-D Tensor with shape (I, O)"); - AddInput("Bias0", - "(Tensor, optional), The bias tensor of fully connecred " - "operation. It is a 1-D Tensor with shape (O), or a 2-D Tensor " - "with shape (1, O).") - .AsDispensable(); - AddInput("Y", - "(Tensor), The second input tensor of elementwise_add operation. " - "Note that the shape should be the same as fully connect's result " - "tensor."); - AddInput( - "Scale", - "(Tensor, optional), It is a 1-D input Tensor of layer_norm operation.") - .AsDispensable(); - AddInput( - "Bias1", - "(Tensor, optional), It is a 1-D input Tensor of layer_norm operation.") - .AsDispensable(); - AddOutput("Out", - "(Tensor), Output after normalization. The shape is the shame as " - "layer_norm's input."); - AddOutput("Mean", "(Tensor, optional), Mean of the current minibatch") - .AsDispensable(); - AddOutput("Variance", - "(Tensor, optional), Variance of the current minibatch") - .AsDispensable(); - AddAttr("x_num_col_dims", - "(int, default 1), This op can take tensors with more than " - "two dimensions as its inputs.") - .SetDefault(1) - .EqualGreaterThan(1); - AddAttr("activation_type", - "Activation type used in fully connected operator.") - .SetDefault(""); - AddAttr("epsilon", - "Constant for numerical stability [default 1e-5].") - .SetDefault(1e-5) - .AddCustomChecker([](const float &epsilon) { - PADDLE_ENFORCE_GE(epsilon, - 0.0f, - platform::errors::InvalidArgument( - "'epsilon' should be between 0.0 and 0.001.")); - PADDLE_ENFORCE_LE(epsilon, - 0.001f, - platform::errors::InvalidArgument( - "'epsilon' should be between 0.0 and 0.001.")); - }); - AddAttr("begin_norm_axis", - "the axis of `begin_norm_axis ... Rank(Y) - 1` will be " - "normalized. `begin_norm_axis` splits the tensor(`X`) to a " - "matrix [N,H]. 
[default 1].") - .SetDefault(1) - .AddCustomChecker([](const int &begin_norm_axis) { - PADDLE_ENFORCE_GT( - begin_norm_axis, - 0, - platform::errors::InvalidArgument( - "'begin_norm_axis' should be greater than zero.")); - }); - AddComment(R"DOC( -fc_out <= fc(X, W, Bias0) -add_out <= elementwise_add(fc_out, Y) -(out, mean, variance) <= layer_norm(add_out, Scale, Bias1) -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR( - fused_fc_elementwise_layernorm, - ops::FusedFCElementwiseLayerNormOp, - ops::FusedFCElementwiseLayerNormOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cc b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cc deleted file mode 100644 index e7bb037a3f3aa..0000000000000 --- a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cc +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h" - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_GE( - ctx->Inputs("X").size(), - 1UL, - platform::errors::InvalidArgument( - "Inputs(X) of TransposeFlattenConcat op should not be empty.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), - true, - platform::errors::InvalidArgument( - "Inputs(X) of TransposeFlattenConcat op should not be empty.")); - - auto ins = ctx->GetInputsDim("X"); - const size_t n = ins.size(); - PADDLE_ENFORCE_GT(n, - 0, - platform::errors::InvalidArgument( - "The size of Inputs(X)'s dimension should be greater " - " than 0, but received %d.", - n)); - - std::vector trans_axis = - ctx->Attrs().Get>("trans_axis"); - int flatten_axis = ctx->Attrs().Get("flatten_axis"); - int concat_axis = ctx->Attrs().Get("concat_axis"); - - size_t x_rank = ins[0].size(); - size_t trans_axis_size = trans_axis.size(); - PADDLE_ENFORCE_EQ(x_rank, - trans_axis_size, - platform::errors::InvalidArgument( - "The input tensor's rank(%d) " - "should be equal to the permutation axis's size(%d)", - x_rank, - trans_axis_size)); - - auto dims0 = - GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[0])); - std::vector out_dims(dims0); - for (size_t i = 1; i < n; i++) { - auto dimsi = - GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[i])); - for (int j = 0; j < static_cast(dims0.size()); j++) { - if (j == concat_axis) { - out_dims[concat_axis] += dimsi[j]; - } else { - PADDLE_ENFORCE_EQ(out_dims[j], - dimsi[j], - platform::errors::InvalidArgument( - "After flatting, the %d-th dim should be save " - "except the 
specify axis.", - j)); - } - } - } - if (out_dims[concat_axis] < 0) { - out_dims[concat_axis] = -1; - } - ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); - } -}; - -class TransposeFlattenConcatFusionOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "X", - "(Tensor) The input tensor, tensors with rank up to 6 are supported.") - .AsDuplicable(); - AddOutput("Out", "(Tensor)The output tensor."); - AddAttr>( - "trans_axis", - "(vector) A list of values, and the size of the list should be " - "the same with the input tensor rank. This operator permutes the input " - "tensor's axes according to the values given."); - AddAttr("flatten_axis", - "(int)" - "Indicate up to which input dimensions (exclusive) should be" - "flattened to the outer dimension of the output. The value" - "for axis must be in the range [0, R], where R is the rank of" - "the input tensor. When axis = 0, the shape of the output" - "tensor is (1, (d_0 X d_1 ... d_n), where the shape of the" - "input tensor is (d_0, d_1, ... d_n)."); - AddAttr("concat_axis", - "The axis along which the input tensors will be concatenated. " - "It should be 0 or 1, since the tensor is 2D after flatting."); - AddComment(R"DOC( - - -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR( - fusion_transpose_flatten_concat, - ops::TransposeFlattenConcatFusionOp, - ops::TransposeFlattenConcatFusionOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc deleted file mode 100644 index 3d843ac6409ec..0000000000000 --- a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h" - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/fluid/platform/place.h" - -namespace paddle { -namespace operators { - -template -using CudnnDataType = platform::CudnnDataType; - -template -class TransposeFlattenConcatFusionKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); - auto& dev_ctx = ctx.template device_context(); - dev_ctx.Alloc(out, out->numel() * sizeof(T)); - auto odims = out->dims(); - - std::vector trans_axis = ctx.Attr>("trans_axis"); - int flatten_axis = ctx.Attr("flatten_axis"); - int concat_axis = ctx.Attr("concat_axis"); - - int rank = ins[0]->dims().size(); - // use at least 4D in cudnnTransformTensor - int max_dim = rank < 4 ? 
4 : rank; - std::vector stride_x(max_dim, 0); - std::vector stride_y(max_dim, 0); - std::vector dims_y(max_dim, 0); - - cudnnTensorDescriptor_t in_desc; - cudnnTensorDescriptor_t out_desc; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&in_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&out_desc)); - cudnnDataType_t cudnn_dtype = CudnnDataType::type; - - auto handle = dev_ctx.cudnn_handle(); - - T* odata = out->data(); - for (auto& item : ins) { - auto perm_shape = GetPermuteShape(trans_axis, item->dims()); - int osize = 1; - auto idims = item->dims(); - for (int i = 0; i < rank; i++) { - stride_x[i] = 1; - for (int j = trans_axis[i] + 1; j < rank; j++) { - stride_x[i] *= idims[j]; - } - dims_y[i] = perm_shape[i]; - osize *= perm_shape[i]; - } - stride_y[rank - 1] = 1; - for (int i = rank - 2; i >= 0; i--) { - if (((i + 1) == flatten_axis) && (concat_axis == 1)) { - stride_y[i] = odims[1]; - } else { - stride_y[i] = stride_y[i + 1] * perm_shape[i + 1]; - } - } - - // Since concat is after flatten, the output is 2D tensor. - // If concat_axis is 0, each input's permutated tensor is continuous. - // If concat_axis is 1, the stride of 0-th dim of each input's - // permutated tensor is odims()[1]. - - for (int i = rank; i < max_dim; i++) { - stride_x[i] = 1; - stride_y[i] = 1; - dims_y[i] = 1; - } - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data())); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data())); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnTransformTensor( - handle, - CudnnDataType::kOne(), - in_desc, - static_cast(item->data()), - CudnnDataType::kZero(), - out_desc, - static_cast(odata))); - if (concat_axis == 0) { - odata += osize; - } else { - auto flat_shape = GetFlattenShape(flatten_axis, perm_shape); - odata += flat_shape[1]; - } - } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(in_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(out_desc)); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(fusion_transpose_flatten_concat, - GPU, - ALL_LAYOUT, - ops::TransposeFlattenConcatFusionKernel, - float, - double) {} diff --git a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h deleted file mode 100644 index 52140c0ca46ee..0000000000000 --- a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include - -#include "paddle/phi/core/ddim.h" - -namespace paddle { -namespace operators { - -inline std::vector GetPermuteShape(const std::vector& axis, - const framework::DDim& in_dims) { - std::vector out_dims(in_dims.size()); - for (size_t i = 0; i < axis.size(); i++) { - out_dims[i] = in_dims[axis[i]]; - } - return out_dims; -} - -inline std::vector GetFlattenShape(const int axis, - const std::vector& in_dims) { - int64_t outer = 1, inner = 1; - for (int i = 0; i < static_cast(in_dims.size()); ++i) { - if (i < axis) { - outer *= in_dims[i]; - } else { - inner *= in_dims[i]; - } - } - std::vector out_shape(2); - out_shape[0] = outer; - out_shape[1] = inner; - return out_shape; -} - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index 657b0b976ef62..9424ab8fa9924 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -129,132 +129,6 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp &thread_data, } } -template -__global__ void EmbEltwiseLayernormKernel(int hidden, - const int64_t *ids, - const T *scale, - const T *bias, - const int64_t *embs, - T *output, - T eps, - int input_num) { - cub::Sum pair_sum; - // blockIdx.x: position in the sequence - // blockIdx.y: batch - // gridDim.x: Seq - // gridDim.y: Batch - - extern __shared__ int64_t array_id[]; - - const T rhidden = T(1.f) / T(hidden); - const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y; - if (threadIdx.x == 0) { - for (int i = 0; i < input_num; ++i) { - const int64_t *ids_p = reinterpret_cast(ids[i]); - array_id[i] = ids_p[seq_pos]; - } - } - __syncthreads(); - - const int64_t out_offset = seq_pos * hidden; - - phi::funcs::kvp thread_data(0, 0); - -#pragma unroll - for (int it = threadIdx.x; it < hidden; it += TPB) { - T val = 0; - for (int i = 0; i < input_num; ++i) { - val += reinterpret_cast(embs[i])[array_id[i] * hidden + it]; - } - - output[out_offset + it] = val; - const T rhiddenval = rhidden * val; - thread_data = - pair_sum(thread_data, phi::funcs::kvp(rhiddenval, rhiddenval * val)); - } - LayerNorm(thread_data, hidden, out_offset, bias, scale, output, eps); -} - -// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake -#ifndef __HIPCC__ // @{ Half kernel: EmbEltwiseLayernormKernel -template <> -__global__ void EmbEltwiseLayernormKernel(int hidden, - const int64_t *ids, - const half *scale, - const half *bias, - const int64_t *embs, - half *output, - half eps, - int input_num) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - cub::Sum pair_sum; - // blockIdx.x: position in the sequence - // blockIdx.y: batch - // gridDim.x: Seq - // gridDim.y: Batch - - extern __shared__ int64_t array_id[]; - - const half rhidden = half(1.f) / half(hidden); - const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y; - if (threadIdx.x == 0) { - for (int i = 0; i < input_num; ++i) { - const int64_t *ids_p = reinterpret_cast(ids[i]); - array_id[i] = ids_p[seq_pos]; - } - } - __syncthreads(); - - const int64_t out_offset = seq_pos * hidden; - - phi::funcs::kvp thread_data(0, 0); - -#pragma unroll - for (int it = threadIdx.x; it < hidden; it += 256) { - half val = 0; - for (int i = 0; i < input_num; ++i) { - val += reinterpret_cast(embs[i])[array_id[i] * hidden + it]; - } - - output[out_offset + it] = val; - const half rhiddenval = rhidden * val; - thread_data = pair_sum(thread_data, - 
phi::funcs::kvp(rhiddenval, rhiddenval * val)); - } - LayerNorm( - thread_data, hidden, out_offset, bias, scale, output, eps); -#endif -} -#endif // @} End Half kernel: EmbEltwiseLayernormKernel - -template -void EmbEltwiseLayerNormFunctor::operator()(int batch, - int seq_len, - int hidden, - const int64_t *ids, - const T *scale, - const T *bias, - const int64_t *embs, - T *output, - float eps, - int input_num, - gpuStream_t stream) { - const unsigned tpb = 256; - const dim3 grid(seq_len, batch, 1); - const dim3 block(tpb, 1, 1); - int shared_bytes = input_num * sizeof(int64_t); - EmbEltwiseLayernormKernel<<>>( - hidden, ids, scale, bias, embs, output, eps, input_num); -} - -template class EmbEltwiseLayerNormFunctor; - -// device function 'operator()' is not supportted until cuda 10.0 -// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000 -template class EmbEltwiseLayerNormFunctor; -#endif - template __global__ void SkipLayerNormSmallKernel(int num, int hidden, diff --git a/paddle/fluid/operators/math/bert_encoder_functor.h b/paddle/fluid/operators/math/bert_encoder_functor.h index 6d31098686608..76e27380b90e2 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.h +++ b/paddle/fluid/operators/math/bert_encoder_functor.h @@ -48,35 +48,6 @@ struct CUDATypeTraits { }; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -// This functor involves a fusion calculation in Ernie or Bert. -// The fusion mode is as follows: -// -// in_var emb in_var emb -// | | | | -// lookup_table lookup_table -// | | -// lkt_var lkt_var -// \ / -// elementwise_add -// | -// elt_out_var -// -template -class EmbEltwiseLayerNormFunctor { - public: - void operator()(int batch, - int seq_len, - int hidden, - const int64_t *ids, - const T *scale, - const T *bias, - const int64_t *embs, - T *output, - float eps, - int input_num, - gpuStream_t stream); -}; - // This functor involves a fusion calculation in Ernie or Bert. 
// The fusion mode is as follows: // diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 9f19dcf31d728..b54307861b367 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -145,6 +145,25 @@ backward : fused_dropout_add_grad support_dygraph_mode : true +- op : fused_embedding_eltwise_layernorm + args : (Tensor[] ids, Tensor[] embs, Tensor bias, Tensor scale, float epsilon = 0.00001f) + output : Tensor(out) + infer_meta : + func : FusedEmbeddingEltWiseLayerNormInferMeta + kernel : + func : fused_embedding_eltwise_layernorm + data_type : embs + +- op : fused_fc_elementwise_layernorm + args : (Tensor x, Tensor w, Tensor y, Tensor bias0, Tensor scale, Tensor bias1, int x_num_col_dims = 1, str activation_type = "", float epsilon = 0.00001f, int begin_norm_axis = 1) + output : Tensor(out), Tensor(mean), Tensor(variance) + infer_meta : + func : FusedFCElementwiseLayerNormInferMeta + kernel : + func : fused_fc_elementwise_layernorm + data_type : x + optional : bias0, scale, bias1, mean, variance + - op : fused_linear_param_grad_add args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true, bool has_bias = true) output : Tensor(dweight_out), Tensor(dbias_out) @@ -188,6 +207,15 @@ func : fused_scale_bias_relu_conv_bnstats data_type : x +- op : fusion_transpose_flatten_concat + args : (Tensor[] x, int[] trans_axis, int flatten_axis, int concat_axis) + output : Tensor(out) + infer_meta : + func : FusionTransposeFlattenConcatInferMeta + kernel : + func : fusion_transpose_flatten_concat + data_type : x + - op : generate_sequence_xpu args : (Tensor x, DataType dtype) output : Tensor diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 30041bf323b1e..a21cb39ac076f 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1255,6 +1255,35 @@ attrs : [bool use_cudnn = false, float fuse_alpha = 0.0f, float fuse_beta = 0.0f, float Scale_in = 1.0f, float Scale_out = 1.0f, float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}'] +- op : fused_embedding_eltwise_layernorm + inputs : + ids : Ids + embs : Embs + bias : Bias + scale : Scale + outputs : + out : Out + attrs : + epsilon : epsilon + +- op : fused_fc_elementwise_layernorm + inputs : + x : X + w : W + y : Y + bias0 : Bias0 + scale : Scale + bias1 : Bias1 + outputs : + out : Out + mean : Mean + variance : Variance + attrs : + x_num_col_dims : x_num_col_dims + activation_type : activation_type + epsilon : epsilon + begin_norm_axis : begin_norm_axis + - op : fused_feedforward backward: fused_feedforward_grad inputs: @@ -1291,6 +1320,16 @@ extra : attrs : [str data_format = "AnyLayout"] +- op : fusion_transpose_flatten_concat + inputs : + x : X + outputs : + out : Out + attrs : + trans_axis : trans_axis + flatten_axis : flatten_axis + concat_axis : concat_axis + - op : gather backward : gather_grad inputs : diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 6846b5928c116..d047670a9ee5f 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -1842,4 +1842,329 @@ void SqueezeExcitationInferMeta(const MetaTensor& x, out->set_dims(DDim(out_shape.data(), static_cast(out_shape.size()))); } +void FusedEmbeddingEltWiseLayerNormInferMeta( + const std::vector& ids, + const std::vector& embs, + const MetaTensor& bias, + const MetaTensor& scale, + const float epsilon, + MetaTensor* out) { + PADDLE_ENFORCE_EQ( + ids.size(), + 
embs.size(), + phi::errors::InvalidArgument( + "Two inputs of EmbeddingEltWiseLayerNormOp shoube be " + "the same size, but received the size of input Ids = %d," + " the size of input Embs = %d", + ids.size(), + embs.size())); + PADDLE_ENFORCE_GE(embs.size(), + 2UL, + phi::errors::InvalidArgument( + "Input Embs of EmbeddingEltWiseLayerNormOp should " + "have at least 2 tensors")); + PADDLE_ENFORCE_GE(ids.size(), + 2UL, + phi::errors::InvalidArgument( + "Input Ids of EmbeddingEltWiseLayerNormOp should " + "have at least 2 tensors")); + + // batch * seq_len * 1 + std::vector ids_dims, embs_dims; + ids_dims.reserve(ids.size()); + std::transform(ids.begin(), + ids.end(), + std::back_inserter(ids_dims), + [](const MetaTensor* var) { return var->dims(); }); + // word_num * hidden + embs_dims.reserve(embs.size()); + std::transform(embs.begin(), + embs.end(), + std::back_inserter(embs_dims), + [](const MetaTensor* var) { return var->dims(); }); + // hidden + DDim dims_bias = bias.dims(); + + int batch = ids_dims[0][0]; + int seq_len = ids_dims[0][1]; + int hidden = embs_dims[0][1]; + for (auto& embs_dim : embs_dims) { + PADDLE_ENFORCE_EQ( + embs_dim.size(), + 2, + phi::errors::InvalidArgument( + "The Emb dim's size shoule be 2, but found %d.", embs_dim.size())); + PADDLE_ENFORCE_EQ( + embs_dim[1], + dims_bias[0], + phi::errors::InvalidArgument( + "The second dims (%d) of the Embedding should be equal " + "to the Bias's size(%d).", + embs_dim[1], + dims_bias[0])); + PADDLE_ENFORCE_EQ( + embs_dim[1], + hidden, + phi::errors::InvalidArgument( + "The second dimension size(%d) of the Embedding should be " + "equal to the hidden's size(%d)", + embs_dim[1], + hidden)); + } + + auto dim_output = phi::make_ddim({batch, seq_len, hidden}); + out->set_dims(dim_output); + // out->share_lod(ids); + // context->ShareLoD("Ids", /*->*/ "Out"); +} + +void FusionTransposeFlattenConcatInferMeta( + const std::vector& x, + const std::vector& trans_axis, + const int flatten_axis, + const int concat_axis, + MetaTensor* out) { + PADDLE_ENFORCE_GE( + x.size(), + 1UL, + phi::errors::InvalidArgument( + "Inputs(X) of TransposeFlattenConcat op should not be empty.")); + + std::vector ins; + ins.reserve(x.size()); + std::transform( + x.begin(), x.end(), std::back_inserter(ins), [](const MetaTensor* var) { + return var->dims(); + }); + const size_t n = ins.size(); + PADDLE_ENFORCE_GT(n, + 0, + phi::errors::InvalidArgument( + "The size of Inputs(X)'s dimension should be greater " + " than 0, but received %d.", + n)); + + size_t x_rank = ins[0].size(); + size_t trans_axis_size = trans_axis.size(); + PADDLE_ENFORCE_EQ(x_rank, + trans_axis_size, + phi::errors::InvalidArgument( + "The input tensor's rank(%d) " + "should be equal to the permutation axis's size(%d)", + x_rank, + trans_axis_size)); + + auto dims0 = phi::funcs::GetFlattenShape( + flatten_axis, phi::funcs::GetPermuteShape(trans_axis, ins[0])); + std::vector out_dims(dims0); + for (size_t i = 1; i < n; i++) { + auto dimsi = phi::funcs::GetFlattenShape( + flatten_axis, phi::funcs::GetPermuteShape(trans_axis, ins[i])); + for (int j = 0; j < static_cast(dims0.size()); j++) { + if (j == concat_axis) { + out_dims[concat_axis] += dimsi[j]; + } else { + PADDLE_ENFORCE_EQ(out_dims[j], + dimsi[j], + phi::errors::InvalidArgument( + "After flatting, the %d-th dim should be save " + "except the specify axis.", + j)); + } + } + } + if (out_dims[concat_axis] < 0) { + out_dims[concat_axis] = -1; + } + out->set_dims(phi::make_ddim(out_dims)); +} + +void 
FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& y, + const MetaTensor& bias0, + const MetaTensor& scale, + const MetaTensor& bias1, + const int x_num_col_dims, + const std::string& activation_type, + const float epsilon, + const int begin_norm_axis, + MetaTensor* out, + MetaTensor* mean, + MetaTensor* variance, + MetaConfig config) { + DDim w_dims = w.dims(); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2, + phi::errors::InvalidArgument( + "The input Weight of fc is expected to be a 2-D tensor. " + "But received the number of Weight's dimensions is %d, ", + "Weight's shape is %s.", + w_dims.size(), + w_dims)); + + if (bias0) { + DDim bias0_dims = bias0.dims(); + + PADDLE_ENFORCE_LE(bias0_dims.size(), + 2, + phi::errors::InvalidArgument( + "The input Bias of fc is expected to be an 1-D or " + "2-D tensor. But received the number of Bias's " + "dimensions is %d, Bias's shape is %s.", + bias0_dims.size(), + bias0_dims)); + + PADDLE_ENFORCE_EQ( + bias0_dims[bias0_dims.size() - 1], + w_dims[1], + phi::errors::InvalidArgument( + "The last dimension of input Bias is expected be equal " + "to the actual width of input Weight. But received the last " + "dimension of Bias is %d, Bias's shape is %s; " + "the actual width of Weight is %d, Weight's shape is %s.", + bias0_dims[bias0_dims.size() - 1], + bias0_dims, + w_dims[1], + w_dims)); + + if (bias0_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + bias0_dims[0], + 1, + phi::errors::InvalidArgument( + "The first dimension of input Bias is expected to be 1, " + "but received %d, Bias's shape is %s.", + bias0_dims[0], + bias0_dims)); + } + } + + DDim x_dims = x.dims(); + PADDLE_ENFORCE_LT( + x_num_col_dims, + x_dims.size(), + phi::errors::InvalidArgument( + "The attribute x_num_col_dims used to flatten input X to " + "a 2-D tensor, is expected to be less than the number of " + "input X's dimensions. But received x_num_col_dims is %d, " + "the number of input X's dimensions is %d, input X's shape is %s.", + x_num_col_dims, + x_dims.size(), + x_dims)); + + auto x_mat_dims = phi::flatten_to_2d(x_dims, x_num_col_dims); + PADDLE_ENFORCE_EQ( + x_mat_dims[1], + w_dims[0], + phi::errors::InvalidArgument( + "The input's second dimension and weight's first dimension is " + "expected to be the same. But received input's second dimension is " + "%d, input's shape is %s; weight's first dimension is %d, weight's " + "shape is %s.", + x_mat_dims[1], + x_mat_dims, + w_dims[0], + w_dims)); + + std::vector fc_out_dims; + for (int i = 0; i < x_num_col_dims; ++i) { + fc_out_dims.push_back(x_dims[i]); + } + fc_out_dims.push_back(w_dims[1]); + + DDim y_dims = y.dims(); + PADDLE_ENFORCE_EQ(phi::make_ddim(fc_out_dims), + y_dims, + phi::errors::InvalidArgument( + "The output's shape of fc is expected to be equal to " + "that of input Y. But received output's shape of fc " + "is %s, input Y's shape is %s.", + phi::make_ddim(fc_out_dims), + y_dims)); + + PADDLE_ENFORCE_LT( + begin_norm_axis, + y_dims.size(), + phi::errors::InvalidArgument( + "The attribute begin_norm_axis used to flatten input Y to a 2-D " + "tensor, is expected to be less than the number of input Y's " + "dimensions. 
But received begin_norm_axis is %d, the number of " + "input Y's dimensions is %d, input Y's shape is %s.", + begin_norm_axis, + y_dims.size(), + y_dims)); + + auto y_mat_dim = phi::flatten_to_2d(y_dims, begin_norm_axis); + int64_t dim_0 = y_mat_dim[0]; + int64_t dim_1 = y_mat_dim[1]; + if (scale) { + DDim scale_dims = scale.dims(); + PADDLE_ENFORCE_EQ(scale_dims.size(), + 1, + phi::errors::InvalidArgument( + "The input Scale is expected to be an 1-D tensor. " + "But received the number of input Scale's " + "dimensions is %d, input Scale's shape is %s.", + scale_dims.size(), + scale_dims)); + + if (config.is_runtime) { + PADDLE_ENFORCE_EQ( + scale_dims[0], + dim_1, + phi::errors::InvalidArgument( + "The first dimension of input Scale is expected to be equal to " + "the second dimension of input Y after flattened. " + "But received the first dimension of input Scale is %d, input " + "Scale's shape is %s; the second dimension of flattened input " + "Y is %d, input Y's shape is %s, flattened axis is %d.", + scale_dims[0], + scale_dims, + dim_1, + y_dims, + begin_norm_axis)); + } + } + if (bias1) { + DDim bias1_dims = bias1.dims(); + PADDLE_ENFORCE_EQ( + bias1_dims.size(), + 1, + phi::errors::InvalidArgument( + "The input Bias1 is expected to be an 1-D tensor. " + "But received the number of input Bias1's dimension is %d, " + "input Bias1's shape is %s.", + bias1_dims.size(), + bias1_dims)); + + if (config.is_runtime) { + PADDLE_ENFORCE_EQ( + bias1_dims[0], + dim_1, + phi::errors::InvalidArgument( + "The first dimension of input Bias1 is expected to be equal to " + "the second dimension of input Y after flattened. " + "But received the first dimension of input Bias1 is %d, input " + "Bias1's shape is %s; the second dimension of flatten input " + "Y is %d, input Y's shape is %s, flattened axis is %d.", + bias1_dims[0], + bias1_dims, + dim_1, + y_dims, + begin_norm_axis)); + } + } + + out->set_dims(y_dims); + if (mean) { + mean->set_dims({dim_0}); + } + if (variance) { + variance->set_dims({dim_0}); + } + out->share_lod(x); +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index fe3ebe989cdc3..c022a4257e4dc 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -485,4 +485,34 @@ void SqueezeExcitationInferMeta(const MetaTensor& x, const std::vector& filter_dims, MetaTensor* out); +void FusedEmbeddingEltWiseLayerNormInferMeta( + const std::vector& ids, + const std::vector& embs, + const MetaTensor& bias, + const MetaTensor& scale, + const float epsilon, + MetaTensor* out); + +void FusionTransposeFlattenConcatInferMeta( + const std::vector& x, + const std::vector& trans_axis, + const int flatten_axis, + const int concat_axis, + MetaTensor* out); + +void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& y, + const MetaTensor& bias0, + const MetaTensor& scale, + const MetaTensor& bias1, + const int x_num_col_dims, + const std::string& activation_type, + const float epsilon, + const int begin_norm_axis, + MetaTensor* out, + MetaTensor* mean, + MetaTensor* variance, + MetaConfig config = MetaConfig()); + } // namespace phi diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index 8db9a92f47d5a..d186bda9ceb95 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -244,5 +244,30 @@ inline int64_t CalStride(phi::DDim dim) { return strides; } +inline std::vector GetPermuteShape(const 
std::vector &axis, + const DDim &in_dims) { + std::vector out_dims(in_dims.size()); + for (size_t i = 0; i < axis.size(); i++) { + out_dims[i] = in_dims[axis[i]]; + } + return out_dims; +} + +inline std::vector GetFlattenShape(const int axis, + const std::vector &in_dims) { + int64_t outer = 1, inner = 1; + for (int i = 0; i < static_cast(in_dims.size()); ++i) { + if (i < axis) { + outer *= in_dims[i]; + } else { + inner *= in_dims[i]; + } + } + std::vector out_shape(2); + out_shape[0] = outer; + out_shape[1] = inner; + return out_shape; +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.cu b/paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.cu new file mode 100644 index 0000000000000..5d4611fa9d09a --- /dev/null +++ b/paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.cu @@ -0,0 +1,210 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUDA +#include +#include + +#include // NOLINT +#endif +#ifdef PADDLE_WITH_HIP +#include + +#include +namespace cub = hipcub; +#endif + +#include "paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.h" + +#include "paddle/phi/common/float16.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" + +namespace phi { +namespace funcs { + +template +__device__ inline T rsqrt(const T& x); + +template <> +__device__ inline float rsqrt(const float& x) { + return rsqrtf(x); +} + +template +__device__ __forceinline__ T local_rsqrt(T num) { + return rsqrt(static_cast(num)); +} +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +__device__ __forceinline__ half local_rsqrt(half num) { return hrsqrt(num); } +#endif + +template +__device__ inline void LayerNorm(const phi::funcs::kvp& thread_data, + const int ld, + const int offset, + const T* bias, + const T* scale, + T* output, + T eps) { + using BlockReduce = cub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T mu; // mean + __shared__ T rsigma; // 1 / std.dev. 
+ + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum()); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = local_rsqrt(sum_kv.value - mu * mu + eps); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const T val = output[idx]; + const T g(scale[i]); + const T b(bias[i]); + output[idx] = g * (val - mu) * rsigma + b; + } +} + +template +__global__ void EmbEltwiseLayernormKernel(int hidden, + const int64_t* ids, + const T* scale, + const T* bias, + const int64_t* embs, + T* output, + T eps, + int input_num) { + cub::Sum pair_sum; + // blockIdx.x: position in the sequence + // blockIdx.y: batch + // gridDim.x: Seq + // gridDim.y: Batch + + extern __shared__ int64_t array_id[]; + + const T rhidden = T(1.f) / T(hidden); + const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y; + if (threadIdx.x == 0) { + for (int i = 0; i < input_num; ++i) { + const int64_t* ids_p = reinterpret_cast(ids[i]); + array_id[i] = ids_p[seq_pos]; + } + } + __syncthreads(); + + const int64_t out_offset = seq_pos * hidden; + + phi::funcs::kvp thread_data(0, 0); + +#pragma unroll + for (int it = threadIdx.x; it < hidden; it += TPB) { + T val = 0; + for (int i = 0; i < input_num; ++i) { + val += reinterpret_cast(embs[i])[array_id[i] * hidden + it]; + } + + output[out_offset + it] = val; + const T rhiddenval = rhidden * val; + thread_data = + pair_sum(thread_data, phi::funcs::kvp(rhiddenval, rhiddenval * val)); + } + LayerNorm(thread_data, hidden, out_offset, bias, scale, output, eps); +} + +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#ifndef __HIPCC__ // @{ Half kernel: EmbEltwiseLayernormKernel +template <> +__global__ void EmbEltwiseLayernormKernel(int hidden, + const int64_t* ids, + const half* scale, + const half* bias, + const int64_t* embs, + half* output, + half eps, + int input_num) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + cub::Sum pair_sum; + // blockIdx.x: position in the sequence + // blockIdx.y: batch + // gridDim.x: Seq + // gridDim.y: Batch + + extern __shared__ int64_t array_id[]; + + const half rhidden = half(1.f) / half(hidden); + const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y; + if (threadIdx.x == 0) { + for (int i = 0; i < input_num; ++i) { + const int64_t* ids_p = reinterpret_cast(ids[i]); + array_id[i] = ids_p[seq_pos]; + } + } + __syncthreads(); + + const int64_t out_offset = seq_pos * hidden; + + phi::funcs::kvp thread_data(0, 0); + +#pragma unroll + for (int it = threadIdx.x; it < hidden; it += 256) { + half val = 0; + for (int i = 0; i < input_num; ++i) { + val += reinterpret_cast(embs[i])[array_id[i] * hidden + it]; + } + + output[out_offset + it] = val; + const half rhiddenval = rhidden * val; + thread_data = pair_sum(thread_data, + phi::funcs::kvp(rhiddenval, rhiddenval * val)); + } + LayerNorm( + thread_data, hidden, out_offset, bias, scale, output, eps); +#endif +} +#endif // @} End Half kernel: EmbEltwiseLayernormKernel + +template +void EmbEltwiseLayerNormFunctor::operator()(int batch, + int seq_len, + int hidden, + const int64_t* ids, + const T* scale, + const T* bias, + const int64_t* embs, + T* output, + float eps, + int input_num, + gpuStream_t stream) { + const unsigned tpb = 256; + const dim3 grid(seq_len, batch, 1); + const dim3 block(tpb, 1, 1); + int shared_bytes = input_num * sizeof(int64_t); + EmbEltwiseLayernormKernel<<>>( + hidden, ids, scale, bias, embs, output, eps, input_num); +} + +template class EmbEltwiseLayerNormFunctor; + +// device function 
'operator()' is not supportted until cuda 10.0 +// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000 +template class EmbEltwiseLayerNormFunctor; +#endif + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.h b/paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.h new file mode 100644 index 0000000000000..d50224dd5bdaf --- /dev/null +++ b/paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.h @@ -0,0 +1,51 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/gpu/gpu_context.h" + +namespace phi { +namespace funcs { + +// This functor involves a fusion calculation in Ernie or Bert. +// The fusion mode is as follows: +// +// in_var emb in_var emb +// | | | | +// lookup_table lookup_table +// | | +// lkt_var lkt_var +// \ / +// elementwise_add +// | +// elt_out_var +// +template +class EmbEltwiseLayerNormFunctor { + public: + void operator()(int batch, + int seq_len, + int hidden, + const int64_t* ids, + const T* scale, + const T* bias, + const int64_t* embs, + T* output, + float eps, + int input_num, + gpuStream_t stream); +}; +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu new file mode 100644 index 0000000000000..0344a71b97062 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu @@ -0,0 +1,156 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/emb_eltwise_layer_norm_functor.h" + +namespace phi { +namespace fusion { + +template +void EmbeddingEltWiseLayerNormKernel( + const Context& dev_ctx, + const std::vector& ids, + const std::vector& embs, + const DenseTensor& bias, + const DenseTensor& scale, + const float epsilon, + DenseTensor* out) { + PADDLE_ENFORCE_GE( + epsilon, + 0.0f, + phi::errors::InvalidArgument( + "'epsilon' is %f, but it should be between 0.0 and 0.001", epsilon)); + PADDLE_ENFORCE_LE( + epsilon, + 0.001f, + phi::errors::InvalidArgument( + "'epsilon' is %f, but it should be between 0.0 and 0.001.", epsilon)); + int input_num = static_cast(ids.size()); + + DenseTensor in_ids_(phi::DataType::INT64), in_embs_(phi::DataType::INT64); + DDim in_dim{input_num}; + int device_id; +#ifdef PADDLE_WITH_HIP + hipGetDevice(&device_id); +#else + cudaGetDevice(&device_id); +#endif + + in_ids_.Resize(in_dim); + in_embs_.Resize(in_dim); + + int64_t* in_ids_d = dev_ctx.template Alloc( + &in_ids_, in_ids_.numel() * sizeof(int64_t)); + int64_t* in_embs_d = dev_ctx.template Alloc( + &in_embs_, in_embs_.numel() * sizeof(int64_t)); + + std::vector in1s, in2s; + for (int i = 0; i < input_num; ++i) { + in1s.push_back(reinterpret_cast(ids[i]->data())); + in2s.push_back(reinterpret_cast(embs[i]->data())); + } +#ifdef PADDLE_WITH_HIP + hipMemcpyAsync(in_ids_d, + in1s.data(), + sizeof(int64_t) * input_num, + hipMemcpyHostToDevice, + dev_ctx.stream()); + hipMemcpyAsync(in_embs_d, + in2s.data(), + sizeof(int64_t) * input_num, + hipMemcpyHostToDevice, + dev_ctx.stream()); +#else + cudaMemcpyAsync(in_ids_d, + in1s.data(), + sizeof(int64_t) * input_num, + cudaMemcpyHostToDevice, + dev_ctx.stream()); + cudaMemcpyAsync(in_embs_d, + in2s.data(), + sizeof(int64_t) * input_num, + cudaMemcpyHostToDevice, + dev_ctx.stream()); +#endif + + // should be (B * S * hidden) + auto id0_dims = ids[0]->dims(); + auto emb0_dims = embs[0]->dims(); + + int batch = id0_dims[0]; + int seq_len = id0_dims[1]; + int hidden = emb0_dims[1]; + + auto* bias_d = bias.data(); + auto* scale_d = scale.data(); + auto* output_d = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + + if (std::is_same::value) { + const half* scale_new = reinterpret_cast(scale_d); + const half* bias_new = reinterpret_cast(bias_d); + half* output_new = reinterpret_cast(output_d); + + phi::funcs::EmbEltwiseLayerNormFunctor emb_eltwise_layernorm_func; + emb_eltwise_layernorm_func(batch, + seq_len, + hidden, + in_ids_d, + scale_new, + bias_new, + in_embs_d, + output_new, + epsilon, + input_num, + dev_ctx.stream()); + } else { + phi::funcs::EmbEltwiseLayerNormFunctor emb_eltwise_layernorm_func; + emb_eltwise_layernorm_func(batch, + seq_len, + hidden, + in_ids_d, + scale_d, + bias_d, + in_embs_d, + output_d, + epsilon, + input_num, + dev_ctx.stream()); + } +} + +} // namespace fusion +} // namespace phi + +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000 +PD_REGISTER_KERNEL(fused_embedding_eltwise_layernorm, + GPU, + ALL_LAYOUT, + phi::fusion::EmbeddingEltWiseLayerNormKernel, + float, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(fused_embedding_eltwise_layernorm, + GPU, + ALL_LAYOUT, + phi::fusion::EmbeddingEltWiseLayerNormKernel, + float) {} +#endif diff --git a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu 
b/paddle/phi/kernels/fusion/gpu/fused_fc_elementwise_layernorm_kernel.cu similarity index 71% rename from paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu rename to paddle/phi/kernels/fusion/gpu/fused_fc_elementwise_layernorm_kernel.cu index f4a9f0a77a53b..f7f8faa329d60 100644 --- a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_fc_elementwise_layernorm_kernel.cu @@ -1,16 +1,19 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include #ifdef __NVCC__ #include @@ -24,13 +27,17 @@ namespace cub = hipcub; #include #endif -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/blas/blas.h" -namespace paddle { -namespace operators { +namespace phi { +namespace fusion { using float16 = phi::dtype::float16; @@ -300,8 +307,8 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data, } } -template -void AddReluAddLayerNorm(gpuStream_t stream, +template +void AddReluAddLayerNorm(const Context& dev_ctx, bool with_relu, int max_threads, const T* y, @@ -315,30 +322,30 @@ void AddReluAddLayerNorm(gpuStream_t stream, int N, float epsilon) { if (with_relu) { - switch (platform::RoundToPowerOfTwo(N)) { + switch (phi::backends::gpu::RoundToPowerOfTwo(N)) { CUDA_LAUNCH_KERNEL_HELPER( InplaceAddReluAddLayerNormKernel <<>>( + dev_ctx.stream()>>>( y, bias_0, bias_1, scale, out, mean, variance, M, N, epsilon)); } } else { - switch (platform::RoundToPowerOfTwo(N)) { + switch (phi::backends::gpu::RoundToPowerOfTwo(N)) { CUDA_LAUNCH_KERNEL_HELPER( InplaceAddReluAddLayerNormKernel <<>>( + dev_ctx.stream()>>>( y, bias_0, bias_1, scale, out, mean, variance, M, N, epsilon)); } } } -template <> -void AddReluAddLayerNorm(gpuStream_t stream, +template +void AddReluAddLayerNorm(const Context& dev_ctx, bool with_relu, int max_threads, const float16* y, @@ -352,109 +359,122 @@ void AddReluAddLayerNorm(gpuStream_t stream, int N, float epsilon) { if 
(with_relu) { - switch (platform::RoundToPowerOfTwo(N)) { + switch (phi::backends::gpu::RoundToPowerOfTwo(N)) { CUDA_LAUNCH_KERNEL_HELPER( InplaceAddReluAddLayerNormKernel <<>>( + dev_ctx.stream()>>>( y, bias_0, bias_1, scale, out, mean, variance, M, N, epsilon)); } } else { - switch (platform::RoundToPowerOfTwo(N)) { + switch (phi::backends::gpu::RoundToPowerOfTwo(N)) { CUDA_LAUNCH_KERNEL_HELPER( InplaceAddReluAddLayerNormKernel <<>>( + dev_ctx.stream()>>>( y, bias_0, bias_1, scale, out, mean, variance, M, N, epsilon)); } } } -template -class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* w = ctx.Input("W"); - auto* out = ctx.Output("Out"); - - auto w_dims = w->dims(); - int N = w_dims[1]; - int K = w_dims[0]; - int M = phi::product(x->dims()) / K; - - const T* x_data = x->data(); - const T* w_data = w->data(); - - auto& dev_ctx = ctx.template device_context(); - auto* out_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); - - auto blas = phi::funcs::GetBlas(dev_ctx); - blas.GEMM(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - static_cast(1.0), - x_data, - w_data, - static_cast(0.0), - out_data); - auto* y = ctx.Input("Y"); - auto* bias_0 = ctx.Input("Bias0"); - auto* bias_1 = ctx.Input("Bias1"); - auto* scale = ctx.Input("Scale"); - - const T* y_data = y->data(); - const T* bias_0_data = bias_0 ? bias_0->data() : nullptr; - const T* bias_1_data = bias_1 ? bias_1->data() : nullptr; - const T* scale_data = scale ? scale->data() : nullptr; - - auto* mean = ctx.Output("Mean"); - auto* variance = ctx.Output("Variance"); - - T* mean_data = - mean ? dev_ctx.template Alloc(mean, mean->numel() * sizeof(T)) - : nullptr; - T* variance_data = variance ? dev_ctx.template Alloc( - variance, variance->numel() * sizeof(T)) - : nullptr; - - bool with_relu = - (ctx.Attr("activation_type") == "relu") ? 
true : false; - float epsilon = ctx.Attr("epsilon"); - - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - AddReluAddLayerNorm(dev_ctx.stream(), - with_relu, - max_threads, - y_data, - bias_0_data, - bias_1_data, - scale_data, - out_data, - mean_data, - variance_data, - M, - N, - epsilon); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -PD_REGISTER_STRUCT_KERNEL(fused_fc_elementwise_layernorm, - GPU, - ALL_LAYOUT, - ops::FusedFCElementwiseLayerNormOpKernel, - float, - double, - plat::float16) {} +template +void FusedFCElementwiseLayerNormKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& y, + const paddle::optional& bias0, + const paddle::optional& scale, + const paddle::optional& bias1, + const int x_num_col_dims, + const std::string& activation_type, + const float epsilon, + const int begin_norm_axis, + DenseTensor* out, + DenseTensor* mean, + DenseTensor* variance) { + PADDLE_ENFORCE_GE( + x_num_col_dims, + 1, + phi::errors::InvalidArgument( + "The x_num_col_dims must be greater than or equal to 1, " + "But received the x_num_col_dims is %d", + x_num_col_dims)); + PADDLE_ENFORCE_GE(epsilon, + 0.0f, + phi::errors::InvalidArgument( + "'epsilon' should be between 0.0 and 0.001.")); + PADDLE_ENFORCE_LE(epsilon, + 0.001f, + phi::errors::InvalidArgument( + "'epsilon' should be between 0.0 and 0.001.")); + PADDLE_ENFORCE_GT(begin_norm_axis, + 0, + phi::errors::InvalidArgument( + "'begin_norm_axis' should be greater than zero.")); + + auto w_dims = w.dims(); + int N = w_dims[1]; + int K = w_dims[0]; + int M = phi::product(x.dims()) / K; + + const T* x_data = x.data(); + const T* w_data = w.data(); + + auto* out_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.GEMM(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + static_cast(1.0), + x_data, + w_data, + static_cast(0.0), + out_data); + + const T* y_data = y.data(); + const T* bias_0_data = bias0 ? bias0->data() : nullptr; + const T* bias_1_data = bias1 ? bias1->data() : nullptr; + const T* scale_data = scale ? scale->data() : nullptr; + + T* mean_data = + mean ? dev_ctx.template Alloc(mean, mean->numel() * sizeof(T)) + : nullptr; + T* variance_data = variance ? dev_ctx.template Alloc( + variance, variance->numel() * sizeof(T)) + : nullptr; + + bool with_relu = (activation_type == "relu") ? true : false; + + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + AddReluAddLayerNorm(dev_ctx, + with_relu, + max_threads, + y_data, + bias_0_data, + bias_1_data, + scale_data, + out_data, + mean_data, + variance_data, + M, + N, + epsilon); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_fc_elementwise_layernorm, + GPU, + ALL_LAYOUT, + phi::fusion::FusedFCElementwiseLayerNormKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu b/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu new file mode 100644 index 0000000000000..954fbd67b96ab --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu @@ -0,0 +1,127 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/common_shape.h" + +namespace phi { +namespace fusion { + +template +using CudnnDataType = phi::backends::gpu::CudnnDataType; + +template +void TransposeFlattenConcatFusionKernel( + const Context& dev_ctx, + const std::vector& x, + const std::vector& trans_axis, + const int flatten_axis, + const int concat_axis, + DenseTensor* out) { + dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + auto odims = out->dims(); + + int rank = x[0]->dims().size(); + // use at least 4D in cudnnTransformTensor + int max_dim = rank < 4 ? 4 : rank; + std::vector stride_x(max_dim, 0); + std::vector stride_y(max_dim, 0); + std::vector dims_y(max_dim, 0); + + cudnnTensorDescriptor_t in_desc; + cudnnTensorDescriptor_t out_desc; + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&in_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&out_desc)); + cudnnDataType_t cudnn_dtype = CudnnDataType::type; + + auto handle = dev_ctx.cudnn_handle(); + + T* odata = out->data(); + for (auto& item : x) { + auto perm_shape = phi::funcs::GetPermuteShape(trans_axis, item->dims()); + int osize = 1; + auto idims = item->dims(); + for (int i = 0; i < rank; i++) { + stride_x[i] = 1; + for (int j = trans_axis[i] + 1; j < rank; j++) { + stride_x[i] *= idims[j]; + } + dims_y[i] = perm_shape[i]; + osize *= perm_shape[i]; + } + stride_y[rank - 1] = 1; + for (int i = rank - 2; i >= 0; i--) { + if (((i + 1) == flatten_axis) && (concat_axis == 1)) { + stride_y[i] = odims[1]; + } else { + stride_y[i] = stride_y[i + 1] * perm_shape[i + 1]; + } + } + + // Since concat is after flatten, the output is 2D tensor. + // If concat_axis is 0, each input's permutated tensor is continuous. + // If concat_axis is 1, the stride of 0-th dim of each input's + // permutated tensor is odims()[1]. 
+ + for (int i = rank; i < max_dim; i++) { + stride_x[i] = 1; + stride_y[i] = 1; + dims_y[i] = 1; + } + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor( + in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data())); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor( + out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data())); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnTransformTensor( + handle, + CudnnDataType::kOne(), + in_desc, + static_cast(item->data()), + CudnnDataType::kZero(), + out_desc, + static_cast(odata))); + if (concat_axis == 0) { + odata += osize; + } else { + auto flat_shape = phi::funcs::GetFlattenShape(flatten_axis, perm_shape); + odata += flat_shape[1]; + } + } + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(in_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(out_desc)); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fusion_transpose_flatten_concat, + GPU, + ALL_LAYOUT, + phi::fusion::TransposeFlattenConcatFusionKernel, + float, + double) {} diff --git a/test/legacy_test/test_fusion_transpose_flatten_concat_op.py b/test/legacy_test/test_fusion_transpose_flatten_concat_op.py index de557e4c4a52e..a0ef5e25b58b6 100644 --- a/test/legacy_test/test_fusion_transpose_flatten_concat_op.py +++ b/test/legacy_test/test_fusion_transpose_flatten_concat_op.py @@ -54,7 +54,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, 1e-6) + self.check_output_with_place(place, 1e-6, check_dygraph=False) def init_test_case(self): self.shapes = [(3, 4, 17, 17), (3, 8, 7, 7), (3, 12, 5, 5)]