From fbbfe8b8594960934529330dc1321e1fdc6c2a6d Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 4 Dec 2017 14:21:13 +0800 Subject: [PATCH 1/5] code refine --- paddle/operators/elementwise_add_op.h | 39 +++++++- paddle/operators/elementwise_op_function.h | 108 +++++++++++++++++++++ 2 files changed, 146 insertions(+), 1 deletion(-) diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h index f04fe3ec6069a..686d45573d887 100644 --- a/paddle/operators/elementwise_add_op.h +++ b/paddle/operators/elementwise_add_op.h @@ -19,11 +19,48 @@ namespace paddle { namespace operators { +template +struct AddFunctor { + HOSTDEVICE T operator()(T a, T b) const { return a + b; } +}; + template class ElementwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - ElementwiseCompute(ctx); + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + TransformFunctor, T, Place> functor(x, y, z, ctx, + AddFunctor()); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 8aa35b2c46678..22b96b93121a1 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -16,6 +16,7 @@ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" +#include "paddle/platform/transform.h" #include "paddle/operators/math/math_function.h" @@ -54,6 +55,113 @@ inline void get_mid_dims(const framework::DDim& x_dims, } } +template +struct RowwiseTransformIterator; +template +struct MidWiseTransformIterator; + +template +struct RowwiseTransformIterator { + RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} + + RowwiseTransformIterator& operator++() { + ++i_; + if (i_ == n_) { + i_ = 0; + } + return *this; + } + + bool operator==( + const RowwiseTransformIterator& rhs) const { + return &(this->operator*()) == &(*rhs); + } + + bool operator!=( + const RowwiseTransformIterator& rhs) const { + return &(this->operator*()) &= &(*rhs); + } + + const T& operator*() { return ptr_[i_]; } + + const T* ptr_; + int i_; + int n_; +}; + +template +struct MidWiseTransformIterator { + MidWiseTransformIterator(const T* ptr, int n, int post) + : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} + + MidWiseTransformIterator& operator++() { + ++j_; + if (j_ == post_) { + j_ = 0; + ++i_; + if (i_ == n_) { + i_ = 0; + } + } + return *this; + } + + bool operator==( + const MidWiseTransformIterator& rhs) const { + return &(this->operator*()) == &(*rhs); + } + + bool operator!=( + const MidWiseTransformIterator& rhs) const { + return &(this->operator*()) &= &(*rhs); + } + + const T& operator*() { return ptr_[i_]; } + + const T* ptr_; + int i_; + int j_; + int n_; + int post_; +}; + +template +struct TransformFunctor { + TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z, const framework::ExecutionContext& ctx, + Functor func) + : x_(x->data()), + y_(y->data()), + z_(z->mutable_data(ctx.GetPlace())), + nx_(x->numel()), + ctx_(ctx), + func_(func) {} + + inline void Run() const { + platform::Transform trans; + trans(ctx_.device_context(), x_, x_ + nx_, y_, z_, func_); + } + + inline void RunRowWise(int n, int pre) const { + platform::Transform trans; + trans(ctx_.device_context(), x_, x_ + nx_, + RowwiseTransformIterator(y_, n), z_, func_); + } + + inline void RunMidWise(int n, int pre, int post) const { + platform::Transform trans; + trans(ctx_.device_context(), x_, x_ + nx_, + MidWiseTransformIterator(y_, n, post), z_, func_); + } + + const T* x_; + const T* y_; + T* z_; + int64_t nx_; + const framework::ExecutionContext& ctx_; + Functor func_; +}; + #define EIGEN_FUNCTOR(name, eigen_op) \ struct Eigen##name##Functor { \ template \ From 488908e95b1f17be97ab295e0971cd7832d703c7 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 4 Dec 2017 20:55:28 +0800 Subject: [PATCH 2/5] refine cuda --- paddle/operators/elementwise_op_function.h | 55 ++++++++++++++++++++-- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 22b96b93121a1..09ab42b501b3f 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -18,6 +18,10 @@ #include "paddle/framework/operator.h" #include "paddle/platform/transform.h" +#ifdef __NVCC__ +#include +#endif + #include "paddle/operators/math/math_function.h" namespace paddle { @@ -74,12 +78,12 @@ struct RowwiseTransformIterator { bool operator==( const RowwiseTransformIterator& rhs) const { - return &(this->operator*()) == &(*rhs); + return (ptr_ + i_) == &(*rhs); } bool operator!=( const RowwiseTransformIterator& rhs) const { - return &(this->operator*()) &= &(*rhs); + return (ptr_ + i_) &= &(*rhs); } const T& operator*() { return ptr_[i_]; } @@ -108,12 +112,12 @@ struct MidWiseTransformIterator { bool operator==( const MidWiseTransformIterator& rhs) const { - return &(this->operator*()) == &(*rhs); + return (ptr_ + i_) == &(*rhs); } bool operator!=( const MidWiseTransformIterator& rhs) const { - return &(this->operator*()) &= &(*rhs); + return (ptr_ + i_) &= &(*rhs); } const T& operator*() { return ptr_[i_]; } @@ -125,6 +129,49 @@ struct MidWiseTransformIterator { int post_; }; +#ifdef __NVCC__ +template +struct RowwiseTransformIterator + : public thrust::iterator_adaptor< + RowwiseTransformIterator, const T*> { + public: + typedef thrust::iterator_adaptor< + RowwiseTransformIterator, const T*> + super_t; + __host__ __device__ RowwiseTransformIterator(const T* x, int n) + : super_t(x), begin_(x), n_(n){}; + friend class thrust::iterator_core_access; + + private: + unsigned int n_; + const T* begin_; + __host__ __device__ typename super_t::reference dereference() const { + return *(begin_ + (this->base() - begin_) % n_); + } +}; + +template +struct MidWiseTransformIterator + : public thrust::iterator_adaptor< + MidWiseTransformIterator, const T*> { + public: + typedef thrust::iterator_adaptor< + MidWiseTransformIterator, const T*> + super_t; + __host__ __device__ MidWiseTransformIterator(const T* x, int n, int post) + : super_t(x), begin_(x), n_(n), post_(post){}; + friend class thrust::iterator_core_access; + + private: + unsigned int post_; + unsigned int n_; + const T* begin_; + __host__ __device__ typename super_t::reference dereference() const { + return *(begin_ + (((this->base() - begin_) / post_) % n_)); + } +}; +#endif + template struct TransformFunctor { TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, From 54f09620cf6ee80224270cbf400cdbbcbdd19697 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 5 Dec 2017 08:43:08 +0800 Subject: [PATCH 3/5] code refine --- paddle/operators/elementwise_op_function.h | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 09ab42b501b3f..14da42a786f08 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -70,9 +70,7 @@ struct RowwiseTransformIterator { RowwiseTransformIterator& operator++() { ++i_; - if (i_ == n_) { - i_ = 0; - } + i_ %= n_; return *this; } @@ -90,7 +88,7 @@ struct RowwiseTransformIterator { const T* ptr_; int i_; - int n_; + int64_t n_; }; template @@ -99,14 +97,7 @@ struct MidWiseTransformIterator { : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} MidWiseTransformIterator& operator++() { - ++j_; - if (j_ == post_) { - j_ = 0; - ++i_; - if (i_ == n_) { - i_ = 0; - } - } + i_ = ++j_ / post_ % n_; return *this; } @@ -124,8 +115,8 @@ struct MidWiseTransformIterator { const T* ptr_; int i_; - int j_; - int n_; + int64_t j_; + int64_t n_; int post_; }; From 9e244a8cbe8e0d089e0f3d402230a1d5f2ffcbb9 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 5 Dec 2017 10:59:43 +0800 Subject: [PATCH 4/5] follow comments --- paddle/operators/elementwise_add_op.h | 4 ++-- paddle/operators/elementwise_op_function.h | 28 +++++++++++----------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h index 686d45573d887..3a198c167e4cb 100644 --- a/paddle/operators/elementwise_add_op.h +++ b/paddle/operators/elementwise_add_op.h @@ -34,8 +34,8 @@ class ElementwiseAddKernel : public framework::OpKernel { auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); - TransformFunctor, T, Place> functor(x, y, z, ctx, - AddFunctor()); + TransformFunctor, T, Place> functor( + x, y, z, ctx.device_context(), AddFunctor()); auto x_dims = x->dims(); auto y_dims = y->dims(); diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 14da42a786f08..6d849bff49909 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -81,7 +81,7 @@ struct RowwiseTransformIterator { bool operator!=( const RowwiseTransformIterator& rhs) const { - return (ptr_ + i_) &= &(*rhs); + return (ptr_ + i_) != &(*rhs); } const T& operator*() { return ptr_[i_]; } @@ -97,7 +97,7 @@ struct MidWiseTransformIterator { : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} MidWiseTransformIterator& operator++() { - i_ = ++j_ / post_ % n_; + i_ = (++j_ / post_) % n_; return *this; } @@ -108,7 +108,7 @@ struct MidWiseTransformIterator { bool operator!=( const MidWiseTransformIterator& rhs) const { - return (ptr_ + i_) &= &(*rhs); + return (ptr_ + i_) != &(*rhs); } const T& operator*() { return ptr_[i_]; } @@ -129,14 +129,14 @@ struct RowwiseTransformIterator typedef thrust::iterator_adaptor< RowwiseTransformIterator, const T*> super_t; - __host__ __device__ RowwiseTransformIterator(const T* x, int n) + HOSTDEVICE RowwiseTransformIterator(const T* x, int n) : super_t(x), begin_(x), n_(n){}; friend class thrust::iterator_core_access; private: unsigned int n_; const T* begin_; - __host__ __device__ typename super_t::reference dereference() const { + HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (this->base() - begin_) % n_); } }; @@ -149,7 +149,7 @@ struct MidWiseTransformIterator typedef thrust::iterator_adaptor< MidWiseTransformIterator, const T*> super_t; - __host__ __device__ MidWiseTransformIterator(const T* x, int n, int post) + HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post) : super_t(x), begin_(x), n_(n), post_(post){}; friend class thrust::iterator_core_access; @@ -157,7 +157,7 @@ struct MidWiseTransformIterator unsigned int post_; unsigned int n_; const T* begin_; - __host__ __device__ typename super_t::reference dereference() const { + HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (((this->base() - begin_) / post_) % n_)); } }; @@ -166,7 +166,7 @@ struct MidWiseTransformIterator template struct TransformFunctor { TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z, const framework::ExecutionContext& ctx, + framework::Tensor* z, const platform::DeviceContext& ctx, Functor func) : x_(x->data()), y_(y->data()), @@ -177,26 +177,26 @@ struct TransformFunctor { inline void Run() const { platform::Transform trans; - trans(ctx_.device_context(), x_, x_ + nx_, y_, z_, func_); + trans(ctx_, x_, x_ + nx_, y_, z_, func_); } inline void RunRowWise(int n, int pre) const { platform::Transform trans; - trans(ctx_.device_context(), x_, x_ + nx_, - RowwiseTransformIterator(y_, n), z_, func_); + trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator(y_, n), z_, + func_); } inline void RunMidWise(int n, int pre, int post) const { platform::Transform trans; - trans(ctx_.device_context(), x_, x_ + nx_, - MidWiseTransformIterator(y_, n, post), z_, func_); + trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator(y_, n, post), + z_, func_); } const T* x_; const T* y_; T* z_; int64_t nx_; - const framework::ExecutionContext& ctx_; + const platform::DeviceContext& ctx_; Functor func_; }; From 37671ac0539b69e2c5bb72dbaad22a96d633118c Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 5 Dec 2017 12:39:47 +0800 Subject: [PATCH 5/5] follow comments --- paddle/operators/elementwise_op_function.h | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 6d849bff49909..ec448a9e95640 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -60,12 +60,13 @@ inline void get_mid_dims(const framework::DDim& x_dims, } template -struct RowwiseTransformIterator; +class RowwiseTransformIterator; template -struct MidWiseTransformIterator; +class MidWiseTransformIterator; template -struct RowwiseTransformIterator { +class RowwiseTransformIterator { + public: RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} RowwiseTransformIterator& operator++() { @@ -86,13 +87,15 @@ struct RowwiseTransformIterator { const T& operator*() { return ptr_[i_]; } + private: const T* ptr_; int i_; int64_t n_; }; template -struct MidWiseTransformIterator { +class MidWiseTransformIterator { + public: MidWiseTransformIterator(const T* ptr, int n, int post) : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} @@ -113,6 +116,7 @@ struct MidWiseTransformIterator { const T& operator*() { return ptr_[i_]; } + private: const T* ptr_; int i_; int64_t j_; @@ -122,7 +126,7 @@ struct MidWiseTransformIterator { #ifdef __NVCC__ template -struct RowwiseTransformIterator +class RowwiseTransformIterator : public thrust::iterator_adaptor< RowwiseTransformIterator, const T*> { public: @@ -142,7 +146,7 @@ struct RowwiseTransformIterator }; template -struct MidWiseTransformIterator +class MidWiseTransformIterator : public thrust::iterator_adaptor< MidWiseTransformIterator, const T*> { public: @@ -164,7 +168,8 @@ struct MidWiseTransformIterator #endif template -struct TransformFunctor { +class TransformFunctor { + public: TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, framework::Tensor* z, const platform::DeviceContext& ctx, Functor func) @@ -192,6 +197,7 @@ struct TransformFunctor { z_, func_); } + private: const T* x_; const T* y_; T* z_;