From 5fed7ad53be3656c5cd3cb1d556bbbc6deee5575 Mon Sep 17 00:00:00 2001
From: Zero Rains
Date: Fri, 10 Nov 2023 14:52:26 +0800
Subject: [PATCH] 【Hackathon 5th No.104】move fusion_repeated_fc_relu/fusion_squared_mat_sub to phi (#58300)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../fused/fusion_repeated_fc_relu_op.cc       | 209 ------------------
 .../fused/fusion_repeated_fc_relu_op.h        |  38 ----
 .../fused/fusion_squared_mat_sub_op.cc        | 159 -------------
 .../fused/fusion_squared_mat_sub_op.h         |  39 ----
 .../pir/dialect/op_generator/ops_api_gen.py   |   2 +
 paddle/phi/api/yaml/fused_ops.yaml            |  20 ++
 paddle/phi/api/yaml/op_compat.yaml            |  19 ++
 paddle/phi/infermeta/fusion.cc                | 130 +++++++++++
 paddle/phi/infermeta/fusion.h                 |  14 ++
 .../cpu/fusion_repeated_fc_relu_kernel.cc     | 101 +++++++++
 .../cpu/fusion_squared_mat_sub_kernel.cc      |  89 ++++++++
 .../test_fusion_repeated_fc_relu_op.py        |   2 +-
 .../test_fusion_squared_mat_sub_op.py         |   2 +-
 13 files changed, 377 insertions(+), 447 deletions(-)
 delete mode 100644 paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
 delete mode 100644 paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
 delete mode 100644 paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
 delete mode 100644 paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
 create mode 100644 paddle/phi/kernels/fusion/cpu/fusion_repeated_fc_relu_kernel.cc
 create mode 100644 paddle/phi/kernels/fusion/cpu/fusion_squared_mat_sub_kernel.cc
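
Note (editorial, not part of the patch): as a review reference, a minimal
NumPy sketch of what fusion_repeated_fc_relu computes, derived from the
deleted fluid kernel and the new phi kernel below — every stage, including
the last, is relu(x @ w + b). Names and shapes here are illustrative only.

    import numpy as np

    def fusion_repeated_fc_relu_ref(x, ws, bs):
        # x: (m, k); ws[i]: (k_i, n_i); bs[i]: (n_i,)
        y = x
        for w, b in zip(ws, bs):
            # fc_relu in the kernel: matmul, then fused bias-add + relu
            y = np.maximum(y @ w + b, 0.0)
        return y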
diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
deleted file mode 100644
index 8b88316645d8a..0000000000000
--- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
+++ /dev/null
@@ -1,209 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License. */
-
-#include "paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h"
-
-#include <string>
-#include <vector>
-
-#include "paddle/phi/kernels/funcs/jit/kernels.h"
-
-namespace paddle {
-namespace operators {
-
-void FusionRepeatedFCReluOp::InferShape(
-    framework::InferShapeContext* ctx) const {
-  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionRepeatedFCRelu");
-  auto sz = ctx->Inputs("W").size();
-  PADDLE_ENFORCE_GT(sz,
-                    1UL,
-                    platform::errors::InvalidArgument(
-                        "Inputs(W) of FusionRepeatedFCReluOp should "
-                        "be greater than 1, but received value is %d.",
-                        sz));
-  PADDLE_ENFORCE_EQ(
-      ctx->Inputs("Bias").size(),
-      sz,
-      platform::errors::InvalidArgument(
-          "Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
-          "equal to inputs size %d, but received value is %d.",
-          sz,
-          ctx->Inputs("Bias").size()));
-  PADDLE_ENFORCE_EQ(
-      ctx->Outputs("ReluOut").size(),
-      sz - 1,
-      platform::errors::InvalidArgument(
-          "Size of output(ReluOut) of FusionRepeatedFCReluOp should "
-          "be equal to inputs size minus one %d, but received value is %d",
-          sz - 1,
-          ctx->Outputs("ReluOut").size()));
-  OP_INOUT_CHECK(
-      ctx->HasOutput("Out"), "Output", "Out", "FusionRepeatedFCRelu");
-
-  auto i_dims = ctx->GetInputDim("X");
-  PADDLE_ENFORCE_EQ(
-      i_dims.size(),
-      2,
-      platform::errors::InvalidArgument(
-          "Input shape size should be 2, but received value is %d.",
-          i_dims.size()));
-
-  auto w_dims = ctx->GetInputsDim("W");
-  auto b_dims = ctx->GetInputsDim("Bias");
-  PADDLE_ENFORCE_EQ(w_dims.size(),
-                    b_dims.size(),
-                    platform::errors::InvalidArgument(
-                        "Shape size of weight and bias should be equal, but "
-                        "weight size is %d, bias size is %d.",
-                        w_dims.size(),
-                        b_dims.size()));
-  PADDLE_ENFORCE_EQ(i_dims[1],
-                    w_dims[0][0],
-                    platform::errors::InvalidArgument(
-                        "input width should be equal to weight height, but "
-                        "input width is %d, weight height is %d.",
-                        i_dims[1],
-                        w_dims[0][0]));
-
-  for (size_t i = 1; i < sz; ++i) {
-    PADDLE_ENFORCE_EQ(w_dims[i].size(),
-                      2,
-                      platform::errors::InvalidArgument(
-                          "Every weight shape size should be 2, but received "
-                          "w_dims[%d].size() = %d.",
-                          i,
-                          w_dims[i].size()));
-    PADDLE_ENFORCE_EQ(
-        phi::product(b_dims[i]),
-        w_dims[i][1],
-        platform::errors::InvalidArgument(
-            "The length of Bias must be equal with w_dims[1], but received "
-            "product(b_dims[%d]) = %d, w_dims[%d][1] = %d.",
-            i,
-            phi::product(b_dims[i]),
-            i,
-            w_dims[i][1]));
-  }
-  ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]});
-  ctx->ShareLoD("X", /*->*/ "Out");
-}
-
-phi::KernelKey FusionRepeatedFCReluOp::GetExpectedKernelType(
-    const framework::ExecutionContext& ctx) const {
-  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-                        ctx.GetPlace());
-}
-
-void FusionRepeatedFCReluOpMaker::Make() {
-  AddInput("X", "(phi::DenseTensor) Input tensors of this operator.");
-  AddInput("W", "(phi::DenseTensor) The weight tensors of this operator.")
-      .AsDuplicable();
-  AddInput("Bias", "(phi::DenseTensor) The bias tensors of this operator.")
-      .AsDuplicable();
-  AddOutput("ReluOut",
-            "(phi::DenseTensor) The output tensor of each relu operator.")
-      .AsDuplicable()
-      .AsIntermediate();
-  AddOutput("Out", "(phi::DenseTensor) Output tensor of this operator.");
-  AddComment(R"DOC(
-  Fusion Repeated FC with Relu Operator.
-)DOC");
-}
-
-template <typename T>
-static void fc_relu(const T* x,
-                    const T* w,
-                    const T* b,
-                    T* y,
-                    const phi::jit::matmul_attr_t& attr) {
-  auto matmul = phi::jit::KernelFuncs<phi::jit::MatMulTuple<T>,
-                                      platform::CPUPlace>::Cache()
-                    .At(attr);
-  auto addbias_relu = phi::jit::KernelFuncs<phi::jit::VAddReluTuple<T>,
-                                            platform::CPUPlace>::Cache()
-                          .At(attr.n);
-  matmul(x, w, y, &attr);
-  T* dst = y;
-  for (int i = 0; i < attr.m; ++i) {
-    addbias_relu(b, dst, dst, attr.n);
-    dst += attr.n;
-  }
-}
-
-template <typename T, typename DeviceContext>
-class FusionRepeatedFCReluKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto in = ctx.Input<phi::DenseTensor>("X");
-    auto weights = ctx.MultiInput<phi::DenseTensor>("W");
-    auto biases = ctx.MultiInput<phi::DenseTensor>("Bias");
-    auto relus = ctx.MultiOutput<phi::DenseTensor>("ReluOut");
-    auto* out = ctx.Output<phi::DenseTensor>("Out");
-    auto place = ctx.GetPlace();
-    int weight_sz = static_cast<int>(weights.size());
-
-    auto i_dims = phi::vectorize<int>(in->dims());
-    const auto& w_dims = weights[0]->dims();
-    phi::jit::matmul_attr_t attr;
-    attr.m = i_dims[0];
-    attr.n = static_cast<int>(w_dims[1]);
-    attr.k = static_cast<int>(w_dims[0]);
-    relus[0]->Resize({attr.m, attr.n});
-    fc_relu(in->data<T>(),
-            weights[0]->data<T>(),
-            biases[0]->data<T>(),
-            relus[0]->mutable_data<T>(place),
-            attr);
-
-    for (int i = 1; i < weight_sz - 1; ++i) {
-      const auto& i_dims = relus[i - 1]->dims();
-      const auto& w_dims = weights[i]->dims();
-      attr.m = static_cast<int>(i_dims[0]);
-      attr.n = static_cast<int>(w_dims[1]);
-      attr.k = static_cast<int>(w_dims[0]);
-      relus[i]->Resize({attr.m, attr.n});
-      fc_relu(relus[i - 1]->data<T>(),
-              weights[i]->data<T>(),
-              biases[i]->data<T>(),
-              relus[i]->mutable_data<T>(place),
-              attr);
-    }
-
-    const auto& i_dims_last = relus[weight_sz - 2]->dims();
-    const auto& w_dims_last = weights[weight_sz - 1]->dims();
-    attr.m = static_cast<int>(i_dims_last[0]);
-    attr.n = static_cast<int>(w_dims_last[1]);
-    attr.k = static_cast<int>(w_dims_last[0]);
-    fc_relu(relus[weight_sz - 2]->data<T>(),
-            weights[weight_sz - 1]->data<T>(),
-            biases[weight_sz - 1]->data<T>(),
-            out->mutable_data<T>(place),
-            attr);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OPERATOR(fusion_repeated_fc_relu,
-                  ops::FusionRepeatedFCReluOp,
-                  ops::FusionRepeatedFCReluOpMaker);
-
-PD_REGISTER_STRUCT_KERNEL(fusion_repeated_fc_relu,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::FusionRepeatedFCReluKernel,
-                          float,
-                          double) {}
diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
deleted file mode 100644
index 62eae8f7c0525..0000000000000
--- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License. */
-
-#pragma once
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-class FusionRepeatedFCReluOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override;
-};
-
-class FusionRepeatedFCReluOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override;
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
deleted file mode 100644
index c1d902754a4be..0000000000000
--- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
+++ /dev/null
@@ -1,159 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License. */
-
-#include "paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h"
-
-#include <string>
-#include <vector>
-
-#include "paddle/phi/kernels/funcs/jit/kernels.h"
-
-namespace paddle {
-namespace operators {
-
-void FusionSquaredMatSubOp::InferShape(
-    framework::InferShapeContext* ctx) const {
-  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionSquaredMatSub");
-  OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusionSquaredMatSub");
-  OP_INOUT_CHECK(
-      ctx->HasOutput("SquaredX"), "SquaredX", "Out", "FusionSquaredMatSub");
-  OP_INOUT_CHECK(
-      ctx->HasOutput("SquaredY"), "SquaredY", "Out", "FusionSquaredMatSub");
-  OP_INOUT_CHECK(
-      ctx->HasOutput("SquaredXY"), "SquaredXY", "Out", "FusionSquaredMatSub");
-  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Out", "Out", "FusionSquaredMatSub");
-
-  auto x_dims = ctx->GetInputDim("X");
-  auto y_dims = ctx->GetInputDim("Y");
-  PADDLE_ENFORCE_EQ(
-      x_dims.size(),
-      y_dims.size(),
-      platform::errors::InvalidArgument("The input tensor X's dims size should "
-                                        "be equal to Y's. But received X's "
-                                        "dims size = %d, Y's dims size = %d.",
-                                        x_dims.size(),
-                                        y_dims.size()));
-  PADDLE_ENFORCE_EQ(x_dims.size(),
-                    2,
-                    platform::errors::InvalidArgument(
-                        "The input tensor X's dims size should be 2. But "
-                        "received X's dims size = %d.",
-                        x_dims.size()));
-  PADDLE_ENFORCE_EQ(
-      x_dims[1],
-      y_dims[0],
-      platform::errors::InvalidArgument("The input tensor X's dims[1] should "
-                                        "be equal to Y's dims[0]. But received "
-                                        "X's dims[1] = %d, Y's dims[0] = %d.",
-                                        x_dims[1],
-                                        y_dims[0]));
-  ctx->SetOutputDim("SquaredX", x_dims);
-  ctx->SetOutputDim("SquaredY", y_dims);
-  ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]});
-  ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]});
-}
-
-phi::KernelKey FusionSquaredMatSubOp::GetExpectedKernelType(
-    const framework::ExecutionContext& ctx) const {
-  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-                        ctx.GetPlace());
-}
-
-void FusionSquaredMatSubOpMaker::Make() {
-  AddInput("X", "(phi::DenseTensor) Input Mat A of this operator.");
-  AddInput("Y", "(phi::DenseTensor) Input Mat B of this operator.");
-  AddOutput("SquaredX", "(phi::DenseTensor) Squared X.").AsIntermediate();
-  AddOutput("SquaredY", "(phi::DenseTensor) Squared Y.").AsIntermediate();
-  AddOutput("SquaredXY", "(phi::DenseTensor) Squared X*Y.").AsIntermediate();
-  AddOutput("Out", "(phi::DenseTensor) Output tensor of concat operator.");
-  AddAttr<float>("scalar", "The scalar on output matrix.").SetDefault(1.f);
-  AddComment(R"DOC(
-  Fusion Squared Matrix and substrct operator.
-
-  ( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar
-)DOC");
-}
-
-template <typename T, typename DeviceContext>
-class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto x = ctx.Input<phi::DenseTensor>("X");
-    auto y = ctx.Input<phi::DenseTensor>("Y");
-    auto* squared_x = ctx.Output<phi::DenseTensor>("SquaredX");
-    auto* squared_y = ctx.Output<phi::DenseTensor>("SquaredY");
-    auto* squared_xy = ctx.Output<phi::DenseTensor>("SquaredXY");
-    auto* out = ctx.Output<phi::DenseTensor>("Out");
-    auto place = ctx.GetPlace();
-    T scalar = static_cast<T>(ctx.Attr<float>("scalar"));
-
-    auto x_dims = x->dims();
-    auto y_dims = y->dims();
-    phi::jit::matmul_attr_t attr;
-    attr.m = static_cast<int>(x_dims[0]);
-    attr.k = static_cast<int>(x_dims[1]);
-    attr.n = static_cast<int>(y_dims[1]);
-    int o_numel = attr.m * attr.n;
-
-    auto vsquare_x = phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>,
-                                           platform::CPUPlace>::Cache()
-                         .At(attr.m * attr.k);
-    auto vsquare_y = phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>,
-                                           platform::CPUPlace>::Cache()
-                         .At(attr.k * attr.n);
-    auto vsquare_xy = phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>,
-                                            platform::CPUPlace>::Cache()
-                          .At(o_numel);
-    auto vsub = phi::jit::KernelFuncs<phi::jit::VSubTuple<T>,
-                                      platform::CPUPlace>::Cache()
-                    .At(o_numel);
-    auto vscal = phi::jit::KernelFuncs<phi::jit::VScalTuple<T>,
-                                       platform::CPUPlace>::Cache()
-                     .At(o_numel);
-    auto matmul = phi::jit::KernelFuncs<phi::jit::MatMulTuple<T>,
-                                        platform::CPUPlace>::Cache()
-                      .At(attr);
-
-    const T* x_data = x->data<T>();
-    const T* y_data = y->data<T>();
-    T* squared_x_data = squared_x->mutable_data<T>(place);
-    T* squared_y_data = squared_y->mutable_data<T>(place);
-    T* squared_xy_data = squared_xy->mutable_data<T>(place);
-    T* o_data = out->mutable_data<T>(place);
-
-    matmul(x_data, y_data, squared_xy_data, &attr);
-    vsquare_xy(squared_xy_data, squared_xy_data, o_numel);
-
-    vsquare_x(x_data, squared_x_data, attr.m * attr.k);
-    vsquare_y(y_data, squared_y_data, attr.k * attr.n);
-    matmul(squared_x_data, squared_y_data, o_data, &attr);
-
-    vsub(squared_xy_data, o_data, o_data, o_numel);
-    vscal(&scalar, o_data, o_data, o_numel);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OPERATOR(fusion_squared_mat_sub,
-                  ops::FusionSquaredMatSubOp,
-                  ops::FusionSquaredMatSubOpMaker);
-
-PD_REGISTER_STRUCT_KERNEL(fusion_squared_mat_sub,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::FusionSquaredMatSubKernel,
-                          float,
-                          double) {}
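
Note (editorial, not part of the patch): the deleted kernel above — like the
phi kernel added below — chains jit matmul/vsquare/vsub/vscal calls. A minimal
NumPy sketch of the fused computation, with illustrative names:

    import numpy as np

    def fusion_squared_mat_sub_ref(x, y, scalar=1.0):
        # ((x @ y)^2 - (x^2 @ y^2)) * scalar; x: (m, k), y: (k, n)
        squared_xy = (x @ y) ** 2
        return (squared_xy - (x * x) @ (y * y)) * scalar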
diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
deleted file mode 100644
index 41bde97c4bdb0..0000000000000
--- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License. */
-
-#pragma once
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-// ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar
-class FusionSquaredMatSubOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override;
-};
-
-class FusionSquaredMatSubOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override;
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
index 6b6087afe5e97..f6d0d7ed8740f 100644
--- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -84,6 +84,8 @@
     'fusion_gru',
     'fusion_seqconv_eltadd_relu',
     'fusion_seqexpand_concat_fc',
+    'fusion_repeated_fc_relu',
+    'fusion_squared_mat_sub',
     'fused_attention',
     'fused_feedforward',
     'self_dp_attention',
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index bc2e3ecf64857..df49fe040812c 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -227,6 +227,16 @@
   optional : h0, bias
   intermediate : reordered_h0, xx, batched_input, batched_out
 
+- op : fusion_repeated_fc_relu
+  args : (Tensor x, Tensor[] w, Tensor[] bias)
+  output : Tensor[](relu_out){w.size()-1}, Tensor(out)
+  infer_meta :
+    func : FusionRepeatedFCReluInferMeta
+  kernel :
+    func : fusion_repeated_fc_relu
+    data_type: x
+  intermediate : relu_out
+
 - op : fusion_seqconv_eltadd_relu
   args : (Tensor x, Tensor filter, Tensor bias, int context_length, int context_start = 0, int context_stride = 1)
   output : Tensor(out), Tensor(col_mat)
@@ -248,6 +258,16 @@
   optional : fc_bias
   intermediate : fc_out
 
+- op : fusion_squared_mat_sub
+  args : (Tensor x, Tensor y, float scalar = 1.0f)
+  output : Tensor(squared_x), Tensor(squared_y), Tensor(squared_xy), Tensor(out)
+  infer_meta :
+    func : FusionSquaredMatSubInferMeta
+  kernel :
+    func : fusion_squared_mat_sub
+    data_type : x
+  intermediate : squared_x, squared_y, squared_xy
+
 - op : fusion_transpose_flatten_concat
   args : (Tensor[] x, int[] trans_axis, int flatten_axis, int concat_axis)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index b78b12386d049..7b07c8ef6b240 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1343,6 +1343,15 @@
     shift_data : Shift_data
     scale_weights : Scale_weights
 
+- op : fusion_repeated_fc_relu
+  inputs :
+    x : X
+    w : W
+    bias : Bias
+  outputs :
+    relu_out : ReluOut
+    out : Out
+
 - op : fusion_seqconv_eltadd_relu
   inputs :
     x : X
@@ -3312,6 +3321,16 @@
   outputs:
     {out: Out}
 
+- op: fusion_squared_mat_sub
+  inputs :
+    x : X
+    y : Y
+  outputs :
+    squared_x : SquaredX
+    squared_y : SquaredY
+    squared_xy : SquaredXY
+    out : Out
+
 - op: get_tensor_from_selected_rows
   inputs :
     x : X
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index bcc411c40aa1a..ae6c9d367a320 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -2175,6 +2175,136 @@ void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void FusionRepeatedFCReluInferMeta(const MetaTensor& x,
+                                   const std::vector<const MetaTensor*>& w,
+                                   const std::vector<const MetaTensor*>& bias,
+                                   std::vector<MetaTensor*> relu_out,
+                                   MetaTensor* out) {
+  auto sz = w.size();
+  PADDLE_ENFORCE_GT(sz,
+                    1UL,
+                    phi::errors::InvalidArgument(
+                        "Inputs(W) of FusionRepeatedFCReluOp should "
+                        "be greater than 1, but received value is %d.",
+                        sz));
+  PADDLE_ENFORCE_EQ(
+      bias.size(),
+      sz,
+      phi::errors::InvalidArgument(
+          "Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
+          "equal to inputs size %d, but received value is %d.",
+          sz,
+          bias.size()));
+  PADDLE_ENFORCE_EQ(
+      relu_out.size(),
+      sz - 1,
+      phi::errors::InvalidArgument(
+          "Size of output(ReluOut) of FusionRepeatedFCReluOp should "
+          "be equal to inputs size minus one %d, but received value is %d",
+          sz - 1,
+          relu_out.size()));
+
+  auto i_dims = x.dims();
+  PADDLE_ENFORCE_EQ(
+      i_dims.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "Input shape size should be 2, but received value is %d.",
+          i_dims.size()));
+
+  std::vector<phi::DDim> w_dims, b_dims;
+  w_dims.reserve(w.size());
+  std::transform(w.begin(),
+                 w.end(),
+                 std::back_inserter(w_dims),
+                 [](const MetaTensor* var) { return var->dims(); });
+
+  b_dims.reserve(bias.size());
+  std::transform(bias.begin(),
+                 bias.end(),
+                 std::back_inserter(b_dims),
+                 [](const MetaTensor* var) { return var->dims(); });
+
+  PADDLE_ENFORCE_EQ(w_dims.size(),
+                    b_dims.size(),
+                    phi::errors::InvalidArgument(
+                        "Shape size of weight and bias should be equal, but "
+                        "weight size is %d, bias size is %d.",
+                        w_dims.size(),
+                        b_dims.size()));
+  PADDLE_ENFORCE_EQ(i_dims[1],
+                    w_dims[0][0],
+                    phi::errors::InvalidArgument(
+                        "input width should be equal to weight height, but "
+                        "input width is %d, weight height is %d.",
+                        i_dims[1],
+                        w_dims[0][0]));
+
+  for (size_t i = 1; i < sz; ++i) {
+    PADDLE_ENFORCE_EQ(w_dims[i].size(),
+                      2,
+                      phi::errors::InvalidArgument(
+                          "Every weight shape size should be 2, but received "
+                          "w_dims[%d].size() = %d.",
+                          i,
+                          w_dims[i].size()));
+    PADDLE_ENFORCE_EQ(
+        phi::product(b_dims[i]),
+        w_dims[i][1],
+        phi::errors::InvalidArgument(
+            "The length of Bias must be equal with w_dims[1], but received "
+            "product(b_dims[%d]) = %d, w_dims[%d][1] = %d.",
+            i,
+            phi::product(b_dims[i]),
+            i,
+            w_dims[i][1]));
+  }
+  out->set_dims({i_dims[0], w_dims[sz - 1][1]});
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+}
+
+void FusionSquaredMatSubInferMeta(const MetaTensor& x,
+                                  const MetaTensor& y,
+                                  const float scalar,
+                                  MetaTensor* squared_x,
+                                  MetaTensor* squared_y,
+                                  MetaTensor* squared_xy,
+                                  MetaTensor* out) {
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  PADDLE_ENFORCE_EQ(
+      x_dims.size(),
+      y_dims.size(),
+      phi::errors::InvalidArgument("The input tensor X's dims size should "
+                                   "be equal to Y's. But received X's "
+                                   "dims size = %d, Y's dims size = %d.",
+                                   x_dims.size(),
+                                   y_dims.size()));
+  PADDLE_ENFORCE_EQ(x_dims.size(),
+                    2,
+                    phi::errors::InvalidArgument(
+                        "The input tensor X's dims size should be 2. But "
+                        "received X's dims size = %d.",
+                        x_dims.size()));
+  PADDLE_ENFORCE_EQ(
+      x_dims[1],
+      y_dims[0],
+      phi::errors::InvalidArgument("The input tensor X's dims[1] should "
+                                   "be equal to Y's dims[0]. But received "
+                                   "X's dims[1] = %d, Y's dims[0] = %d.",
+                                   x_dims[1],
+                                   y_dims[0]));
+  squared_x->set_dims(x_dims);
+  squared_x->set_dtype(x.dtype());
+  squared_y->set_dims(y_dims);
+  squared_y->set_dtype(x.dtype());
+  squared_xy->set_dims({x_dims[0], y_dims[1]});
+  squared_xy->set_dtype(x.dtype());
+  out->set_dims({x_dims[0], y_dims[1]});
+  out->set_dtype(x.dtype());
+}
+
 void FusionGRUInferMeta(const MetaTensor& x,
                         const MetaTensor& h0,
                         const MetaTensor& weight_x,
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index c5bf412c715cb..ff171766fe70a 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -519,6 +519,20 @@ void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x,
                                           MetaTensor* variance,
                                           MetaConfig config = MetaConfig());
 
+void FusionRepeatedFCReluInferMeta(const MetaTensor& x,
+                                   const std::vector<const MetaTensor*>& w,
+                                   const std::vector<const MetaTensor*>& bias,
+                                   std::vector<MetaTensor*> relu_out,
+                                   MetaTensor* out);
+
+void FusionSquaredMatSubInferMeta(const MetaTensor& x,
+                                  const MetaTensor& y,
+                                  const float scalar,
+                                  MetaTensor* squared_x,
+                                  MetaTensor* squared_y,
+                                  MetaTensor* squared_xy,
+                                  MetaTensor* out);
+
 void FusionGRUInferMeta(const MetaTensor& x,
                         const MetaTensor& h0,
                         const MetaTensor& weight_x,
diff --git a/paddle/phi/kernels/fusion/cpu/fusion_repeated_fc_relu_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_repeated_fc_relu_kernel.cc
new file mode 100644
index 0000000000000..b65cf71bf9385
--- /dev/null
+++ b/paddle/phi/kernels/fusion/cpu/fusion_repeated_fc_relu_kernel.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T>
+static void fc_relu(const T* x,
+                    const T* w,
+                    const T* b,
+                    T* y,
+                    const phi::jit::matmul_attr_t& attr) {
+  auto matmul =
+      phi::jit::KernelFuncs<phi::jit::MatMulTuple<T>, phi::CPUPlace>::Cache()
+          .At(attr);
+  auto addbias_relu =
+      phi::jit::KernelFuncs<phi::jit::VAddReluTuple<T>, phi::CPUPlace>::Cache()
+          .At(attr.n);
+  matmul(x, w, y, &attr);
+  T* dst = y;
+  for (int i = 0; i < attr.m; ++i) {
+    addbias_relu(b, dst, dst, attr.n);
+    dst += attr.n;
+  }
+}
+
+template <typename T, typename Context>
+void FusionRepeatedFCReluKernel(const Context& dev_ctx,
+                                const DenseTensor& x,
+                                const std::vector<const DenseTensor*>& w,
+                                const std::vector<const DenseTensor*>& bias,
+                                std::vector<DenseTensor*> relu_out,
+                                DenseTensor* out) {
+  int weight_sz = static_cast<int>(w.size());
+
+  auto i_dims = phi::vectorize<int>(x.dims());
+  const auto& w_dims = w[0]->dims();
+  phi::jit::matmul_attr_t attr;
+  attr.m = i_dims[0];
+  attr.n = static_cast<int>(w_dims[1]);
+  attr.k = static_cast<int>(w_dims[0]);
+  relu_out[0]->Resize({attr.m, attr.n});
+  auto* relu_out_temp = dev_ctx.template Alloc<T>(relu_out[0]);
+  fc_relu(
+      x.data<T>(), w[0]->data<T>(), bias[0]->data<T>(), relu_out_temp, attr);
+
+  for (int i = 1; i < weight_sz - 1; ++i) {
+    const auto& i_dims = relu_out[i - 1]->dims();
+    const auto& w_dims = w[i]->dims();
+    attr.m = static_cast<int>(i_dims[0]);
+    attr.n = static_cast<int>(w_dims[1]);
+    attr.k = static_cast<int>(w_dims[0]);
+    relu_out[i]->Resize({attr.m, attr.n});
+    auto* relu_out_tmp = dev_ctx.template Alloc<T>(relu_out[i]);
+    fc_relu(relu_out[i - 1]->data<T>(),
+            w[i]->data<T>(),
+            bias[i]->data<T>(),
+            relu_out_tmp,
+            attr);
+  }
+
+  const auto& i_dims_last = relu_out[weight_sz - 2]->dims();
+  const auto& w_dims_last = w[weight_sz - 1]->dims();
+  attr.m = static_cast<int>(i_dims_last[0]);
+  attr.n = static_cast<int>(w_dims_last[1]);
+  attr.k = static_cast<int>(w_dims_last[0]);
+  auto* out_data = dev_ctx.template Alloc<T>(out);
+  fc_relu(relu_out[weight_sz - 2]->data<T>(),
+          w[weight_sz - 1]->data<T>(),
+          bias[weight_sz - 1]->data<T>(),
+          out_data,
+          attr);
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fusion_repeated_fc_relu,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusionRepeatedFCReluKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/fusion/cpu/fusion_squared_mat_sub_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_squared_mat_sub_kernel.cc
new file mode 100644
index 0000000000000..4f44364577468
--- /dev/null
+++ b/paddle/phi/kernels/fusion/cpu/fusion_squared_mat_sub_kernel.cc
@@ -0,0 +1,89 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void FusionSquaredMatSubKernel(const Context& dev_ctx,
+                               const DenseTensor& x,
+                               const DenseTensor& y,
+                               const float scalar,
+                               DenseTensor* squared_x,
+                               DenseTensor* squared_y,
+                               DenseTensor* squared_xy,
+                               DenseTensor* out) {
+  T scalar_t = static_cast<T>(scalar);
+
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  phi::jit::matmul_attr_t attr;
+  attr.m = static_cast<int>(x_dims[0]);
+  attr.k = static_cast<int>(x_dims[1]);
+  attr.n = static_cast<int>(y_dims[1]);
+  int o_numel = attr.m * attr.n;
+
+  auto vsquare_x =
+      phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>, phi::CPUPlace>::Cache()
+          .At(attr.m * attr.k);
+  auto vsquare_y =
+      phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>, phi::CPUPlace>::Cache()
+          .At(attr.k * attr.n);
+  auto vsquare_xy =
+      phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>, phi::CPUPlace>::Cache()
+          .At(o_numel);
+  auto vsub =
+      phi::jit::KernelFuncs<phi::jit::VSubTuple<T>, phi::CPUPlace>::Cache().At(
+          o_numel);
+  auto vscal =
+      phi::jit::KernelFuncs<phi::jit::VScalTuple<T>, phi::CPUPlace>::Cache()
+          .At(o_numel);
+  auto matmul =
+      phi::jit::KernelFuncs<phi::jit::MatMulTuple<T>, phi::CPUPlace>::Cache()
+          .At(attr);
+
+  const T* x_data = x.data<T>();
+  const T* y_data = y.data<T>();
+  T* squared_x_data = dev_ctx.template Alloc<T>(squared_x);
+  T* squared_y_data = dev_ctx.template Alloc<T>(squared_y);
+  T* squared_xy_data = dev_ctx.template Alloc<T>(squared_xy);
+  T* o_data = dev_ctx.template Alloc<T>(out);
+
+  matmul(x_data, y_data, squared_xy_data, &attr);
+  vsquare_xy(squared_xy_data, squared_xy_data, o_numel);
+
+  vsquare_x(x_data, squared_x_data, attr.m * attr.k);
+  vsquare_y(y_data, squared_y_data, attr.k * attr.n);
+  matmul(squared_x_data, squared_y_data, o_data, &attr);
+
+  vsub(squared_xy_data, o_data, o_data, o_numel);
+  vscal(&scalar_t, o_data, o_data, o_numel);
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fusion_squared_mat_sub,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusionSquaredMatSubKernel,
+                   float,
+                   double) {}
diff --git a/test/legacy_test/test_fusion_repeated_fc_relu_op.py b/test/legacy_test/test_fusion_repeated_fc_relu_op.py
index b57cdfc1d4eb2..52c8852e2ddcd 100644
--- a/test/legacy_test/test_fusion_repeated_fc_relu_op.py
+++ b/test/legacy_test/test_fusion_repeated_fc_relu_op.py
@@ -78,7 +78,7 @@ def setUp(self):
         self.outputs = {'Out': outs[-1], 'ReluOut': relu_outs}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_dygraph=False)
 
     def set_conf(self):
         pass
diff --git a/test/legacy_test/test_fusion_squared_mat_sub_op.py b/test/legacy_test/test_fusion_squared_mat_sub_op.py
index a9c692e538261..c8e0c67523932 100644
--- a/test/legacy_test/test_fusion_squared_mat_sub_op.py
+++ b/test/legacy_test/test_fusion_squared_mat_sub_op.py
@@ -42,7 +42,7 @@ def set_conf(self):
         pass
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_dygraph=False)
 
 
 class TestFusionSquaredMatSubOpCase1(TestFusionSquaredMatSubOp):
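
Note (editorial, not part of the patch): the two test changes above pass
check_dygraph=False, presumably because these fused ops are static-graph-only
and have no dygraph counterpart to cross-check. Assuming the usual
unittest.main() entry point of the legacy OpTest files, the updated tests can
be run directly:

    python test/legacy_test/test_fusion_repeated_fc_relu_op.py
    python test/legacy_test/test_fusion_squared_mat_sub_op.py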