From dd462092fa3add1b80993dc0b960e7ec564efdaa Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 24 Aug 2020 14:58:02 +0000 Subject: [PATCH] Add more unitest for grid sample API test=develop --- .../unittests/test_affine_grid_function.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_affine_grid_function.py b/python/paddle/fluid/tests/unittests/test_affine_grid_function.py index ecef9a3456e94..5cfab78fda988 100644 --- a/python/paddle/fluid/tests/unittests/test_affine_grid_function.py +++ b/python/paddle/fluid/tests/unittests/test_affine_grid_function.py @@ -26,13 +26,17 @@ def __init__(self, theta_shape=(20, 2, 3), output_shape=[20, 2, 5, 7], align_corners=True, - dtype="float32"): + dtype="float32", + invalid_theta=False, + variable_output_shape=False): super(AffineGridTestCase, self).__init__(methodName) self.theta_shape = theta_shape self.output_shape = output_shape self.align_corners = align_corners self.dtype = dtype + self.invalid_theta = invalid_theta + self.variable_output_shape = variable_output_shape def setUp(self): self.theta = np.random.randn(*(self.theta_shape)).astype(self.dtype) @@ -70,9 +74,12 @@ def functional(self, place): return y_np def paddle_dygraph_layer(self): - theta_var = dg.to_variable(self.theta) + theta_var = dg.to_variable( + self.theta) if not self.invalid_theta else "invalid" + output_shape = dg.to_variable( + self.output_shape) if variable_output_shape else self.output_shape y_var = F.affine_grid( - theta_var, self.output_shape, align_corners=self.align_corners) + theta_var, output_shape, align_corners=self.align_corners) y_np = y_var.numpy() return y_np @@ -108,6 +115,9 @@ def add_cases(suite): suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=True)) suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=False)) + suite.addTest( + AffineGridTestCase( + methodName='runTest', variable_output_shape=True)) suite.addTest( AffineGridTestCase( @@ -121,6 +131,10 @@ def add_error_cases(suite): suite.addTest( AffineGridErrorTestCase( methodName='runTest', output_shape="not_valid")) + suite.addTest( + AffineGridErrorTestCase( + methodName='runTest', + invalid_theta=True)) # to test theta not variable error checking def load_tests(loader, standard_tests, pattern):