Skip to content

Commit

Permalink
Merge pull request apache#139 from sxjscience/revise_flat3d
Browse files Browse the repository at this point in the history
Revise Flat3D + Fix index_t/int comparison warning
  • Loading branch information
sxjscience authored Jun 23, 2016
2 parents 948abff + 6d5e7f2 commit 3615b22
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
45 changes: 37 additions & 8 deletions mshadow/tensor_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,26 +206,39 @@ struct TShape {
return s;
}
/*!
* flatten the axis before and after the specified axis, so it becomes 3D tensor
* \param axis The axis specified.
* \return the flat 3d shape
*/
inline Shape<3> FlatTo3D(index_t axis) const {
* flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim)
* \param axis_begin The beginning axis specified.
* \param axis_end The ending axis specified.
* \return the flat 3d shape
*/
inline Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const {
CHECK(axis_end >= axis_begin);
Shape<3> s;
if (ndim_ == 0) return Shape3(0, 0, 0);
const index_t *d = this->data();
s.shape_[0] = 1;
s.shape_[1] = d[axis];
s.shape_[1] = 1;
s.shape_[2] = 1;

for (index_t i = 0; i < axis; ++i) {
for (index_t i = 0; i < axis_begin; ++i) {
s.shape_[0] *= d[i];
}
for (index_t i = axis + 1; i < ndim_; ++i) {
for (index_t i = axis_begin; i <= axis_end; ++i) {
s.shape_[1] *= d[i];
}
for (index_t i = axis_end + 1; i < ndim_; ++i) {
s.shape_[2] *= d[i];
}
return s;
}
/*!
* flatten the axis before and after the specified axis, so it becomes 3D tensor
* \param axis The axis specified.
* \return the flat 3d shape
*/
inline Shape<3> FlatTo3D(index_t axis) const {
return FlatTo3D(axis, axis);
}
/*!
* \return product shape in [dimstart,dimend)
* \param dimstart start dimension
Expand Down Expand Up @@ -697,6 +710,22 @@ class TBlob {
return this->get_with_shape<Device, 3, DType>(
this->shape_.FlatTo3D(axis), stream);
}
/*!
* \brief flatten the tensor to 3 dimension,
* collapse the dimension: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim).
* \param axis_begin The beginning axis specified.
* \param axis_end The ending axis specified.
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam DType the type of elements in the tensor
* \return tensor after flatten
*/
template<typename Device, typename DType>
inline Tensor<Device, 3, DType> FlatTo3D(int axis_begin, int axis_end,
Stream<Device> *stream = NULL) const {
return this->get_with_shape<Device, 3, DType>(
this->shape_.FlatTo3D(axis_begin, axis_end), stream);
}
};
} // namespace mshadow
#endif // MSHADOW_TENSOR_BLOB_H_
2 changes: 1 addition & 1 deletion mshadow/tensor_cpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ inline void MapPlan(TRValue<R, cpu, dim, DType> *dst,
Shape<2> shape = expr::ShapeCheck<dim, R>::Check(dst->self()).FlatTo2D();
expr::Plan<R, DType> dplan = expr::MakePlan(dst->self());
#pragma omp parallel for
for (int y = 0; y < shape[0]; ++y) {
for (index_t y = 0; y < shape[0]; ++y) {
for (index_t x = 0; x < shape[1]; ++x) {
// trust your compiler! -_- they will optimize it
Saver::Save(dplan.REval(y, x), plan.Eval(y, x));
Expand Down

0 comments on commit 3615b22

Please sign in to comment.