modify code according to review
wangxinxin08 committed Oct 9, 2020
1 parent 32604c3 commit bb9cae0
Showing 2 changed files with 85 additions and 54 deletions.
115 changes: 72 additions & 43 deletions paddle/fluid/operators/matmul_op.cc
@@ -403,8 +403,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
     auto x = *context.Input<framework::Tensor>("X");
     auto y = *context.Input<framework::Tensor>("Y");
     auto dout = *context.Input<framework::LoDTensor>("DOut");
-    auto ddx = *context.Input<framework::LoDTensor>("DDX");
-    auto ddy = *context.Input<framework::LoDTensor>("DDY");
+    auto *ddx = context.Input<framework::LoDTensor>("DDX");
+    auto *ddy = context.Input<framework::LoDTensor>("DDY");

     auto *dx = context.Output<framework::LoDTensor>("DX");
     auto *dy = context.Output<framework::LoDTensor>("DY");
@@ -440,47 +440,67 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
     }

     bool ddout_flag = false;
-    if (dy) {
-      if (transpose_x && transpose_y) {
-        // dy = dout' * ddx'
-        CalcInputGrad(context, dout, true, true, ddx, true, false, false, dy);
-      } else if (transpose_x) {
-        // dy = ddx * dout
-        CalcInputGrad(context, ddx, false, false, dout, false, true, false, dy);
-      } else if (transpose_y) {
-        // dy = dout' * ddx
-        CalcInputGrad(context, dout, true, true, ddx, false, true, false, dy);
-      } else {
-        // dy = ddx' * dout
-        CalcInputGrad(context, ddx, true, true, dout, false, true, false, dy);
-      }
-    }
-
-    if (ddout) {
-      CalcInputGrad(context, ddx, transpose_x, true, y, transpose_y, false,
-                    ddout_flag, ddout);
-      ddout_flag = true;
-    }
-
-    if (dx) {
-      if (transpose_x && transpose_y) {
-        // dx = ddy' * dout'
-        CalcInputGrad(context, ddy, true, true, dout, true, false, false, dx);
-      } else if (transpose_x) {
-        // dx = ddy * dout'
-        CalcInputGrad(context, ddy, false, false, dout, true, false, false, dx);
-      } else if (transpose_y) {
-        // dx = dout * ddy
-        CalcInputGrad(context, dout, false, false, ddy, false, true, false, dx);
-      } else {
-        // dx = dout * ddy'
-        CalcInputGrad(context, dout, false, false, ddy, true, false, false, dx);
-      }
-    }
-
-    if (ddout) {
-      CalcInputGrad(context, x, transpose_x, true, ddy, transpose_y, false,
-                    ddout_flag, ddout);
-    }
+    if (ddx) {
+      auto ddx_mat = *ddx;
+      if (ddx_mat.dims() != x.dims()) {
+        ddx_mat.Resize(x.dims());
+      }
+      if (dy) {
+        if (transpose_x && transpose_y) {
+          // dy = dout' * ddx'
+          CalcInputGrad(context, dout, true, true, ddx_mat, true, false, false,
+                        dy);
+        } else if (transpose_x) {
+          // dy = ddx * dout
+          CalcInputGrad(context, ddx_mat, false, false, dout, false, true,
+                        false, dy);
+        } else if (transpose_y) {
+          // dy = dout' * ddx
+          CalcInputGrad(context, dout, true, true, ddx_mat, false, true, false,
+                        dy);
+        } else {
+          // dy = ddx' * dout
+          CalcInputGrad(context, ddx_mat, true, true, dout, false, true, false,
+                        dy);
+        }
+      }
+
+      if (ddout) {
+        CalcInputGrad(context, ddx_mat, transpose_x, true, y, transpose_y,
+                      false, ddout_flag, ddout);
+        ddout_flag = true;
+      }
+    }
+
+    if (ddy) {
+      auto ddy_mat = *ddy;
+      if (ddy_mat.dims() != y.dims()) {
+        ddy_mat.Resize(y.dims());
+      }
+      if (dx) {
+        if (transpose_x && transpose_y) {
+          // dx = ddy' * dout'
+          CalcInputGrad(context, ddy_mat, true, true, dout, true, false, false,
+                        dx);
+        } else if (transpose_x) {
+          // dx = ddy * dout'
+          CalcInputGrad(context, ddy_mat, false, false, dout, true, false,
+                        false, dx);
+        } else if (transpose_y) {
+          // dx = dout * ddy
+          CalcInputGrad(context, dout, false, false, ddy_mat, false, true,
+                        false, dx);
+        } else {
+          // dx = dout * ddy'
+          CalcInputGrad(context, dout, false, false, ddy_mat, true, false,
+                        false, dx);
+        }
+      }
+
+      if (ddout) {
+        CalcInputGrad(context, x, transpose_x, true, ddy_mat, transpose_y,
+                      false, ddout_flag, ddout);
+      }
+    }

     if (dx) {
@@ -813,15 +833,16 @@ class MatMulOpDoubleGrad : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
     OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", "matmul");

-    if (context->HasOutput("DX")) {
+    if (context->HasOutput("DX") && context->HasInput("DDY")) {
       context->ShareDim("X", "DX");
     }

-    if (context->HasOutput("DY")) {
+    if (context->HasOutput("DY") && context->HasInput("DDX")) {
       context->ShareDim("Y", "DY");
     }

-    if (context->HasOutput("DDOut")) {
+    if (context->HasOutput("DDOut") &&
+        (context->HasInput("DDY") || context->HasInput("DDX"))) {
       context->ShareDim("DOut", "DDOut");
     }
   }
@@ -841,9 +862,17 @@ class MatMulOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
     retv->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
     retv->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y")));

-    retv->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    retv->SetOutput("DX", this->InputGrad("X"));
-    retv->SetOutput("DY", this->InputGrad("Y"));
+    auto ddx = this->OutputGrad(framework::GradVarName("X"));
+    auto ddy = this->OutputGrad(framework::GradVarName("Y"));
+
+    if (!ddx.empty() || !ddy.empty()) {
+      retv->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
+    }
+    retv->SetOutput(
+        "DX", ddy.empty() ? this->EmptyInputGrad() : this->InputGrad("X"));
+    retv->SetOutput(
+        "DY", ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Y"));

     retv->SetAttrMap(this->Attrs());
   }
 };
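The double-grad identities named in the kernel comments above — dy = ddx' * dout, dx = dout * ddy', and ddout built up as ddx * y + x * ddy across the two CalcInputGrad calls guarded by ddout_flag — can be sanity-checked numerically. The NumPy sketch below is illustrative only (it is not part of this commit) and covers the transpose_x = transpose_y = false case:

# Illustrative check of the formulas used in MatMulDoubleGradKernel
# (transpose_x = transpose_y = false); not part of the commit.
import numpy as np

x = np.random.rand(3, 4)
y = np.random.rand(4, 5)
dout = np.random.rand(3, 5)  # gradient w.r.t. out = x @ y
ddx = np.random.rand(3, 4)   # second-order input DDX (grad-of-grad of X)
ddy = np.random.rand(4, 5)   # second-order input DDY (grad-of-grad of Y)

dy = ddx.T @ dout            # "dy = ddx' * dout" branch
dx = dout @ ddy.T            # "dx = dout * ddy'" branch
ddout = ddx @ y + x @ ddy    # accumulated over the two ddout CalcInputGrad calls

# ddout is the directional derivative of x @ y along (ddx, ddy), so a
# central finite difference should agree with it:
eps = 1e-6
fd = ((x + eps * ddx) @ (y + eps * ddy) -
      (x - eps * ddx) @ (y - eps * ddy)) / (2 * eps)
assert dy.shape == y.shape and dx.shape == x.shape
assert np.allclose(fd, ddout, atol=1e-5)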
24 changes: 13 additions & 11 deletions python/paddle/fluid/tests/unittests/test_nn_grad.py
@@ -156,22 +156,24 @@ def test_grad(self):
 class TestMatmulDoubleGradCheck(unittest.TestCase):
     @prog_scope()
     def func(self, place):
-        prog = fluid.Program()
-        with fluid.program_guard(prog):
-            x_shape = [2, 3, 4]
-            y_shape = [2, 4, 5]
-            eps = 0.005
-            dtype = np.float64
-
+        eps = 0.005
+        x_shapes = [[2], [2, 3], [2, 4, 3], [2, 3, 4, 5], [2, 3, 4]]
+        y_shapes = [[2], [3, 2], [2, 4, 5], [2, 3, 3, 5], [4, 3]]
+        transpose_xs = [False, True, True, False, False]
+        transpose_ys = [False, True, False, True, False]
+        dtypes = [np.float64, np.float64, np.float32, np.float32, np.float64]
+        typenames = ["float64", "float64", "float32", "float32", "float64"]
+        for i, (x_shape, y_shape, transpose_x, transpose_y, dtype, typename) \
+            in enumerate(zip(x_shapes, y_shapes, transpose_xs, transpose_ys, dtypes, typenames)):
             x = layers.create_parameter(
-                dtype="float64", shape=x_shape, name='x')
+                dtype=typename, shape=x_shape, name='x{}'.format(i))
             y = layers.create_parameter(
-                dtype="float64", shape=y_shape, name='y')
-            out = layers.matmul(x, y)
+                dtype=typename, shape=y_shape, name='y{}'.format(i))
+            out = layers.matmul(
+                x, y, transpose_x, transpose_y, name='out{}'.format(i))

             x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
             y_arr = np.random.uniform(-1, 1, y_shape).astype(dtype)

             gradient_checker.double_grad_check(
                 [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)

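Only func appears in this hunk; as the hunk header's "def test_grad(self)" context suggests, each check in test_nn_grad.py is normally invoked by a test_grad method that runs it on CPU and, when available, CUDA. A minimal sketch of that driver pattern, assuming the file's usual imports (paddle.fluid as fluid, paddle.fluid.core as core), would be:

# Sketch of the usual driver method in test_nn_grad.py (below the fold of this
# diff); assumes: import paddle.fluid as fluid, import paddle.fluid.core as core
def test_grad(self):
    places = [fluid.CPUPlace()]
    if core.is_compiled_with_cuda():
        places.append(fluid.CUDAPlace(0))
    for p in places:
        self.func(p)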
