Commit

add more complex test
GGBond8488 committed May 7, 2023
1 parent 0dc471f commit dbcb051
Showing 4 changed files with 23 additions and 48 deletions.
5 changes: 5 additions & 0 deletions python/paddle/fluid/tests/unittests/eager_op_test.py
@@ -2553,6 +2553,11 @@ def check_grad_with_place(
         max_relative_error = (
             0.001 if max_relative_error < 0.001 else max_relative_error
         )
+        if self.dtype in [np.complex128, np.complex64]:
+            print("numeric_grads:")
+            print(numeric_grads)
+            print("analytic_grads:")
+            print(analytic_grads)
         self._assert_is_close(
             numeric_grads,
             analytic_grads,
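
Note: the five added lines print both gradient arrays whenever the dtype under test is complex, making it easy to eyeball where numeric and analytic gradients diverge. For context, a minimal standalone sketch of the finite-difference convention that a numeric gradient for complex inputs typically follows (an illustration only, not OpTest's actual implementation; numeric_grad and its signature are invented for this note):

    import numpy as np

    def numeric_grad(f, x, eps=1e-6):
        # Central differences on a real-valued scalar f at a complex point x:
        # perturb the real and imaginary parts of each entry separately and
        # pack the two slopes as grad = df/dRe + 1j * df/dIm.
        g = np.zeros_like(x)
        for idx in np.ndindex(x.shape):
            for step, weight in ((eps, 1.0), (1j * eps, 1j)):
                xp = x.copy(); xp[idx] += step
                xm = x.copy(); xm[idx] -= step
                g[idx] += weight * (f(xp) - f(xm)) / (2 * eps)
        return g

    x = np.random.rand(3) + 1j * np.random.rand(3)
    f = lambda z: np.sum(np.abs(z) ** 2)  # gradient under this convention: 2 * z
    print(np.allclose(numeric_grad(f, x), 2 * x))  # True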
9 changes: 1 addition & 8 deletions python/paddle/fluid/tests/unittests/test_complex_abs.py
@@ -26,10 +26,9 @@ def setUp(self):
         paddle.enable_static()
         self.python_api = paddle.abs
         self.op_type = "abs"
-        self.dtype = np.float64
+        self.dtype = np.complex128
         self.shape = (2, 3, 4, 5)
         self.init_input_output()
-        self.init_grad_input_output()
 
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x)}
         self.outputs = {'Out': self.out}
@@ -40,19 +39,13 @@ def init_input_output(self):
         ) + 1j * np.random.random(self.shape).astype(self.dtype)
         self.out = np.abs(self.x)
 
-    def init_grad_input_output(self):
-        self.grad_out = np.ones(self.shape, self.dtype)
-        self.grad_x = self.grad_out * (self.x / np.abs(self.x))
-
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
         self.check_grad(
             ['X'],
             'Out',
-            user_defined_grads=[self.grad_x],
-            user_defined_grad_outputs=[self.grad_out],
         )
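
With the hand-written gradients removed, the abs test now exercises the framework's own complex-gradient path for np.complex128. The deleted helper encoded grad_x = grad_out * x / |x|; that formula agrees with central differences taken on the real and imaginary parts, as this standalone NumPy check illustrates (a sanity check for this note, not part of the test):

    import numpy as np

    np.random.seed(0)
    x = np.random.rand(4) + 1j * np.random.rand(4)  # kept away from 0, where |x| is not differentiable
    eps = 1e-7
    # grad = d|x|/dRe + 1j * d|x|/dIm, elementwise since abs acts elementwise
    fd = ((np.abs(x + eps) - np.abs(x - eps)) / (2 * eps)
          + 1j * (np.abs(x + 1j * eps) - np.abs(x - 1j * eps)) / (2 * eps))
    print(np.allclose(fd, x / np.abs(x)))  # True: matches the removed grad_x with grad_out = 1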
27 changes: 2 additions & 25 deletions python/paddle/fluid/tests/unittests/test_cumprod_op.py
@@ -80,6 +80,8 @@ def prepare_inputs_outputs_attrs(self, dim, zero_num):
         self.x = (
             np.random.uniform(0.0, 0.5, self.shape).astype(self.val_dtype) + 0.5
         )
+        if self.dtype in [np.complex128, np.complex64]:
+            self.x = self.x + np.random.uniform(0.0, 0.5, self.shape) * 1j
         if zero_num > 0:
             zero_num = min(zero_num, self.x.size)
             shape = self.x.shape
@@ -97,28 +99,6 @@ def prepare_inputs_outputs_attrs(self, dim, zero_num):
         self.outputs = {'Out': self.out}
         self.attrs = {'dim': dim}
 
-    def init_grad_input_output(self, dim):
-        reshape_x = self.x.reshape(self.x.size)
-        self.grad_out = np.ones(self.x.size, self.val_dtype)
-        self.grad_x = np.zeros(self.x.size, self.val_dtype)
-        out_data = self.out.reshape(self.x.size)
-        if self.dtype == np.complex128 or self.dtype == np.complex64:
-            reshape_x = np.conj(reshape_x)
-            out_data = np.conj(out_data)
-        cumprod_grad(
-            reshape_x, out_data, self.grad_out, self.grad_x, self.shape, dim
-        )
-        if self.dtype == np.uint16:
-            self.grad_x = convert_float_to_uint16(
-                self.grad_x.reshape(self.shape)
-            )
-            self.grad_out = convert_float_to_uint16(
-                self.grad_out.reshape(self.shape)
-            )
-        else:
-            self.grad_x = self.grad_x.reshape(self.shape)
-            self.grad_out = self.grad_out.reshape(self.shape)
-
     # test forward.
     def test_check_output(self):
         for dim in range(-len(self.shape), len(self.shape)):
@@ -131,15 +111,12 @@ def test_check_grad(self):
         for dim in range(-len(self.shape), len(self.shape)):
             for zero_num in self.zero_nums:
                 self.prepare_inputs_outputs_attrs(dim, zero_num)
-                self.init_grad_input_output(dim)
                 if self.dtype == np.float64:
                     self.check_grad(['X'], 'Out')
                 else:
                     self.check_grad(
                         ['X'],
                         'Out',
-                        user_defined_grads=[self.grad_x],
-                        user_defined_grad_outputs=[self.grad_out],
                     )
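
Here too, deleting init_grad_input_output hands gradient checking for complex cumprod over to the framework. The removed code conjugated both the input and the output before calling cumprod_grad, which points at the usual conjugate (Wirtinger-style) convention. A rough pure-NumPy reference for the 1-D, no-zeros case might look like the following (cumprod_vjp is a hypothetical name, and the conjugation convention is inferred from the removed lines rather than taken from Paddle's kernel):

    import numpy as np

    def cumprod_vjp(x, grad_out):
        # For out[j] = x[0] * ... * x[j]:  d out[j] / d x[i] = out[j] / x[i], j >= i.
        # Complex partials are conjugated, mirroring the np.conj calls in the
        # removed helper. Assumes x has no zero entries.
        out = np.cumprod(x)
        grad_x = np.zeros_like(x)
        for i in range(x.size):
            partial = out[i:] / x[i]
            if np.iscomplexobj(x):
                partial = np.conj(partial)
            grad_x[i] = np.sum(grad_out[i:] * partial)
        return grad_x

    x = np.array([1 + 1j, 2 - 1j, 0.5 + 2j])
    print(cumprod_vjp(x, np.ones_like(x)))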
30 changes: 15 additions & 15 deletions python/paddle/fluid/tests/unittests/test_dot_op.py
@@ -189,7 +189,7 @@ def setUp(self):
         self.python_api = paddle.dot
         self.init_base_dtype()
         self.init_input_output()
-        self.init_grad_input_output()
+        # self.init_grad_input_output()
 
         self.inputs = {
             'X': OpTest.np_dtype_to_fluid_dtype(self.x),
@@ -198,7 +198,7 @@
         self.outputs = {'Out': self.out}
 
     def init_base_dtype(self):
-        self.dtype = np.float64
+        self.dtype = np.complex64
 
     def init_input_output(self):
         self.x = np.random.random(100).astype(
@@ -221,26 +221,26 @@ def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            user_defined_grads=[self.grad_x, self.grad_y],
-            user_defined_grad_outputs=[self.grad_out],
+            # user_defined_grads=[self.grad_x, self.grad_y],
+            # user_defined_grad_outputs=[self.grad_out],
         )
 
     def test_check_grad_ingore_x(self):
         self.check_grad(
             ['Y'],
             'Out',
             no_grad_set=set("X"),
-            user_defined_grads=[self.grad_y],
-            user_defined_grad_outputs=[self.grad_out],
+            # user_defined_grads=[self.grad_y],
+            # user_defined_grad_outputs=[self.grad_out],
         )
 
     def test_check_grad_ingore_y(self):
         self.check_grad(
             ['X'],
             'Out',
             no_grad_set=set('Y'),
-            user_defined_grads=[self.grad_x],
-            user_defined_grad_outputs=[self.grad_out],
+            # user_defined_grads=[self.grad_x],
+            # user_defined_grad_outputs=[self.grad_out],
         )
 
 
@@ -250,7 +250,7 @@ def setUp(self):
         self.python_api = paddle.dot
         self.init_base_dtype()
         self.init_input_output()
-        self.init_grad_input_output()
+        # self.init_grad_input_output()
 
         self.inputs = {
             'X': OpTest.np_dtype_to_fluid_dtype(self.x),
@@ -288,26 +288,26 @@ def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            user_defined_grads=[self.grad_x, self.grad_y],
-            user_defined_grad_outputs=[self.grad_out],
+            # user_defined_grads=[self.grad_x, self.grad_y],
+            # user_defined_grad_outputs=[self.grad_out],
         )
 
     def test_check_grad_ingore_x(self):
         self.check_grad(
             ['Y'],
             'Out',
             no_grad_set=set("X"),
-            user_defined_grads=[self.grad_y],
-            user_defined_grad_outputs=[self.grad_out],
+            # user_defined_grads=[self.grad_y],
+            # user_defined_grad_outputs=[self.grad_out],
         )
 
     def test_check_grad_ingore_y(self):
         self.check_grad(
             ['X'],
             'Out',
             no_grad_set=set('Y'),
-            user_defined_grads=[self.grad_x],
-            user_defined_grad_outputs=[self.grad_out],
+            # user_defined_grads=[self.grad_x],
+            # user_defined_grad_outputs=[self.grad_out],
         )
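
The commented-out user-defined gradients for dot would follow the same conjugate convention: for out = sum(x[k] * y[k]), grad_x = grad_out * conj(y) and grad_y = grad_out * conj(x) (assumed here from the convention above, not lifted from the test file). A quick standalone check that central differences on Re(out) reproduce conj(y):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.random(5) + 1j * rng.random(5)
    y = rng.random(5) + 1j * rng.random(5)

    eps = 1e-7
    fd = np.empty_like(x)
    for k in range(x.size):
        e = np.zeros(x.size)
        e[k] = 1.0
        # grad = dRe(out)/dRe(x_k) + 1j * dRe(out)/dIm(x_k)
        dre = (np.dot(x + eps * e, y).real - np.dot(x - eps * e, y).real) / (2 * eps)
        dim = (np.dot(x + 1j * eps * e, y).real - np.dot(x - 1j * eps * e, y).real) / (2 * eps)
        fd[k] = dre + 1j * dim
    print(np.allclose(fd, np.conj(y)))  # True: grad_x with grad_out = 1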
