
Initial support for multioutput regression. (#7514)
* Add the num_target model parameter, which is configured from the input labels.
* Change the elementwise metrics to handle matrix labels and index weights per sample.
* Add demo.
* Add tests.
trivialfis authored Dec 18, 2021
1 parent 9ab73f7 commit 58a6723
Showing 22 changed files with 306 additions and 67 deletions.
48 changes: 48 additions & 0 deletions demo/guide-python/multioutput_regression.py
@@ -0,0 +1,48 @@
"""
A demo for multi-output regression
==================================
The demo is adopted from scikit-learn:
https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_regression_multioutput.html#sphx-glr-auto-examples-ensemble-plot-random-forest-regression-multioutput-py
"""
import numpy as np
import xgboost as xgb
import argparse
from matplotlib import pyplot as plt


def plot_predt(y, y_predt, name):
s = 25
plt.scatter(y[:, 0], y[:, 1], c="navy", s=s,
edgecolor="black", label="data")
plt.scatter(y_predt[:, 0], y_predt[:, 1], c="cornflowerblue", s=s,
edgecolor="black")
plt.xlim([-1, 2])
plt.ylim([-1, 2])
plt.show()


def main(plot_result: bool):
"""Draw a circle with 2-dim coordinate as target variables."""
rng = np.random.RandomState(1994)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
y[::5, :] += (0.5 - rng.rand(20, 2))
y = y - y.min()
y = y / y.max()

# Train a regressor on it
reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64)
reg.fit(X, y, eval_set=[(X, y)])

y_predt = reg.predict(X)
if plot_result:
plot_predt(y, y_predt, 'multi')


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
args = parser.parse_args()
main(args.plot == 1)
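
As a quick sanity check of what the demo exercises: with this commit, the sklearn wrapper accepts a 2-dim label array and predict() returns one column per target. A minimal sketch, not part of the commit, assuming a build that includes this change:

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X = rng.randn(100, 10)
y = rng.randn(100, 2)  # two regression targets

# One model is trained over both targets.
reg = xgb.XGBRegressor(tree_method="hist", n_estimators=8)
reg.fit(X, y)
assert reg.predict(X).shape == (100, 2)  # one column per target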
10 changes: 10 additions & 0 deletions include/xgboost/objective.h
@@ -77,6 +77,16 @@ class ObjFunction : public Configurable {
* \brief Return task of this objective.
*/
virtual struct ObjInfo Task() const = 0;
/**
* \brief Return number of targets for input matrix. Right now XGBoost supports only
* multi-target regression.
*/
virtual uint32_t Targets(MetaInfo const& info) const {
if (info.labels.Shape(1) > 1) {
LOG(FATAL) << "multioutput is not supported by current objective function";
}
return 1;
}

/*!
* \brief Create an objective function according to name.
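The base Targets() implementation rejects labels with more than one column, so only objectives that override it accept multi-output labels. A hedged sketch of the failure mode (not part of the commit; it assumes a build with this change and that binary:logistic keeps the base implementation):

import numpy as np
import xgboost as xgb

X = np.random.randn(50, 4)
y = np.random.randint(0, 2, size=(50, 2)).astype(np.float64)
dtrain = xgb.DMatrix(X, label=y)

try:
    # binary:logistic does not override ObjFunction::Targets(), so the
    # LOG(FATAL) above should fire for the 2-column label.
    xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=1)
except xgb.core.XGBoostError as err:
    print(err)  # "multioutput is not supported by current objective function"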
2 changes: 1 addition & 1 deletion python-package/xgboost/data.py
@@ -21,7 +21,7 @@
# meta info that can be a matrix instead of vector.
# For now it's base_margin for multi-class, but it can be extended to label once we have
# multi-output.
_matrix_meta = {"base_margin"}
_matrix_meta = {"base_margin", "label"}


def _warn_unused_missing(data, missing):
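Adding "label" to _matrix_meta lets the Python binding pass a 2-dim label array through without flattening it to a vector, which is what allows matrix-valued labels to reach MetaInfo intact. A small sketch under the same assumption:

import numpy as np
import xgboost as xgb

X = np.random.randn(100, 5)
y = np.random.randn(100, 3)  # three targets per row

# The (100, 3) shape survives the meta-info dispatch instead of being
# coerced into a length-300 vector.
dtrain = xgb.DMatrix(X, label=y)
print(dtrain.num_row())  # 100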
7 changes: 4 additions & 3 deletions src/data/data.cc
@@ -160,10 +160,11 @@ void LoadTensorField(dmlc::Stream* strm, std::string const& expected_name,
CHECK(strm->Read(&is_scalar)) << invalid;
CHECK(!is_scalar) << invalid << "Expected field " << expected_name
<< " to be a tensor; got a scalar";
std::array<size_t, D> shape;
size_t shape[D];
for (size_t i = 0; i < D; ++i) {
CHECK(strm->Read(&(shape[i])));
}
p_out->Reshape(shape);
auto& field = p_out->Data()->HostVector();
CHECK(strm->Read(&field)) << invalid;
}
@@ -411,6 +412,7 @@ template <int32_t D, typename T>
void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
ArrayInterface<D> array{arr_interface};
if (array.n == 0) {
p_out->Reshape(array.shape);
return;
}
CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value.";
@@ -737,8 +739,7 @@ void MetaInfo::Validate(int32_t device) const {
return;
}
if (labels.Size() != 0) {
CHECK_EQ(labels.Size(), num_row_)
<< "Size of labels must equal to number of rows.";
CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
check_device(*labels.Data());
return;
}
1 change: 1 addition & 0 deletions src/data/data.cu
@@ -29,6 +29,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
ArrayInterface<D> array(arr_interface);
if (array.n == 0) {
p_out->SetDevice(0);
p_out->Reshape(array.shape);
return;
}
CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value.";
69 changes: 61 additions & 8 deletions src/learner.cc
@@ -88,12 +88,15 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
/*! \brief the version of XGBoost. */
uint32_t major_version;
uint32_t minor_version;

uint32_t num_target{1};
/*! \brief reserved field */
int reserved[27];
int reserved[26];
/*! \brief constructor */
LearnerModelParamLegacy() {
std::memset(this, 0, sizeof(LearnerModelParamLegacy));
base_score = 0.5f;
num_target = 1;
major_version = std::get<0>(Version::Self());
minor_version = std::get<1>(Version::Self());
static_assert(sizeof(LearnerModelParamLegacy) == 136,
@@ -119,13 +122,24 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
CHECK(ret.ec == std::errc());
obj["num_class"] =
std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};

ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
static_cast<int64_t>(num_target));
obj["num_target"] =
std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};

return Json(std::move(obj));
}
void FromJson(Json const& obj) {
auto const& j_param = get<Object const>(obj);
std::map<std::string, std::string> m;
m["num_feature"] = get<String const>(j_param.at("num_feature"));
m["num_class"] = get<String const>(j_param.at("num_class"));
auto n_targets_it = j_param.find("num_target");
if (n_targets_it != j_param.cend()) {
m["num_target"] = get<String const>(n_targets_it->second);
}

this->Init(m);
std::string str = get<String const>(j_param.at("base_score"));
from_chars(str.c_str(), str.c_str() + str.size(), base_score);
@@ -139,6 +153,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
dmlc::ByteSwap(&x.contain_eval_metrics, sizeof(x.contain_eval_metrics), 1);
dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1);
dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1);
dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1);
dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
return x;
}
@@ -156,15 +171,24 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
DMLC_DECLARE_FIELD(num_class).set_default(0).set_lower_bound(0).describe(
"Number of class option for multi-class classifier. "
" By default equals 0 and corresponds to binary classifier.");
DMLC_DECLARE_FIELD(num_target)
.set_default(1)
.set_lower_bound(1)
.describe("Number of target for multi-target regression.");
}
};

LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin,
ObjInfo t)
: base_score{base_margin},
num_feature{user_param.num_feature},
num_output_group{user_param.num_class == 0 ? 1 : static_cast<uint32_t>(user_param.num_class)},
task{t} {}
: base_score{base_margin}, num_feature{user_param.num_feature}, task{t} {
auto n_classes = std::max(static_cast<uint32_t>(user_param.num_class), 1u);
auto n_targets = user_param.num_target;
num_output_group = std::max(n_classes, n_targets);
// For version < 1.6, n_targets == 0
CHECK(n_classes <= 1 || n_targets <= 1)
<< "Multi-class multi-output is not yet supported. n_classes:" << n_classes
<< ", n_targets:" << n_targets;
}

struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none.
@@ -325,6 +349,8 @@ class LearnerConfiguration : public Learner {
args = {cfg_.cbegin(), cfg_.cend()}; // renew
this->ConfigureObjective(old_tparam, &args);

auto task = this->ConfigureTargets();

// Before 1.0.0, we save `base_score` into binary as a transformed value by objective.
// After 1.0.0 we save the value provided by user and keep it immutable instead. To
// keep the stability, we initialize it in binary LoadModel instead of configuration.
@@ -339,7 +365,7 @@
// - model is configured second time due to change of parameter
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), task);
}

this->ConfigureGBM(old_tparam, args);
@@ -586,8 +612,7 @@ class LearnerConfiguration : public Learner {
CHECK(matrix.first);
CHECK(!matrix.second.ref.expired());
const uint64_t num_col = matrix.first->Info().num_col_;
CHECK_LE(num_col,
static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
<< "Unfortunately, XGBoost does not support data matrices with "
<< std::numeric_limits<unsigned>::max() << " features or greater";
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
@@ -652,6 +677,31 @@ class LearnerConfiguration : public Learner {
p_metric->Configure(args);
}
}

/**
* Get number of targets from objective function.
*/
ObjInfo ConfigureTargets() {
CHECK(this->obj_);
auto const& cache = this->GetPredictionCache()->Container();
size_t n_targets = 1;
for (auto const& d : cache) {
if (n_targets == 1) {
n_targets = this->obj_->Targets(d.first->Info());
} else {
auto t = this->obj_->Targets(d.first->Info());
CHECK(n_targets == t || 1 == t) << "Inconsistent labels.";
}
}
if (mparam_.num_target != 1) {
CHECK(n_targets == 1 || n_targets == mparam_.num_target)
<< "Inconsistent configuration of num_target. Configuration result from input data:"
<< n_targets << ", configuration from parameter:" << mparam_.num_target;
} else {
mparam_.num_target = n_targets;
}
return this->obj_->Task();
}
};

std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT
@@ -784,6 +834,9 @@ class LearnerIO : public LearnerConfiguration {
if (!DMLC_IO_NO_ENDIAN_SWAP) {
mparam_ = mparam_.ByteSwap();
}
if (mparam_.num_target == 0) {
mparam_.num_target = 1;
}
CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";

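num_target round-trips through both the JSON and binary model formats, with the num_target == 0 fallback covering models saved before this change. A sketch of how one might inspect it, assuming the JSON layout implied by ToJson/FromJson above (learner/learner_model_param/num_target, serialized as a string like num_feature and num_class):

import json
import numpy as np
import xgboost as xgb

X, y = np.random.randn(100, 5), np.random.randn(100, 2)
reg = xgb.XGBRegressor(tree_method="hist", n_estimators=4)
reg.fit(X, y)
reg.save_model("model.json")

with open("model.json") as fd:
    model = json.load(fd)
print(model["learner"]["learner_model_param"]["num_target"])  # expected: "2"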
42 changes: 22 additions & 20 deletions src/metric/elementwise_metric.cu
@@ -37,20 +37,26 @@ class ElementWiseMetricsReduction {

PackedReduceResult
CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
const HostDeviceVector<bst_float> &labels,
linalg::TensorView<float const, 2> labels,
const HostDeviceVector<bst_float> &preds,
int32_t n_threads) const {
size_t ndata = labels.Size();
auto n_targets = std::max(labels.Shape(1), static_cast<size_t>(1));
auto h_labels = labels.Values();

const auto& h_labels = labels.HostVector();
const auto& h_weights = weights.HostVector();
const auto& h_preds = preds.HostVector();

std::vector<double> score_tloc(n_threads, 0.0);
std::vector<double> weight_tloc(n_threads, 0.0);

// We sum the losses over all samples and targets instead of performing the
// reduction per target, since the former is more accurate; the latter is used
// as an approximation in the distributed setting. For rmse:
// - sqrt(1/w(sum_t0 + sum_t1 + ... + sum_tm))        // multi-target
// - sqrt(avg_t0) + sqrt(avg_t1) + ... + sqrt(avg_tm) // distributed
common::ParallelFor(ndata, n_threads, [&](size_t i) {
float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
float wt = h_weights.size() > 0 ? h_weights[i / n_targets] : 1.0f;
auto t_idx = omp_get_thread_num();
score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
weight_tloc[t_idx] += wt;
@@ -66,14 +72,15 @@ class ElementWiseMetricsReduction {

PackedReduceResult DeviceReduceMetrics(
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
linalg::TensorView<float const, 2> labels,
const HostDeviceVector<bst_float>& preds) {
size_t n_data = preds.Size();
auto n_targets = std::max(labels.Shape(1), static_cast<size_t>(1));

thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + n_data;

auto s_label = labels.DeviceSpan();
auto s_label = labels.Values();
auto s_preds = preds.DeviceSpan();
auto s_weights = weights.DeviceSpan();

@@ -86,7 +93,7 @@
thrust::cuda::par(alloc),
begin, end,
[=] XGBOOST_DEVICE(size_t idx) {
float weight = is_null_weight ? 1.0f : s_weights[idx];
float weight = is_null_weight ? 1.0f : s_weights[idx / n_targets];

float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]);
residue *= weight;
@@ -100,26 +107,22 @@

#endif // XGBOOST_USE_CUDA

PackedReduceResult Reduce(
const GenericParameter &ctx,
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds) {
PackedReduceResult Reduce(const GenericParameter& ctx, const HostDeviceVector<bst_float>& weights,
linalg::Tensor<float, 2> const& labels,
const HostDeviceVector<bst_float>& preds) {
PackedReduceResult result;

if (ctx.gpu_id < 0) {
auto n_threads = ctx.Threads();
result = CpuReduceMetrics(weights, labels, preds, n_threads);
result = CpuReduceMetrics(weights, labels.HostView(), preds, n_threads);
}
#if defined(XGBOOST_USE_CUDA)
else { // NOLINT
device_ = ctx.gpu_id;
preds.SetDevice(device_);
labels.SetDevice(device_);
weights.SetDevice(device_);
preds.SetDevice(ctx.gpu_id);
weights.SetDevice(ctx.gpu_id);

dh::safe_cuda(cudaSetDevice(device_));
result = DeviceReduceMetrics(weights, labels, preds);
dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
result = DeviceReduceMetrics(weights, labels.View(ctx.gpu_id), preds);
}
#endif // defined(XGBOOST_USE_CUDA)
return result;
@@ -128,7 +131,6 @@
private:
EvalRow policy_;
#if defined(XGBOOST_USE_CUDA)
int device_{-1};
#endif // defined(XGBOOST_USE_CUDA)
};

@@ -364,7 +366,7 @@ struct EvalEWiseBase : public Metric {
CHECK_EQ(preds.Size(), info.labels.Size())
<< "label and prediction size not match, "
<< "hint: use merror or mlogloss for multi-class classification";
auto result = reducer_.Reduce(*tparam_, info.weights_, *info.labels.Data(), preds);
auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels, preds);

double dat[2] { result.Residue(), result.Weights() };

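The heart of the metric change is the weight index: labels and predictions are flattened to n_samples * n_targets elements while weights stay per sample, hence h_weights[i / n_targets]. A numpy reconstruction of the CPU path for the squared-error policy (my sketch, not the committed code), following the summation strategy described in the comment above:

import numpy as np

def weighted_rmse(labels, preds, weights):
    # Losses are summed over all samples *and* targets; weights are
    # indexed per sample and therefore counted once per target.
    n_samples, n_targets = labels.shape
    flat_labels, flat_preds = labels.ravel(), preds.ravel()
    residue, weight_sum = 0.0, 0.0
    for i in range(flat_labels.size):
        wt = weights[i // n_targets]  # one weight per sample
        residue += (flat_labels[i] - flat_preds[i]) ** 2 * wt
        weight_sum += wt
    return np.sqrt(residue / weight_sum)

y = np.random.randn(8, 2)
p = np.random.randn(8, 2)
print(weighted_rmse(y, p, np.ones(8)))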
(Diffs for the remaining 15 changed files are not shown.)
