From fcd93b324f2720ed7661d37c8b0881225b7832be Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Sat, 12 Jun 2021 12:24:18 +0800 Subject: [PATCH] Support Div and FloorDiv functor in elementwise system (#33053) --- .../elementwise/elementwise_div_op.cu | 58 ++++++++-------- .../elementwise/elementwise_floordiv_op.cu | 34 +++++++++- .../elementwise/elementwise_floordiv_op.h | 1 - .../elementwise/elementwise_op_impl.cu.h | 67 +++++++++++++++---- 4 files changed, 114 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index b10ed57af901f..8853fd609f77c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_div_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -23,38 +22,37 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { +template +struct CudaDivFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return args[0] / args[1]; + } +}; + template -struct SameDimsElemwiseDiv { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - DivRangeFunctor functor(x->data(), y->data(), z->data()); - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, - x->numel()); - for_range(functor); +struct CudaDivFunctor::value>> { + inline HOSTDEVICE T operator()(const T* args) const { + PADDLE_ENFORCE(args[1] != 0, + "Invalid Argument Error: Integer division by zero " + "encountered in divide. Please check the input value."); + return args[0] / args[1]; } }; -template <> -struct SameDimsElemwiseDiv { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - auto size = x->numel(); - dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) / - PADDLE_CUDA_THREAD_SIZE, - 1); - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); - const half* x2 = - reinterpret_cast(x->data()); - const half* y2 = - reinterpret_cast(y->data()); - half* z2 = reinterpret_cast(z->data()); - SameDimsElemwiseDivCUDAKernel<<< - grid_size, block_size, 0, - ctx.template device_context().stream()>>>( - x2, y2, z2, size); +template +class ElementwiseDivKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaDivFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu index 60846d1e8fee1..a0510d95700b2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu @@ -12,11 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h" -#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" namespace ops = paddle::operators; namespace plat = paddle::platform; +namespace paddle { +namespace operators { + +template +struct CudaFloorDivFunctor { + inline HOSTDEVICE T operator()(const T argv[]) const { + PADDLE_ENFORCE(argv[1] != 0, + "InvalidArgument: divide by zero " + "encountered in floor-divide ops, please check.\n"); + return static_cast(std::trunc(argv[0] / argv[1])); + } +}; + +template +class ElementwiseFloorDivKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaFloorDivFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_floordiv, ops::ElementwiseFloorDivKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h index 06eb0b1cc8510..bc3c2994c847c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h @@ -16,7 +16,6 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 33a2b7e182f0a..101512e35fdcb 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/fast_divmod.h" #ifdef __HIPCC__ @@ -28,19 +28,62 @@ namespace operators { enum ElementwiseType { kUnary = 1, kBinary = 2 }; +/* +* According to NVIDIA, if number of threads per block is 64/128/256/512, +* cuda performs better. And number of blocks should be greater (at least +* 2x~4x) than number of SMs. Hence, SM count is took into account within +* this function to determine the right number of threads per block. +*/ +inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx, + int64_t numel, int vec_size) { + int threads = ELEMENTWISE_BLOCK_SIZE; + int sm_count = ctx.GetSMCount(); + int active_threads_num = numel / vec_size; + if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) { + // Round up threads number into an exponential multiple of 2, while number + // of acitve blocks is about twice of SM, to acquire better performance. + threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 1)); + } else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) { + // Round up threads number into an exponential multiple of 2, while number + // of acitve blocks is about 4 times of SM, to acquire better performance. + threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 2)); + } + // Number of threads per block shall be larger than 64. + return std::max(64, threads); +} + +/* +* Only the address of input data is the multiplier of 1,2,4, vectorized load +* with corresponding multiplier-value is possible. Moreover, the maximum length +* of vectorized load is 128 bits once. Hence, valid length of vectorized load +* shall be determined under both former constraints. +*/ template int GetVectorizedSizeImpl(const T *pointer) { + constexpr int max_load_bits = 128; + int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); uint64_t address = reinterpret_cast(pointer); + constexpr int vec8 = + std::alignment_of>::value; // NOLINT constexpr int vec4 = std::alignment_of>::value; // NOLINT constexpr int vec2 = std::alignment_of>::value; // NOLINT - if (address % vec4 == 0) { - return 4; + if (address % vec8 == 0) { + /* + * Currently, decide to deal with no more than 4 data once while adopting + * vectorization load/store, if performance test shows that dealing with + * 8 data once in vectorization load/store does get optimized, return code + * below can be changed into " return std::min(8, valid_vec_size); " . + */ + return std::min(4, valid_vec_size); + } else if (address % vec4 == 0) { + return std::min(4, valid_vec_size); } else if (address % vec2 == 0) { - return 2; + return std::min(2, valid_vec_size); + } else { + return 1; } - return 1; } template @@ -96,7 +139,7 @@ struct ElementwiseDataWrapper { template -__device__ void VectorizedKernelImpl( +__device__ inline void VectorizedKernelImpl( ElementwiseDataWrapper data, Functor func, int tid) { using InVecType = CudaAlignedVector; @@ -104,34 +147,30 @@ __device__ void VectorizedKernelImpl( InVecType ins_vec[ET]; OutVecType out_vec; InT *ins_ptr[ET]; - OutT *out_ptr; + InT ins[ET]; #pragma unroll for (int i = 0; i < ET; ++i) { ins_ptr[i] = reinterpret_cast(&(ins_vec[i])); } - out_ptr = reinterpret_cast(&out_vec); - // load data.load_vector(ins_vec, tid); // compute #pragma unroll for (int i = 0; i < VecSize; ++i) { - InT ins[ET]; #pragma unroll for (int j = 0; j < ET; ++j) { ins[j] = ins_ptr[j][i]; } - out_ptr[i] = func(ins); + out_vec.val[i] = func(ins); } - // store data.store_vector(out_vec, tid); } template -__device__ void ScalarKernelImpl( +__device__ inline void ScalarKernelImpl( ElementwiseDataWrapper data, Functor func, int start, int remain) { InT ins[ET]; @@ -182,7 +221,7 @@ void LaunchSameDimsElementwiseCudaKernel( // calculate the max vec_size for all ins and outs auto size = ins[0]->numel(); int vec_size = GetVectorizedSize(ins, *outs); - int block_size = ELEMENTWISE_BLOCK_SIZE; + int block_size = GetThreadsConfig(ctx, size, vec_size); int grid_size = ((size + vec_size - 1) / vec_size + block_size - 1) / block_size; const InT *in0 = ins[0]->data();