From 6b58489c82a419096b324c70e381ea908845769d Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Mon, 30 Oct 2017 19:07:59 +0000 Subject: [PATCH] add dim_idx attribute --- .../fill_constant_batch_size_like_op.cc | 18 ++++++++++++----- .../test_fill_constant_batch_size_like_op.py | 20 ++++++++++++++++--- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/paddle/operators/fill_constant_batch_size_like_op.cc b/paddle/operators/fill_constant_batch_size_like_op.cc index 58c9f1cd2c79c..0244adb42392c 100644 --- a/paddle/operators/fill_constant_batch_size_like_op.cc +++ b/paddle/operators/fill_constant_batch_size_like_op.cc @@ -36,7 +36,12 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { [](int a) { return static_cast(a); }); auto dims = framework::make_ddim(shape_int64); - dims[0] = ctx->GetInputDim("Input")[0]; + int dim_idx = ctx->Attrs().Get("dim_idx"); + PADDLE_ENFORCE_GE(dim_idx, 0); + PADDLE_ENFORCE_GT(static_cast(shape.size()), dim_idx); + PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), dim_idx); + + dims[dim_idx] = ctx->GetInputDim("Input")[dim_idx]; ctx->SetOutputDim("Out", dims); } @@ -57,15 +62,18 @@ class FillConstantBatchSizeLikeOpMaker "(int, default 5 (FP32)) " "Output data type") .SetDefault(framework::DataType::FP32); - AddAttr>("shape", "(vector) The shape of the output"); - AddAttr("value", "(float, default 0) The value to be filled") - .SetDefault(0.0f); AddInput("Input", "(Tensor) Tensor " - "whose first dimension is used to specify the batch_size"); + "whose dim_idx th dimension is used to specify the batch_size"); AddOutput("Out", "(Tensor) Tensor of specified shape will be filled " "with the specified value"); + AddAttr>("shape", "(vector) The shape of the output"); + AddAttr("dim_idx", + "(int, default 0) the index of batch size dimension") + .SetDefault(0); + AddAttr("value", "(float, default 0) The value to be filled") + .SetDefault(0.0f); AddComment(R"DOC(Fill up a variable with specified constant value.)DOC"); } }; diff --git a/python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py b/python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py index 065a9133dca25..b7d4a75838963 100644 --- a/python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py +++ b/python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py @@ -3,13 +3,27 @@ from op_test import OpTest -class TestFillConstantBatchSizeLikeOp(OpTest): +class TestFillConstantBatchSizeLikeOp1(OpTest): def setUp(self): self.op_type = "fill_constant_batch_size_like" self.inputs = {'Input': np.random.random((219, 232)).astype("float32")} - self.attrs = {'value': 3.5, 'shape': [-1, 132, 777]} + self.attrs = {'value': 3.5, 'shape': [-1, 132, 7]} - out = np.random.random((219, 132, 777)).astype("float32") + out = np.random.random((219, 132, 7)).astype("float32") + out.fill(3.5) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantBatchSizeLikeOp2(OpTest): + def setUp(self): + self.op_type = "fill_constant_batch_size_like" + self.inputs = {'Input': np.random.random((219, 232)).astype("float32")} + self.attrs = {'value': 3.5, 'shape': [132, -1, 7], 'dim_idx': 1} + + out = np.random.random((132, 232, 7)).astype("float32") out.fill(3.5) self.outputs = {'Out': out}