Fix dart inplace prediction with GPU input. #6777

Merged: 2 commits, Mar 25, 2021
12 changes: 9 additions & 3 deletions src/common/hist_util.cu
@@ -185,9 +185,15 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
size_t begin, size_t end, SketchContainer *sketch_container,
int num_cuts_per_feature, size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
dh::device_vector<Entry> sorted_entries;
Member Author: This is to avoid copying data when input is already on GPU.

if (page.data.DeviceCanRead()) {
const auto& device_data = page.data.ConstDevicePointer();
sorted_entries = dh::device_vector<Entry>(device_data + begin, device_data + end);
} else {
const auto& host_data = page.data.ConstHostVector();
sorted_entries = dh::device_vector<Entry>(host_data.begin() + begin,
host_data.begin() + end);
}
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
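
The branch above avoids a device-to-host-to-device round trip: when the SparsePage's data is already readable on the GPU, sorted_entries is built straight from the device pointer. A minimal standalone sketch of the same pattern in plain Thrust; the Entry struct and the SliceOnDevice helper below are illustrative stand-ins, not XGBoost's actual internals.

#include <cstddef>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>

struct Entry { int index; float fvalue; };  // stand-in for xgboost::Entry

// Copy the slice [begin, end) of a buffer that already lives in device memory
// into a new device_vector. Wrapping the raw pointer in thrust::device_ptr
// tells Thrust the source is on the GPU, so the range constructor performs a
// device-to-device copy instead of staging the slice through host memory.
thrust::device_vector<Entry> SliceOnDevice(const Entry *d_data,
                                           std::size_t begin, std::size_t end) {
  thrust::device_ptr<const Entry> first(d_data + begin);
  thrust::device_ptr<const Entry> last(d_data + end);
  return thrust::device_vector<Entry>(first, last);
}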

5 changes: 4 additions & 1 deletion src/common/host_device_vector.cu
@@ -92,7 +92,10 @@ class HostDeviceVectorImpl {
} else {
gpu_access_ = GPUAccess::kWrite;
SetDevice();
thrust::fill(data_d_->begin(), data_d_->end(), v);
auto s_data = dh::ToSpan(*data_d_);
Member Author: Avoid synchronization.

dh::LaunchN(device_, data_d_->size(), [=]XGBOOST_DEVICE(size_t i) {
s_data[i] = v;
Collaborator: Should we use the bound-checked interface here, given that the size of data_d_ is already clear?

Member Author: I can use a pointer.

});
}
}
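
The exchange above concerns the device-side fill that replaces thrust::fill. A sketch in plain CUDA of the pointer-based variant the author mentions; the kernel name and launch configuration are illustrative, and in XGBoost the launch itself is wrapped by dh::LaunchN.

#include <cstddef>

// Grid-stride fill through a raw device pointer. The unchecked write is safe
// because the loop bound n is taken from the vector's own size, and a plain
// kernel launch stays asynchronous with respect to the host, which is what
// the "Avoid synchronization" comment above is after.
template <typename T>
__global__ void FillKernel(T *data, std::size_t n, T v) {
  for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<std::size_t>(gridDim.x) * blockDim.x) {
    data[i] = v;
  }
}

// Example launch: FillKernel<<<256, 256>>>(data_d_->data().get(), data_d_->size(), v);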

16 changes: 12 additions & 4 deletions src/data/ellpack_page.cu
@@ -407,7 +407,6 @@ void EllpackPageImpl::CreateHistIndices(int device,
size_t gpu_batch_nrows =
std::min(dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(row_batch.Size()));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();

size_t gpu_nbatches = common::DivRoundUp(row_batch.Size(), gpu_batch_nrows);

@@ -429,9 +428,18 @@ void EllpackPageImpl::CreateHistIndices(int device,
size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries);
// copy data entries to device.
dh::safe_cuda(cudaMemcpyAsync(entries_d.data().get(),
data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
if (row_batch.data.DeviceCanRead()) {
Member Author (@trivialfis, Mar 24, 2021): Avoid copying data when it's already on GPU.

auto const& d_data = row_batch.data.ConstDeviceSpan();
dh::safe_cuda(cudaMemcpyAsync(
entries_d.data().get(), d_data.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
} else {
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
dh::safe_cuda(cudaMemcpyAsync(
entries_d.data().get(), data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
}

const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y), 1);
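
Both branches above end in the same cudaMemcpyAsync call. With cudaMemcpyDefault the CUDA runtime infers the transfer direction from the pointer values (this relies on unified virtual addressing), so a single call covers both the device-to-device and the host-to-device case. A reduced sketch with an illustrative helper name:

#include <cstddef>
#include <cuda_runtime.h>

// dst_device points to device memory; src may point to either host or device
// memory. cudaMemcpyDefault lets the runtime pick the right transfer kind.
// Error checking (dh::safe_cuda in XGBoost) is omitted for brevity.
void CopyEntriesAsync(void *dst_device, const void *src, std::size_t bytes) {
  cudaMemcpyAsync(dst_device, src, bytes, cudaMemcpyDefault);
}
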
79 changes: 62 additions & 17 deletions src/gbm/gbtree.cc
@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2020 by Contributors
* Copyright 2014-2021 by Contributors
* \file gbtree.cc
* \brief gradient boosted tree implementation.
* \author Tianqi Chen
@@ -558,6 +558,23 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
return cpu_predictor_;
}

/** Increment the prediction on GPU.
*
* \param out_predts Prediction for the whole model.
* \param predts Prediction for current tree.
* \param tree_w Tree weight.
*/
void GPUDartPredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w, size_t n_rows,
bst_group_t n_groups, bst_group_t group)
#if defined(XGBOOST_USE_CUDA)
; // NOLINT
#else
{
common::AssertGPUSupport();
}
#endif
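
GPUDartPredictInc above uses a conditional-compilation idiom: when XGBoost is built with CUDA this translation unit only sees a declaration and the definition comes from src/gbm/gbtree.cu further down, while the CPU-only build gets an inline stub that fails at runtime through common::AssertGPUSupport(). A stripped-down sketch of the same pattern with a hypothetical function; the real code reports the error via AssertGPUSupport rather than an exception.

#include <cstddef>
#include <stdexcept>

// Hypothetical device-only routine following the same declaration-vs-stub
// pattern as GPUDartPredictInc.
void DeviceAxpy(float *y, const float *x, float a, std::size_t n)
#if defined(XGBOOST_USE_CUDA)
    ;  // NOLINT -- real definition is compiled from the matching .cu file
#else
{
  (void)y; (void)x; (void)a; (void)n;
  throw std::runtime_error("This build was compiled without CUDA support.");
}
#endif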

class Dart : public GBTree {
public:
explicit Dart(LearnerModelParam const* booster_config) :
@@ -647,31 +664,46 @@ class Dart : public GBTree {
model_);
p_out_preds->version = 0;
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
std::tie(tree_begin, tree_end) =
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
auto n_groups = model_.learner_model_param->num_output_group;

PredictionCacheEntry predts; // temporary storage for prediction
if (generic_param_->gpu_id != GenericParameter::kCpuId) {
predts.predictions.SetDevice(generic_param_->gpu_id);
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);

for (size_t i = tree_begin; i < tree_end; i += 1) {
if (training &&
std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
if (training && std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
continue;
}

CHECK_GE(i, p_out_preds->version);
auto version = i / this->LayerTrees();
p_out_preds->version = version;

auto n_groups = model_.learner_model_param->num_output_group;
PredictionCacheEntry predts;
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
predts.predictions.Fill(0);
predictor->PredictBatch(p_fmat, &predts, model_, i, i + 1);

// Multiple the weight to output prediction.
auto w = this->weight_drop_.at(i);
auto &h_predts = predts.predictions.HostVector();
auto group = model_.tree_info.at(i);
auto &h_out_predts = p_out_preds->predictions.HostVector();
CHECK_EQ(h_out_predts.size(), h_predts.size());
for (size_t ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
const size_t offset = ridx * n_groups + group;
h_out_predts[offset] += (h_predts[offset] * w);
CHECK_EQ(p_out_preds->predictions.Size(), predts.predictions.Size());

size_t n_rows = p_fmat->Info().num_row_;
if (predts.predictions.DeviceIdx() != GenericParameter::kCpuId) {
p_out_preds->predictions.SetDevice(predts.predictions.DeviceIdx());
GPUDartPredictInc(p_out_preds->predictions.DeviceSpan(),
predts.predictions.DeviceSpan(), w, n_rows, n_groups,
group);
} else {
auto &h_out_predts = p_out_preds->predictions.HostVector();
auto &h_predts = predts.predictions.HostVector();
#pragma omp parallel for
for (omp_ulong ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
const size_t offset = ridx * n_groups + group;
h_out_predts[offset] += (h_predts[offset] * w);
}
}
}
}
@@ -699,6 +731,7 @@ class Dart : public GBTree {

MetaInfo info;
StringView msg{"Unsupported data type for inplace predict."};
int32_t device = GenericParameter::kCpuId;
// Inplace predict is not used for training, so no need to drop tree.
for (size_t i = tree_begin; i < tree_end; ++i) {
PredictionCacheEntry predts;
@@ -709,21 +742,26 @@ class Dart : public GBTree {
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
i + 1)) {
success = true;
#if defined(XGBOOST_USE_CUDA)
device = predts.predictions.DeviceIdx();
#endif // defined(XGBOOST_USE_CUDA)
break;
}
}
CHECK(success) << msg;
} else {
// No base margin for each tree
bool success = this->GetPredictor()->InplacePredict(
x, nullptr, model_, missing, &predts, tree_begin, tree_end);
x, nullptr, model_, missing, &predts, i, i + 1);
device = predts.predictions.DeviceIdx();
CHECK(success) << msg;
}

auto w = this->weight_drop_.at(i);
auto &h_predts = predts.predictions.HostVector();
auto &h_out_predts = out_preds->predictions.HostVector();
if (h_out_predts.empty()) {

if (i == tree_begin) {
auto n_rows =
h_predts.size() / model_.learner_model_param->num_output_group;
if (p_m) {
@@ -739,12 +777,19 @@ class Dart : public GBTree {

// Multiple the tree weight
CHECK_EQ(h_predts.size(), h_out_predts.size());
for (size_t i = 0; i < h_out_predts.size(); ++i) {

#pragma omp parallel for
for (omp_ulong i = 0; i < h_out_predts.size(); ++i) {
// Need to remove the base margin from indiviual tree.
h_out_predts[i] +=
(h_predts[i] - model_.learner_model_param->base_score) * w;
}
}

if (device != GenericParameter::kCpuId) {
out_preds->predictions.SetDevice(device);
out_preds->predictions.DeviceSpan();
}
}

void PredictInstance(const SparsePage::Inst &inst,
18 changes: 18 additions & 0 deletions src/gbm/gbtree.cu
@@ -0,0 +1,18 @@
/*!
* Copyright 2021 by Contributors
*/
#include "xgboost/span.h"
#include "../common/device_helpers.cuh"

namespace xgboost {
namespace gbm {
void GPUDartPredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w, size_t n_rows,
bst_group_t n_groups, bst_group_t group) {
dh::LaunchN(dh::CurrentDevice(), n_rows, [=]XGBOOST_DEVICE(size_t ridx) {
const size_t offset = ridx * n_groups + group;
out_predts[offset] += (predts[offset] * tree_w);
});
}
} // namespace gbm
} // namespace xgboost
30 changes: 30 additions & 0 deletions tests/python-gpu/test_gpu_prediction.py
@@ -312,3 +312,33 @@ def test_predict_categorical_split(self, df):
pred = bst.predict(dtrain)
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)

def test_predict_dart(self):
import cupy as cp
rng = cp.random.RandomState(1994)
n_samples = 1000
X = rng.randn(n_samples, 10)
y = rng.randn(n_samples)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{
"tree_method": "gpu_hist",
"booster": "dart",
"rate_drop": 0.5,
},
Xy,
num_boost_round=32
)
# predictor=auto
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)

copied = cp.array(copied)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)

booster.set_param({"predictor": "gpu_predictor"})
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)

copied = cp.array(copied)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)