Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Dec 21, 2021
1 parent 62a8cd9 commit 1fc3841
Showing 1 changed file with 29 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -379,5 +379,34 @@ def test_raise_error(self):
net.forward.outputs


class CallNonForwardFuncNet(paddle.nn.Layer):
def __init__(self):
super(CallNonForwardFuncNet, self).__init__()
self.sub = CallNonForwardFuncSubNet()

@paddle.jit.to_static
def forward(self):
return self.sub.func()


class CallNonForwardFuncSubNet(paddle.nn.Layer):
def __init__(self):
super(CallNonForwardFuncSubNet, self).__init__()
self.a = paddle.to_tensor([1, 2])

def func(self):
x = self.a * 2
return x


class TestCallNonForwardFunc(unittest.TestCase):
def test_call_non_forward(self):
paddle.disable_static()
net = CallNonForwardFuncNet()
out = net()
self.assertEqual(out.numpy().tolist(), [2, 4])
paddle.enable_static()


if __name__ == '__main__':
unittest.main()

1 comment on commit 1fc3841

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.