Skip to content

Commit

Permalink
less resizes
Browse files Browse the repository at this point in the history
  • Loading branch information
SHVETS, KIRILL committed May 29, 2020
1 parent 3136f63 commit 5e04cc6
Showing 1 changed file with 18 additions and 25 deletions.
43 changes: 18 additions & 25 deletions src/objective/regression_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
template<typename Loss>
class RegLossObj : public ObjFunction {
protected:
HostDeviceVector<int> label_correct_;
HostDeviceVector<float> additional_input_;

public:
RegLossObj() = default;
// 0 - label_correct flag, 1 - scale_pos_weight, 2 - is_null_weight
RegLossObj(): additional_input_(3) {}

void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args);
Expand All @@ -64,53 +65,45 @@ class RegLossObj : public ObjFunction {
size_t const ndata = preds.Size();
out_gpair->Resize(ndata);
auto device = tparam_->gpu_id;
label_correct_.Resize(1);
label_correct_.Fill(1);
additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag

bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
auto scale_pos_weight = param_.scale_pos_weight;
HostDeviceVector<float> scale_pos_weight_;
scale_pos_weight_.Resize(1);
scale_pos_weight_.Fill(scale_pos_weight);
HostDeviceVector<int> is_null_weight_;
is_null_weight_.Resize(1);
is_null_weight_.Fill(is_null_weight);
additional_input_.HostVector().begin()[1] = scale_pos_weight;
additional_input_.HostVector().begin()[2] = is_null_weight;

common::Transform<>::Init([] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<float> _additional_input,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights,
common::Span<int> _is_null_weight,
common::Span<float> _scale_pos_weight) {
common::Span<const bst_float> _weights) {
const float _scale_pos_weight = _additional_input[1];
const bool _is_null_weight = _additional_input[2];

bst_float p = Loss::PredTransform(_preds[_idx]);
bst_float w = _is_null_weight[0] ? 1.0f : _weights[_idx];
bst_float w = _is_null_weight ? 1.0f : _weights[_idx];
bst_float label = _labels[_idx];
if (label == 1.0f) {
w *= _scale_pos_weight[0];
w *= _scale_pos_weight;
}
if (!Loss::CheckLabel(label)) {
// If there is an incorrect label, the host code will know.
_label_correct[0] = 0;
_additional_input[0] = 0;
}
_out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
Loss::SecondOrderGradient(p, label) * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_,
&is_null_weight_, &scale_pos_weight_);
&additional_input_, out_gpair, &preds, &info.labels_, &info.weights_);

// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
auto const flag = additional_input_.HostVector().begin()[0];
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
}

Expand Down

0 comments on commit 5e04cc6

Please sign in to comment.