Skip to content

Commit

Permalink
[Prim][PIR] prim pir support eval mode in jit (PaddlePaddle#58829)
Browse files Browse the repository at this point in the history
* prim pir support eval mode

* fix bug

* remove unused code
  • Loading branch information
cyber-pioneer authored Nov 9, 2023
1 parent 75e7845 commit 3d0cb2f
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 13 deletions.
5 changes: 2 additions & 3 deletions python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,8 @@ def _create_program(self, is_infer_mode=False):
infer_program = PirPassContext.apply(
infer_program, self._build_strategy
)
# TODO(Aurelius84): Support this later.
# if self._hooker:
# infer_program = self._hooker.after_infer(infer_program)
if self._hooker:
self._hooker.after_infer(infer_program)
return infer_program
else:
train_program = self.origin_runable_program.clone()
Expand Down
14 changes: 10 additions & 4 deletions python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,12 +1500,18 @@ def after_append_backward(self, whole_program, src_vars, forward_end_idx):
return whole_program, new_start_index, dst_vars
return whole_program, forward_end_idx, src_vars

def after_infer(self, infer_program, src_vars):
def after_infer(self, infer_program):
with backend_guard(self.backend):
if core._is_fwd_prim_enabled():
dst_vars = decomposition.decompose(infer_program, src_vars)
return infer_program, dst_vars
return infer_program, src_vars
targets = decomposition.decompose(
infer_program.program, infer_program.out_values
)
infer_program.out_values = targets
infer_program.forward_range = (
0,
len(infer_program.program.global_block().ops),
)
return


class ProgramCache:
Expand Down
53 changes: 47 additions & 6 deletions test/prim/pir_prim/test_prim_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
from paddle.framework import core


def func(x):
    """Reduce *x* to its mean, then apply exact (non-approximate) gelu.

    With prim decomposition enabled, the gelu op is expected to be lowered
    into primitive ops (e.g. pd_op.erf) in the captured program.
    """
    # approximate=False selects the erf-based gelu formulation.
    return paddle.nn.functional.gelu(paddle.mean(x), False)


class TestDy2staticPir(unittest.TestCase):
def test_basic_network_backward(self):
core._set_prim_all_enabled(True)

def func(x):
x1 = paddle.mean(x)
out = paddle.nn.functional.gelu(x1, False)
return out

# ==== dygraph computation ====
static_func = paddle.jit.to_static(func, full_graph=True)
x = paddle.randn((8, 16, 64))
Expand All @@ -43,7 +44,16 @@ def func(x):
actual_out = out * 2
actual_out.backward()
actual_grad = x.grad
core._set_prim_all_enabled(True)

core._set_prim_all_enabled(False)
ops = [
op.name()
for op in static_func.program_cache.last()[-1][-1]
.train_program.program.global_block()
.ops
]
assert "pd_op.erf" in ops
assert "pd_op.gelu" not in ops

np.testing.assert_allclose(
ref_out, actual_out.numpy(), atol=1e-6, rtol=1e-6
Expand All @@ -54,5 +64,36 @@ def func(x):
)


class TestDy2staticPirEval(unittest.TestCase):
    """Verify prim decomposition is applied to the inference program in eval mode.

    In eval mode `to_static` builds an infer program; with prim enabled, the
    gelu op should be decomposed into primitives (pd_op.erf appears,
    pd_op.gelu disappears) while numerical results match eager execution.
    """

    def test_basic_network_backward_(self):
        # Enable prim lowering before the program is traced.
        core._set_prim_all_enabled(True)

        # ==== dygraph computation ====
        static_func = paddle.jit.to_static(func, full_graph=True)
        static_func.eval()
        x = paddle.randn((8, 16, 64))
        x.stop_gradient = False
        ref_out = func(x) * 2

        # ==== to static computation ====
        actual_out = static_func(x) * 2

        # Collect op names from the cached inference program's global block.
        cache_entry = static_func.program_cache.last()[-1][-1]
        global_block = cache_entry.infer_program.program.global_block()
        op_names = {op.name() for op in global_block.ops}
        core._set_prim_all_enabled(False)

        # gelu must have been decomposed into primitive ops (erf-based).
        assert "pd_op.erf" in op_names
        assert "pd_op.gelu" not in op_names

        np.testing.assert_allclose(
            ref_out, actual_out.numpy(), atol=1e-6, rtol=1e-6
        )


# Run the unittest suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 3d0cb2f

Please sign in to comment.