Skip to content

Commit

Permalink
Adjust the relative error of QR's grad (#42221)
Browse files Browse the repository at this point in the history
  • Loading branch information
aoyulong authored Apr 27, 2022
1 parent acca035 commit 4c80385
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/paddle/fluid/tests/unittests/test_qr_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class TestQrOp(OpTest):
def setUp(self):
paddle.enable_static()
np.random.seed(4)
np.random.seed(7)
self.op_type = "qr"
a, q, r = self.get_input_and_output()
self.inputs = {"X": a}
Expand Down Expand Up @@ -74,7 +74,8 @@ def test_check_output(self):
self.check_output()

def test_check_grad_normal(self):
self.check_grad(['X'], ['Q', 'R'])
self.check_grad(
['X'], ['Q', 'R'], numeric_grad_delta=1e-5, max_relative_error=1e-6)


class TestQrOpCase1(TestQrOp):
Expand Down Expand Up @@ -116,6 +117,7 @@ def get_shape(self):
class TestQrAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
np.random.seed(7)

def run_qr_dygraph(shape, mode, dtype):
if dtype == "float32":
Expand Down Expand Up @@ -180,6 +182,7 @@ def run_qr_dygraph(shape, mode, dtype):

def test_static(self):
paddle.enable_static()
np.random.seed(7)

def run_qr_static(shape, mode, dtype):
if dtype == "float32":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
'matrix_power', \
'cholesky_solve', \
'solve', \
'qr', \
]

NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
Expand Down

0 comments on commit 4c80385

Please sign in to comment.