Skip to content

Commit

Permalink
Fixed bug causing illegal memory access
Browse files Browse the repository at this point in the history
  • Loading branch information
vishalmehta1991 committed Dec 13, 2019
1 parent 8fe7326 commit a23ee8f
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 24 deletions.
8 changes: 7 additions & 1 deletion cpp/src/decisiontree/levelalgo/common_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,13 @@ void convert_scatter_to_gather(const unsigned int *flagsptr,
nodecount, nodestart, n_nodes + 1,
tempmem->stream);
CUDA_CHECK(cudaGetLastError());

unsigned int *h_nodestart = (unsigned int *)(tempmem->h_split_binidx->data());
MLCommon::updateHost(h_nodestart, nodestart + n_nodes, 1, tempmem->stream);
CUDA_CHECK(cudaStreamSynchronize(tempmem->stream));
CUDA_CHECK(cudaMemsetAsync(nodecount, 0, n_nodes * sizeof(unsigned int),
tempmem->stream));
CUDA_CHECK(cudaMemsetAsync(
samplelist, 0, h_nodestart[0] * sizeof(unsigned int), tempmem->stream));
build_list<<<nblocks, nthreads, 0, tempmem->stream>>>(
flagsptr, nodestart, n_rows, nodecount, samplelist);
CUDA_CHECK(cudaGetLastError());
Expand Down Expand Up @@ -293,6 +297,8 @@ void make_split_gather(const T *data, unsigned int *nodestart,
nodecount, nodestart, h_counter[0] + 1,
tempmem->stream);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaMemsetAsync(samplelist, 0, h_counter[0] * sizeof(unsigned int),
tempmem->stream));
CUDA_CHECK(cudaMemsetAsync(nodecount, 0, h_counter[0] * sizeof(unsigned int),
tempmem->stream));
build_list<<<nblocks, nthreads, 0, tempmem->stream>>>(
Expand Down
11 changes: 6 additions & 5 deletions cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -334,20 +334,21 @@ void best_split_gather_classification(
const float min_impurity_split,
std::shared_ptr<TemporaryMemory<T, int>> tempmem,
SparseTreeNode<T, int> *d_sparsenodes, int *d_nodelist) {
const int TPB = TemporaryMemory<T, int>::gather_threads;
if (split_algo == 0) {
using E = typename MLCommon::Stats::encode_traits<T>::E;
T init_val = std::numeric_limits<T>::max();
size_t shmemsz = n_unique_labels * (nbins + 1) * sizeof(int);
best_split_gather_classification_minmax_kernel<T, E, FDEV>
<<<n_nodes, 64, shmemsz, tempmem->stream>>>(
best_split_gather_classification_minmax_kernel<T, E, FDEV, TPB>
<<<n_nodes, tempmem->gather_threads, shmemsz, tempmem->stream>>>(
data, labels, d_colids, d_colstart, d_nodestart, d_samplelist, n_nodes,
n_unique_labels, nbins, nrows, Ncols, ncols_sampled, treesz,
min_impurity_split, init_val, d_sparsenodes, d_nodelist);
} else {
const T *d_question_ptr = tempmem->d_quantile->data();
size_t shmemsz = n_unique_labels * (nbins + 1) * sizeof(int);
best_split_gather_classification_kernel<T, QuantileQues<T>, FDEV>
<<<n_nodes, 64, shmemsz, tempmem->stream>>>(
best_split_gather_classification_kernel<T, QuantileQues<T>, FDEV, TPB>
<<<n_nodes, tempmem->gather_threads, shmemsz, tempmem->stream>>>(
data, labels, d_colids, d_colstart, d_question_ptr, d_nodestart,
d_samplelist, n_nodes, n_unique_labels, nbins, nrows, Ncols,
ncols_sampled, treesz, min_impurity_split, d_sparsenodes, d_nodelist);
Expand All @@ -362,7 +363,7 @@ void make_leaf_gather_classification(
std::shared_ptr<TemporaryMemory<T, int>> tempmem) {
size_t shmemsz = n_unique_labels * sizeof(int);
make_leaf_gather_classification_kernel<T, FDEV>
<<<n_nodes, 64, shmemsz, tempmem->stream>>>(
<<<n_nodes, tempmem->gather_threads, shmemsz, tempmem->stream>>>(
labels, nodestart, samplelist, n_unique_labels, d_sparsenodes, nodelist);
CUDA_CHECK(cudaGetLastError());
}
13 changes: 8 additions & 5 deletions cpp/src/decisiontree/levelalgo/levelhelper_regressor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -466,20 +466,22 @@ void best_split_gather_regression(
const int split_algo, const size_t treesz, const float min_impurity_split,
std::shared_ptr<TemporaryMemory<T, T>> tempmem,
SparseTreeNode<T, T> *d_sparsenodes, int *d_nodelist) {
const int TPB = TemporaryMemory<T, T>::gather_threads;

if (split_algo == 0) {
using E = typename MLCommon::Stats::encode_traits<T>::E;
T init_val = std::numeric_limits<T>::max();
size_t shmemsz = nbins * sizeof(int) + nbins * sizeof(T);
best_split_gather_regression_minmax_kernel<T, E>
<<<n_nodes, 64, shmemsz, tempmem->stream>>>(
best_split_gather_regression_minmax_kernel<T, E, TPB>
<<<n_nodes, tempmem->gather_threads, shmemsz, tempmem->stream>>>(
data, labels, d_colids, d_colstart, d_nodestart, d_samplelist, n_nodes,
nbins, nrows, Ncols, ncols_sampled, treesz, min_impurity_split,
init_val, d_sparsenodes, d_nodelist);
} else {
const T *d_question_ptr = tempmem->d_quantile->data();
size_t shmemsz = nbins * sizeof(T) + nbins * sizeof(int);
best_split_gather_regression_kernel<T, QuantileQues<T>>
<<<n_nodes, 64, shmemsz, tempmem->stream>>>(
best_split_gather_regression_kernel<T, QuantileQues<T>, TPB>
<<<n_nodes, tempmem->gather_threads, shmemsz, tempmem->stream>>>(
data, labels, d_colids, d_colstart, d_question_ptr, d_nodestart,
d_samplelist, n_nodes, nbins, nrows, Ncols, ncols_sampled, treesz,
min_impurity_split, d_sparsenodes, d_nodelist);
Expand All @@ -492,7 +494,8 @@ void make_leaf_gather_regression(
const unsigned int *samplelist, SparseTreeNode<T, T> *d_sparsenodes,
int *nodelist, const int n_nodes,
std::shared_ptr<TemporaryMemory<T, T>> tempmem) {
make_leaf_gather_regression_kernel<<<n_nodes, 64, 0, tempmem->stream>>>(
make_leaf_gather_regression_kernel<<<n_nodes, tempmem->gather_threads, 0,
tempmem->stream>>>(
labels, nodestart, samplelist, d_sparsenodes, nodelist);
CUDA_CHECK(cudaGetLastError());
}
8 changes: 4 additions & 4 deletions cpp/src/decisiontree/levelalgo/levelkernel_classifier.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ DI GainIdxPair bin_info_gain_classification(
return tid_pair;
}

template <typename T, typename QuestionType, typename FDEV>
template <typename T, typename QuestionType, typename FDEV, int TPB>
__global__ void best_split_gather_classification_kernel(
const T* __restrict__ data, const int* __restrict__ labels,
const unsigned int* __restrict__ colids,
Expand All @@ -364,7 +364,7 @@ __global__ void best_split_gather_classification_kernel(
__shared__ GainIdxPair shmem_pair;
__shared__ int shmem_col;
__shared__ float parent_metric;
typedef cub::BlockReduce<GainIdxPair, 64> BlockReduce;
typedef cub::BlockReduce<GainIdxPair, TPB> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

//shmemhist_left[n_unique_labels*nbins]
Expand Down Expand Up @@ -451,7 +451,7 @@ __global__ void best_split_gather_classification_kernel(
}

//The same as above but fused minmax at block level
template <typename T, typename E, typename FDEV>
template <typename T, typename E, typename FDEV, int TPB>
__global__ void best_split_gather_classification_minmax_kernel(
const T* __restrict__ data, const int* __restrict__ labels,
const unsigned int* __restrict__ colids,
Expand All @@ -467,7 +467,7 @@ __global__ void best_split_gather_classification_minmax_kernel(
__shared__ GainIdxPair shmem_pair;
__shared__ int shmem_col;
__shared__ float parent_metric;
typedef cub::BlockReduce<GainIdxPair, 64> BlockReduce;
typedef cub::BlockReduce<GainIdxPair, TPB> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T shmem_min, shmem_max, best_min, best_delta;
//shmemhist_left[n_unique_labels*nbins]
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/decisiontree/levelalgo/levelkernel_regressor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ DI GainIdxPair bin_info_gain_regression(const T sum_parent, const T *sum_left,
return tid_pair;
}

template <typename T, typename QuestionType>
template <typename T, typename QuestionType, int TPB>
__global__ void best_split_gather_regression_kernel(
const T *__restrict__ data, const T *__restrict__ labels,
const unsigned int *__restrict__ colids,
Expand All @@ -486,7 +486,7 @@ __global__ void best_split_gather_regression_kernel(
__shared__ T mean_parent;
__shared__ GainIdxPair shmem_pair;
__shared__ int shmem_col;
typedef cub::BlockReduce<GainIdxPair, 64> BlockReduce;
typedef cub::BlockReduce<GainIdxPair, TPB> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

int colstart_local = -1;
Expand Down Expand Up @@ -570,7 +570,7 @@ __global__ void best_split_gather_regression_kernel(
}

//Same as above but fused with minmax mode.
template <typename T, typename E>
template <typename T, typename E, int TPB>
__global__ void best_split_gather_regression_minmax_kernel(
const T *__restrict__ data, const T *__restrict__ labels,
const unsigned int *__restrict__ colids,
Expand All @@ -587,7 +587,7 @@ __global__ void best_split_gather_regression_minmax_kernel(
__shared__ T mean_parent;
__shared__ GainIdxPair shmem_pair;
__shared__ int shmem_col;
typedef cub::BlockReduce<GainIdxPair, 64> BlockReduce;
typedef cub::BlockReduce<GainIdxPair, TPB> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T shmem_min, shmem_max, best_min, best_delta;

Expand Down
10 changes: 6 additions & 4 deletions cpp/src/decisiontree/memory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ void TemporaryMemory<T, L>::LevelMemAllocator(int nrows, int ncols,
// << gather_max_nodes << std::endl;
//std::cout << "Parent size --> " << parentsz << std::endl;
//std::cout << "Child size --> " << childsz << std::endl;
//std::cout << "Nrows size --> " << (nrows + 1) << std::endl;
d_flags =
new MLCommon::device_buffer<unsigned int>(device_allocator, stream, nrows);
h_new_node_flags =
Expand All @@ -104,14 +105,15 @@ void TemporaryMemory<T, L>::LevelMemAllocator(int nrows, int ncols,
new MLCommon::device_buffer<int>(device_allocator, stream, parentsz);
d_split_binidx =
new MLCommon::device_buffer<int>(device_allocator, stream, parentsz);
h_parent_metric = new MLCommon::host_buffer<T>(
host_allocator, stream, std::max(parentsz, (size_t)(nrows + 1)));
size_t metric_size = std::max(parentsz, (size_t)(nrows + 1));
h_parent_metric =
new MLCommon::host_buffer<T>(host_allocator, stream, metric_size);
d_parent_metric =
new MLCommon::device_buffer<T>(device_allocator, stream, metric_size);
h_child_best_metric =
new MLCommon::host_buffer<T>(host_allocator, stream, childsz);
h_outgain =
new MLCommon::host_buffer<float>(host_allocator, stream, parentsz);
d_parent_metric = new MLCommon::device_buffer<T>(
device_allocator, stream, std::max(parentsz, (size_t)(nrows + 1)));
d_child_best_metric =
new MLCommon::device_buffer<T>(device_allocator, stream, childsz);
d_outgain =
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/decisiontree/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
template <class T, class L>
struct TemporaryMemory {
//depth algorithm changer
const int swap_depth = 8;
const int swap_depth = 14;
static const int gather_threads = 256;
//Allocators parsed from CUML handle
std::shared_ptr<MLCommon::deviceAllocator> device_allocator;
std::shared_ptr<MLCommon::hostAllocator> host_allocator;
Expand Down

0 comments on commit a23ee8f

Please sign in to comment.