From bb9cae0a5939b6f2804b7b5c396db3edc5aa9b82 Mon Sep 17 00:00:00 2001
From: wangxinxin08
Date: Fri, 9 Oct 2020 14:13:14 +0000
Subject: [PATCH] modify code according to review

---
 paddle/fluid/operators/matmul_op.cc       | 115 +++++++++++-------
 .../fluid/tests/unittests/test_nn_grad.py |  24 ++--
 2 files changed, 85 insertions(+), 54 deletions(-)

diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index 24eeefe942a95..5a167adde41f0 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -403,8 +403,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
     auto x = *context.Input<framework::Tensor>("X");
     auto y = *context.Input<framework::Tensor>("Y");
     auto dout = *context.Input<framework::Tensor>("DOut");
-    auto ddx = *context.Input<framework::Tensor>("DDX");
-    auto ddy = *context.Input<framework::Tensor>("DDY");
+    auto *ddx = context.Input<framework::Tensor>("DDX");
+    auto *ddy = context.Input<framework::Tensor>("DDY");
 
     auto *dx = context.Output<framework::Tensor>("DX");
     auto *dy = context.Output<framework::Tensor>("DY");
@@ -440,47 +440,67 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
     }
 
     bool ddout_flag = false;
-    if (dy) {
-      if (transpose_x && transpose_y) {
-        // dy = dout' * ddx'
-        CalcInputGrad(context, dout, true, true, ddx, true, false, false, dy);
-      } else if (transpose_x) {
-        // dy = ddx * dout
-        CalcInputGrad(context, ddx, false, false, dout, false, true, false, dy);
-      } else if (transpose_y) {
-        // dy = dout' * ddx
-        CalcInputGrad(context, dout, true, true, ddx, false, true, false, dy);
-      } else {
-        // dy = ddx' * dout
-        CalcInputGrad(context, ddx, true, true, dout, false, true, false, dy);
+    if (ddx) {
+      auto ddx_mat = *ddx;
+      if (ddx_mat.dims() != x.dims()) {
+        ddx_mat.Resize(x.dims());
+      }
+      if (dy) {
+        if (transpose_x && transpose_y) {
+          // dy = dout' * ddx'
+          CalcInputGrad(context, dout, true, true, ddx_mat, true, false, false,
+                        dy);
+        } else if (transpose_x) {
+          // dy = ddx * dout
+          CalcInputGrad(context, ddx_mat, false, false, dout, false, true,
+                        false, dy);
+        } else if (transpose_y) {
+          // dy = dout' * ddx
+          CalcInputGrad(context, dout, true, true, ddx_mat, false, true, false,
+                        dy);
+        } else {
+          // dy = ddx' * dout
+          CalcInputGrad(context, ddx_mat, true, true, dout, false, true, false,
+                        dy);
+        }
       }
-    }
 
-    if (ddout) {
-      CalcInputGrad(context, ddx, transpose_x, true, y, transpose_y, false,
-                    ddout_flag, ddout);
-      ddout_flag = true;
+      if (ddout) {
+        CalcInputGrad(context, ddx_mat, transpose_x, true, y, transpose_y,
+                      false, ddout_flag, ddout);
+        ddout_flag = true;
+      }
     }
 
-    if (dx) {
-      if (transpose_x && transpose_y) {
-        // dx = ddy' * dout'
-        CalcInputGrad(context, ddy, true, true, dout, true, false, false, dx);
-      } else if (transpose_x) {
-        // dx = ddy * dout'
-        CalcInputGrad(context, ddy, false, false, dout, true, false, false, dx);
-      } else if (transpose_y) {
-        // dx = dout * ddy
-        CalcInputGrad(context, dout, false, false, ddy, false, true, false, dx);
-      } else {
-        // dx = dout * ddy'
-        CalcInputGrad(context, dout, false, false, ddy, true, false, false, dx);
+    if (ddy) {
+      auto ddy_mat = *ddy;
+      if (ddy_mat.dims() != y.dims()) {
+        ddy_mat.Resize(y.dims());
+      }
+      if (dx) {
+        if (transpose_x && transpose_y) {
+          // dx = ddy' * dout'
+          CalcInputGrad(context, ddy_mat, true, true, dout, true, false, false,
+                        dx);
+        } else if (transpose_x) {
+          // dx = ddy * dout'
+          CalcInputGrad(context, ddy_mat, false, false, dout, true, false,
+                        false, dx);
+        } else if (transpose_y) {
+          // dx = dout * ddy
+          CalcInputGrad(context, dout, false, false, ddy_mat, false, true,
+                        false, dx);
+        } else {
+          // dx = dout * ddy'
+          CalcInputGrad(context, dout, false, false, ddy_mat, true, false,
+                        false, dx);
+        }
      }
-    }
 
-    if (ddout) {
-      CalcInputGrad(context, x, transpose_x, true, ddy, transpose_y, false,
-                    ddout_flag, ddout);
+      if (ddout) {
+        CalcInputGrad(context, x, transpose_x, true, ddy_mat, transpose_y,
+                      false, ddout_flag, ddout);
+      }
     }
 
     if (dx) {
@@ -813,15 +833,16 @@ class MatMulOpDoubleGrad : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
     OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", "matmul");
 
-    if (context->HasOutput("DX")) {
+    if (context->HasOutput("DX") && context->HasInput("DDY")) {
       context->ShareDim("X", "DX");
     }
 
-    if (context->HasOutput("DY")) {
+    if (context->HasOutput("DY") && context->HasInput("DDX")) {
       context->ShareDim("Y", "DY");
     }
 
-    if (context->HasOutput("DDOut")) {
+    if (context->HasOutput("DDOut") &&
+        (context->HasInput("DDY") || context->HasInput("DDX"))) {
       context->ShareDim("DOut", "DDOut");
     }
   }
@@ -841,9 +862,17 @@ class MatMulOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
     retv->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
     retv->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y")));
 
-    retv->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    retv->SetOutput("DX", this->InputGrad("X"));
-    retv->SetOutput("DY", this->InputGrad("Y"));
+    auto ddx = this->OutputGrad(framework::GradVarName("X"));
+    auto ddy = this->OutputGrad(framework::GradVarName("Y"));
+
+    if (!ddx.empty() || !ddy.empty()) {
+      retv->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
+    }
+    retv->SetOutput(
+        "DX", ddy.empty() ? this->EmptyInputGrad() : this->InputGrad("X"));
+    retv->SetOutput(
+        "DY", ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Y"));
+
     retv->SetAttrMap(this->Attrs());
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py
index 8dde782fe536f..bf1955c5711f5 100644
--- a/python/paddle/fluid/tests/unittests/test_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py
@@ -156,22 +156,24 @@ def test_grad(self):
 class TestMatmulDoubleGradCheck(unittest.TestCase):
     @prog_scope()
     def func(self, place):
-        prog = fluid.Program()
-        with fluid.program_guard(prog):
-            x_shape = [2, 3, 4]
-            y_shape = [2, 4, 5]
-            eps = 0.005
-            dtype = np.float64
-
+        eps = 0.005
+        x_shapes = [[2], [2, 3], [2, 4, 3], [2, 3, 4, 5], [2, 3, 4]]
+        y_shapes = [[2], [3, 2], [2, 4, 5], [2, 3, 3, 5], [4, 3]]
+        transpose_xs = [False, True, True, False, False]
+        transpose_ys = [False, True, False, True, False]
+        dtypes = [np.float64, np.float64, np.float32, np.float32, np.float64]
+        typenames = ["float64", "float64", "float32", "float32", "float64"]
+        for i, (x_shape, y_shape, transpose_x, transpose_y, dtype, typename) \
+                in enumerate(zip(x_shapes, y_shapes, transpose_xs, transpose_ys, dtypes, typenames)):
             x = layers.create_parameter(
-                dtype="float64", shape=x_shape, name='x')
+                dtype=typename, shape=x_shape, name='x{}'.format(i))
             y = layers.create_parameter(
-                dtype="float64", shape=y_shape, name='y')
-            out = layers.matmul(x, y)
+                dtype=typename, shape=y_shape, name='y{}'.format(i))
+            out = layers.matmul(
+                x, y, transpose_x, transpose_y, name='out{}'.format(i))
 
             x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
             y_arr = np.random.uniform(-1, 1, y_shape).astype(dtype)
-
             gradient_checker.double_grad_check(
                 [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
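
Note (illustration only, not part of the patch): the CalcInputGrad calls in MatMulDoubleGradKernel encode the matmul double-grad identities stated in the inline comments; for the non-transposed case they reduce to DDOut = ddX*Y + X*ddY, DX = dOut*ddY', DY = ddX'*dOut. The NumPy sketch below checks exactly those three formulas against finite-difference directional derivatives. The array names are made up for the example and nothing in it comes from the Paddle codebase.

    # Standalone sanity check of the non-transposed matmul double-grad formulas.
    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((3, 4))
    Y = rng.standard_normal((4, 5))
    dOut = rng.standard_normal((3, 5))   # upstream gradient of Out = X @ Y
    ddX = rng.standard_normal(X.shape)   # perturbation fed in as DDX
    ddY = rng.standard_normal(Y.shape)   # perturbation fed in as DDY

    # First-order backward of Out = X @ Y: dX = dOut @ Y.T, dY = X.T @ dOut.
    # Double-grad outputs, matching the comments in the kernel:
    DDOut = ddX @ Y + X @ ddY            # ddout = ddx * y + x * ddy
    DX = dOut @ ddY.T                    # dx = dout * ddy'
    DY = ddX.T @ dOut                    # dy = ddx' * dout

    # Finite-difference check: perturb the inputs of the first-order backward
    # pass along ddX / ddY and compare the resulting directional derivatives.
    eps = 1e-6
    DX_fd = (dOut @ (Y + eps * ddY).T - dOut @ Y.T) / eps
    DY_fd = ((X + eps * ddX).T @ dOut - X.T @ dOut) / eps
    DDOut_fd = ((X + eps * ddX) @ (Y + eps * ddY) - X @ Y) / eps

    print(np.allclose(DX, DX_fd, atol=1e-4),
          np.allclose(DY, DY_fd, atol=1e-4),
          np.allclose(DDOut, DDOut_fd, atol=1e-4))

Running the sketch prints True three times, which is the same property the patched TestMatmulDoubleGradCheck verifies through gradient_checker.double_grad_check, just written out by hand for one shape.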