Skip to content

Commit

Permalink
Fix/sequence op (#5264)
Browse files Browse the repository at this point in the history
* "replace enum with string"

* "fix layers"
  • Loading branch information
dzhwinter committed Oct 31, 2017
1 parent bcdedec commit 2649221
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 113 deletions.
13 changes: 6 additions & 7 deletions paddle/operators/sequence_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out",
"(Tensor), output of SequencePoolOp, which does not contain LoD "
"infomation.");
AddAttr<int>(
"strategy",
"(int, default AVERAGE) the pooling strategy of SequencePoolOp.")
.SetDefault(AVERAGE)
.InEnum({AVERAGE, SUM, SQRT, MAX, LAST, FIRST});
AddAttr<std::string>(
"pooltype",
"(int, default AVERAGE) the pooling pooltype of SequencePoolOp.")
.SetDefault("AVERAGE");
AddComment(R"DOC(
SequencePoolOp pools features of all time-steps of each instance.
It supports six pooling strategy:
It supports six pooling pooltype:
- AVERAGE: Out[i] = average_{for each instance in i-th sequence}{X[i]}
- SUM: Out[i] = sum_{for each instance in i-th sequence}{X[i]}
- SQRT: Out[i] = sum_{for each instance in i-th sequence}{X[i]}
Expand All @@ -63,7 +62,7 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
and the value of X = [[1, 3], [2, 4, 6], [5, 1]].
Thus, Out is a [3,1,1] Tensor without LoD infomation.
And for different strategy, the value of Out is as follows:
And for different pooltype, the value of Out is as follows:
- AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
- SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
Expand Down
114 changes: 45 additions & 69 deletions paddle/operators/sequence_pool_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,13 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

enum SeqPoolType {
AVERAGE = 0,
SUM = 1,
SQRT = 2, // square_root_n
MAX = 3,
LAST = 4,
FIRST = 5
};

template <typename Place, typename T>
class SequencePoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int strategy = context.Attr<int>("strategy");
std::string pooltype = context.Attr<std::string>("pooltype");

auto dims = in->dims();
auto lod = in->lod();
Expand All @@ -71,28 +62,21 @@ class SequencePoolKernel : public framework::OpKernel<T> {
auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
auto out_e = EigenVector<T>::Flatten(out_t);

switch (strategy) {
case AVERAGE:
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
break;
case SUM:
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
break;
case SQRT:
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
break;
case MAX:
out_e.device(place) = in_e.maximum(Eigen::array<int, 1>({{0}}));
break;
case LAST:
out_e.device(place) = in_e.chip(h - 1, 0);
break;
case FIRST:
out_e.device(place) = in_e.chip(0, 0);
break;
default:
PADDLE_THROW("unsupported pooling strategy");
if (pooltype == "AVERAGE") {
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SUM") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
} else if (pooltype == "MAX") {
out_e.device(place) = in_e.maximum(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "LAST") {
out_e.device(place) = in_e.chip(h - 1, 0);
} else if (pooltype == "FIRST") {
out_e.device(place) = in_e.chip(0, 0);
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
}
}
Expand All @@ -105,15 +89,15 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
auto* in = context.Input<LoDTensor>("X");
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
int strategy = context.Attr<int>("strategy");
std::string pooltype = context.Attr<std::string>("pooltype");

auto dims = in->dims();
auto lod = in->lod()[0];
int64_t w = in->numel() / dims[0];

in_g->mutable_data<T>(context.GetPlace());
if (strategy == LAST || strategy == FIRST) {
// set X@Grad be zero at first when strategy is LAST/FIRST
if (pooltype == "LAST" || pooltype == "FIRST") {
// set X@Grad be zero at first when pooltype is LAST/FIRST
math::SetConstant<Place, T> functor;
functor(context.device_context(), in_g, 0);
}
Expand All @@ -127,41 +111,33 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
Eigen::DSizes<int, 2> bcast(h, 1);

switch (strategy) {
case AVERAGE:
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
break;
case SUM:
in_g_e.device(place) = (out_g_e).broadcast(bcast);
break;
case SQRT:
in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
break;
case MAX: {
auto in_t =
in->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
in_t_map(in_t.data<T>(), h, w);
int row_id;
Eigen::array<int, 2> extents{{1, 1}};
for (int col_id = 0; col_id < w; col_id++) {
in_t_map.col(col_id).maxCoeff(&row_id);
Eigen::array<int, 2> in_offsets{{row_id, col_id}};
Eigen::array<int, 2> out_offsets{{0, col_id}};
in_g_e.slice(in_offsets, extents).device(place) =
out_g_e.slice(out_offsets, extents);
}
break;
if (pooltype == "AVERAGE") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
in_g_e.device(place) = (out_g_e).broadcast(bcast);
} else if (pooltype == "SQRT") {
in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
} else if (pooltype == "MAX") {
auto in_t =
in->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
in_t_map(in_t.data<T>(), h, w);
int row_id;
Eigen::array<int, 2> extents{{1, 1}};
for (int col_id = 0; col_id < w; col_id++) {
in_t_map.col(col_id).maxCoeff(&row_id);
Eigen::array<int, 2> in_offsets{{row_id, col_id}};
Eigen::array<int, 2> out_offsets{{0, col_id}};
in_g_e.slice(in_offsets, extents).device(place) =
out_g_e.slice(out_offsets, extents);
}
case LAST:
in_g_e.chip(h - 1, 0).device(place) = out_g_e;
break;
case FIRST:
in_g_e.chip(0, 0).device(place) = out_g_e;
break;
default:
PADDLE_THROW("unsupported pooling strategy");
} else if (pooltype == "LAST") {
in_g_e.chip(h - 1, 0).device(place) = out_g_e;
} else if (pooltype == "FIRST") {
in_g_e.chip(0, 0).device(place) = out_g_e;
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
}
}
Expand Down
21 changes: 5 additions & 16 deletions python/paddle/v2/framework/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,32 +351,21 @@ def conv2d(input,
return helper.append_activation(pre_act)


def sequence_pool(input, pool_type, program=None, init_program=None):
# FIXME(dzh) : want to unify the argument of python layer
# function. So we ignore some unecessary attributes

ENUM_POOL_TYPE = dict({
"AVERAGE": 0,
"SUM": 1,
"SQRT": 2,
"MAX": 3,
"LAST": 4,
"FIRST": 5
})
def sequence_pool(input, pool_type, **kwargs):
ENUM_POOL_TYPE = set(["MAX", "AVG", "SQRT", "LAST", "FIRST"])
if pool_type.upper() not in ENUM_POOL_TYPE:
raise ValueError("Unknown pool_type: '%s'. It can only be %s.",
str(pool_type), " ".join(ENUM_POOL_TYPE.keys()))
str(pool_type), " ".join(ENUM_POOL_TYPE))

helper = LayerHelper('sequence_pool', **locals())
helper = LayerHelper('sequence_pool', **kwargs)
dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype)

# FIXME(dzh): strategy
helper.append_op(
type="sequence_pool",
inputs={"X": [input]},
outputs={"Out": [pool_out]},
attrs={"strategy": ENUM_POOL_TYPE[pool_type.upper()]})
attrs={"pooltype": pool_type.upper()})

return pool_out

Expand Down
33 changes: 12 additions & 21 deletions python/paddle/v2/framework/tests/test_seq_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@
from op_test import OpTest


class SeqPoolType(OpTest):
AVERAGE = 0
SUM = 1
SQRT = 2
MAX = 3
LAST = 4
FIRST = 5


class TestSeqAvgPool(OpTest):
def set_data(self):
self.op_type = 'sequence_pool'
Expand All @@ -25,7 +16,7 @@ def set_data(self):
return x, lod, out

def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.AVERAGE}
self.attrs = {'pooltype': "AVERAGE"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x.mean(axis=0)
Expand Down Expand Up @@ -54,31 +45,31 @@ def set_data(self):
return x, lod, out

def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.AVERAGE}
self.attrs = {'pooltype': "AVERAGE"}
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))


class TestSeqSumPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.SUM}
self.attrs = {'pooltype': "SUM"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x.sum(axis=0)


class TestSeqSumPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.SUM}
self.attrs = {'pooltype': "SUM"}
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))


class TestSeqSqrtPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.SQRT}
self.attrs = {'pooltype': "SQRT"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
len = lod[0][i + 1] - lod[0][i]
Expand All @@ -87,7 +78,7 @@ def compute(self, x, lod, out):

class TestSeqSqrtPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.SQRT}
self.attrs = {'pooltype': "SQRT"}
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
len = lod[0][i + 1] - lod[0][i]
Expand All @@ -99,7 +90,7 @@ def test_check_grad(self):

class TestSeqMaxPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.MAX}
self.attrs = {'pooltype': "MAX"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = np.amax(sub_x, axis=0)
Expand All @@ -111,7 +102,7 @@ def test_check_grad(self):

class TestSeqMaxPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.MAX}
self.attrs = {'pooltype': "MAX"}
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 17))
Expand All @@ -123,31 +114,31 @@ def test_check_grad(self):

class TestSeqLastPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.LAST}
self.attrs = {'pooltype': "LAST"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x[-1, :]


class TestSeqLastPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.LAST}
self.attrs = {'pooltype': "LAST"}
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(sub_x[-1, :], (3, 17))


class TestSeqFirstPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.FIRST}
self.attrs = {'pooltype': "FIRST"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x[0, :]


class TestSeqFirstPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'strategy': SeqPoolType.FIRST}
self.attrs = {'pooltype': "FIRST"}
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(sub_x[0, :], (3, 17))
Expand Down

0 comments on commit 2649221

Please sign in to comment.