Composite hardswish #51003

Merged: 14 commits, Mar 2, 2023
29 changes: 27 additions & 2 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -2150,6 +2150,7 @@ def setUp(self):
        self.op_type = 'hard_swish'
        self.init_dtype()
        self.init_shape()
        self.prim_op_type = "comp"
        self.python_api = paddle.nn.functional.hardswish

        np.random.seed(1024)
@@ -2165,18 +2166,42 @@ def setUp(self):
        self.inputs = {'X': x}
        self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
        self.outputs = {'Out': out}
        self.enable_cinn = False

    def init_shape(self):
        self.shape = [10, 12]

    def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_eager=True)
+        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)

    def test_check_output(self):
-        self.check_output(check_eager=True)
+        self.check_output(check_eager=True, check_prim=True)


class TestHardSwish_ZeroDim(TestHardSwish):
    def setUp(self):
        super().setUp()
        self.enable_cinn = False

    def init_shape(self):
        self.shape = []


class TestHardSwishFP16(TestHardSwish):
    def setUp(self):
        super().setUp()
        self.only_prim = True
        self.enable_cinn = False

    def init_dtype(self):
        self.dtype = np.float16


class TestHardSwish_ZeroDim_FP16(TestHardSwishFP16):
    def setUp(self):
        super().setUp()
        self.enable_cinn = False

    def init_shape(self):
        self.shape = []

…
25 changes: 25 additions & 0 deletions python/paddle/incubate/autograd/composite_rules.py
@@ -261,6 +261,31 @@ def bernoulli(shape, dtype, p, seed=0):
    )


@REGISTER_COMPOSITE('hard_swish')
def hard_swish_composite(x):
    """Define composite rule of op hard_swish.

    With offset=3, threshold=6, scale=6:
        out = minimum(maximum(x + offset, 0), threshold) * x / scale
    """
    offset = 3.0
    threshold = 6.0
    scale = 6.0
    res = (
        minimum(
            maximum(
                x + full(x.shape, offset, dtype=x.dtype),
                full(x.shape, 0.0, dtype=x.dtype),
            ),
            full(x.shape, threshold, dtype=x.dtype),
        )
        * x
        / full(x.shape, scale, dtype=x.dtype)
    )
    return res


@REGISTER_COMPOSITE('silu')
def silu_composite(x):
    """
…
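For reference, the composite rule above matches the closed-form hardswish, out = x * min(max(x + 3, 0), 6) / 6. Below is a minimal NumPy sketch (not part of the PR; the helper names are made up) that checks the decomposition against the closed form:

import numpy as np

def hard_swish_closed_form(x, offset=3.0, threshold=6.0, scale=6.0):
    # Closed-form hardswish: x * clip(x + offset, 0, threshold) / scale
    return x * np.clip(x + offset, 0.0, threshold) / scale

def hard_swish_decomposed(x, offset=3.0, threshold=6.0, scale=6.0):
    # Mirrors the composite rule: minimum(maximum(x + offset, 0), threshold) * x / scale
    return np.minimum(np.maximum(x + offset, 0.0), threshold) * x / scale

x = np.random.uniform(-8.0, 8.0, size=(10, 12)).astype('float32')
assert np.allclose(hard_swish_closed_form(x), hard_swish_decomposed(x))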
4 changes: 4 additions & 0 deletions python/paddle/incubate/autograd/primitives.py
@@ -60,6 +60,8 @@
from paddle.tensor import zeros # noqa: F401
from paddle.tensor.creation import assign # noqa: F401
from paddle.tensor.manipulation import cast # noqa: F401
from paddle.tensor.math import maximum # noqa: F401
from paddle.tensor.math import minimum # noqa: F401

"""
math_op = [
@@ -87,6 +89,8 @@
    'logit',
    'max',
    'min',
    'minimum',
    'maximum',
]

trigonometric_op = [
…
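The two imports added above re-export paddle.tensor.math.maximum and paddle.tensor.math.minimum so composite rules can call them as primitives. A quick usage sketch (eager mode, illustrative values only):

import paddle
from paddle.incubate.autograd.primitives import maximum, minimum

x = paddle.to_tensor([-5.0, -1.0, 2.0, 7.0])
zero = paddle.zeros_like(x)
six = paddle.full_like(x, 6.0)
# Clamp x + 3 into [0, 6], as the hard_swish composite rule does: [0., 2., 5., 6.]
print(minimum(maximum(x + 3.0, zero), six))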