Skip to content

Commit

Permalink
Add ConditionalBlockGradInferVarType
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 committed Feb 16, 2022
1 parent 3581c07 commit 1d9339b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
15 changes: 14 additions & 1 deletion paddle/fluid/operators/controlflow/conditional_block_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,18 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
}
};

class ConditionalBlockGradInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
// NOTE(Aurelius84): VarType of Output is LoDTensor by default. In case of
// Input is {Tensor, LoDTensorArray}, we need synchronous the Input's
// VarType into Input@GRAD to avoid generating {Tensor, Tensor} as
// Input@GRAD.
ctx->SyncTypeAndDataType(ConditionalOp::kInputs,
framework::GradVarName(ConditionalOp::kInputs));
}
};

template <typename T>
class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
public:
Expand Down Expand Up @@ -300,4 +312,5 @@ REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
ops::ConditionalBlockOpProtoMaker,
ops::ConditionalBlockGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
ops::ConditionalBlockGradInferShape);
ops::ConditionalBlockGradInferShape,
ops::ConditionalBlockGradInferVarType);
30 changes: 30 additions & 0 deletions python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,5 +306,35 @@ def init_data(self):
self.input = np.random.random((3, 4)).astype('float32')


class ListWithCondNet(paddle.nn.Layer):
def __init__(self):
super(ListWithCondNet, self).__init__()

@paddle.jit.to_static
def forward(self, x, index):
y = paddle.nn.functional.relu(x)
a = []

for i in y:
a.append(i)

if index > 0:
res = a[0] * a[0]
else:
res = a[-1] * a[-1]

z = a[-1] * res
return z


class TestListWithCondGradInferVarType(unittest.TestCase):
def test_to_static(self):
net = ListWithCondNet()
x = paddle.to_tensor([2, 3, 4], dtype='float32')
index = paddle.to_tensor([1])
res = net(x, index)
self.assertEqual(res[0], 16.)


if __name__ == '__main__':
unittest.main()

1 comment on commit 1d9339b

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.