From f606cb8ef4d256e3bb8a9efbe001af0d83275a0f Mon Sep 17 00:00:00 2001 From: Andy Adinets Date: Sat, 8 Sep 2018 04:48:45 +0200 Subject: [PATCH] Fixed the performance regression within EvaluateSplits(). (#3680) - it turns out creating an std::vector on every call is faster than cudaMallocHost()/cudaFreeHost() --- src/tree/updater_gpu_hist.cu | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 8231182132e9..0c559e4ed23e 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -747,6 +747,7 @@ class GPUHistMaker : public TreeUpdater { struct ExpandEntry; GPUHistMaker() : initialised_(false), p_last_fmat_(nullptr) {} + void Init( const std::vector>& args) override { param_.InitAllowUnknown(args); @@ -919,9 +920,7 @@ class GPUHistMaker : public TreeUpdater { const std::vector& nidx_set, RegTree* p_tree) { auto columns = info_->num_col_; std::vector best_splits(nidx_set.size()); - DeviceSplitCandidate* candidate_splits; - dh::safe_cuda(cudaMallocHost(&candidate_splits, nidx_set.size() * - columns * sizeof(DeviceSplitCandidate))); + std::vector candidate_splits(nidx_set.size() * columns); // Use first device auto& shard = shards_.front(); dh::safe_cuda(cudaSetDevice(shard->device_idx)); @@ -952,10 +951,10 @@ class GPUHistMaker : public TreeUpdater { } dh::safe_cuda(cudaDeviceSynchronize()); - dh::safe_cuda( - cudaMemcpy(candidate_splits, shard->temp_memory.d_temp_storage, - sizeof(DeviceSplitCandidate) * columns * nidx_set.size(), - cudaMemcpyDeviceToHost)); + dh::safe_cuda + (cudaMemcpy(candidate_splits.data(), shard->temp_memory.d_temp_storage, + sizeof(DeviceSplitCandidate) * columns * nidx_set.size(), + cudaMemcpyDeviceToHost)); for (auto i = 0; i < nidx_set.size(); i++) { auto depth = p_tree->GetDepth(nidx_set[i]); DeviceSplitCandidate nidx_best; @@ -965,7 +964,6 @@ class GPUHistMaker : public TreeUpdater { } best_splits[i] = nidx_best; } - dh::safe_cuda(cudaFreeHost(candidate_splits)); return std::move(best_splits); }