Back port fixes to 1.2 #6002

Merged 4 commits on Aug 11, 2020
36 changes: 29 additions & 7 deletions python-package/xgboost/dask.py
@@ -738,7 +738,8 @@ def dispatched_predict(worker_id):
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
*args)
- ret = (delayed(predt), order)
+ columns = 1 if len(predt.shape) == 1 else predt.shape[1]
+ ret = ((delayed(predt), columns), order)
predictions.append(ret)
return predictions
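A note on this hunk: `Booster.predict` returns a 1-D array for regression and binary objectives but a 2-D `(rows, classes)` array for `multi:softprob`, so each worker now ships the column count alongside its delayed chunk. A minimal sketch of that rule, using made-up arrays in place of real predictions:

    import numpy as np

    # Toy stand-ins for worker-local prediction outputs (assumed shapes).
    binary_predt = np.zeros(8, dtype=np.float32)      # shape (8,)
    multi_predt = np.zeros((8, 3), dtype=np.float32)  # shape (8, 3)

    for predt in (binary_predt, multi_predt):
        # Same rule as the diff: 1-D output means a single column.
        columns = 1 if len(predt.shape) == 1 else predt.shape[1]
        print(columns)  # prints 1, then 3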

@@ -775,8 +776,10 @@ async def map_function(func):
# See https://docs.dask.org/en/latest/array-creation.html
arrays = []
for i, shape in enumerate(shapes):
- arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
-                               dtype=numpy.float32))
+ arrays.append(da.from_delayed(
+     results[i][0], shape=(shape[0],)
+     if results[i][1] == 1 else (shape[0], results[i][1]),
+     dtype=numpy.float32))
predictions = await da.concatenate(arrays, axis=0)
return predictions
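This is where the extra column count pays off: `dask.array.from_delayed` cannot inspect a delayed object, so the caller must declare the chunk's shape and dtype up front, and a wrong shape only surfaces at `compute()` time. A self-contained sketch of the API (toy chunk, not PR code):

    import numpy as np
    import dask.array as da
    from dask import delayed

    # dask cannot peek inside a delayed value; the declared shape is taken
    # on faith, which is why the worker must report its column count.
    chunk = delayed(np.ones((100, 3), dtype=np.float32))
    arr = da.from_delayed(chunk, shape=(100, 3), dtype=np.float32)
    assert arr.shape == arr.compute().shape  # (100, 3), lazily and computed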

@@ -978,6 +981,7 @@ def client(self):
def client(self, clt):
self._client = clt


@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model'])
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
@@ -1032,9 +1036,6 @@ def predict(self, data):
['estimators', 'model']
)
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
- # pylint: disable=missing-docstring
- _client = None
-
async def _fit_async(self, X, y,
sample_weights=None,
eval_set=None,
@@ -1078,13 +1079,34 @@ def fit(self, X, y,
return self.client.sync(self._fit_async, X, y, sample_weights,
eval_set, sample_weight_eval_set, verbose)

- async def _predict_async(self, data):
+ async def _predict_proba_async(self, data):
_assert_dask_support()

test_dmatrix = await DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
return pred_probs

+ def predict_proba(self, data): # pylint: disable=arguments-differ,missing-docstring
+     _assert_dask_support()
+     return self.client.sync(self._predict_proba_async, data)
+
+ async def _predict_async(self, data):
+     _assert_dask_support()
+
+     test_dmatrix = await DaskDMatrix(client=self.client, data=data,
+                                      missing=self.missing)
+     pred_probs = await predict(client=self.client,
+                                model=self.get_booster(), data=test_dmatrix)
+
+     if self.n_classes_ == 2:
+         preds = (pred_probs > 0.5).astype(int)
+     else:
+         preds = da.argmax(pred_probs, axis=1)
+
+     return preds

def predict(self, data): # pylint: disable=arguments-differ
_assert_dask_support()
return self.client.sync(self._predict_async, data)
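With these changes `DaskXGBClassifier` mirrors scikit-learn's split between `predict_proba` (probabilities) and `predict` (class labels: a 0.5 threshold for binary models, `argmax` across columns otherwise). A usage sketch; the data, cluster size, and hyperparameters are made-up placeholders:

    import dask.array as da
    import xgboost as xgb
    from dask.distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = (da.random.random(1000, chunks=100) * 3).astype('int64')  # 3 classes

        clf = xgb.dask.DaskXGBClassifier(n_estimators=2)
        clf.client = client
        clf.fit(X, y)

        proba = clf.predict_proba(X)  # lazy (1000, 3) array of probabilities
        labels = clf.predict(X)       # lazy (1000,) array of class indices
        print(labels.compute()[:5])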
2 changes: 1 addition & 1 deletion python-package/xgboost/sklearn.py
@@ -77,7 +77,7 @@ def inner(preds, dmatrix):
gamma : float
Minimum loss reduction required to make a further partition on a leaf
node of the tree.
- min_child_weight : int
+ min_child_weight : float
Minimum sum of instance weight(hessian) needed in a child.
max_delta_step : int
Maximum delta step we allow each tree's weight estimation to be.
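The docstring now matches the parameter's real type: `min_child_weight` is a minimum sum of hessians, which need not be an integer. A quick sketch with assumed toy data showing a fractional setting:

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(64, 4), np.random.rand(64)
    # Fractional values are legitimate, which the old ``int`` annotation obscured.
    model = xgb.XGBRegressor(min_child_weight=0.5, n_estimators=2)
    model.fit(X, y)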
24 changes: 13 additions & 11 deletions src/gbm/gbtree_model.cc
@@ -1,6 +1,8 @@
/*!
- * Copyright 2019 by Contributors
+ * Copyright 2019-2020 by Contributors
*/
+ #include <utility>
+
#include "xgboost/json.h"
#include "xgboost/logging.h"
#include "gbtree_model.h"
@@ -41,15 +43,14 @@ void GBTreeModel::SaveModel(Json* p_out) const {
auto& out = *p_out;
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
out["gbtree_model_param"] = ToJson(param);
- std::vector<Json> trees_json;
- size_t t = 0;
- for (auto const& tree : trees) {
+ std::vector<Json> trees_json(trees.size());
+
+ for (size_t t = 0; t < trees.size(); ++t) {
+   auto const& tree = trees[t];
Json tree_json{Object()};
tree->SaveModel(&tree_json);
// The field is not used in XGBoost, but might be useful for external project.
tree_json["id"] = Integer(t);
trees_json.emplace_back(tree_json);
t++;
tree_json["id"] = Integer(static_cast<Integer::Int>(t));
trees_json[t] = std::move(tree_json);
}

std::vector<Json> tree_info_json(tree_info.size());
@@ -70,9 +71,10 @@ void GBTreeModel::LoadModel(Json const& in) {
auto const& trees_json = get<Array const>(in["trees"]);
trees.resize(trees_json.size());

- for (size_t t = 0; t < trees.size(); ++t) {
-   trees[t].reset( new RegTree() );
-   trees[t]->LoadModel(trees_json[t]);
+ for (size_t t = 0; t < trees_json.size(); ++t) { // NOLINT
+   auto tree_id = get<Integer>(trees_json[t]["id"]);
+   trees.at(tree_id).reset(new RegTree());
+   trees.at(tree_id)->LoadModel(trees_json[t]);
}

tree_info.resize(param.num_trees);
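Together the two hunks make the JSON round-trip order-independent: `SaveModel` stamps each tree with its `id`, and `LoadModel` places every parsed tree into the slot that id names rather than trusting array order. A sketch of the resulting layout from the Python side; the field path is an assumption based on this diff and the JSON model schema:

    import json
    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(64, 4), np.random.rand(64)
    booster = xgb.train({}, xgb.DMatrix(X, label=y), num_boost_round=3)
    booster.save_model('model.json')  # JSON format, chosen by extension

    with open('model.json') as fd:
        model = json.load(fd)
    trees = model['learner']['gradient_booster']['model']['trees']
    print([t['id'] for t in trees])  # [0, 1, 2]: the slot each tree reloads into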
11 changes: 10 additions & 1 deletion tests/cpp/test_learner.cc
@@ -148,7 +148,16 @@ TEST(Learner, JsonModelIO) {
Json out { Object() };
learner->SaveModel(&out);

- learner->LoadModel(out);
+ dmlc::TemporaryDirectory tmpdir;
+
+ std::ofstream fout (tmpdir.path + "/model.json");
+ fout << out;
+ fout.close();
+
+ auto loaded_str = common::LoadSequentialFile(tmpdir.path + "/model.json");
+ Json loaded = Json::Load(StringView{loaded_str.c_str(), loaded_str.size()});
+
+ learner->LoadModel(loaded);
learner->Configure();

Json new_in { Object() };
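The old assertion reloaded the in-memory `Json` object, so a bug in the text writer or parser could never fail it; saving to a file and parsing that file back exercises the full serialization path. The same pattern from the Python side, sketched with assumed toy data:

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(64, 4), np.random.rand(64)
    dtrain = xgb.DMatrix(X, label=y)
    booster = xgb.train({}, dtrain, num_boost_round=2)

    booster.save_model('model.json')                 # write the JSON text
    reloaded = xgb.Booster(model_file='model.json')  # parse it back

    np.testing.assert_allclose(booster.predict(dtrain),
                               reloaded.predict(dtrain))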
64 changes: 52 additions & 12 deletions tests/python/test_with_dask.py
@@ -5,6 +5,7 @@
import numpy as np
import json
import asyncio
+ from sklearn.datasets import make_classification

if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -36,7 +37,7 @@ def generate_array():


def test_from_dask_dataframe():
- with LocalCluster(n_workers=5) as cluster:
+ with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y = generate_array()

@@ -74,7 +75,7 @@ def test_from_dask_dataframe():


def test_from_dask_array():
- with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
+ with LocalCluster(n_workers=kWorkers, threads_per_worker=5) as cluster:
with Client(cluster) as client:
X, y = generate_array()
dtrain = DaskDMatrix(client, X, y)
@@ -104,8 +105,28 @@ def test_from_dask_array():
assert np.all(single_node_predt == from_arr.compute())


+ def test_dask_predict_shape_infer():
+     with LocalCluster(n_workers=kWorkers) as cluster:
+         with Client(cluster) as client:
+             X, y = make_classification(n_samples=1000, n_informative=5,
+                                        n_classes=3)
+             X_ = dd.from_array(X, chunksize=100)
+             y_ = dd.from_array(y, chunksize=100)
+             dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
+
+             model = xgb.dask.train(
+                 client,
+                 {"objective": "multi:softprob", "num_class": 3},
+                 dtrain=dtrain
+             )
+
+             preds = xgb.dask.predict(client, model, dtrain)
+             assert preds.shape[0] == preds.compute().shape[0]
+             assert preds.shape[1] == preds.compute().shape[1]


def test_dask_missing_value_reg():
- with LocalCluster(n_workers=5) as cluster:
+ with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X_0 = np.ones((20 // 2, kCols))
X_1 = np.zeros((20 // 2, kCols))
@@ -144,19 +165,19 @@ def test_dask_missing_value_cls():
missing=0.0)
cls.client = client
cls.fit(X, y, eval_set=[(X, y)])
- dd_predt = cls.predict(X).compute()
+ dd_pred_proba = cls.predict_proba(X).compute()

np_X = X.compute()
- np_predt = cls.get_booster().predict(
+ np_pred_proba = cls.get_booster().predict(
      xgb.DMatrix(np_X, missing=0.0))
- np.testing.assert_allclose(np_predt, dd_predt)
+ np.testing.assert_allclose(np_pred_proba, dd_pred_proba)

cls = xgb.dask.DaskXGBClassifier()
assert hasattr(cls, 'missing')


def test_dask_regressor():
- with LocalCluster(n_workers=5) as cluster:
+ with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y = generate_array()
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
@@ -178,7 +199,7 @@ def test_dask_regressor():


def test_dask_classifier():
- with LocalCluster(n_workers=5) as cluster:
+ with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y = generate_array()
y = (y * 10).astype(np.int32)
@@ -201,7 +222,18 @@ def test_dask_classifier():
assert len(list(history['validation_0'])) == 1
assert len(history['validation_0']['merror']) == 2

+ # Test .predict_proba()
+ probas = classifier.predict_proba(X)
+ assert classifier.n_classes_ == 10
+ assert probas.ndim == 2
+ assert probas.shape[0] == kRows
+ assert probas.shape[1] == 10
+
+ cls_booster = classifier.get_booster()
+ single_node_proba = cls_booster.inplace_predict(X.compute())
+
+ np.testing.assert_allclose(single_node_proba,
+                            probas.compute())

# Test with dataframe.
X_d = dd.from_dask_array(X)
@@ -218,7 +250,7 @@ def test_dask_classifier():
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_grid_search():
from sklearn.model_selection import GridSearchCV
- with LocalCluster(n_workers=4) as cluster:
+ with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y = generate_array()
reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
@@ -292,7 +324,9 @@ def _check_outputs(out, predictions):
evals=[(dtrain, 'validation')],
num_boost_round=2)
predictions = xgb.dask.predict(client=client, model=out,
-                                data=dtrain).compute()
+                                data=dtrain)
+ assert predictions.shape[1] == n_classes
+ predictions = predictions.compute()
_check_outputs(out, predictions)

# train has more rows than evals
Expand All @@ -315,15 +349,15 @@ def _check_outputs(out, predictions):
# environment and Exact doesn't support it.

def test_empty_dmatrix_hist():
- with LocalCluster(n_workers=5) as cluster:
+ with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
parameters = {'tree_method': 'hist'}
run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)


def test_empty_dmatrix_approx():
- with LocalCluster(n_workers=5) as cluster:
+ with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
parameters = {'tree_method': 'approx'}
run_empty_dmatrix_reg(client, parameters)
@@ -397,7 +431,13 @@ async def run_dask_classifier_asyncio(scheduler_address):
assert len(list(history['validation_0'])) == 1
assert len(history['validation_0']['merror']) == 2

+ # Test .predict_proba()
+ probas = await classifier.predict_proba(X)
+ assert classifier.n_classes_ == 10
+ assert probas.ndim == 2
+ assert probas.shape[0] == kRows
+ assert probas.shape[1] == 10


# Test with dataframe.
X_d = dd.from_dask_array(X)