Fix calculations in gru_unit_op to be consistent with gru_op #5804

Merged
merged 3 commits into from
Nov 22, 2017
Merged
23 changes: 9 additions & 14 deletions paddle/operators/gru_unit_op.cc
@@ -114,18 +114,19 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(sigmoid)
.InEnum({identity, sigmoid, tanh, relu});
AddComment(R"DOC(
GRUUnit Operator.

This operator implements partial calculations of the GRU unit as follows:
GRUUnit Operator implements partial calculations of the GRU unit as following:

$$
update \ gate: u_t = actGate(xu_t + W_u * hidden_{prev} + bias_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * hidden_{prev} + bias_r) \\
output \ candidate: {h}_t = actNode({xc}_t + W_c * dot(r_t, hidden_{prev}) + bias_c) \\
output: h_t = dot((1-u_t), {h}_t) + dot(u_t, hidden_{prev})
update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
$$

The rest of GRU unit can be completed by using FCOp's output as the input of GRUUnitOp.
which is same as one time step of GRU Operator.

@note To implement the complete GRU unit, fully-connected operator must be
used before to feed xu, xr and xc as the Input of GRUUnit operator.

)DOC");
}
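
For reference while reading the diff, here is a minimal NumPy sketch of one step under the corrected formulas, assuming sigmoid gate activation and tanh candidate activation. The names `x_u`, `x_r`, `x_c` stand for the input projections that, per the note above, must be produced by a fully-connected operator beforehand; they and the per-gate weight/bias names are illustrative only, not the operator's actual variable names.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_unit_ref(x_u, x_r, x_c, h_prev, w_u, w_r, w_c, b_u, b_r, b_c):
    """One GRU unit step following the corrected formulas above."""
    u = sigmoid(x_u + h_prev @ w_u + b_u)        # update gate
    r = sigmoid(x_r + h_prev @ w_r + b_r)        # reset gate
    c = np.tanh(x_c + (r * h_prev) @ w_c + b_c)  # output candidate
    h = (1.0 - u) * h_prev + u * c               # h_t = (1 - u_t) * h_{t-1} + u_t * c_t
    return h, u, r, c
```

The old formula effectively swapped the roles of $h_{t-1}$ and the candidate in the last line, which is the inconsistency with `gru_op` that this PR removes.
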
@@ -150,12 +151,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Gate");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Hidden");
76 changes: 41 additions & 35 deletions paddle/operators/gru_unit_op.h
@@ -110,7 +110,7 @@ class GRUUnitKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate

// calculate final output
h.device(place) = u * (h_p - c) + c;
h.device(place) = u * (c - h_p) + h_p;
}
};
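
Writing the candidate as $\tilde{h}_t$ and the previous state as $h_{t-1}$, the new expression is just a rearrangement of the documented output formula, while the old one weighted the candidate by $(1 - u_t)$ instead:

$$
u_t \odot (\tilde{h}_t - h_{t-1}) + h_{t-1} = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t,
\qquad
u_t \odot (h_{t-1} - \tilde{h}_t) + \tilde{h}_t = (1 - u_t) \odot \tilde{h}_t + u_t \odot h_{t-1}.
$$
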

@@ -146,35 +146,27 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
auto* weight_grad =
context.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
input_grad->mutable_data<T>(context.GetPlace());
hidden_prev_grad->mutable_data<T>(context.GetPlace());
weight_grad->mutable_data<T>(context.GetPlace());
Tensor gate_grad;
gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
Tensor reset_hidden_prev_grad;
reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
context.GetPlace());

int batch_size = input->dims()[0];
int frame_size = hidden_prev->dims()[1];

const T* hidden_prev_data = hidden_prev->data<T>();
T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
const T* weight_data = weight->data<T>();
T* weight_grad_data = weight_grad->data<T>();
T* gate_grad_data = gate_grad.data<T>();
T* gate_grad_data =
gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>();
T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.mutable_data<T>(
reset_hidden_prev->dims(), context.GetPlace());

auto h_p = EigenMatrix<T>::From(*hidden_prev);
auto g = EigenMatrix<T>::From(*gate);
auto d_h = EigenMatrix<T>::From(*hidden_grad);
auto d_x = EigenMatrix<T>::From(*input_grad);
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
auto d_g = EigenMatrix<T>::From(gate_grad);
auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
auto place = context.GetEigenDevice<Place>();

int batch_size = input->dims()[0];
int frame_size = hidden_prev->dims()[1];

Eigen::array<int, 2> extents({{batch_size, frame_size}});
Eigen::array<int, 2> u_offsets({{0, 0}});
auto u = g.slice(u_offsets, extents); // update gate
@@ -185,38 +177,52 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {

// backward for unactivated update gate
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (h_p - c));
d_g.slice(u_offsets, extents), d_h * (c - h_p));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
d_g.slice(c_offsets, extents), d_h * u);
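
The two changes above follow directly from the corrected output formula. With $h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t$, the local derivatives are

$$
\frac{\partial h_t}{\partial u_t} = \tilde{h}_t - h_{t-1},
\qquad
\frac{\partial h_t}{\partial \tilde{h}_t} = u_t,
$$

so the incoming gradient `d_h` is now scaled by `(c - h_p)` for the unactivated update gate and by `u` for the unactivated candidate, replacing the old `(h_p - c)` and `(1 - u)` factors.
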
// backward for reset_hidden_prev
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
frame_size, frame_size, 1,
gate_grad_data + frame_size * 2, frame_size * 3,
weight_data + frame_size * frame_size * 2, frame_size,
0, reset_hidden_prev_grad_data, frame_size);
// backward for state_weight
math::gemm<Place, T>(
context.device_context(), true, false, frame_size, frame_size,
batch_size, 1, reset_hidden_prev_data, frame_size,
gate_grad_data + frame_size * 2, frame_size * 3, 0,
weight_grad_data + frame_size * frame_size * 2, frame_size);
// backward for unactivated reset gate
ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
d_g.slice(r_offsets, extents), d_r_h_p * h_p);
// backward for update_gate_weight and reset_gate_weight
math::gemm<Place, T>(context.device_context(), true, false, frame_size,
frame_size * 2, batch_size, 1, hidden_prev_data,
frame_size, gate_grad_data, frame_size * 3, 0,
weight_grad_data, frame_size * 2);
// backward for weight
if (weight_grad) {
T* weight_grad_data = weight_grad->mutable_data<T>(context.GetPlace());
// backward for state_weight
math::gemm<Place, T>(
context.device_context(), true, false, frame_size, frame_size,
batch_size, 1, reset_hidden_prev_data, frame_size,
gate_grad_data + frame_size * 2, frame_size * 3, 0,
weight_grad_data + frame_size * frame_size * 2, frame_size);

// backward for update_gate_weight and reset_gate_weight
math::gemm<Place, T>(context.device_context(), true, false, frame_size,
frame_size * 2, batch_size, 1, hidden_prev_data,
frame_size, gate_grad_data, frame_size * 3, 0,
weight_grad_data, frame_size * 2);
}
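
A rough NumPy picture of the two weight-gradient GEMMs inside this `if (weight_grad)` block, as I read the pointer offsets: the gate gradient is laid out as `[batch_size, 3 * frame_size]` with the candidate columns last, and the candidate ("state") weight block sits after the update/reset block in the flat weight buffer. These layout details are inferred from the code above, not taken from separate documentation.

```python
import numpy as np

batch_size, frame_size = 5, 10
h_p = np.random.rand(batch_size, frame_size)        # HiddenPrev
r_h_p = np.random.rand(batch_size, frame_size)      # ResetHiddenPrev = r * h_p
d_g = np.random.rand(batch_size, 3 * frame_size)    # gate gradient: [d_u | d_r | d_c]

# state_weight gradient: ResetHiddenPrev^T x d_c -> [frame_size, frame_size]
d_w_c = r_h_p.T @ d_g[:, 2 * frame_size:]
# update/reset gate weight gradient: HiddenPrev^T x [d_u | d_r] -> [frame_size, 2 * frame_size]
d_w_ur = h_p.T @ d_g[:, :2 * frame_size]
```
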
// backward for hidden_prev
d_h_p.device(place) = d_r_h_p * r + d_h * u;
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
frame_size, frame_size * 2, 1, gate_grad_data,
frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size);
if (hidden_prev_grad) {
T* hidden_prev_grad_data =
hidden_prev_grad->mutable_data<T>(context.GetPlace());
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
frame_size, frame_size * 2, 1, gate_grad_data,
frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size);
}
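
Under the corrected formula the direct contribution of $h_{t-1}$ to $h_t$ carries a factor of $(1 - u_t)$, not $u_t$, which is why the initialization of `d_h_p` changes here:

$$
\frac{\partial L}{\partial h_{t-1}} = d_{rhp} \odot r_t + d_h \odot (1 - u_t) + \text{(GEMM term through } W_u, W_r\text{)},
$$

where $d_{rhp}$ is the gradient with respect to `ResetHiddenPrev` computed earlier, and the trailing GEMM accumulates the path that flows back through the update and reset gate weights.
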
// backward for input
d_x.device(place) = d_g;
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
auto d_x = EigenMatrix<T>::From(*input_grad);
d_x.device(place) = d_g;
}
// backward for bias
if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace());
21 changes: 10 additions & 11 deletions python/paddle/v2/fluid/tests/test_gru_unit_op.py
@@ -28,8 +28,8 @@ def relu(x):


class TestGRUUnitOp(OpTest):
batch_size = 3
frame_size = 5
batch_size = 5
frame_size = 10
activate = {
GRUActivationType.identity: identity,
GRUActivationType.sigmoid: sigmoid,
@@ -77,7 +77,7 @@ def set_outputs(self):
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
h = u * h_p + (1 - u) * c
h = u * c + (1 - u) * h_p
self.outputs = {
'Gate': g.astype('float64'),
'ResetHiddenPrev': r_h_p.astype('float64'),
@@ -92,10 +92,7 @@ def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(
['Input', 'HiddenPrev', 'Weight'],
['Hidden', 'ResetHiddenPrev', 'Gate'],
max_relative_error=0.007)
self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden'])
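
The reference computation now uses the same convention as the fixed kernel: the update gate weights the candidate and `(1 - u)` weights the carried-over state. A quick standalone NumPy sanity check of that convention (values are illustrative only):

```python
import numpy as np

h_p = np.array([1.0, 2.0, 3.0])    # previous hidden state
c = np.array([10.0, 20.0, 30.0])   # output candidate

for u in (np.zeros(3), np.ones(3)):
    h = u * c + (1 - u) * h_p      # same expression as the updated test reference
    print(int(u[0]), h)            # u == 0 keeps h_p; u == 1 takes the candidate
```
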


class TestGRUUnitOpWithBias(TestGRUUnitOp):
@@ -104,18 +101,20 @@ def set_inputs(self):
frame_size = self.frame_size
super(TestGRUUnitOpWithBias, self).set_inputs()
self.inputs['Bias'] = np.random.uniform(
-0.1, 0.1, (1, frame_size * 3)).astype('float32')
-0.1, 0.1, (1, frame_size * 3)).astype('float64')
self.attrs = {
'activation': GRUActivationType.identity,
'gate_activation': GRUActivationType.sigmoid
}

def test_check_grad(self):
self.check_grad(['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'])

def test_check_grad_ingore_input(self):
self.check_grad(
['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
max_relative_error=0.007)
['HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
no_grad_set=set('Input'))


if __name__ == '__main__':
exit(0) # FIXME(yuyang18): This unittest is not pass. Fix it later
unittest.main()