From 1fc384183451ddbb16c5feb362dcd6c030c0c282 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Tue, 21 Dec 2021 04:01:25 +0000 Subject: [PATCH] add unit test --- .../dygraph_to_static/test_declarative.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index 91086c31a396a..1c2ac34e1b501 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -379,5 +379,34 @@ def test_raise_error(self): net.forward.outputs +class CallNonForwardFuncNet(paddle.nn.Layer): + def __init__(self): + super(CallNonForwardFuncNet, self).__init__() + self.sub = CallNonForwardFuncSubNet() + + @paddle.jit.to_static + def forward(self): + return self.sub.func() + + +class CallNonForwardFuncSubNet(paddle.nn.Layer): + def __init__(self): + super(CallNonForwardFuncSubNet, self).__init__() + self.a = paddle.to_tensor([1, 2]) + + def func(self): + x = self.a * 2 + return x + + +class TestCallNonForwardFunc(unittest.TestCase): + def test_call_non_forward(self): + paddle.disable_static() + net = CallNonForwardFuncNet() + out = net() + self.assertEqual(out.numpy().tolist(), [2, 4]) + paddle.enable_static() + + if __name__ == '__main__': unittest.main()