From a77d6fbe3dc744114fce89991b5fbf6b47060b4e Mon Sep 17 00:00:00 2001
From: Ganesh Venkataramana
Date: Wed, 3 Jul 2019 23:39:26 -0700
Subject: [PATCH] review changes #1

- refactored code to avoid excess spacing
- avoided the raw, numerically unstable computations of tanh and arctanh
- exposed the tanh and arctanh functions available in the CUDA toolkit to cuML
- removed redundant cudaStreamSynchronize statements
---
 cpp/src_prims/cuda_utils.h                 |  32 ++
 cpp/src_prims/timeSeries/jones_transform.h | 423 +++++++++------------
 2 files changed, 212 insertions(+), 243 deletions(-)

diff --git a/cpp/src_prims/cuda_utils.h b/cpp/src_prims/cuda_utils.h
index 0a7a308bad..5a533118b8 100644
--- a/cpp/src_prims/cuda_utils.h
+++ b/cpp/src_prims/cuda_utils.h
@@ -421,6 +421,38 @@ HDI double myPow(double x, double power) {
 }
 /** @} */
 
+/**
+ * @defgroup myTanh tanh function
+ * @{
+ */
+template <typename T>
+HDI T myTanh(T x);
+template <>
+HDI float myTanh(float x) {
+  return tanhf(x);
+}
+template <>
+HDI double myTanh(double x) {
+  return tanh(x);
+}
+/** @} */
+
+/**
+ * @defgroup myATanh arctanh function
+ * @{
+ */
+template <typename T>
+HDI T myATanh(T x);
+template <>
+HDI float myATanh(float x) {
+  return atanhf(x);
+}
+template <>
+HDI double myATanh(double x) {
+  return atanh(x);
+}
+/** @} */
+
 /**
  * @defgroup LambdaOps Lambda operations in reduction kernels
  * @{
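Note (illustrative, not part of the patch): myTanh/myATanh forward to the CUDA math library's tanhf/tanh and atanhf/atanh. The closed forms they replace in jones_transform.h below are algebraically equivalent, since (1 - exp(-x)) / (1 + exp(-x)) == tanh(x / 2) and (log(1 + x) - log(1 - x)) / 2 == atanh(x), but the raw forms lose accuracy at the extremes. A minimal host-side C++ sketch of the failure modes (the sample values are chosen for illustration only):

#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
  // (1 - exp(-x)) / (1 + exp(-x)) == tanh(x / 2), but the raw form
  // rounds to exactly 0 for tiny x and becomes inf/inf (NaN) once
  // exp(-x) overflows, while tanh saturates cleanly at +/-1.
  for (double x : {1e-20, 30.0, -710.0}) {
    double raw = (1.0 - std::exp(-x)) / (1.0 + std::exp(-x));
    std::printf("x=%-9.3g raw=%-24.17g tanh(x/2)=%.17g\n", x, raw,
                std::tanh(x / 2));
  }
  // (log(1 + x) - log(1 - x)) / 2 == atanh(x), but log(1 + x) rounds
  // to log(1.0) == 0 for tiny x, so the raw form returns exactly 0.
  for (double x : {1e-20, 0.5}) {
    double raw = (std::log(1 + x) - std::log(1 - x)) / 2;
    std::printf("x=%-9.3g raw=%-24.17g atanh(x)=%.17g\n", x, raw,
                std::atanh(x));
  }
  return 0;
}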
diff --git a/cpp/src_prims/timeSeries/jones_transform.h b/cpp/src_prims/timeSeries/jones_transform.h
index 63b788141a..8f65ab2ebb 100644
--- a/cpp/src_prims/timeSeries/jones_transform.h
+++ b/cpp/src_prims/timeSeries/jones_transform.h
@@ -1,5 +1,3 @@
-
-
 /*
  * Copyright (c) 2019, NVIDIA CORPORATION.
  *
@@ -16,42 +14,21 @@
  * limitations under the License.
  */
 /**
-* @file ar_param_transform.h
-* @brief TODO brief
+* @file jones_transform.h
+* @brief Transforms params to induce stationarity/invertability.
+* reference: Jones(1980)
 */
 
 #include <math.h>
-#include <stdio.h>
 #include "common/cuml_allocator.hpp"
 #include "common/device_buffer.hpp"
 #include "cuda_utils.h"
 #include "linalg/unary_op.h"
-
 namespace MLCommon {
 namespace TimeSeries {
 
-//just a helper function to display stuff
-template <typename DataT>
-void display_helper(DataT *arr, int row, int col, cudaStream_t stream){
-
-  DataT *h_arr = (DataT*) malloc(row*col*sizeof(DataT*));
-
-  updateHost(h_arr, arr, row*col, stream);
-
-  CUDA_CHECK(cudaStreamSynchronize(stream));
-
-  for(int i = 0; i<row; ++i){
-    for(int j = 0; j<col; ++j){
-      printf("%f ", h_arr[i*col + j]);
-    }
-    printf("\n");
-  }
-}
-
 /**
 * @brief Lambda to map to the partial autocorrelation
 *
 * @tparam Type: Data type of the input
 * @param in: the input to the functional mapping
 * @return : the mapped value
 */
 template <typename Type>
 struct PAC {
-  HDI Type operator()(Type in) { return ((1- myExp(-1*in))/(1+myExp(-1*in))); }
-};
-
-
-/**
-* @brief Lambda to map to the arctanh
-*
-* @tparam Type: Data type of the input
-* @param in: the input to the functional mapping
-* @return : arctanh() of the input
-*/
-template <typename Type>
-struct arctanh {
-  HDI Type operator()(Type in) { return (log(1+in) - log(1-in))/2; }
+  HDI Type operator()(Type in) { return myTanh(in / 2); }
 };
 
 /**
@@ -87,51 +51,43 @@
 * @param newParams: pointer to the memory where the new params are to be stored, which is also where the initial mapped input is stored
 * @param batchSize: number of models in a batch
 */
-template<typename DataT, typename IdxT, int P_VALUE>
-__global__ void ar_param_invtransform_kernel(DataT *newParams, IdxT batchSize) {
-
-  arctanh<DataT> arctanh;
+template <typename DataT, typename IdxT, int P_VALUE>
+__global__ void ar_param_invtransform_kernel(DataT* newParams, IdxT batchSize) {
   //calculating the index of the model that the coefficients belong to
   IdxT modelIndex = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x);
 
-    DataT tmp[P_VALUE];
-    DataT myNewParams[P_VALUE];
+  DataT tmp[P_VALUE];
+  DataT myNewParams[P_VALUE];
 
-  if(modelIndex<batchSize){
-
+  if (modelIndex < batchSize) {
     //populating the local memory with the global memory
     for (int i = 0; i < P_VALUE; ++i) {
       tmp[i] = newParams[modelIndex * P_VALUE + i];
       myNewParams[i] = tmp[i];
     }
 
-    for(int j = P_VALUE-1; j>0; --j){
-
+    for (int j = P_VALUE - 1; j > 0; --j) {
       DataT a = myNewParams[j];
-      for(int k = 0; k<j; ++k){
+      for (int k = 0; k < j; ++k) {
         tmp[k] = (myNewParams[k] + a * myNewParams[j - k - 1]) / (1 - (a * a));
       }
-      for(int iter = 0; iter<j; ++iter){
+      for (int iter = 0; iter < j; ++iter) {
         myNewParams[iter] = tmp[iter];
       }
     }
 
     //the inverse of the pac mapping
-    for(int i = 0; i<P_VALUE; ++i){
-      myNewParams[i] = 2*arctanh(myNewParams[i]);
+    for (int i = 0; i < P_VALUE; ++i) {
+      myNewParams[i] = 2 * myATanh(myNewParams[i]);
     }
 
    //storing the results back into global memory
-    for(int i = 0; i<P_VALUE; ++i){
+    for (int i = 0; i < P_VALUE; ++i) {
      newParams[modelIndex * P_VALUE + i] = myNewParams[i];
    }
   }
 }
 
 /**
 * @brief kernel to perform the inverse jones transformation on a batch of moving average parameters
 *
 * @param newParams: pointer to the memory where the new params are to be stored, which is also where the initial mapped input is stored
 * @param batchSize: number of models in a batch
 */
-template<typename DataT, typename IdxT, int Q_VALUE>
-__global__ void ma_param_invtransform_kernel(DataT *newParams, IdxT batchSize) {
-
-  arctanh<DataT> arctanh;
+template <typename DataT, typename IdxT, int Q_VALUE>
+__global__ void ma_param_invtransform_kernel(DataT* newParams, IdxT batchSize) {
   //calculating the index of the model that the coefficients belong to
   IdxT modelIndex = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x);
+  IdxT k, j;
 
-    DataT tmp[Q_VALUE];
-    DataT myNewParams[Q_VALUE];
+  DataT tmp[Q_VALUE];
+  DataT myNewParams[Q_VALUE];
 
-  if(modelIndex<batchSize){
-
+  if (modelIndex < batchSize) {
     //populating the local memory with the global memory
     for (int i = 0; i < Q_VALUE; ++i) {
       tmp[i] = newParams[modelIndex * Q_VALUE + i];
       myNewParams[i] = tmp[i];
     }
 
-    for(int j = Q_VALUE-1; j>0; --j){
-
+    for (j = Q_VALUE - 1; j > 0; --j) {
       DataT b = myNewParams[j];
-      for(int k = 0; k<j; ++k){
+      for (k = 0; k < j; ++k) {
         tmp[k] = (myNewParams[k] - b * myNewParams[j - k - 1]) / (1 - (b * b));
       }
-      for(int iter = 0; iter<j; ++iter){
+      for (int iter = 0; iter < j; ++iter) {
         myNewParams[iter] = tmp[iter];
       }
     }
 
     //the inverse of the pac mapping
-    for(int i = 0; i<Q_VALUE; ++i){
-      myNewParams[i] = 2*arctanh(myNewParams[i]);
+    for (int i = 0; i < Q_VALUE; ++i) {
+      myNewParams[i] = 2 * myATanh(myNewParams[i]);
     }
 
     //storing the results back into global memory
-    for(int i = 0; i<Q_VALUE; ++i){
+    for (int i = 0; i < Q_VALUE; ++i) {
       newParams[modelIndex * Q_VALUE + i] = myNewParams[i];
     }
   }
 }
 
 /**
 * @brief kernel to perform the jones transformation on a batch of autoregressive parameters
 *
 * @param newParams: pointer to the memory where the new params are to be stored, which is also where the initial mapped input is stored
 * @param batchSize: number of models in a batch
 */
-template<typename DataT, typename IdxT, int P_VALUE>
-__global__ void ar_param_transform_kernel(DataT *newParams, IdxT batchSize) {
-
+template <typename DataT, typename IdxT, int P_VALUE>
+__global__ void ar_param_transform_kernel(DataT* newParams, IdxT batchSize) {
   //calculating the index of the model that the coefficients belong to
   IdxT modelIndex = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x);
 
-    DataT tmp[P_VALUE];
-    DataT myNewParams[P_VALUE];
+  DataT tmp[P_VALUE];
+  DataT myNewParams[P_VALUE];
 
-  if(modelIndex<batchSize){
-
+  if (modelIndex < batchSize) {
     //populating the local memory with the global memory
     for (int i = 0; i < P_VALUE; ++i) {
       tmp[i] = newParams[modelIndex * P_VALUE + i];
       myNewParams[i] = tmp[i];
     }
 
-    for(int j = 1; j<P_VALUE; ++j){
-
+    for (int j = 1; j < P_VALUE; ++j) {
       DataT a = myNewParams[j];
-      for(int k = 0; k<j; ++k){
+      for (int k = 0; k < j; ++k) {
         tmp[k] -= a * myNewParams[j - k - 1];
       }
-      for(int iter = 0; iter<j; ++iter){
+      for (int iter = 0; iter < j; ++iter) {
         myNewParams[iter] = tmp[iter];
       }
     }
 
     //storing the results back into global memory
-    for(int i = 0; i<P_VALUE; ++i){
+    for (int i = 0; i < P_VALUE; ++i) {
       newParams[modelIndex * P_VALUE + i] = myNewParams[i];
     }
   }
 }
 
 /**
 * @brief kernel to perform the jones transformation on a batch of moving average parameters
 *
 * @param newParams: pointer to the memory where the new params are to be stored, which is also where the initial mapped input is stored
 * @param batchSize: number of models in a batch
 */
-template<typename DataT, typename IdxT, int Q_VALUE>
-__global__ void ma_param_transform_kernel(DataT *newParams, IdxT batchSize) {
-
+template <typename DataT, typename IdxT, int Q_VALUE>
+__global__ void ma_param_transform_kernel(DataT* newParams, IdxT batchSize) {
   //calculating the index of the model that the coefficients belong to
   IdxT modelIndex = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x);
 
-    DataT tmp[Q_VALUE];
-    DataT myNewParams[Q_VALUE];
+  DataT tmp[Q_VALUE];
+  DataT myNewParams[Q_VALUE];
 
-  if(modelIndex<batchSize){
-
+  if (modelIndex < batchSize) {
     //populating the local memory with the global memory
     for (int i = 0; i < Q_VALUE; ++i) {
       tmp[i] = newParams[modelIndex * Q_VALUE + i];
       myNewParams[i] = tmp[i];
     }
 
-    for(int j = 1; j<Q_VALUE; ++j){
-
+    for (int j = 1; j < Q_VALUE; ++j) {
       DataT b = myNewParams[j];
-      for(int k = 0; k<j; ++k){
+      for (int k = 0; k < j; ++k) {
         tmp[k] += b * myNewParams[j - k - 1];
       }
-      for(int iter = 0; iter<j; ++iter){
+      for (int iter = 0; iter < j; ++iter) {
         myNewParams[iter] = tmp[iter];
       }
     }
 
     //storing the results back into global memory
-    for(int i = 0; i<Q_VALUE; ++i){
+    for (int i = 0; i < Q_VALUE; ++i) {
       newParams[modelIndex * Q_VALUE + i] = myNewParams[i];
     }
   }
 }
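Note (illustrative, not part of the patch): the four kernels above share one per-model recurrence, run by one thread per model. For review, here is the AR forward path restated as a hypothetical single-threaded host function (ar_transform_host is not a cuML symbol); the tanh squash corresponds to the LinAlg::unaryOp(PAC) call in the host wrappers below, and the loop mirrors ar_param_transform_kernel:

#include <cmath>
#include <vector>

// Jones(1980) forward map for one AR coefficient vector: squash each raw
// value into (-1, 1) as a partial autocorrelation (the PAC functor), then
// expand the PACs into coefficients of a stationary AR(p) process with
// the same recursion the kernel runs per model.
std::vector<double> ar_transform_host(std::vector<double> params) {
  const int p = static_cast<int>(params.size());
  for (double& v : params) v = std::tanh(v / 2);  // PAC step
  std::vector<double> tmp = params;  // scratch, like tmp[] in the kernel
  for (int j = 1; j < p; ++j) {
    const double a = params[j];
    for (int k = 0; k < j; ++k) tmp[k] -= a * params[j - k - 1];
    for (int k = 0; k < j; ++k) params[k] = tmp[k];
  }
  return params;
}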
 /**
 * @brief Host Function to batchwise transform the autoregressive coefficients according to "jone's (1980)" transformation
 *
 * @param params: 1D array holding the raw parameters, of shape (batchSize x pValue)
 * @param batchSize: number of models in a batch
 * @param pValue: number of coefficients per model
 * @param newParams: pointer to the memory where the new params are to be stored, which is also where the initial mapped input is stored
 * @param allocator: object that takes care of temporary device memory allocation of type std::shared_ptr<deviceAllocator>
 * @param stream: the cudaStream object
 */
-template<typename DataT, typename IdxT = int>
-void ar_param_transform(
-  const DataT* params, IdxT batchSize, IdxT pValue,
-  DataT* newParams, std::shared_ptr<deviceAllocator> allocator, cudaStream_t stream) {
-
-  //rand index for size less than 2 is not defined
-  ASSERT( batchSize>= 1 && pValue>=1, "not defined!");
+template <typename DataT, typename IdxT = int>
+void ar_param_transform(const DataT* params, IdxT batchSize, IdxT pValue,
+                        DataT* newParams,
+                        std::shared_ptr<deviceAllocator> allocator,
+                        cudaStream_t stream) {
+  ASSERT(batchSize >= 1 && pValue >= 1, "not defined!");
 
-  IdxT nElements = batchSize*pValue;
+  IdxT nElements = batchSize * pValue;
 
   //elementWise transforming the params matrix
   LinAlg::unaryOp(newParams, params, nElements, PAC<DataT>(), stream);
 
@@ -317,32 +254,33 @@ void ar_param_transform(
   //setting the kernel configuration
   static const int BLOCK_DIM_Y = 1, BLOCK_DIM_X = 256;
   dim3 numThreadsPerBlock(BLOCK_DIM_X, BLOCK_DIM_Y);
-  dim3 numBlocks(ceildiv(batchSize, numThreadsPerBlock.x),1);
+  dim3 numBlocks(ceildiv(batchSize, numThreadsPerBlock.x), 1);
 
   //calling the kernel
-  switch(pValue){
-
-    case 1: ar_param_transform_kernel<DataT, IdxT, 1>
-              <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
-            break;
-    case 2: ar_param_transform_kernel<DataT, IdxT, 2>
-              <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
-            break;
-    case 3: ar_param_transform_kernel<DataT, IdxT, 3>
-              <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
-            break;
-    case 4: ar_param_transform_kernel<DataT, IdxT, 4>
-              <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
-            break;
-    default: ASSERT(false, "Unsupported pValue '%d'!", pValue);
+  switch (pValue) {
+    case 1:
+      ar_param_transform_kernel<DataT, IdxT, 1>
+        <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
+      break;
+    case 2:
+      ar_param_transform_kernel<DataT, IdxT, 2>
+        <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
+      break;
+    case 3:
+      ar_param_transform_kernel<DataT, IdxT, 3>
+        <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
+      break;
+    case 4:
+      ar_param_transform_kernel<DataT, IdxT, 4>
+        <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
+      break;
+    default:
+      ASSERT(false, "Unsupported pValue '%d'!", pValue);
   }
 
   CUDA_CHECK(cudaPeekAtLastError());
-
-  CUDA_CHECK(cudaStreamSynchronize(stream));
 }
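Note (illustrative, not part of the patch): with the trailing cudaStreamSynchronize gone, ar_param_transform now returns as soon as the work is enqueued, and the caller owns the synchronization point. A hypothetical caller sketch, with error handling elided and assuming the defaultDeviceAllocator from common/cuml_allocator.hpp and device-resident input/output buffers:

#include <memory>
#include "common/cuml_allocator.hpp"
#include "timeSeries/jones_transform.h"

// d_params and d_newParams are device pointers of shape (batchSize x p).
void transform_batch(const double* d_params, int batchSize, int p,
                     double* d_newParams) {
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));
  auto allocator = std::make_shared<MLCommon::defaultDeviceAllocator>();

  // asynchronous: only enqueues the unary op and the kernel on `stream`
  MLCommon::TimeSeries::ar_param_transform(d_params, batchSize, p,
                                           d_newParams, allocator, stream);

  // the synchronize that used to live inside the prim is now the caller's job
  CUDA_CHECK(cudaStreamSynchronize(stream));
  CUDA_CHECK(cudaStreamDestroy(stream));
}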
"Unsupported pValue '%d'!", pValue); } CUDA_CHECK(cudaPeekAtLastError()); - - CUDA_CHECK(cudaStreamSynchronize(stream)); } - /** * @brief Host Function to batchwise transform the moving average coefficients according to "jone's (1980)" transformation * @@ -405,14 +342,14 @@ void ar_param_inverse_transform( * @param allocator: object that takes care of temporary device memory allocation of type std::shared_ptr * @param stream: the cudaStream object */ -template -void ma_param_transform( - const DataT* params, IdxT batchSize, IdxT qValue, - DataT* newParams, std::shared_ptr allocator, cudaStream_t stream) { - //rand index for size less than 2 is not defined - ASSERT( batchSize>= 1 && qValue>=1, "not defined!"); +template +void ma_param_transform(const DataT* params, IdxT batchSize, IdxT qValue, + DataT* newParams, + std::shared_ptr allocator, + cudaStream_t stream) { + ASSERT(batchSize >= 1 && qValue >= 1, "not defined!"); - IdxT nElements = batchSize*qValue; + IdxT nElements = batchSize * qValue; //elementWise transforming the params matrix LinAlg::unaryOp(newParams, params, nElements, PAC(), stream); @@ -420,32 +357,33 @@ void ma_param_transform( //setting the kernel configuration static const int BLOCK_DIM_Y = 1, BLOCK_DIM_X = 256; dim3 numThreadsPerBlock(BLOCK_DIM_X, BLOCK_DIM_Y); - dim3 numBlocks(ceildiv(batchSize, numThreadsPerBlock.x),1); + dim3 numBlocks(ceildiv(batchSize, numThreadsPerBlock.x), 1); //calling the kernel - switch(qValue){ - - case 1: ma_param_transform_kernel - <<>>(newParams, batchSize); - break; - case 2: ma_param_transform_kernel - <<>>(newParams, batchSize); - break; - case 3: ma_param_transform_kernel - <<>>(newParams, batchSize); - break; - case 4: ma_param_transform_kernel - <<>>(newParams, batchSize); - break; - default: ASSERT(false, "Unsupported qValue '%d'!", qValue); + switch (qValue) { + case 1: + ma_param_transform_kernel + <<>>(newParams, batchSize); + break; + case 2: + ma_param_transform_kernel + <<>>(newParams, batchSize); + break; + case 3: + ma_param_transform_kernel + <<>>(newParams, batchSize); + break; + case 4: + ma_param_transform_kernel + <<>>(newParams, batchSize); + break; + default: + ASSERT(false, "Unsupported qValue '%d'!", qValue); } CUDA_CHECK(cudaPeekAtLastError()); - - CUDA_CHECK(cudaStreamSynchronize(stream)); } - /** * @brief Host Function to batchwise inverse transform the moving average coefficients according to "jone's (1980)" transformation * @@ -456,47 +394,46 @@ void ma_param_transform( * @param allocator: object that takes care of temporary device memory allocation of type std::shared_ptr * @param stream: the cudaStream object */ -template +template void ma_param_inverse_transform( - const DataT* params, IdxT batchSize, IdxT qValue, - DataT* newParams, std::shared_ptr allocator, cudaStream_t stream) { - - //rand index for size less than 2 is not defined - ASSERT( batchSize>= 1 && qValue>=1, "not defined!"); + const DataT* params, IdxT batchSize, IdxT qValue, DataT* newParams, + std::shared_ptr allocator, cudaStream_t stream) { + ASSERT(batchSize >= 1 && qValue >= 1, "not defined!"); - IdxT nElements = batchSize*qValue; + IdxT nElements = batchSize * qValue; -//copying contents + //copying contents copy(newParams, params, (size_t)nElements, stream); //setting the kernel configuration static const int BLOCK_DIM_Y = 1, BLOCK_DIM_X = 256; dim3 numThreadsPerBlock(BLOCK_DIM_X, BLOCK_DIM_Y); - dim3 numBlocks(ceildiv(batchSize, numThreadsPerBlock.x),1); + dim3 numBlocks(ceildiv(batchSize, numThreadsPerBlock.x), 1); //calling the 
-
@@ -456,47 +394,46 @@ void ma_param_transform(
 /**
 * @brief Host Function to batchwise inverse transform the moving average coefficients according to "jone's (1980)" transformation
 *
 * @param params: 1D array holding the transformed parameters, of shape (batchSize x qValue)
 * @param batchSize: number of models in a batch
 * @param qValue: number of coefficients per model
 * @param newParams: pointer to the memory where the new params are to be stored, which is also where the initial mapped input is stored
 * @param allocator: object that takes care of temporary device memory allocation of type std::shared_ptr<deviceAllocator>
 * @param stream: the cudaStream object
 */
-template<typename DataT, typename IdxT = int>
+template <typename DataT, typename IdxT = int>
 void ma_param_inverse_transform(
-  const DataT* params, IdxT batchSize, IdxT qValue,
-  DataT* newParams, std::shared_ptr<deviceAllocator> allocator, cudaStream_t stream) {
-
-  //rand index for size less than 2 is not defined
-  ASSERT( batchSize>= 1 && qValue>=1, "not defined!");
+  const DataT* params, IdxT batchSize, IdxT qValue, DataT* newParams,
+  std::shared_ptr<deviceAllocator> allocator, cudaStream_t stream) {
+  ASSERT(batchSize >= 1 && qValue >= 1, "not defined!");
 
-  IdxT nElements = batchSize*qValue;
+  IdxT nElements = batchSize * qValue;
 
-//copying contents
+  //copying contents
   copy(newParams, params, (size_t)nElements, stream);
 
   //setting the kernel configuration
   static const int BLOCK_DIM_Y = 1, BLOCK_DIM_X = 256;
   dim3 numThreadsPerBlock(BLOCK_DIM_X, BLOCK_DIM_Y);
-  dim3 numBlocks(ceildiv(batchSize, numThreadsPerBlock.x),1);
+  dim3 numBlocks(ceildiv(batchSize, numThreadsPerBlock.x), 1);
 
   //calling the kernel
-  switch(qValue){
-
-    case 1: ma_param_invtransform_kernel<DataT, IdxT, 1>
-              <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
-            break;
-    case 2: ma_param_invtransform_kernel<DataT, IdxT, 2>
-              <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
-            break;
-    case 3: ma_param_invtransform_kernel<DataT, IdxT, 3>
-              <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
-            break;
-    case 4: ma_param_invtransform_kernel<DataT, IdxT, 4>
-              <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
-            break;
-    default: ASSERT(false, "Unsupported qValue '%d'!", qValue);
+  switch (qValue) {
+    case 1:
+      ma_param_invtransform_kernel<DataT, IdxT, 1>
+        <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
+      break;
+    case 2:
+      ma_param_invtransform_kernel<DataT, IdxT, 2>
+        <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
+      break;
+    case 3:
+      ma_param_invtransform_kernel<DataT, IdxT, 3>
+        <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
+      break;
+    case 4:
+      ma_param_invtransform_kernel<DataT, IdxT, 4>
+        <<<numBlocks, numThreadsPerBlock, 0, stream>>>(newParams, batchSize);
+      break;
+    default:
+      ASSERT(false, "Unsupported qValue '%d'!", qValue);
   }
 
   CUDA_CHECK(cudaPeekAtLastError());
-
-  CUDA_CHECK(cudaStreamSynchronize(stream));
 }
-
 };  //end namespace TimeSeries
 };  //end namespace MLCommon
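Note (illustrative, not part of the patch): a useful property check for review is that the inverse undoes the transform. The sketch below round-trips one MA coefficient vector on the host, mirroring ma_param_transform_kernel (PAC squash, then the += recursion) and ma_param_invtransform_kernel (the normalized recursion run backwards, then 2 * atanh); the printed pairs should agree to rounding error:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> x = {0.3, -0.8, 1.7};  // arbitrary raw params
  const int q = static_cast<int>(x.size());
  std::vector<double> p = x;

  // forward: PAC squash, then the ma_param_transform_kernel recursion
  for (double& v : p) v = std::tanh(v / 2);
  std::vector<double> tmp = p;
  for (int j = 1; j < q; ++j) {
    const double b = p[j];
    for (int k = 0; k < j; ++k) tmp[k] += b * p[j - k - 1];
    for (int k = 0; k < j; ++k) p[k] = tmp[k];
  }

  // inverse: ma_param_invtransform_kernel recursion, then 2 * atanh
  tmp = p;
  for (int j = q - 1; j > 0; --j) {
    const double b = p[j];
    for (int k = 0; k < j; ++k)
      tmp[k] = (p[k] - b * p[j - k - 1]) / (1 - b * b);
    for (int k = 0; k < j; ++k) p[k] = tmp[k];
  }
  for (double& v : p) v = 2 * std::atanh(v);

  for (int i = 0; i < q; ++i)
    std::printf("in=% .15f  roundtrip=% .15f\n", x[i], p[i]);
  return 0;
}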