From 6d5e7f281020a37bc6599b1b1cd463d64c428630 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Thu, 23 Jun 2016 22:05:55 +0800 Subject: [PATCH] Revise Flat3D + Fix index_t/int comparison warning --- mshadow/tensor_blob.h | 45 +++++++++++++++++++++++++++++++++------- mshadow/tensor_cpu-inl.h | 2 +- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/mshadow/tensor_blob.h b/mshadow/tensor_blob.h index 3d5b8fde8fad..c1cc939f0fc4 100644 --- a/mshadow/tensor_blob.h +++ b/mshadow/tensor_blob.h @@ -206,26 +206,39 @@ struct TShape { return s; } /*! - * flatten the axis before and after the specified axis, so it becomes 3D tensor - * \param axis The axis specified. - * \return the flat 3d shape - */ - inline Shape<3> FlatTo3D(index_t axis) const { + * flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim) + * \param axis_begin The beginning axis specified. + * \param axis_end The ending axis specified. + * \return the flat 3d shape + */ + inline Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const { + CHECK(axis_end >= axis_begin); Shape<3> s; if (ndim_ == 0) return Shape3(0, 0, 0); const index_t *d = this->data(); s.shape_[0] = 1; - s.shape_[1] = d[axis]; + s.shape_[1] = 1; s.shape_[2] = 1; - for (index_t i = 0; i < axis; ++i) { + for (index_t i = 0; i < axis_begin; ++i) { s.shape_[0] *= d[i]; } - for (index_t i = axis + 1; i < ndim_; ++i) { + for (index_t i = axis_begin; i <= axis_end; ++i) { + s.shape_[1] *= d[i]; + } + for (index_t i = axis_end + 1; i < ndim_; ++i) { s.shape_[2] *= d[i]; } return s; } + /*! + * flatten the axis before and after the specified axis, so it becomes 3D tensor + * \param axis The axis specified. + * \return the flat 3d shape + */ + inline Shape<3> FlatTo3D(index_t axis) const { + return FlatTo3D(axis, axis); + } /*! * \return product shape in [dimstart,dimend) * \param dimstart start dimension @@ -697,6 +710,22 @@ class TBlob { return this->get_with_shape( this->shape_.FlatTo3D(axis), stream); } + /*! + * \brief flatten the tensor to 3 dimension, + * collapse the dimension: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim). + * \param axis_begin The beginning axis specified. + * \param axis_end The ending axis specified. + * \param stream the possible stream target tensor should reside on + * \tparam Device which device the tensor is on + * \tparam DType the type of elements in the tensor + * \return tensor after flatten + */ + template + inline Tensor FlatTo3D(int axis_begin, int axis_end, + Stream *stream = NULL) const { + return this->get_with_shape( + this->shape_.FlatTo3D(axis_begin, axis_end), stream); + } }; } // namespace mshadow #endif // MSHADOW_TENSOR_BLOB_H_ diff --git a/mshadow/tensor_cpu-inl.h b/mshadow/tensor_cpu-inl.h index 545affcde325..fd12dc87fd6d 100644 --- a/mshadow/tensor_cpu-inl.h +++ b/mshadow/tensor_cpu-inl.h @@ -143,7 +143,7 @@ inline void MapPlan(TRValue *dst, Shape<2> shape = expr::ShapeCheck::Check(dst->self()).FlatTo2D(); expr::Plan dplan = expr::MakePlan(dst->self()); #pragma omp parallel for - for (int y = 0; y < shape[0]; ++y) { + for (index_t y = 0; y < shape[0]; ++y) { for (index_t x = 0; x < shape[1]; ++x) { // trust your compiler! -_- they will optimize it Saver::Save(dplan.REval(y, x), plan.Eval(y, x));