diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 4009e50c8611..fac23ece9f2c 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -73,18 +73,26 @@ class RegLossObj : public ObjFunction { << "Number of weights should be equal to number of data points."; } auto scale_pos_weight = param_.scale_pos_weight; - common::Transform<>::Init( - [=] XGBOOST_DEVICE(size_t _idx, + HostDeviceVector scale_pos_weight_; + scale_pos_weight_.Resize(1); + scale_pos_weight_.Fill(scale_pos_weight); + HostDeviceVector is_null_weight_; + is_null_weight_.Resize(1); + is_null_weight_.Fill(is_null_weight); + + common::Transform<>::Init([] XGBOOST_DEVICE(size_t _idx, common::Span _label_correct, common::Span _out_gpair, common::Span _preds, common::Span _labels, - common::Span _weights) { + common::Span _weights, + common::Span _is_null_weight, + common::Span _scale_pos_weight) { bst_float p = Loss::PredTransform(_preds[_idx]); - bst_float w = is_null_weight ? 1.0f : _weights[_idx]; + bst_float w = _is_null_weight[0] ? 1.0f : _weights[_idx]; bst_float label = _labels[_idx]; if (label == 1.0f) { - w *= scale_pos_weight; + w *= _scale_pos_weight[0]; } if (!Loss::CheckLabel(label)) { // If there is an incorrect label, the host code will know. @@ -94,7 +102,8 @@ class RegLossObj : public ObjFunction { Loss::SecondOrderGradient(p, label) * w); }, common::Range{0, static_cast(ndata)}, device).Eval( - &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_, + &is_null_weight_, &scale_pos_weight_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector();