diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc index 2d6a1122b0c28..75fd64297db01 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc @@ -13,8 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/fusion.h" #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" namespace paddle { @@ -25,107 +28,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueOp"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueOp"); - OP_INOUT_CHECK( - ctx->HasInput("Bias"), "Output", "Bias", "FusedGemmEpilogueOp"); - OP_INOUT_CHECK( - ctx->HasOutput("Out"), "Output", "Out", "FusedGemmEpilogueOp"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - auto bias_dims = ctx->GetInputDim("Bias"); - auto trans_x = ctx->Attrs().Get("trans_x"); - auto trans_y = ctx->Attrs().Get("trans_y"); - - PADDLE_ENFORCE_EQ( - y_dims.size(), - 2, - platform::errors::InvalidArgument( - "The Input tensor Y's dimension of FusedGemmEpilogueOp " - " should be 2, but got %d.", - y_dims.size())); - - PADDLE_ENFORCE_GE( - x_dims.size(), - 2, - platform::errors::InvalidArgument( - "The Input tensor X's dimension of FusedGemmEpilogueOp " - " should be >= 2, but got %d.", - x_dims.size())); - - PADDLE_ENFORCE_EQ( - bias_dims.size(), - 1, - platform::errors::InvalidArgument( - "The Input tensor bias's dimension of FusedGemmEpilogueOp " - " should be == 1, but got %d.", - bias_dims.size())); - - PADDLE_ENFORCE_EQ(bias_dims[0], - trans_y ? y_dims[0] : y_dims[1], - platform::errors::InvalidArgument( - "The Input tensor bias's dimension 0" - " should be == Y[-1], but got bias's shape = [%s] " - "and Y's shape = [%s]", - bias_dims, - y_dims)); - - auto x_mat_dims = - common::flatten_to_2d(x_dims, trans_x ? 1 : x_dims.size() - 1); - - int K_from_x = trans_x ? x_mat_dims[0] : x_mat_dims[1]; - int K_from_y = trans_y ? y_dims[1] : y_dims[0]; - - PADDLE_ENFORCE_EQ( - K_from_x, - K_from_y, - platform::errors::InvalidArgument( - "The last dimension of X should be equal with Y's first dimension." 
- "But received X[-1] = [%d], Y[0] = [%d].", - K_from_x, - K_from_y)); - - std::vector out_dims; - out_dims.reserve(static_cast(x_dims.size())); - if (trans_x) { - for (int i = 1; i < x_dims.size(); ++i) out_dims.push_back(x_dims[i]); - } else { - for (int i = 0; i < x_dims.size() - 1; ++i) out_dims.push_back(x_dims[i]); - } - - if (trans_y) { - out_dims.push_back(y_dims[0]); - } else { - out_dims.push_back(y_dims[1]); - } - ctx->SetOutputDim("Out", common::make_ddim(out_dims)); - - auto activation = ctx->Attrs().Get("activation"); - if (ctx->HasOutput("ReserveSpace")) { - ctx->SetOutputDim("ReserveSpace", common::make_ddim(out_dims)); - - if (activation == "none") { - PADDLE_THROW(platform::errors::InvalidArgument( - "The ReserveSpace would not be used when activation = \"none\"")); - } else { - int min_size_of_n = activation == "relu" ? 128 : 8; - int N_size = trans_y ? y_dims[0] : y_dims[1]; - PADDLE_ENFORCE_EQ( - N_size % min_size_of_n, - 0, - platform::errors::InvalidArgument( - "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) " - "should be multiple of %d when auxiliary_key given " - "and activation=%s, but got N = %d.", - min_size_of_n, - activation, - N_size)); - } - } - } - phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -188,94 +90,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK( - ctx->HasInput("DOut"), "Input", "DOut", "FusedGemmEpilogueGradOp"); - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueGradOp"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueGradOp"); - OP_INOUT_CHECK(ctx->HasOutput("DY"), "Output", "DY", "FusedGemmEpilogueOp"); - - auto dout_dims = ctx->GetInputDim("DOut"); - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - auto trans_x = ctx->Attrs().Get("trans_x"); - auto trans_y = ctx->Attrs().Get("trans_y"); - - PADDLE_ENFORCE_GE( - dout_dims.size(), - 2, - platform::errors::InvalidArgument( - "The Input tensor DOut's dimension of FusedGemmEpilogueGradOp " - " should be >= 2, but got %d.", - dout_dims.size())); - - PADDLE_ENFORCE_EQ( - y_dims.size(), - 2, - platform::errors::InvalidArgument( - "The Input tensor Y's dimension of FusedGemmEpilogueGradOp " - " should be 2, but got %d.", - y_dims.size())); - - PADDLE_ENFORCE_GE( - x_dims.size(), - 2, - platform::errors::InvalidArgument( - "The Input tensor X's dimension of FusedGemmEpilogueGradOp " - " should be >= 2, but got %d.", - x_dims.size())); - - PADDLE_ENFORCE_EQ( - dout_dims.size(), - x_dims.size(), - platform::errors::InvalidArgument( - "The Input tensor DOut's and X's dimension of " - "FusedGemmEpilogueGradOp " - " should be the same, but got DOut's dim = %d and X's = %d.", - dout_dims.size(), - x_dims.size())); - - auto dout_mat_dims = common::flatten_to_2d(dout_dims, dout_dims.size() - 1); - auto x_mat_dims = common::flatten_to_2d(x_dims, x_dims.size() - 1); - - PADDLE_ENFORCE_EQ( - dout_mat_dims[1], - trans_y ? y_dims[0] : y_dims[1], - platform::errors::InvalidArgument( - "The last dimension of DOut should be equal with Y's last" - "dimension. But received DOut[-1] = [%d], Y[1] = [%d].", - dout_mat_dims[1], - y_dims[1])); - - PADDLE_ENFORCE_EQ( - dout_mat_dims[0], - trans_x ? 
x_mat_dims[1] : x_mat_dims[0], - platform::errors::InvalidArgument( - "The first dimension of DOut should be equal with X's first" - "dimension. But received DOut[0] = [%d], Y[0] = [%d].", - dout_mat_dims[0], - x_mat_dims[0])); - - auto activation_grad = ctx->Attrs().Get("activation_grad"); - if (activation_grad != "none" && !ctx->HasInput("ReserveSpace")) { - PADDLE_ENFORCE_EQ(true, - false, - platform::errors::InvalidArgument( - "The ReserveSpace should not be empty. " - "when activation == {relu_grad, gelu_grad}.")); - } - - if (ctx->HasOutput("DX")) { - ctx->SetOutputDim("DX", x_dims); - } - ctx->SetOutputDim("DY", y_dims); - - if (ctx->HasOutput("DBias")) { - int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1]; - ctx->SetOutputDim("DBias", common::make_ddim({dbias_dim})); - } - } - phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); @@ -367,12 +181,19 @@ class FusedGemmEpilogueOpGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR( - fused_gemm_epilogue, - ops::FusedGemmEpilogueOp, - ops::FusedGemmEpilogueOpMaker, - ops::FusedGemmEpilogueOpGradMaker, - ops::FusedGemmEpilogueOpGradMaker); +DECLARE_INFER_SHAPE_FUNCTOR(fused_gemm_epilogue, + FusedGemmEpilogueInferShapeFunctor, + PD_INFER_META(phi::FusedGemmEpilogueInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(fused_gemm_epilogue_grad, + FusedGemmEpilogueGradInferShapeFunctor, + PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta)); +REGISTER_OPERATOR(fused_gemm_epilogue, + ops::FusedGemmEpilogueOp, + ops::FusedGemmEpilogueOpMaker, + ops::FusedGemmEpilogueOpGradMaker, + ops::FusedGemmEpilogueOpGradMaker, + FusedGemmEpilogueInferShapeFunctor); REGISTER_OPERATOR(fused_gemm_epilogue_grad, ops::FusedGemmEpilogueGradOp, - ops::FusedGemmEpilogueGradOpMaker); + ops::FusedGemmEpilogueGradOpMaker, + FusedGemmEpilogueGradInferShapeFunctor); diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu deleted file mode 100644 index 2ae9f65c4e5a2..0000000000000 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ /dev/null @@ -1,195 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" - -namespace paddle { -namespace operators { - -#if CUDA_VERSION >= 11060 - -template -phi::funcs::MatmulFusedType GetFwdFusedEpilogueType( - const phi::GPUContext& ctx, - const std::string& activation, - phi::DenseTensor* reserve_space) { - using FusedType = phi::funcs::MatmulFusedType; - - FusedType fused_type = FusedType::kMatmulBias; - if (activation != "none") { - if (activation == "relu") { - if (reserve_space == nullptr) { - fused_type = FusedType::kMatmulBiasRelu; - } else { - fused_type = FusedType::kMatmulBiasReluWithReservedData; - int64_t reserve_size = SizeOf(phi::DataType::BOOL) * - common::product(reserve_space->dims()); - ctx.Alloc(reserve_space, phi::DataType::BOOL, reserve_size); - } - } else if (activation == "gelu") { - if (reserve_space == nullptr) { - fused_type = FusedType::kMatmulBiasGelu; - } else { - fused_type = FusedType::kMatmulBiasGeluWithReservedData; - int64_t reserve_size = - sizeof(T) * common::product(reserve_space->dims()); - ctx.Alloc(reserve_space, reserve_size); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "fused_gemm_epilogue's activate should be one of {none, relu, gelu}," - " but received %s, please check", - activation)); - } - } - return fused_type; -} - -template -class FusedGemmEpilogueKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { -#if CUDA_VERSION < 11060 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_gemm_epilogue operator only support CUDA 11.6 " - "or higher version.")); -#endif - auto& dev_ctx = ctx.template device_context(); - - const phi::DenseTensor* x = ctx.Input("X"); - const phi::DenseTensor* y = ctx.Input("Y"); - const phi::DenseTensor* bias = ctx.Input("Bias"); - phi::DenseTensor* out = ctx.Output("Out"); - phi::DenseTensor* reserve_space = - ctx.Output("ReserveSpace"); - - bool trans_x = ctx.Attr("trans_x"); - bool trans_y = ctx.Attr("trans_y"); - - std::string activation = ctx.Attr("activation"); - dev_ctx.Alloc(out, out->numel() * sizeof(T)); - // (M * K) * (K * N) - auto x_mat_dims = - common::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1); - int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0]; - int64_t K = trans_y ? y->dims()[1] : y->dims()[0]; - int64_t N = trans_y ? y->dims()[0] : y->dims()[1]; - - auto fused_type = - GetFwdFusedEpilogueType(dev_ctx, activation, reserve_space); - void* reserve_data = reserve_space ? 
reserve_space->data() : nullptr; - - VLOG(6) << "x.shape={" << x->dims() << "}, y.shape={" << y->dims() - << "}, out.shape={" << out->dims() << "}, M=" << M << ", N=" << N - << ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y - << ", activation=" << activation << ", fused_type=" << fused_type - << ", reserve_space=" << reserve_space; - - phi::funcs::LinearWithCublasLt::Run( - dev_ctx, - x, - y, - out, - static_cast(bias->data()), - reserve_data, - M, - N, - K, - trans_x, - trans_y, - fused_type); - } -}; - -template -class FusedGemmEpilogueGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { -#if CUDA_VERSION < 11060 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_gemm_epilogue operator only support CUDA 11.6 " - "or higher version.")); -#endif - auto& dev_ctx = ctx.template device_context(); - - const phi::DenseTensor* dout = ctx.Input("DOut"); - const phi::DenseTensor* x = ctx.Input("X"); - const phi::DenseTensor* y = ctx.Input("Y"); - const phi::DenseTensor* reserve_space = - ctx.Input("ReserveSpace"); - phi::DenseTensor* dx = ctx.Output("DX"); - phi::DenseTensor* dy = ctx.Output("DY"); - phi::DenseTensor* dbias = ctx.Output("DBias"); - - std::string activation_grad = ctx.Attr("activation_grad"); - bool trans_x = ctx.Attr("trans_x"); - bool trans_y = ctx.Attr("trans_y"); - - // (M * K) * (K * N) - auto x_mat_dims = - common::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1); - int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0]; - int64_t K = trans_y ? y->dims()[1] : y->dims()[0]; - int64_t N = trans_y ? y->dims()[0] : y->dims()[1]; - - VLOG(6) << "x.shape={" << x->dims() << "}, y.shape={" << y->dims() - << "}, dout.shape={" << dout->dims() << "}, M=" << M << ", N=" << N - << ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y - << ", activation=" << activation_grad - << ", reserve_space=" << reserve_space; - - phi::funcs::ComputeFusedGemmEpilogueBackward(dev_ctx, - dout, - x, - y, - reserve_space, - M, - N, - K, - trans_x, - trans_y, - activation_grad, - dx, - dy, - dbias); - } -}; -#endif - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue, - GPU, - ALL_LAYOUT, - ops::FusedGemmEpilogueKernel, - float, - double, - plat::float16, - plat::bfloat16) {} -PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue_grad, - GPU, - ALL_LAYOUT, - ops::FusedGemmEpilogueGradKernel, - float, - double, - plat::float16, - plat::bfloat16) {} diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op_xpu.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op_xpu.cc deleted file mode 100644 index fb6afbf5d256d..0000000000000 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op_xpu.cc +++ /dev/null @@ -1,242 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/framework/scope_guard.h"
-#include "paddle/fluid/platform/float16.h"
-#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T, typename DeviceContext>
-class FusedGemmEpilogueXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
-
-    const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X");
-    const phi::DenseTensor* y = ctx.Input<phi::DenseTensor>("Y");
-    const phi::DenseTensor* bias = ctx.Input<phi::DenseTensor>("Bias");
-
-    phi::DenseTensor* out = ctx.Output<phi::DenseTensor>("Out");
-    phi::DenseTensor* reserve_space =
-        ctx.Output<phi::DenseTensor>("ReserveSpace");
-
-    bool trans_x = ctx.Attr<bool>("trans_x");
-    bool trans_y = ctx.Attr<bool>("trans_y");
-
-    std::string activation = ctx.Attr<std::string>("activation");
-    VLOG(5) << "trans_x = " << trans_x << " , trans_y = " << trans_y
-            << " , activation = " << activation;
-
-    auto x_mat_dims =
-        common::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
-
-    // (M * K) * (K * N) for new api use
-    // int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
-    // int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
-    // int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
-
-    // Call the new interface. For now each step is invoked separately, pending qingpen's unified interface.
-    int r = 0;
-    xpu::Activation_t act = xpu::Activation_t::LINEAR;
-    if (activation == "relu") {
-      act = xpu::Activation_t::RELU;
-    } else if (activation == "gelu") {
-      act = xpu::Activation_t::GELU;
-    }
-    // fc + bias + act
-    // 1. fc
-    phi::XpuFcInfo fc_info;
-
-    phi::GetFCInfo(x_mat_dims, y->dims(), trans_x, trans_y, &fc_info);
-    xpu::Context* xpu_ctx = dev_ctx.x_context();
-
-    const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x->data<T>());
-    const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y->data<T>());
-    XPUType* out_ptr =
-        reinterpret_cast<XPUType*>(out->mutable_data<T>(ctx.GetPlace()));
-    xpu::ctx_guard RAII_GUARD(xpu_ctx);
-    XPUType* fc_out_ptr = RAII_GUARD.alloc_l3_or_gm<XPUType>(out->numel());
-    phi::MatMulXPUFunction<XPUType>(
-        xpu_ctx, x_ptr, y_ptr, fc_out_ptr, fc_info, 1.0f);
-    XPUType* bias_out_ptr = out_ptr;
-    if (activation != "none" && reserve_space) {
-      bias_out_ptr = reinterpret_cast<XPUType*>(
-          reserve_space->mutable_data<T>(ctx.GetPlace()));
-    }
-    // 2 bias
-    const XPUType* bias_ptr = reinterpret_cast<const XPUType*>(bias->data<T>());
-    r = xpu::broadcast_add(xpu_ctx,
-                           fc_out_ptr,
-                           bias_ptr,
-                           bias_out_ptr,
-                           {fc_info.m, fc_info.n},
-                           {fc_info.n});
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
-    // 3 act
-    if (activation == "relu") {
-      r = xpu::relu(xpu_ctx, bias_out_ptr, out_ptr, out->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
-    } else if (activation == "gelu") {
-      r = xpu::gelu(xpu_ctx, bias_out_ptr, out_ptr, out->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
-    }
-  }
-};
-
-template <typename T, typename DeviceContext>
-class FusedGemmEpilogueXPUGradKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    bool trans_x = ctx.Attr<bool>("trans_x");
-    bool trans_y = ctx.Attr<bool>("trans_y");
-    auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
-    const phi::DenseTensor* dout = ctx.Input<phi::DenseTensor>("DOut");
-    const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X");
-    const phi::DenseTensor* y = ctx.Input<phi::DenseTensor>("Y");
-
-    const phi::DenseTensor* reserve_space =
-        ctx.Input<phi::DenseTensor>("ReserveSpace");
-
-    phi::DenseTensor* dx = ctx.Output<phi::DenseTensor>("DX");
-    phi::DenseTensor* dy = ctx.Output<phi::DenseTensor>("DY");
-    phi::DenseTensor* dbias = ctx.Output<phi::DenseTensor>("DBias");
-
-
std::string activation = "none"; - if (ctx.HasAttr("activation")) { - activation = ctx.Attr("activation"); - } else if (ctx.HasAttr("activation_grad")) { - activation = ctx.Attr("activation_grad"); - } - - auto* xpu_ctx = dev_ctx.x_context(); - xpu::ctx_guard RAII_GUARD(xpu_ctx); - const XPUType* dout_ptr = reinterpret_cast(dout->data()); - - const XPUType* dout_fc_ptr = dout_ptr; - const XPUType* x_ptr = reinterpret_cast(x->data()); - const XPUType* y_ptr = reinterpret_cast(y->data()); - - // const XPUType* - const XPUType* reserve_space_ptr = - (reserve_space == NULL) - ? (reinterpret_cast(NULL)) - : (reinterpret_cast(reserve_space->data())); - XPUType* d_act_input_ptr = NULL; - if (activation != "none") { - d_act_input_ptr = RAII_GUARD.alloc_l3_or_gm(dout->numel()); - dout_fc_ptr = d_act_input_ptr; - } - - // 1. act_grad 2. fc_grad 3. dbias - int r = 0; - if (activation == "relu") { - r = xpu::relu_grad(xpu_ctx, - reserve_space_ptr, - reserve_space_ptr, - dout_ptr, - d_act_input_ptr, - dout->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu_grad"); - } else if (activation == "gelu") { - r = xpu::gelu_grad(xpu_ctx, - reserve_space_ptr, - reserve_space_ptr, - dout_ptr, - d_act_input_ptr, - dout->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad"); - } - - auto x_mat_dims = - common::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1); - phi::XpuFcInfo info_forward; - phi::GetFCInfo(x_mat_dims, y->dims(), trans_x, trans_y, &info_forward); - - // 2. fc_grad - const XPUType* a_1 = reinterpret_cast(NULL); - const XPUType* b_1 = reinterpret_cast(NULL); - const XPUType* a_2 = reinterpret_cast(NULL); - const XPUType* b_2 = reinterpret_cast(NULL); - XPUType* c_1 = - (dx == NULL) - ? reinterpret_cast(NULL) - : reinterpret_cast(dx->mutable_data(ctx.GetPlace())); - XPUType* c_2 = - (dy == NULL) - ? reinterpret_cast(NULL) - : reinterpret_cast(dy->mutable_data(ctx.GetPlace())); - phi::XpuFcInfo info_dx; - phi::XpuFcInfo info_dy; - std::tuple - fc_info = phi::MatmulGradFcInfo(xpu_ctx, - &RAII_GUARD, - info_forward, - trans_x, - trans_y, - x_ptr, - y_ptr, - dout_fc_ptr); - std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info; - if (dx) { - phi::MatMulXPUFunction(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f); - } - if (dy) { - phi::MatMulXPUFunction(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f); - } - // 3. dbias - if (dbias) { - XPUType* dbias_ptr = - reinterpret_cast(dbias->mutable_data(ctx.GetPlace())); - r = xpu::reduce_sum(xpu_ctx, - dout_fc_ptr, - dbias_ptr, - {info_forward.m, info_forward.n}, - {0}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue, - XPU, - ALL_LAYOUT, - ops::FusedGemmEpilogueXPUKernel, - float, - plat::float16) {} -PD_REGISTER_STRUCT_KERNEL(fused_gemm_epilogue_grad, - XPU, - ALL_LAYOUT, - ops::FusedGemmEpilogueXPUGradKernel, - float, - plat::float16) {} diff --git a/paddle/fluid/operators/identity_loss_op.cc b/paddle/fluid/operators/identity_loss_op.cc deleted file mode 100644 index 76e7f8a733e40..0000000000000 --- a/paddle/fluid/operators/identity_loss_op.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class IdentityLossOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); - } -}; - -class IdentityLossOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor) The input of identity_loss op"); - AddOutput("Out", "(Tensor) The output of identity_loss op"); - AddAttr("reduction", "(int, default 1). The reduction.") - .SetDefault(1) - .InEnum({0, 1, 2}); - AddComment(R"DOC( -IdentityLoss Operator mark the Loss var. - -)DOC"); - } -}; - -class IdentityLossGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->ShareLoD("X", framework::GradVarName("X")); - } - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - return phi::KernelKey(input_data_type, platform::CPUPlace()); - } -}; - -template -class IdentityLossGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("identity_loss_grad"); - grad_op->SetInput("X", this->Input("X")); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - grad_op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_INPLACE_OP_INFERER(IdentityLossInplaceInferer, {"X", "Out"}); -DECLARE_INPLACE_OP_INFERER(IdentityLossGradInplaceInferer, - {framework::GradVarName("Out"), - framework::GradVarName("X")}); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(identity_loss, - IdentityLossInferShapeFunctor, - PD_INFER_META(phi::IdentityLossInferMeta)); - -REGISTER_OPERATOR(identity_loss, - ops::IdentityLossOp, - ops::IdentityLossOpMaker, - ops::IdentityLossGradMaker, - ops::IdentityLossGradMaker, - ops::IdentityLossInplaceInferer, - IdentityLossInferShapeFunctor); - -REGISTER_OPERATOR(identity_loss_grad, - ops::IdentityLossGradOp, - ops::IdentityLossGradInplaceInferer); diff --git a/paddle/fluid/operators/ops_signature/fused_gemm_epilogue_sig.cc b/paddle/fluid/operators/ops_signature/fused_gemm_epilogue_sig.cc new file mode 100644 index 0000000000000..9934f8ea443f0 --- /dev/null +++ b/paddle/fluid/operators/ops_signature/fused_gemm_epilogue_sig.cc @@ 
-0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FusedGemmEpilogueOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("fused_gemm_epilogue", + {"X", "Y", "Bias"}, + {"trans_x", "trans_y", "activation"}, + {"Out", "ReserveSpace"}); +} + +KernelSignature FusedGemmEpilogueGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("fused_gemm_epilogue_grad", + {"X", "Y", "ReserveSpace", "DOut"}, + {"trans_x", "trans_y", "activation_grad"}, + {"DX", "DY", "DBias"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fused_gemm_epilogue, + phi::FusedGemmEpilogueOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(fused_gemm_epilogue_grad, + phi::FusedGemmEpilogueGradOpArgumentMapping); diff --git a/paddle/fluid/operators/ops_signature/identity_loss_sig.cc b/paddle/fluid/operators/ops_signature/identity_loss_sig.cc deleted file mode 100644 index 0b748396e9d05..0000000000000 --- a/paddle/fluid/operators/ops_signature/identity_loss_sig.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature IdentityLossOpArgumentMapping( - const ArgumentMappingContext& ctx UNUSED) { - return KernelSignature("identity_loss", {"X"}, {"reduction"}, {"Out"}); -} - -KernelSignature IdentityLossGradOpArgumentMapping( - const ArgumentMappingContext& ctx UNUSED) { - return KernelSignature( - "identity_loss_grad", {"X", "Out@GRAD"}, {"reduction"}, {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(identity_loss, phi::IdentityLossOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(identity_loss_grad, - phi::IdentityLossGradOpArgumentMapping); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 8ee3ca74c69a7..469a18888e515 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1097,6 +1097,17 @@ kernel : func : i1e_grad +- backward_op : identity_loss_grad + forward : identity_loss (Tensor x, int reduction) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int reduction) + output : Tensor(x_grad) + infer_meta : + func : IdentityLossGradInferMeta + kernel : + func : identity_loss_grad + data_type : out_grad + inplace : (out_grad -> x_grad) + - backward_op : imag_grad forward : imag (Tensor x) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 1defdb9906bde..7c0402069954d 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3439,6 +3439,12 @@ outputs : out : Out +- op: identity_loss + inputs : + x: X + outputs : + out : Out + - op: lod_array_length inputs : {x: X} diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e00e6c0c05258..2f5c5b66194b1 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1228,6 +1228,16 @@ func : i1e backward : i1e_grad +- op : identity_loss + args : (Tensor x, int reduction = 1) + output : Tensor(out) + infer_meta : + func : IdentityLossInferMeta + kernel : + func : identity_loss + inplace: (x -> out) + backward : identity_loss_grad + - op : imag args : (Tensor x) output : Tensor (out) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 006fd4d5437d1..c15455e07182c 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1808,6 +1808,15 @@ void HuberLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void IdentityLossGradInferMeta(const MetaTensor& x, + const MetaTensor& out_grad, + const int reduction, + MetaTensor* x_grad) { + x_grad->set_dims(x.dims()); + x_grad->share_lod(x); + x_grad->set_dtype(out_grad.dtype()); +} + void IndexSampleInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 6ca56a5c98a17..92443d66d42ce 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -297,6 +297,11 @@ void HuberLossInferMeta(const MetaTensor& input_meta, MetaTensor* residual, MetaConfig config = MetaConfig()); +void IdentityLossGradInferMeta(const MetaTensor& x, + const MetaTensor& out_grad, + const int reduction, + MetaTensor* x_grad); + void IndexSampleInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out, diff --git a/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu new file mode 100644 index 0000000000000..1d194289cdd68 --- /dev/null +++ 
b/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_grad_kernel.cu
@@ -0,0 +1,87 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "glog/logging.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void FusedGemmEpilogueGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& y,
+    const paddle::optional<DenseTensor>& reserve_space,
+    const DenseTensor& out_grad,
+    const bool trans_x,
+    const bool trans_y,
+    const std::string& activation_grad,
+    DenseTensor* x_grad,
+    DenseTensor* y_grad,
+    DenseTensor* bias_grad) {
+#if CUDA_VERSION < 11060
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The fused_gemm_epilogue operator only supports CUDA 11.6 "
+      "or a higher version."));
+#endif
+
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 11060
+
+  // (M * K) * (K * N)
+  auto x_mat_dims =
+      phi::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
+  int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
+  int64_t K = trans_y ? y.dims()[1] : y.dims()[0];
+  int64_t N = trans_y ? y.dims()[0] : y.dims()[1];
+
+  VLOG(6) << "x.shape={" << x.dims() << "}, y.shape={" << y.dims()
+          << "}, dout.shape={" << out_grad.dims() << "}, M=" << M << ", N=" << N
+          << ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y
+          << ", activation_grad=" << activation_grad
+          << ", reserve_space=" << reserve_space.get_ptr();
+
+  phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx,
+                                                  &out_grad,
+                                                  &x,
+                                                  &y,
+                                                  reserve_space.get_ptr(),
+                                                  M,
+                                                  N,
+                                                  K,
+                                                  trans_x,
+                                                  trans_y,
+                                                  activation_grad,
+                                                  x_grad,
+                                                  y_grad,
+                                                  bias_grad);
+#endif
+#endif
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_gemm_epilogue_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedGemmEpilogueGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu
new file mode 100644
index 0000000000000..810e15c405177
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/fused_gemm_epilogue_kernel.cu
@@ -0,0 +1,125 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "glog/logging.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
+
+namespace phi {
+namespace fusion {
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 11060
+template <typename T>
+phi::funcs::MatmulFusedType GetFwdFusedEpilogueType(
+    const phi::GPUContext& ctx,
+    const std::string& activation,
+    phi::DenseTensor* reserve_space) {
+  using FusedType = phi::funcs::MatmulFusedType;
+
+  FusedType fused_type = FusedType::kMatmulBias;
+  if (activation != "none") {
+    if (activation == "relu") {
+      if (reserve_space == nullptr) {
+        fused_type = FusedType::kMatmulBiasRelu;
+      } else {
+        fused_type = FusedType::kMatmulBiasReluWithReservedData;
+        reserve_space->Resize({phi::product(reserve_space->dims())});
+        ctx.template Alloc<bool>(reserve_space);
+      }
+    } else if (activation == "gelu") {
+      if (reserve_space == nullptr) {
+        fused_type = FusedType::kMatmulBiasGelu;
+      } else {
+        fused_type = FusedType::kMatmulBiasGeluWithReservedData;
+        int64_t reserve_size = sizeof(T) * phi::product(reserve_space->dims());
+        ctx.template Alloc<T>(reserve_space, reserve_size);
+      }
+    } else {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "fused_gemm_epilogue's activation should be one of {none, relu, gelu},"
+          " but received %s, please check",
+          activation));
+    }
+  }
+  return fused_type;
+}
+#endif
+#endif
+
+template <typename T, typename Context>
+void FusedGemmEpilogueKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& bias,
+                             const bool trans_x,
+                             const bool trans_y,
+                             const std::string& activation,
+                             DenseTensor* out,
+                             DenseTensor* reserve_space) {
+#if CUDA_VERSION < 11060
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The fused_gemm_epilogue operator only supports CUDA 11.6 "
+      "or a higher version."));
+#endif
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 11060
+
+  dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
+  // (M * K) * (K * N)
+  auto x_mat_dims =
+      phi::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
+  int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
+  int64_t K = trans_y ? y.dims()[1] : y.dims()[0];
+  int64_t N = trans_y ? y.dims()[0] : y.dims()[1];
+
+  auto fused_type =
+      GetFwdFusedEpilogueType<T>(dev_ctx, activation, reserve_space);
+  void* reserve_data = reserve_space ? reserve_space->data() : nullptr;
+
+  VLOG(6) << "x.shape={" << x.dims() << "}, y.shape={" << y.dims()
+          << "}, out.shape={" << out->dims() << "}, M=" << M << ", N=" << N
+          << ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y
+          << ", activation=" << activation << ", fused_type=" << fused_type
+          << ", reserve_space=" << reserve_space;
+
+  phi::funcs::LinearWithCublasLt<T>::Run(
+      dev_ctx,
+      &x,
+      &y,
+      out,
+      static_cast<const void*>(bias.data<T>()),
+      reserve_data,
+      M,
+      N,
+      K,
+      trans_x,
+      trans_y,
+      fused_type);
+#endif
+#endif
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_gemm_epilogue,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedGemmEpilogueKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_grad_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_grad_kernel.cc
new file mode 100644
index 0000000000000..dff385a34d9cf
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_grad_kernel.cc
@@ -0,0 +1,145 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/scope_guard.h"
+#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void FusedGemmEpilogueXPUGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& y,
+    const paddle::optional<DenseTensor>& reserve_space,
+    const DenseTensor& out_grad,
+    const bool trans_x,
+    const bool trans_y,
+    const std::string& activation_grad,
+    DenseTensor* x_grad,
+    DenseTensor* y_grad,
+    DenseTensor* bias_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  std::string activation = activation_grad;
+
+  auto* xpu_ctx = dev_ctx.x_context();
+  xpu::ctx_guard RAII_GUARD(xpu_ctx);
+  const XPUType* dout_ptr =
+      reinterpret_cast<const XPUType*>(out_grad.data<T>());
+
+  const XPUType* dout_fc_ptr = dout_ptr;
+  const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
+  const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
+
+  // const XPUType*
+  const XPUType* reserve_space_ptr =
+      (reserve_space.get_ptr() == NULL)
+          ? (reinterpret_cast<const XPUType*>(NULL))
+          : (reinterpret_cast<const XPUType*>(reserve_space->data<T>()));
+  XPUType* d_act_input_ptr = NULL;
+  if (activation != "none") {
+    d_act_input_ptr = RAII_GUARD.alloc_l3_or_gm<XPUType>(out_grad.numel());
+    dout_fc_ptr = d_act_input_ptr;
+  }
+
+  // 1. act_grad 2. fc_grad 3. dbias
+  int r = 0;
+  if (activation == "relu") {
+    r = xpu::relu_grad(xpu_ctx,
+                       reserve_space_ptr,
+                       reserve_space_ptr,
+                       dout_ptr,
+                       d_act_input_ptr,
+                       out_grad.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu_grad");
+  } else if (activation == "gelu") {
+    r = xpu::gelu_grad(xpu_ctx,
+                       reserve_space_ptr,
+                       reserve_space_ptr,
+                       dout_ptr,
+                       d_act_input_ptr,
+                       out_grad.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad");
+  }
+  auto x_mat_dims =
+      common::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
+  phi::XpuFcInfo info_forward;
+  phi::GetFCInfo(x_mat_dims, y.dims(), trans_x, trans_y, &info_forward);
+
+  // 2. fc_grad
+  const XPUType* a_1 = reinterpret_cast<const XPUType*>(NULL);
+  const XPUType* b_1 = reinterpret_cast<const XPUType*>(NULL);
+  const XPUType* a_2 = reinterpret_cast<const XPUType*>(NULL);
+  const XPUType* b_2 = reinterpret_cast<const XPUType*>(NULL);
+  XPUType* c_1;
+  if (x_grad == NULL) {
+    c_1 = reinterpret_cast<XPUType*>(NULL);
+  } else {
+    auto* x_grad_tmp = dev_ctx.template Alloc<T>(x_grad);
+    c_1 = reinterpret_cast<XPUType*>(x_grad_tmp);
+  }
+  XPUType* c_2;
+  if (y_grad == NULL) {
+    c_2 = reinterpret_cast<XPUType*>(NULL);
+  } else {
+    auto* y_grad_tmp = dev_ctx.template Alloc<T>(y_grad);
+    c_2 = reinterpret_cast<XPUType*>(y_grad_tmp);
+  }
+  phi::XpuFcInfo info_dx;
+  phi::XpuFcInfo info_dy;
+  std::tuple<phi::XpuFcInfo,
+             phi::XpuFcInfo,
+             const XPUType*,
+             const XPUType*,
+             const XPUType*,
+             const XPUType*>
+      fc_info = phi::MatmulGradFcInfo(xpu_ctx,
+                                      &RAII_GUARD,
+                                      info_forward,
+                                      trans_x,
+                                      trans_y,
+                                      x_ptr,
+                                      y_ptr,
+                                      dout_fc_ptr);
+  std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info;
+  if (x_grad) {
+    phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
+  }
+  if (y_grad) {
+    phi::MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
+  }
+  // 3. dbias
+  if (bias_grad) {
+    XPUType* dbias_ptr;
+    auto* dbias_tmp_ptr = dev_ctx.template Alloc<T>(bias_grad);
+    dbias_ptr = reinterpret_cast<XPUType*>(dbias_tmp_ptr);
+    r = xpu::reduce_sum<XPUType>(
+        xpu_ctx, dout_fc_ptr, dbias_ptr, {info_forward.m, info_forward.n}, {0});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+  }
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_gemm_epilogue_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedGemmEpilogueXPUGradKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc
new file mode 100644
index 0000000000000..8854a79987e79
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc
@@ -0,0 +1,99 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/scope_guard.h"
+#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void FusedGemmEpilogueKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& bias,
+                             const bool trans_x,
+                             const bool trans_y,
+                             const std::string& activation,
+                             DenseTensor* out,
+                             DenseTensor* reserve_space) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  auto x_mat_dims =
+      common::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
+
+  // (M * K) * (K * N) for the new API
+  // int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
+  // int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
+  // int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
+
+  // Call the new interface. For now each step is invoked separately, pending qingpen's unified interface.
+  int r = 0;
+  xpu::Activation_t act = xpu::Activation_t::LINEAR;
+  if (activation == "relu") {
+    act = xpu::Activation_t::RELU;
+  } else if (activation == "gelu") {
+    act = xpu::Activation_t::GELU;
+  }
+  // fc + bias + act
+  // 1. fc
+  phi::XpuFcInfo fc_info;
+
+  phi::GetFCInfo(x_mat_dims, y.dims(), trans_x, trans_y, &fc_info);
+  xpu::Context* xpu_ctx = dev_ctx.x_context();
+
+  const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
+  const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
+  auto* out_tmp_ptr = dev_ctx.template Alloc<T>(out);
+  XPUType* out_ptr = reinterpret_cast<XPUType*>(out_tmp_ptr);
+  xpu::ctx_guard RAII_GUARD(xpu_ctx);
+  XPUType* fc_out_ptr = RAII_GUARD.alloc_l3_or_gm<XPUType>(out->numel());
+  phi::MatMulXPUFunction<XPUType>(
+      xpu_ctx, x_ptr, y_ptr, fc_out_ptr, fc_info, 1.0f);
+  XPUType* bias_out_ptr = out_ptr;
+  if (activation != "none" && reserve_space) {
+    auto* bias_out_temp_ptr = dev_ctx.template Alloc<T>(reserve_space);
+    bias_out_ptr = reinterpret_cast<XPUType*>(bias_out_temp_ptr);
+  }
+  // 2 bias
+  const XPUType* bias_ptr = reinterpret_cast<const XPUType*>(bias.data<T>());
+  r = xpu::broadcast_add(xpu_ctx,
+                         fc_out_ptr,
+                         bias_ptr,
+                         bias_out_ptr,
+                         {fc_info.m, fc_info.n},
+                         {fc_info.n});
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+  // 3 act
+  if (activation == "relu") {
+    r = xpu::relu(xpu_ctx, bias_out_ptr, out_ptr, out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
+  } else if (activation == "gelu") {
+    r = xpu::gelu(xpu_ctx, bias_out_ptr, out_ptr, out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
+  }
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_gemm_epilogue,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedGemmEpilogueKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/python/paddle/incubate/nn/loss.py b/python/paddle/incubate/nn/loss.py
index 0a8043b445b82..c6eb7df467a79 100644
--- a/python/paddle/incubate/nn/loss.py
+++ b/python/paddle/incubate/nn/loss.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddle import _legacy_C_ops
+from paddle import _C_ops
 from paddle.base.data_feeder import check_variable_and_dtype
 from paddle.base.layer_helper import LayerHelper
 from paddle.framework import in_dynamic_mode
@@ -60,7 +60,7 @@ def identity_loss(x, reduction="none"):
         raise Exception("Unsupported reduction type.")
 
     if in_dynamic_mode():
-        return _legacy_C_ops.identity_loss(x, "reduction", reduction)
+        return _C_ops.identity_loss(x, reduction)
 
     check_variable_and_dtype(x, 'x', ['float32', 'float64'], "identity_loss")
     attrs = {'reduction': reduction}
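
Note on the `loss.py` hunk: `identity_loss` moves onto the generated-op path in this patch — the handwritten `identity_loss_op.cc` and `identity_loss_sig.cc` are deleted in favor of entries in `ops.yaml`, `backward.yaml`, and `op_compat.yaml` — so the dygraph branch now calls the auto-generated `_C_ops.identity_loss(x, reduction)`, which takes `reduction` positionally, instead of the legacy attr-style `_legacy_C_ops.identity_loss(x, "reduction", reduction)`.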
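
The forward shape logic deleted from `FusedGemmEpilogueOp::InferShape` now sits behind `PD_INFER_META(phi::FusedGemmEpilogueInferMeta)`. As a reading aid, here is a standalone sketch of that shape rule, distilled from the deleted code above — an illustration only, not the phi implementation (which reports errors via `PADDLE_ENFORCE` and also sets output dtypes):

```cpp
// Sketch of the fused_gemm_epilogue forward shape rule: X is flattened to
// 2-D, multiplied with a strictly 2-D Y, and Bias (1-D of length N) is
// broadcast over the rows.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> FusedGemmEpilogueOutDims(
    const std::vector<int64_t>& x_dims,
    const std::vector<int64_t>& y_dims,
    bool trans_x,
    bool trans_y) {
  assert(x_dims.size() >= 2 && y_dims.size() == 2);
  // flatten_to_2d: without trans_x, all axes but the last fold into M and
  // the last axis is K; with trans_x, axis 0 is K and the rest fold into M.
  int64_t k_from_x = trans_x ? x_dims.front() : x_dims.back();
  int64_t k_from_y = trans_y ? y_dims[1] : y_dims[0];
  assert(k_from_x == k_from_y && "K of X must match K of Y");

  // Out keeps the non-K axes of X and appends N from Y.
  std::vector<int64_t> out_dims(trans_x ? x_dims.begin() + 1 : x_dims.begin(),
                                trans_x ? x_dims.end() : x_dims.end() - 1);
  out_dims.push_back(trans_y ? y_dims[0] : y_dims[1]);  // N
  return out_dims;
}

int main() {
  // Worked example: X = [2, 3, 4], Y = [4, 5], no transposes.
  // flatten_to_2d(X) -> [6, 4], so M = 6, K = 4, N = 5 and Out = [2, 3, 5].
  for (int64_t d : FusedGemmEpilogueOutDims({2, 3, 4}, {4, 5}, false, false))
    std::cout << d << ' ';  // prints: 2 3 5
  std::cout << '\n';
  return 0;
}
```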
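
Two side conditions from the deleted `InferShape` are worth keeping in mind when reviewing the InferMeta move: requesting `ReserveSpace` with `activation = "none"` is rejected outright, and the output width N must be a multiple of 128 for `relu` or 8 for `gelu`. A minimal sketch of those checks, assuming the phi InferMeta preserves them unchanged:

```cpp
// ReserveSpace is the auxiliary buffer the forward pass fills so the
// backward pass can run relu_grad/gelu_grad without recomputing; per
// GetFwdFusedEpilogueType above, it holds one bool per element for relu and
// one T (the pre-activation value) per element for gelu.
#include <cstdint>
#include <stdexcept>
#include <string>

void CheckReserveSpace(const std::string& activation, int64_t n) {
  if (activation == "none") {
    // No activation to differentiate, so the buffer would never be read.
    throw std::invalid_argument(
        "ReserveSpace is unused when activation = \"none\"");
  }
  const int64_t min_multiple = (activation == "relu") ? 128 : 8;
  if (n % min_multiple != 0) {
    throw std::invalid_argument(
        "N must be a multiple of " + std::to_string(min_multiple) +
        " when ReserveSpace is requested with activation = " + activation);
  }
}
```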
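
The gradient side is simpler: `DX` mirrors `X`, `DY` mirrors `Y`, and `DBias` is 1-D with length N taken from `Y` (`y_dims[0]` when `trans_y`, else `y_dims[1]`). A sketch of the rule now expected in `phi::FusedGemmEpilogueGradInferMeta`:

```cpp
// Gradient shapes for fused_gemm_epilogue_grad, mirroring the deleted
// FusedGemmEpilogueGradOp::InferShape above (illustrative sketch).
#include <cstdint>
#include <vector>

struct GradDims {
  std::vector<int64_t> dx, dy, dbias;
};

GradDims FusedGemmEpilogueGradDims(const std::vector<int64_t>& x_dims,
                                   const std::vector<int64_t>& y_dims,
                                   bool trans_y) {
  // DX and DY keep the shapes of their forward inputs; DBias has length N.
  return {x_dims, y_dims, {trans_y ? y_dims[0] : y_dims[1]}};
}
```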
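
Finally, note that the XPU kernels compose the epilogue from three separate XDNN calls (`MatMulXPUFunction`, `broadcast_add`, then `relu`/`gelu`), in contrast to the single fused cuBLASLt call on GPU. A naive CPU reference of the same math, handy for sanity-checking that decomposition (hypothetical helper; assumes row-major 2-D inputs with the transposes already resolved):

```cpp
// out[m][n] = act(sum_k x[m][k] * y[k][n] + bias[n])
#include <algorithm>
#include <cmath>
#include <string>
#include <vector>

std::vector<float> GemmBiasAct(const std::vector<float>& x,     // M x K
                               const std::vector<float>& y,     // K x N
                               const std::vector<float>& bias,  // N
                               int M, int K, int N,
                               const std::string& activation) {
  std::vector<float> out(static_cast<size_t>(M) * N, 0.f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = bias[n];  // step 2: the broadcast bias add, folded in here
      for (int k = 0; k < K; ++k) {
        acc += x[m * K + k] * y[k * N + n];  // step 1: the fc / matmul
      }
      if (activation == "relu") {  // step 3: the pointwise activation
        acc = std::max(acc, 0.f);
      } else if (activation == "gelu") {
        acc = 0.5f * acc * (1.f + std::erf(acc / std::sqrt(2.f)));
      }
      out[m * N + n] = acc;
    }
  }
  return out;
}
```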