Skip to content

Commit

Permalink
Merge pull request #2 from reyoung/rnn_exec
Browse files Browse the repository at this point in the history
Partial complete rnn
  • Loading branch information
Yang Yang(Tony) committed Oct 24, 2017
2 parents 65a1709 + c46297d commit 8a0e4e6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
2 changes: 2 additions & 0 deletions paddle/operators/recurrent_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class RecurrentAlgorithmProtoAndCheckerMaker
AddInput(name.initial_states, "variables to initialize states.")
.AsDuplicable();

AddInput("parameters", "parameter variables used inside").AsDuplicable();

AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");
Expand Down
35 changes: 33 additions & 2 deletions python/paddle/v2/framework/layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
Operator
import re

__all__ = [
Expand Down Expand Up @@ -440,4 +441,34 @@ def __call__(self, *args, **kwargs):
def complete_rnn_op(self):
# TODO(yuyang18): Create RNN Op here.
# Implement this method after RNN op complete.
pass
program = self.helper.program
rnn_block = program.current_block()
parent_block = self.parent_block()

local_inputs = set()

for op in rnn_block.ops:
assert isinstance(op, Operator)
for oname in op.output_names:
for out_var_name in op.output(oname):
local_inputs.add(out_var_name)

for var in self.inputs:
local_inputs.add(var.name)
for m in self.memories:
local_inputs.add(m)

params = list()
for op in rnn_block.ops:
assert isinstance(op, Operator)
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in local_inputs:
params.append(in_var_name)

inlinks = [parent_block.var(i.name) for i in self.inputs]
outlinks = self.outputs
print outlinks

print params
exit(1)

0 comments on commit 8a0e4e6

Please sign in to comment.