Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix release degradation #5720

Merged
merged 2 commits into from
May 31, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions src/objective/regression_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
template<typename Loss>
class RegLossObj : public ObjFunction {
protected:
HostDeviceVector<int> label_correct_;
HostDeviceVector<float> additional_input_;

public:
RegLossObj() = default;
// 0 - label_correct flag, 1 - scale_pos_weight, 2 - is_null_weight
RegLossObj(): additional_input_(3) {}

void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args);
Expand All @@ -64,44 +65,45 @@ class RegLossObj : public ObjFunction {
size_t const ndata = preds.Size();
out_gpair->Resize(ndata);
auto device = tparam_->gpu_id;
label_correct_.Resize(1);
label_correct_.Fill(1);
additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag

bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
auto scale_pos_weight = param_.scale_pos_weight;
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
additional_input_.HostVector().begin()[1] = scale_pos_weight;
additional_input_.HostVector().begin()[2] = is_null_weight;

common::Transform<>::Init([] XGBOOST_DEVICE(size_t _idx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not enough to change [=] to [&] in lambda?
Also, if we need to set scalars to the function - do we really need use common::Span instead of just scalars?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't pass a ref of host data to device.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems scalars are also not applicable, as it isn't aligned with common interface of Transform::Eval

common::Span<float> _additional_input,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
const float _scale_pos_weight = _additional_input[1];
const bool _is_null_weight = _additional_input[2];

bst_float p = Loss::PredTransform(_preds[_idx]);
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float w = _is_null_weight ? 1.0f : _weights[_idx];
bst_float label = _labels[_idx];
if (label == 1.0f) {
w *= scale_pos_weight;
w *= _scale_pos_weight;
}
if (!Loss::CheckLabel(label)) {
// If there is an incorrect label, the host code will know.
_label_correct[0] = 0;
_additional_input[0] = 0;
}
_out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
Loss::SecondOrderGradient(p, label) * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
&additional_input_, out_gpair, &preds, &info.labels_, &info.weights_);

// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
auto const flag = additional_input_.HostVector().begin()[0];
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
}

Expand Down