-
Notifications
You must be signed in to change notification settings - Fork 527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
RF: memset and batch size optimization for computing splits #4001
Changes from 8 commits
797d157
93187f1
5fa1faf
c4b1548
4e20c28
98694d5
8a0762e
23ff5da
d9c8c04
68edd49
4b326af
142bed0
70665a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -338,13 +338,16 @@ struct Builder { | |
raft::update_device(curr_nodes, h_nodes.data() + node_start, batchSize, s); | ||
|
||
int total_samples_in_curr_batch = 0; | ||
int n_large_nodes_in_curr_batch = 0; | ||
total_num_blocks = 0; | ||
for (int n = 0; n < batchSize; n++) { | ||
total_samples_in_curr_batch += h_nodes[node_start + n].count; | ||
int num_blocks = raft::ceildiv(h_nodes[node_start + n].count, | ||
SAMPLES_PER_THREAD * Traits::TPB_DEFAULT); | ||
num_blocks = std::max(1, num_blocks); | ||
|
||
if (num_blocks > 1) ++n_large_nodes_in_curr_batch; | ||
|
||
bool is_leaf = leafBasedOnParams<DataT, IdxT>( | ||
h_nodes[node_start + n].depth, params.max_depth, | ||
params.min_samples_split, params.max_leaves, h_n_leaves, | ||
|
@@ -353,6 +356,8 @@ struct Builder { | |
|
||
for (int b = 0; b < num_blocks; b++) { | ||
h_workload_info[total_num_blocks + b].nodeid = n; | ||
h_workload_info[total_num_blocks + b].large_nodeid = | ||
n_large_nodes_in_curr_batch - 1; | ||
h_workload_info[total_num_blocks + b].offset_blockid = b; | ||
h_workload_info[total_num_blocks + b].num_blocks = num_blocks; | ||
} | ||
|
@@ -364,7 +369,8 @@ struct Builder { | |
auto n_col_blks = n_blks_for_cols; | ||
if (total_num_blocks) { | ||
for (IdxT c = 0; c < input.nSampledCols; c += n_col_blks) { | ||
Traits::computeSplit(*this, c, batchSize, params.split_criterion, s); | ||
Traits::computeSplit(*this, c, batchSize, params.split_criterion, | ||
n_large_nodes_in_curr_batch, s); | ||
CUDA_CHECK(cudaGetLastError()); | ||
} | ||
} | ||
|
@@ -426,7 +432,7 @@ struct ClsTraits { | |
*/ | ||
static void computeSplit(Builder<ClsTraits<DataT, LabelT, IdxT>>& b, IdxT col, | ||
IdxT batchSize, CRITERION splitType, | ||
cudaStream_t s) { | ||
int& n_large_nodes_in_curr_batch, cudaStream_t s) { | ||
ML::PUSH_RANGE( | ||
"Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); | ||
auto nbins = b.params.n_bins; | ||
|
@@ -446,7 +452,9 @@ struct ClsTraits { | |
// Pick the max of two | ||
size_t smemSize = std::max(smemSize1, smemSize2); | ||
dim3 grid(b.total_num_blocks, colBlks, 1); | ||
CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * b.nHistBins, s)); | ||
int nHistBins = 0; | ||
nHistBins = n_large_nodes_in_curr_batch * (1 + nbins) * colBlks * nclasses; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Zero-initialising this variable is unnecessary. The 1+ in (1+nbins) should disappear when you merge the objective function PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, that was necessary in my initial prototyping version but I missed it somehow — changed it 👍🏾 |
||
CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * nHistBins, s)); | ||
ML::PUSH_RANGE( | ||
"computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]"); | ||
computeSplitClassificationKernel<DataT, LabelT, IdxT, TPB_DEFAULT> | ||
|
@@ -507,7 +515,7 @@ struct RegTraits { | |
*/ | ||
static void computeSplit(Builder<RegTraits<DataT, IdxT>>& b, IdxT col, | ||
IdxT batchSize, CRITERION splitType, | ||
cudaStream_t s) { | ||
int& n_large_nodes_in_curr_batch, cudaStream_t s) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is n_large_nodes_in_curr_batch passed by reference? Is it modified somewhere? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It was in my initial version, but you're right, it is unnecessary here. I have changed it to |
||
ML::PUSH_RANGE( | ||
"Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); | ||
auto colBlks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col); | ||
|
@@ -529,10 +537,10 @@ struct RegTraits { | |
size_t smemSize = std::max(smemSize1, smemSize2); | ||
dim3 grid(b.total_num_blocks, colBlks, 1); | ||
|
||
CUDA_CHECK( | ||
cudaMemsetAsync(b.pred, 0, sizeof(DataT) * b.nPredCounts * 2, s)); | ||
CUDA_CHECK( | ||
cudaMemsetAsync(b.pred_count, 0, sizeof(IdxT) * b.nPredCounts, s)); | ||
int nPredCounts = 0; | ||
nPredCounts = n_large_nodes_in_curr_batch * nbins * colBlks; | ||
CUDA_CHECK(cudaMemsetAsync(b.pred, 0, sizeof(DataT) * nPredCounts * 2, s)); | ||
CUDA_CHECK(cudaMemsetAsync(b.pred_count, 0, sizeof(IdxT) * nPredCounts, s)); | ||
|
||
ML::PUSH_RANGE( | ||
"computeSplitRegressionKernel @builder_base.cuh [batched-levelalgo]"); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ namespace DecisionTree { | |
template <typename IdxT> | ||
struct WorkloadInfo { | ||
IdxT nodeid; // Node in the batch on which the threadblock needs to work | ||
IdxT large_nodeid; // counts only large nodes | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Make sure you comment what "large nodes" means. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. done 👍🏾 |
||
IdxT offset_blockid; // Offset threadblock id among all the blocks that are | ||
// working on this node | ||
IdxT num_blocks; // Total number of blocks that are working on the node | ||
|
@@ -387,6 +388,7 @@ __global__ void computeSplitClassificationKernel( | |
// Read workload info for this block | ||
WorkloadInfo<IdxT> workload_info_cta = workload_info[blockIdx.x]; | ||
IdxT nid = workload_info_cta.nodeid; | ||
IdxT large_nid = workload_info_cta.large_nodeid; | ||
auto node = nodes[nid]; | ||
auto range_start = node.start; | ||
auto range_len = node.count; | ||
|
@@ -445,7 +447,7 @@ __global__ void computeSplitClassificationKernel( | |
__syncthreads(); | ||
if (num_blocks > 1) { | ||
// update the corresponding global location | ||
auto histOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_shist_len; | ||
auto histOffset = ((large_nid * gridDim.y) + blockIdx.y) * pdf_shist_len; | ||
for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x) { | ||
atomicAdd(hist + histOffset + i, pdf_shist[i]); | ||
} | ||
|
@@ -530,6 +532,7 @@ __global__ void computeSplitRegressionKernel( | |
// Read workload info for this block | ||
WorkloadInfo<IdxT> workload_info_cta = workload_info[blockIdx.x]; | ||
IdxT nid = workload_info_cta.nodeid; | ||
IdxT large_nid = workload_info_cta.large_nodeid; | ||
|
||
auto node = nodes[nid]; | ||
auto range_start = node.start; | ||
|
@@ -598,13 +601,13 @@ __global__ void computeSplitRegressionKernel( | |
|
||
if (num_blocks > 1) { | ||
// update the corresponding global location for counts | ||
auto gcOffset = ((nid * gridDim.y) + blockIdx.y) * nbins; | ||
auto gcOffset = ((large_nid * gridDim.y) + blockIdx.y) * nbins; | ||
for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) { | ||
atomicAdd(count + gcOffset + i, pdf_scount[i]); | ||
} | ||
|
||
// update the corresponding global location for preds | ||
auto gOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_spred_len; | ||
auto gOffset = ((large_nid * gridDim.y) + blockIdx.y) * pdf_spred_len; | ||
for (IdxT i = threadIdx.x; i < pdf_spred_len; i += blockDim.x) { | ||
atomicAdd(pred + gOffset + i, pdf_spred[i]); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment what this variable is — e.g., nodes with a number of training instances larger than the block size. These nodes require global memory for histograms.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done 👍🏾