Skip to content

Commit

Permalink
add more unit test for test_append_backward
Browse files Browse the repository at this point in the history
  • Loading branch information
jacquesqiao committed Oct 11, 2017
1 parent 2e55469 commit e8cad5a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
2 changes: 1 addition & 1 deletion paddle/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ void BindProgramDesc(py::module &m) {
py::return_value_policy::reference)
.def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference)
.def("backward",
.def("append_backward",
[](ProgramDescBind &program_desc,
const std::unordered_set<std::string> &no_grad_vars) {
AppendBackward(program_desc, no_grad_vars);
Expand Down
27 changes: 20 additions & 7 deletions python/paddle/v2/framework/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,33 @@ def test_program(self):
self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx)

def test_backward(self):
def test_append_backward(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)

mul_op_desc = block.append_op()
mul_op_desc.set_type("mul")
mul_op_desc.set_input("X", ["x1"])
mul_op_desc.set_input("Y", ["y1"])
mul_op_desc.set_output("Out", ["out1"])

sum_op_desc = block.append_op()
sum_op_desc.set_type("sum")
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])
sum_op_desc.set_type("elementwise_add")
sum_op_desc.set_input("X", ["out1"])
sum_op_desc.set_input("Y", ["b1"])
sum_op_desc.set_output("Out", ["out2"])

self.assertEqual(len(block.all_ops()), 1)
prog.backward(set())
self.assertEqual(len(block.all_ops()), 3)
expect_ops = [
"mul", "elementwise_add", "elementwise_add_grad", "mul_grad"
]
actual_ops = []
prog.append_backward(set())
for op in block.all_ops():
actual_ops.append(op.type())
print(actual_ops)
self.assertEqual(actual_ops, expect_ops)


if __name__ == '__main__':
Expand Down

0 comments on commit e8cad5a

Please sign in to comment.