Skip to content

Commit

Permalink
add dim_idx attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang Yang committed Oct 30, 2017
1 parent 5f67111 commit 6b58489
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
18 changes: 13 additions & 5 deletions paddle/operators/fill_constant_batch_size_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
[](int a) { return static_cast<int64_t>(a); });
auto dims = framework::make_ddim(shape_int64);

dims[0] = ctx->GetInputDim("Input")[0];
int dim_idx = ctx->Attrs().Get<int>("dim_idx");
PADDLE_ENFORCE_GE(dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), dim_idx);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), dim_idx);

dims[dim_idx] = ctx->GetInputDim("Input")[dim_idx];
ctx->SetOutputDim("Out", dims);
}

Expand All @@ -57,15 +62,18 @@ class FillConstantBatchSizeLikeOpMaker
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddInput("Input",
"(Tensor) Tensor "
"whose first dimension is used to specify the batch_size");
"whose dim_idx th dimension is used to specify the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("dim_idx",
"(int, default 0) the index of batch size dimension")
.SetDefault(0);
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddComment(R"DOC(Fill up a variable with specified constant value.)DOC");
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,27 @@
from op_test import OpTest


class TestFillConstantBatchSizeLikeOp(OpTest):
class TestFillConstantBatchSizeLikeOp1(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.inputs = {'Input': np.random.random((219, 232)).astype("float32")}
self.attrs = {'value': 3.5, 'shape': [-1, 132, 777]}
self.attrs = {'value': 3.5, 'shape': [-1, 132, 7]}

out = np.random.random((219, 132, 777)).astype("float32")
out = np.random.random((219, 132, 7)).astype("float32")
out.fill(3.5)
self.outputs = {'Out': out}

def test_check_output(self):
self.check_output()


class TestFillConstantBatchSizeLikeOp2(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.inputs = {'Input': np.random.random((219, 232)).astype("float32")}
self.attrs = {'value': 3.5, 'shape': [132, -1, 7], 'dim_idx': 1}

out = np.random.random((132, 232, 7)).astype("float32")
out.fill(3.5)
self.outputs = {'Out': out}

Expand Down

0 comments on commit 6b58489

Please sign in to comment.