Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#4699 from jacquesqiao/expose-backward
Browse files Browse the repository at this point in the history
expose AppendBackward of ProgramDesc to python
  • Loading branch information
jacquesqiao authored Oct 11, 2017
2 parents 134a073 + e8cad5a commit 5e9d439
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
2 changes: 2 additions & 0 deletions paddle/framework/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);

// TODO(jiayi): Add target as parameter and generate backward op
// according to target.
void AppendBackward(ProgramDescBind& program_desc,
const std::unordered_set<std::string>& no_grad_vars);

Expand Down
6 changes: 6 additions & 0 deletions paddle/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/pybind/protobuf.h"
#include <deque>
#include <iostream>
#include "paddle/framework/backward.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/program_desc.h"
Expand Down Expand Up @@ -116,6 +117,11 @@ void BindProgramDesc(py::module &m) {
py::return_value_policy::reference)
.def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference)
.def("append_backward",
[](ProgramDescBind &program_desc,
const std::unordered_set<std::string> &no_grad_vars) {
AppendBackward(program_desc, no_grad_vars);
})
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size);
}
Expand Down
30 changes: 30 additions & 0 deletions python/paddle/v2/framework/tests/test_program.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import unittest

import paddle.v2.framework.core as core
from paddle.v2.framework.graph import g_program


Expand Down Expand Up @@ -31,6 +33,34 @@ def test_program(self):
self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx)

def test_append_backward(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)

mul_op_desc = block.append_op()
mul_op_desc.set_type("mul")
mul_op_desc.set_input("X", ["x1"])
mul_op_desc.set_input("Y", ["y1"])
mul_op_desc.set_output("Out", ["out1"])

sum_op_desc = block.append_op()
sum_op_desc.set_type("elementwise_add")
sum_op_desc.set_input("X", ["out1"])
sum_op_desc.set_input("Y", ["b1"])
sum_op_desc.set_output("Out", ["out2"])

expect_ops = [
"mul", "elementwise_add", "elementwise_add_grad", "mul_grad"
]
actual_ops = []
prog.append_backward(set())
for op in block.all_ops():
actual_ops.append(op.type())
print(actual_ops)
self.assertEqual(actual_ops, expect_ops)


if __name__ == '__main__':
unittest.main()

0 comments on commit 5e9d439

Please sign in to comment.