[Prim][PIR] gelu forward sink (PaddlePaddle#58981)
* prim gelu op sink

* prim gelu op sink

* update code

* pir gelu sink c++

* pir gelu sink c++

* process accuracy

* adapter windows

* adapter windows

* adapter windows
kevincheng2 authored and SecretXV committed Nov 28, 2023
1 parent ff98e2b commit 32559a1
Showing 5 changed files with 79 additions and 27 deletions.
@@ -25,6 +25,7 @@
"relu",
"softmax",
"layer_norm",
"gelu",
]

# come into effect in generated file op_decomp.cc
@@ -36,6 +37,7 @@
"relu",
"softmax",
"layer_norm",
"gelu",
]


29 changes: 29 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -188,6 +188,35 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
return std::make_tuple(out, mean_, variance);
}

template <typename T>
Tensor gelu_decomp(const Tensor& x, bool approximate) {
const double PM_2_SQRTPI = 1.12837916709551257390; /* 2/sqrt(pi) */
const double PM_SQRT1_2 = 0.70710678118654752440; /* 1/sqrt(2) */

auto org_dtype = x.dtype();
auto half = full<T>(phi::vectorize(x.dims()), 0.5, org_dtype);
auto one = full<T>(phi::vectorize(x.dims()), 1.0, org_dtype);
if (approximate) {
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
auto kAlpha =
full<T>(phi::vectorize(x.dims()), PM_2_SQRTPI * PM_SQRT1_2, org_dtype);
auto GELU_CONSTANT = full<T>(phi::vectorize(x.dims()), 0.044715, org_dtype);
auto x_pow3 =
elementwise_pow<T>(x, full<T>(phi::vectorize(x.dims()), 3, org_dtype));
auto tanh_out = tanh<T>(kAlpha * (x + x_pow3 * GELU_CONSTANT));

auto res = x * half * (one + tanh_out);
return res;
} else {
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
auto M_SQRT1_2T = full<T>(phi::vectorize(x.dims()), PM_SQRT1_2, org_dtype);
auto erf_out = one + erf<T>(x * M_SQRT1_2T);

auto res = x * half * erf_out;
return res;
}
}

} // namespace details

} // namespace primitive
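
As a quick sanity check on the two formulations implemented above, here is a minimal NumPy sketch (not part of the commit; the name gelu_reference is hypothetical, and erf comes from Python's math module rather than a Paddle primitive):

import math
import numpy as np

def gelu_reference(x, approximate=False):
    if approximate:
        # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        k_alpha = math.sqrt(2.0 / math.pi)  # equals PM_2_SQRTPI * PM_SQRT1_2
        return 0.5 * x * (1.0 + np.tanh(k_alpha * (x + 0.044715 * x**3)))
    # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x * math.sqrt(0.5)))

The outputs of both branches can be compared against paddle.nn.functional.gelu for the same inputs to confirm the decomposition formulas.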
26 changes: 0 additions & 26 deletions python/paddle/decomposition/rules.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _pir_ops

from .primitives import * # noqa: F403
from .register import register_decomp
@@ -37,31 +36,6 @@ def mean(x, axis, keepdim):
return res


@register_decomp('pd_op.gelu')
def gelu(x, approximate):
"""define composite rule of op gelu"""
M_SQRT1_2 = (
0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc
)
M_2_SQRTPI = 1.12837916709551257390 # /* 2/sqrt(pi) */
full_shape = x.shape if len(x.shape) == 0 else [1]
one = ones(full_shape, x.dtype)
half = full(full_shape, 0.5, x.dtype)
# Todo(cz): after symbol overload, add and multiply will be replaced by "+" and "*"
if approximate:
# gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
kAlpha = full(full_shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
GELU_CONSTANT = full(full_shape, 0.044715, x.dtype)
tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
out = x * half * (one + tanh_out)
return out
else:
# gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
cdf = half * (one + _pir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype)))
out = x * cdf
return out


@register_decomp('pd_op.sqrt')
def sqrt(x):
"""
4 changes: 3 additions & 1 deletion test/legacy_test/test_activation_op.py
@@ -2691,6 +2691,8 @@ def setUp(self):
self.public_python_api = paddle.nn.functional.gelu
self.init_dtype()
self.init_shape()
# TODO: Under float64, only this accuracy is currently supported; to be improved later.
self.fw_comp_rtol = 1e-7
approximate = False
np.random.seed(2048)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
@@ -2713,7 +2715,7 @@ def if_enable_cinn(self):
pass

def test_check_output(self):
self.check_output(check_prim=True, check_pir=True, check_prim_pir=False)
self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)

def test_check_grad(self):
if self.dtype == np.float16:
45 changes: 45 additions & 0 deletions test/prim/pir_prim/test_sink_decomp.py
@@ -149,5 +149,50 @@ def test_relu_forward(self):
np.testing.assert_equal(ref, actual)


class TestGeluSink(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
self.shape_x = [8, 16, 32, 64]
self.x = np.random.random(self.shape_x).astype("float32")
self.prog = None

def base_net(self, approximate=True, flag=None):
if flag == "forward":
core._set_prim_forward_enabled(True)
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.static.data('x', self.shape_x, dtype='float32')
x.stop_gradient = False
sum_out = F.gelu(x, approximate=approximate)
[new_out] = decompose(main_program, [sum_out])
gradients = grad(new_out, x)

exe = paddle.static.Executor()
[fwd, dx] = exe.run(
feed={'x': self.x}, fetch_list=[new_out, gradients]
)

whole_ops = [op.name() for op in main_program.global_block().ops]
self.prog = main_program
if flag == "forward":
core._set_prim_forward_enabled(False)
assert 'pd_op.gelu' not in whole_ops
else:
assert 'pd_op.gelu' in whole_ops
return fwd, dx

def test_gelu_forward_true(self):
res_ref = self.base_net(approximate=True)
res = self.base_net(approximate=True, flag="forward")
for ref, actual in zip(res_ref, res):
np.testing.assert_allclose(ref, actual, rtol=1e-6)

def test_gelu_approximate_false(self):
res_ref = self.base_net(approximate=False)
res = self.base_net(approximate=False, flag="forward")
for ref, actual in zip(res_ref, res):
np.testing.assert_allclose(ref, actual, rtol=1e-6)


if __name__ == "__main__":
unittest.main()
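
To exercise only the new sink test class locally, one option (a sketch; the working directory assumes this commit's test layout) is:

import subprocess, sys

# Run just TestGeluSink from the new test file via the unittest CLI.
subprocess.run(
    [sys.executable, "-m", "unittest", "-v", "test_sink_decomp.TestGeluSink"],
    cwd="test/prim/pir_prim",
    check=True,
)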
