Skip to content

Commit

Permalink
Fix GPU Predictor.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 3, 2020
1 parent fc6c123 commit def3135
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/predictor/gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -308,16 +308,20 @@ class GPUPredictor : public xgboost::Predictor {
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) override {
int device = generic_param_->gpu_id;
auto* out_preds = &predts->predictions;
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
ConfigureDevice(device);

auto* out_preds = &predts->predictions;
if (predts->version == 0) {
CHECK_EQ(out_preds->Size(), 0);
this->InitOutPredictions(dmat->Info(), out_preds, model);
}

ntree_limit *= model.learner_model_param_->num_output_group;
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size());
}
auto tree_end = tree_begin + ntree_limit;
CHECK_GE(tree_end, predts->version);
if (tree_end - predts->version == 0) {
CHECK_EQ(out_preds->Size(), dmat->Info().num_row_);
} else {
Expand Down

0 comments on commit def3135

Please sign in to comment.