diff --git a/src/common/span.h b/src/common/span.h
index a618b682622b..a59c39e72222 100644
--- a/src/common/span.h
+++ b/src/common/span.h
@@ -120,7 +120,7 @@ class SpanIterator {
   using reference = typename std::conditional<                    // NOLINT
       IsConst, const ElementType, ElementType>::type&;
-  using pointer = typename std::add_pointer<reference>::type&;    // NOLINT
+  using pointer = typename std::add_pointer<reference>::type;     // NOLINT
 
   XGBOOST_DEVICE constexpr SpanIterator() : span_{nullptr}, index_{0} {}
 
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index e33a40b7f1f6..7a747a7c5c7f 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -194,8 +194,9 @@ class GBTree : public GradientBooster {
     CHECK_EQ(in_gpair->Size() % ngroup, 0U)
         << "must have exactly ngroup*nrow gpairs";
     // TODO(canonizer): perform this on GPU if HostDeviceVector has device set.
-    HostDeviceVector<GradientPair> tmp(in_gpair->Size() / ngroup,
-                                       GradientPair(), in_gpair->Distribution());
+    HostDeviceVector<GradientPair> tmp
+      (in_gpair->Size() / ngroup, GradientPair(),
+       GPUDistribution::Block(in_gpair->Distribution().Devices()));
     const auto& gpair_h = in_gpair->ConstHostVector();
     auto nsize = static_cast<bst_omp_uint>(tmp.Size());
     for (int gid = 0; gid < ngroup; ++gid) {
diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu
index fdc5505fc6e5..e46716ce4349 100644
--- a/src/objective/hinge.cu
+++ b/src/objective/hinge.cu
@@ -22,7 +22,7 @@ struct HingeObjParam : public dmlc::Parameter<HingeObjParam> {
   int n_gpus;
   int gpu_id;
   DMLC_DECLARE_PARAMETER(HingeObjParam) {
-    DMLC_DECLARE_FIELD(n_gpus).set_default(0).set_lower_bound(0)
+    DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
         .describe("Number of GPUs to use for multi-gpu algorithms.");
     DMLC_DECLARE_FIELD(gpu_id)
         .set_lower_bound(0)
diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu
index 2149af0f9279..317adda707f4 100644
--- a/src/objective/multiclass_obj.cu
+++ b/src/objective/multiclass_obj.cu
@@ -31,7 +31,7 @@ struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
   DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) {
     DMLC_DECLARE_FIELD(num_class).set_lower_bound(1)
         .describe("Number of output class in the multi-class classification.");
-    DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
+    DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
         .describe("Number of GPUs to use for multi-gpu algorithms.");
     DMLC_DECLARE_FIELD(gpu_id)
         .set_lower_bound(0)
@@ -64,10 +64,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
     const int nclass = param_.num_class;
     const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
 
-    // clear out device memory;
-    out_gpair->Reshard(GPUSet::Empty());
-    preds.Reshard(GPUSet::Empty());
-
     out_gpair->Reshard(GPUDistribution::Granular(devices_, nclass));
     info.labels_.Reshard(GPUDistribution::Block(devices_));
     info.weights_.Reshard(GPUDistribution::Block(devices_));
@@ -109,11 +105,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
         }, common::Range{0, ndata}, devices_, false)
         .Eval(out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_);
 
-    out_gpair->Reshard(GPUSet::Empty());
-    out_gpair->Reshard(GPUDistribution::Block(devices_));
-    preds.Reshard(GPUSet::Empty());
-    preds.Reshard(GPUDistribution::Block(devices_));
-
     std::vector<int>& label_correct_h = label_correct_.HostVector();
     for (auto const flag : label_correct_h) {
       if (flag != 1) {
@@ -136,7 +127,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
     const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
     max_preds_.Resize(ndata);
 
-    io_preds->Reshard(GPUSet::Empty());  // clear out device memory
     if (prob) {
       common::Transform<>::Init(
           [=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
@@ -166,8 +156,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
       io_preds->Resize(max_preds_.Size());
       io_preds->Copy(max_preds_);
     }
-    io_preds->Reshard(GPUSet::Empty());  // clear out device memory
-    io_preds->Reshard(GPUDistribution::Block(devices_));
   }
 
  private:
diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu
index e74c82af1c95..590072d8f9d2 100644
--- a/src/objective/regression_obj.cu
+++ b/src/objective/regression_obj.cu
@@ -34,7 +34,7 @@ struct RegLossParam : public dmlc::Parameter<RegLossParam> {
   DMLC_DECLARE_PARAMETER(RegLossParam) {
     DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
         .describe("Scale the weight of positive examples by this factor");
-    DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
+    DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
         .describe("Number of GPUs to use for multi-gpu algorithms.");
     DMLC_DECLARE_FIELD(gpu_id)
         .set_lower_bound(0)
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 18c23c12662b..3787a9b1c9a9 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -27,10 +27,10 @@ struct GPUPredictionParam : public dmlc::Parameter<GPUPredictionParam> {
   bool silent;
   // declare parameters
   DMLC_DECLARE_PARAMETER(GPUPredictionParam) {
-    DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe(
+    DMLC_DECLARE_FIELD(gpu_id).set_lower_bound(0).set_default(0).describe(
         "Device ordinal for GPU prediction.");
-    DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe(
-        "Number of devices to use for prediction (NOT IMPLEMENTED).");
+    DMLC_DECLARE_FIELD(n_gpus).set_lower_bound(-1).set_default(1).describe(
+        "Number of devices to use for prediction.");
     DMLC_DECLARE_FIELD(silent).set_default(false).describe(
         "Do not print information during trainig.");
   }
@@ -43,53 +43,12 @@ void IncrementOffset(IterT begin_itr, IterT end_itr, size_t amount) {
                     [=] __device__(size_t elem) { return elem + amount; });
 }
 
-/**
- * \struct  DeviceMatrix
- *
- * \brief A csr representation of the input matrix allocated on the device.
- */
-
-struct DeviceMatrix {
-  DMatrix* p_mat;  // Pointer to the original matrix on the host
-  dh::BulkAllocator ba;
-  dh::DVec<size_t> row_ptr;
-  dh::DVec<Entry> data;
-  thrust::device_vector<bst_float> predictions;
-
-  DeviceMatrix(DMatrix* dmat, int device_idx, bool silent) : p_mat(dmat) {
-    dh::safe_cuda(cudaSetDevice(device_idx));
-    const auto& info = dmat->Info();
-    ba.Allocate(device_idx, silent, &row_ptr, info.num_row_ + 1, &data,
-                info.num_nonzero_);
-    size_t data_offset = 0;
-    for (const auto &batch : dmat->GetRowBatches()) {
-      const auto& offset_vec = batch.offset.HostVector();
-      const auto& data_vec = batch.data.HostVector();
-      // Copy row ptr
-      dh::safe_cuda(cudaMemcpy(
-          row_ptr.Data() + batch.base_rowid, offset_vec.data(),
-          sizeof(size_t) * offset_vec.size(), cudaMemcpyHostToDevice));
-      if (batch.base_rowid > 0) {
-        auto begin_itr = row_ptr.tbegin() + batch.base_rowid;
-        auto end_itr = begin_itr + batch.Size() + 1;
-        IncrementOffset(begin_itr, end_itr, batch.base_rowid);
-      }
-      dh::safe_cuda(cudaMemcpy(data.Data() + data_offset, data_vec.data(),
-                               sizeof(Entry) * data_vec.size(),
-                               cudaMemcpyHostToDevice));
-      // Copy data
-      data_offset += batch.data.Size();
-    }
-  }
-};
-
 /**
  * \struct  DevicePredictionNode
  *
  * \brief Packed 16 byte representation of a tree node for use in device
  * prediction
  */
-
 struct DevicePredictionNode {
   XGBOOST_DEVICE DevicePredictionNode()
       : fidx(-1), left_child_idx(-1), right_child_idx(-1) {}
@@ -105,6 +64,7 @@ struct DevicePredictionNode {
   NodeValue val;
 
   DevicePredictionNode(const RegTree::Node& n) {  // NOLINT
+    static_assert(sizeof(DevicePredictionNode) == 16, "Size is not 16 bytes");
     this->left_child_idx = n.LeftChild();
     this->right_child_idx = n.RightChild();
     this->fidx = n.SplitIndex();
@@ -140,19 +100,21 @@ struct DevicePredictionNode {
 
 struct ElementLoader {
   bool use_shared;
-  size_t* d_row_ptr;
-  Entry* d_data;
+  common::Span<size_t> d_row_ptr;
+  common::Span<Entry> d_data;
   int num_features;
   float* smem;
+  size_t entry_start;
 
-  __device__ ElementLoader(bool use_shared, size_t* row_ptr,
-                           Entry* entry, int num_features,
-                           float* smem, int num_rows)
+  __device__ ElementLoader(bool use_shared, common::Span<size_t> row_ptr,
+                           common::Span<Entry> entry, int num_features,
+                           float* smem, int num_rows, size_t entry_start)
       : use_shared(use_shared),
         d_row_ptr(row_ptr),
        d_data(entry),
        num_features(num_features),
-        smem(smem) {
+        smem(smem),
+        entry_start(entry_start) {
     // Copy instances
     if (use_shared) {
       bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -163,7 +125,7 @@ struct ElementLoader {
       bst_uint elem_begin = d_row_ptr[global_idx];
       bst_uint elem_end = d_row_ptr[global_idx + 1];
       for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
-        Entry elem = d_data[elem_idx];
+        Entry elem = d_data[elem_idx - entry_start];
         smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
       }
     }
@@ -175,9 +137,9 @@ struct ElementLoader {
       return smem[threadIdx.x * num_features + fidx];
     } else {
       // Binary search
-      auto begin_ptr = d_data + d_row_ptr[ridx];
-      auto end_ptr = d_data + d_row_ptr[ridx + 1];
-      Entry* previous_middle = nullptr;
+      auto begin_ptr = d_data.begin() + (d_row_ptr[ridx] - entry_start);
+      auto end_ptr = d_data.begin() + (d_row_ptr[ridx + 1] - entry_start);
+      common::Span<Entry>::iterator previous_middle;
       while (end_ptr != begin_ptr) {
         auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
         if (middle == previous_middle) {
@@ -220,22 +182,25 @@ __device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
 }
 
 template <int BLOCK_THREADS>
-__global__ void PredictKernel(const DevicePredictionNode* d_nodes,
-                              float* d_out_predictions, size_t* d_tree_segments,
-                              int* d_tree_group, size_t* d_row_ptr,
-                              Entry* d_data, size_t tree_begin,
+__global__ void PredictKernel(common::Span<DevicePredictionNode> d_nodes,
+                              common::Span<float> d_out_predictions,
+                              common::Span<size_t> d_tree_segments,
+                              common::Span<int> d_tree_group,
+                              common::Span<size_t> d_row_ptr,
+                              common::Span<Entry> d_data, size_t tree_begin,
                               size_t tree_end, size_t num_features,
-                              size_t num_rows, bool use_shared, int num_group) {
+                              size_t num_rows, size_t entry_start,
+                              bool use_shared, int num_group) {
   extern __shared__ float smem[];
   bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
   ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
-                       num_rows);
+                       num_rows, entry_start);
   if (global_idx >= num_rows) return;
   if (num_group == 1) {
     float sum = 0;
     for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       const DevicePredictionNode* d_tree =
-          d_nodes + d_tree_segments[tree_idx - tree_begin];
+          &d_nodes[d_tree_segments[tree_idx - tree_begin]];
       sum += GetLeafWeight(global_idx, d_tree, &loader);
     }
     d_out_predictions[global_idx] += sum;
@@ -243,7 +208,7 @@ __global__ void PredictKernel(const DevicePredictionNode* d_nodes,
     for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       int tree_group = d_tree_group[tree_idx];
       const DevicePredictionNode* d_tree =
-          d_nodes + d_tree_segments[tree_idx - tree_begin];
+          &d_nodes[d_tree_segments[tree_idx - tree_begin]];
       bst_uint out_prediction_idx = global_idx * num_group + tree_group;
       d_out_predictions[out_prediction_idx] +=
           GetLeafWeight(global_idx, d_tree, &loader);
@@ -259,31 +224,89 @@ class GPUPredictor : public xgboost::Predictor {
   };
 
  private:
-  void DevicePredictInternal(DMatrix* dmat,
-                             HostDeviceVector<bst_float>* out_preds,
-                             const gbm::GBTreeModel& model, size_t tree_begin,
-                             size_t tree_end) {
-    if (tree_end - tree_begin == 0) {
-      return;
+  void DeviceOffsets(const HostDeviceVector<size_t>& data, std::vector<size_t>* out_offsets) {
+    auto& offsets = *out_offsets;
+    offsets.resize(devices_.Size() + 1);
+    offsets[0] = 0;
+#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
+    for (int shard = 0; shard < devices_.Size(); ++shard) {
+      int device = devices_[shard];
+      auto data_span = data.DeviceSpan(device);
+      dh::safe_cuda(cudaSetDevice(device));
+      // copy the last element from every shard
+      dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1),
+                               &data_span[data_span.size()-1],
+                               sizeof(size_t), cudaMemcpyDeviceToHost));
     }
+  }
 
-    std::shared_ptr<DeviceMatrix> device_matrix;
-    // Matrix is not in host cache, create a temporary matrix
-    if (this->cache_.find(dmat) == this->cache_.end()) {
-      device_matrix = std::shared_ptr<DeviceMatrix>(
-          new DeviceMatrix(dmat, param.gpu_id, param.silent));
-    } else {
-      // Create this matrix on device if doesn't exist
-      if (this->device_matrix_cache_.find(dmat) ==
-          this->device_matrix_cache_.end()) {
-        this->device_matrix_cache_.emplace(
-            dmat, std::shared_ptr<DeviceMatrix>(
-                      new DeviceMatrix(dmat, param.gpu_id, param.silent)));
+  struct DeviceShard {
+    DeviceShard() : device_(-1) {}
+    void Init(int device) {
+      this->device_ = device;
+      max_shared_memory_bytes = dh::MaxSharedMemory(this->device_);
+    }
+    void PredictInternal
+    (const SparsePage& batch, const MetaInfo& info,
+     HostDeviceVector<bst_float>* predictions,
+     const gbm::GBTreeModel& model,
+     const thrust::host_vector<size_t>& h_tree_segments,
+     const thrust::host_vector<DevicePredictionNode>& h_nodes,
+     size_t tree_begin, size_t tree_end) {
+      dh::safe_cuda(cudaSetDevice(device_));
+      nodes.resize(h_nodes.size());
+      dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
+                               sizeof(DevicePredictionNode) * h_nodes.size(),
+                               cudaMemcpyHostToDevice));
+      tree_segments.resize(h_tree_segments.size());
+
+      dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
+                               sizeof(size_t) * h_tree_segments.size(),
+                               cudaMemcpyHostToDevice));
+      tree_group.resize(model.tree_info.size());
+
+      dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
+                               sizeof(int) * model.tree_info.size(),
+                               cudaMemcpyHostToDevice));
+
+      const int BLOCK_THREADS = 128;
+      size_t num_rows = batch.offset.DeviceSize(device_) - 1;
+
+      const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
+
+      int shared_memory_bytes = static_cast<int>
+        (sizeof(float) * info.num_col_ * BLOCK_THREADS);
+      bool use_shared = true;
+      if (shared_memory_bytes > max_shared_memory_bytes) {
+        shared_memory_bytes = 0;
+        use_shared = false;
       }
-      device_matrix = device_matrix_cache_.find(dmat)->second;
+      const auto& data_distr = batch.data.Distribution();
+      int index = data_distr.Devices().Index(device_);
+      size_t entry_start = data_distr.ShardStart(batch.data.Size(), index);
+
+      PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
+        (dh::ToSpan(nodes), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments),
+         dh::ToSpan(tree_group), batch.offset.DeviceSpan(device_),
+         batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
+         num_rows, entry_start, use_shared, model.param.num_output_group);
+
+      dh::safe_cuda(cudaDeviceSynchronize());
     }
-    dh::safe_cuda(cudaSetDevice(param.gpu_id));
+    int device_;
+    thrust::device_vector<DevicePredictionNode> nodes;
+    thrust::device_vector<size_t> tree_segments;
+    thrust::device_vector<int> tree_group;
+    size_t max_shared_memory_bytes;
+  };
+
+  void DevicePredictInternal(DMatrix* dmat,
+                             HostDeviceVector<bst_float>* out_preds,
+                             const gbm::GBTreeModel& model, size_t tree_begin,
+                             size_t tree_end) {
+    if (tree_end - tree_begin == 0) { return; }
+    CHECK_EQ(model.param.size_leaf_vector, 0);
     // Copy decision trees to device
     thrust::host_vector<size_t> h_tree_segments;
@@ -291,61 +314,33 @@ class GPUPredictor : public xgboost::Predictor {
     size_t sum = 0;
     h_tree_segments.push_back(sum);
     for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
-      sum += model.trees[tree_idx]->GetNodes().size();
+      sum += model.trees.at(tree_idx)->GetNodes().size();
       h_tree_segments.push_back(sum);
     }
 
     thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
     for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
-      auto& src_nodes = model.trees[tree_idx]->GetNodes();
+      auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
       std::copy(src_nodes.begin(), src_nodes.end(),
                 h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
     }
-    nodes.resize(h_nodes.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
-                             sizeof(DevicePredictionNode) * h_nodes.size(),
-                             cudaMemcpyHostToDevice));
-    tree_segments.resize(h_tree_segments.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
-                             sizeof(size_t) * h_tree_segments.size(),
-                             cudaMemcpyHostToDevice));
-    tree_group.resize(model.tree_info.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
-                             sizeof(int) * model.tree_info.size(),
-                             cudaMemcpyHostToDevice));
-
-    device_matrix->predictions.resize(out_preds->Size());
-    auto& predictions = device_matrix->predictions;
-    out_preds->GatherTo(predictions.data(),
-                        predictions.data() + predictions.size());
-
-    dh::safe_cuda(cudaSetDevice(param.gpu_id));
-
-    const int BLOCK_THREADS = 128;
-    const int GRID_SIZE = static_cast<int>(
-        dh::DivRoundUp(device_matrix->row_ptr.Size() - 1, BLOCK_THREADS));
-
-    int shared_memory_bytes = static_cast<int>(
-        sizeof(float) * device_matrix->p_mat->Info().num_col_ * BLOCK_THREADS);
-    bool use_shared = true;
-    if (shared_memory_bytes > max_shared_memory_bytes) {
-      shared_memory_bytes = 0;
-      use_shared = false;
-    }
+    size_t i_batch = 0;
 
-    PredictKernel<BLOCK_THREADS>
-        <<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>(
-            dh::Raw(nodes), dh::Raw(device_matrix->predictions),
-            dh::Raw(tree_segments), dh::Raw(tree_group),
-            device_matrix->row_ptr.Data(), device_matrix->data.Data(),
-            tree_begin, tree_end, device_matrix->p_mat->Info().num_col_,
-            device_matrix->p_mat->Info().num_row_, use_shared,
-            model.param.num_output_group);
-
-    dh::safe_cuda(cudaDeviceSynchronize());
-    out_preds->ScatterFrom(predictions.data(),
-                           predictions.data() + predictions.size());
+    for (const auto &batch : dmat->GetRowBatches()) {
+      CHECK_EQ(i_batch, 0) << "External memory not supported";
+      size_t n_rows = batch.offset.Size() - 1;
+      // out_preds have been resharded and resized in InitOutPredictions()
+      batch.offset.Reshard(GPUDistribution::Overlap(devices_, 1));
+      std::vector<size_t> device_offsets;
+      DeviceOffsets(batch.offset, &device_offsets);
+      batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
+      dh::ExecuteShards(&shards, [&](DeviceShard& shard){
+          shard.PredictInternal(batch, dmat->Info(), out_preds, model, h_tree_segments,
+                                h_nodes, tree_begin, tree_end);
+        });
+      i_batch++;
+    }
   }
 
  public:
@@ -354,6 +349,10 @@ class GPUPredictor : public xgboost::Predictor {
   void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                     const gbm::GBTreeModel& model, int tree_begin,
                     unsigned ntree_limit = 0) override {
+    GPUSet devices = GPUSet::All(
+        param.n_gpus, dmat->Info().num_row_).Normalised(param.gpu_id);
+    ConfigureShards(devices);
+
     if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
       return;
     }
@@ -372,9 +371,10 @@ class GPUPredictor : public xgboost::Predictor {
   void InitOutPredictions(const MetaInfo& info,
                           HostDeviceVector<bst_float>* out_preds,
                           const gbm::GBTreeModel& model) const {
-    size_t n = model.param.num_output_group * info.num_row_;
+    size_t n_classes = model.param.num_output_group;
+    size_t n = n_classes * info.num_row_;
     const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
-    out_preds->Reshard(devices);
+    out_preds->Reshard(GPUDistribution::Granular(devices_, n_classes));
     out_preds->Resize(n);
     if (base_margin.Size() != 0) {
       CHECK_EQ(out_preds->Size(), n);
@@ -392,14 +392,13 @@ class GPUPredictor : public xgboost::Predictor {
       if (it != cache_.end()) {
         const HostDeviceVector<bst_float>& y = it->second.predictions;
         if (y.Size() != 0) {
-          out_preds->Reshard(devices);
+          out_preds->Reshard(y.Distribution());
          out_preds->Resize(y.Size());
          out_preds->Copy(y);
          return true;
        }
      }
    }
-
    return false;
  }
@@ -464,24 +463,33 @@ class GPUPredictor : public xgboost::Predictor {
     Predictor::Init(cfg, cache);
     cpu_predictor->Init(cfg, cache);
     param.InitAllowUnknown(cfg);
-    devices = GPUSet::All(param.n_gpus).Normalised(param.gpu_id);
-    max_shared_memory_bytes = dh::MaxSharedMemory(param.gpu_id);
+
+    GPUSet devices = GPUSet::All(param.n_gpus).Normalised(param.gpu_id);
+    ConfigureShards(devices);
   }
 
  private:
+  /*! \brief Re-configure shards when GPUSet is changed. */
+  void ConfigureShards(GPUSet devices) {
+    if (devices_ == devices) return;
+
+    devices_ = devices;
+    shards.clear();
+    shards.resize(devices_.Size());
+    dh::ExecuteIndexShards(&shards, [=](size_t i, DeviceShard& shard){
+        shard.Init(devices_[i]);
+      });
+  }
+
   GPUPredictionParam param;
   std::unique_ptr<Predictor> cpu_predictor;
-  std::unordered_map<DMatrix*, std::shared_ptr<DeviceMatrix>>
-      device_matrix_cache_;
-  thrust::device_vector<DevicePredictionNode> nodes;
-  thrust::device_vector<size_t> tree_segments;
-  thrust::device_vector<int> tree_group;
-  thrust::device_vector<bst_float> preds;
-  GPUSet devices;
-  size_t max_shared_memory_bytes;
+  std::vector<DeviceShard> shards;
+  GPUSet devices_;
 };
+
 XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
 .describe("Make predictions using GPU.")
 .set_body([]() { return new GPUPredictor(); });
+
 }  // namespace predictor
 }  // namespace xgboost
diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index 82f7719972a9..1cf7e5155c35 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -9,6 +9,7 @@
 
 namespace xgboost {
 namespace predictor {
+
 TEST(gpu_predictor, Test) {
   std::unique_ptr<Predictor> gpu_predictor =
       std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
@@ -41,8 +42,7 @@ TEST(gpu_predictor, Test) {
   std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.HostVector();
   float abs_tolerance = 0.001;
   for (int i = 0; i < gpu_out_predictions.Size(); i++) {
-    ASSERT_LT(std::abs(gpu_out_predictions_h[i] - cpu_out_predictions_h[i]),
-              abs_tolerance);
+    ASSERT_NEAR(gpu_out_predictions_h[i], cpu_out_predictions_h[i], abs_tolerance);
   }
   // Test predict instance
   const auto &batch = *(*dmat)->GetRowBatches().begin();
@@ -76,5 +76,46 @@ TEST(gpu_predictor, Test) {
   delete dmat;
 }
+
+// multi-GPU predictor test
+TEST(gpu_predictor, MGPU_Test) {
+  std::unique_ptr<Predictor> gpu_predictor =
+      std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
+  std::unique_ptr<Predictor> cpu_predictor =
+      std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));
+
+  gpu_predictor->Init({std::pair<std::string, std::string>("n_gpus", "-1")}, {});
+  cpu_predictor->Init({}, {});
+
+  for (size_t i = 1; i < 33; i *= 2) {
+    int n_row = i, n_col = i;
+    auto dmat = CreateDMatrix(n_row, n_col, 0);
+
+    std::vector<std::unique_ptr<RegTree>> trees;
+    trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
+    trees.back()->InitModel();
+    (*trees.back())[0].SetLeaf(1.5f);
+    (*trees.back()).Stat(0).sum_hess = 1.0f;
+    gbm::GBTreeModel model(0.5);
+    model.CommitModel(std::move(trees), 0);
+    model.param.num_output_group = 1;
+
+    // Test predict batch
+    HostDeviceVector<float> gpu_out_predictions;
+    HostDeviceVector<float> cpu_out_predictions;
+
+    gpu_predictor->PredictBatch((*dmat).get(), &gpu_out_predictions, model, 0);
+    cpu_predictor->PredictBatch((*dmat).get(), &cpu_out_predictions, model, 0);
+
+    std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.HostVector();
+    std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.HostVector();
+    float abs_tolerance = 0.001;
+    for (int i = 0; i < gpu_out_predictions.Size(); i++) {
+      ASSERT_NEAR(gpu_out_predictions_h[i], cpu_out_predictions_h[i], abs_tolerance);
+    }
+    delete dmat;
+  }
+}
+
 }  // namespace predictor
 }  // namespace xgboost
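
Note (not part of the patch): the core of the multi-GPU change above is that each device shard holds only its slice of the CSR data. `DeviceOffsets` reads the last row-pointer value of every shard to build explicit per-device entry boundaries, and the kernel then rebases global entry indices with `entry_start`. The stand-alone host-side sketch below illustrates that arithmetic only; `ShardEntryOffsets` and the sample numbers are illustrative, not XGBoost API.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Given a CSR row pointer and the exclusive last row assigned to each shard,
// compute the entry offset at which each shard's data slice begins and ends.
// offsets[s] .. offsets[s + 1] is the half-open entry range owned by shard s,
// mirroring what DeviceOffsets derives from the last element of each
// per-device row_ptr span in the patch.
std::vector<size_t> ShardEntryOffsets(const std::vector<size_t>& row_ptr,
                                      const std::vector<size_t>& row_end_per_shard) {
  std::vector<size_t> offsets(row_end_per_shard.size() + 1, 0);
  for (size_t s = 0; s < row_end_per_shard.size(); ++s) {
    assert(row_end_per_shard[s] < row_ptr.size());
    offsets[s + 1] = row_ptr[row_end_per_shard[s]];
  }
  return offsets;
}

int main() {
  // 4 rows with 2, 1, 3 and 2 entries; rows split as {0,1} and {2,3}.
  std::vector<size_t> row_ptr = {0, 2, 3, 6, 8};
  std::vector<size_t> row_end = {2, 4};  // exclusive last row per shard
  auto offsets = ShardEntryOffsets(row_ptr, row_end);  // yields {0, 3, 8}
  for (size_t v : offsets) std::cout << v << " ";
  std::cout << "\n";
  // Shard 1 owns entries [3, 8), so entry_start == 3 for that shard: a global
  // element index taken from row_ptr is rebased with `elem_idx - entry_start`,
  // which is exactly how d_data is indexed inside ElementLoader after the patch.
  return 0;
}
```

These offsets are what the patch feeds to `GPUDistribution::Explicit(devices_, device_offsets)`, so each device's data span starts at `offsets[s]` and the kernel's `entry_start` subtraction lands indices inside that span.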