Skip to content

Commit

Permalink
[CINN]Add sigmoid convert pass (PaddlePaddle#63733)
Browse files Browse the repository at this point in the history
* add sigmoid convert to cinn pass

* add sigmoid infer symbolic
  • Loading branch information
phlrain authored and co63oc committed Apr 25, 2024
1 parent d716967 commit c8db135
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 0 deletions.
46 changes: 46 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,51 @@ class UnsqueezeOpPattern
}
};

class SigmoidOpPattern
: public pir::OpRewritePattern<paddle::dialect::SigmoidOp> {
public:
using pir::OpRewritePattern<paddle::dialect::SigmoidOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::SigmoidOp op,
pir::PatternRewriter &rewriter) const override {
auto input_dtype = paddle::dialect::TransToPhiDataType(
op->operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype());

auto in = op->operand_source(0);
bool need_cast = (input_dtype == phi::DataType::FLOAT16 ||
input_dtype == phi::DataType::BFLOAT16 ||
input_dtype == phi::DataType::UINT16);
if (need_cast) {
in = rewriter.Build<paddle::dialect::CastOp>(in, phi::DataType::FLOAT32)
.result(0);
}

// 1 / ( 1 + exp(-x))
auto one = rewriter
.Build<paddle::dialect::FullOp>(
std::vector<int64_t>({1}), 1.0, phi::DataType::FLOAT32)
.result(0);
auto minus_x =
rewriter.Build<paddle::dialect::ScaleOp>(in, -1.0, 0.0).result(0);
auto exp = rewriter.Build<paddle::dialect::ExpOp>(minus_x).result(0);
auto add_exp = rewriter.Build<paddle::dialect::AddOp>(one, exp).result(0);
auto div =
rewriter.Build<paddle::dialect::DivideOp>(one, add_exp).result(0);

if (need_cast) {
div = rewriter.Build<paddle::dialect::CastOp>(div, input_dtype).result(0);
}

rewriter.ReplaceAllUsesWith(op.result(0), div);

rewriter.EraseOp(op);

return true;
}
};
class GatherOpPattern
: public pir::OpRewritePattern<paddle::dialect::GatherOp> {
public:
Expand Down Expand Up @@ -948,6 +993,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<RefreshCombineOpPattern>(context);
ps.Add<SqueezeOpPattern>(context);
ps.Add<UnsqueezeOpPattern>(context);
ps.Add<SigmoidOpPattern>(context);
ps.Add<GatherOpPattern>(context);

return ps;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ OP_SAME_OPERANDS_AND_RESULT(Scale_)
OP_SAME_OPERANDS_AND_RESULT(ScatterNdAdd)
OP_SAME_OPERANDS_AND_RESULT(Scatter)
OP_SAME_OPERANDS_AND_RESULT(Scatter_)
OP_SAME_OPERANDS_AND_RESULT(Sigmoid)
OP_SAME_OPERANDS_AND_RESULT(Sigmoid_)
OP_SAME_OPERANDS_AND_RESULT(Sign)
OP_SAME_OPERANDS_AND_RESULT(Sin)
OP_SAME_OPERANDS_AND_RESULT(Sin_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scale_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScatterNdAdd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sigmoid)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sigmoid_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_)
Expand Down
76 changes: 76 additions & 0 deletions test/ir/pir/cinn/symbolic/test_dyshape_sigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest
from os.path import dirname

import numpy as np

os.environ["FLAGS_prim_forward_blacklist"] = "pd_op.sigmoid"

import paddle
from paddle import nn
from paddle.static import InputSpec

sys.path.append(dirname(dirname(__file__)))

import utils


class CastLayer(nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
return x.sigmoid()


class TestCast(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.prepare_data()

def prepare_data(self):
self.shape = [1024, 32, 1024, 17]
self.x = paddle.randn(self.shape, dtype="float32")
self.x.stop_gradient = True

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def eval(self, use_cinn):
net = CastLayer()
input_spec = [
InputSpec(shape=[None, 32, None, None], dtype='float32'),
]
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval(self):
cinn_out = self.eval(use_cinn=True)
dy_out = self.eval(use_cinn=False)
np.testing.assert_allclose(
cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
)


if __name__ == '__main__':
unittest.main()

0 comments on commit c8db135

Please sign in to comment.