Skip to content

Commit

Permalink
Reduce dask test search space.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 10, 2020
1 parent a0ab2c8 commit 8587012
Showing 1 changed file with 48 additions and 47 deletions.
95 changes: 48 additions & 47 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,55 +452,56 @@ def test_with_asyncio():
asyncio.run(run_dask_classifier_asyncio(address))


class TestWithDask(unittest.TestCase):
def run_updater_test(self, client, params, num_rounds, dataset,
tree_method):
params['tree_method'] = tree_method
params = dataset.set_params(params)
# multi class doesn't handle empty dataset well (empty
# means at least 1 worker has data).
if params['objective'] == "multi:softmax":
return
# It doesn't make sense to distribute a completely
# empty dataset.
if dataset.X.shape[0] == 0:
return
def run_updater_test(client, params, num_rounds, dataset,
                     tree_method):
    """Train on ``dataset`` through dask and check the train metric.

    The dataset is wrapped into dask arrays, trained with
    ``xgb.dask.train`` while being evaluated on itself, and the
    recorded ``dataset.metric`` history is asserted to be
    non-increasing.

    Parameters
    ----------
    client :
        Dask client handed to ``DaskDMatrix`` and ``xgb.dask.train``.
    params : dict
        Booster parameters.  A shallow copy is taken, so the caller's
        dictionary is never modified.
    num_rounds : int
        Forwarded as ``num_boost_round``.
    dataset :
        Object exposing ``set_params``, ``X``, ``y``, ``w`` and
        ``metric``.
    tree_method : str
        Stored under ``params['tree_method']`` before training.

    Returns
    -------
    None
        Also returns early (without training) for the
        ``multi:softmax`` objective and for datasets with zero rows.
    """
    # Work on a copy: writing ``tree_method`` (and whatever
    # ``set_params`` adds) into the caller's dict would alter the
    # generated example that the test framework reports on failure.
    params = dict(params)
    params['tree_method'] = tree_method
    params = dataset.set_params(params)
    # Multi-class is skipped because it doesn't handle an empty
    # dataset well in the distributed setting.
    if params['objective'] == "multi:softmax":
        return
    # It doesn't make sense to distribute a completely
    # empty dataset.
    if dataset.X.shape[0] == 0:
        return

    # Rows are split into blocks of ``chunk``; columns stay together.
    chunk = 128
    X = da.from_array(dataset.X,
                      chunks=(chunk, dataset.X.shape[1]))
    y = da.from_array(dataset.y, chunks=(chunk, ))
    if dataset.w is not None:
        w = da.from_array(dataset.w, chunks=(chunk, ))
    else:
        w = None

    m = xgb.dask.DaskDMatrix(
        client, data=X, label=y, weight=w)
    history = xgb.dask.train(client, params=params, dtrain=m,
                             num_boost_round=num_rounds,
                             evals=[(m, 'train')])['history']
    note(history)
    assert tm.non_increasing(history['train'][dataset.metric])


@given(hist_parameter_strategy, tm.dataset_strategy)
@settings(deadline=None)
def test_hist(params, dataset):
    """Run the shared updater check with the ``hist`` tree method.

    A fresh ``LocalCluster`` and ``Client`` are created for every
    generated example and torn down when the example finishes.
    """
    with LocalCluster() as cluster, Client(cluster) as client:
        run_updater_test(client, params, 20, dataset,
                         tree_method='hist')


@given(exact_parameter_strategy, tm.dataset_strategy)
@settings(deadline=None)
def test_approx(params, dataset):
    """Run the shared updater check with the ``approx`` tree method.

    A fresh ``LocalCluster`` and ``Client`` are created for every
    generated example and torn down when the example finishes.
    """
    with LocalCluster() as cluster, Client(cluster) as client:
        run_updater_test(client, params, 20, dataset,
                         tree_method='approx')

chunk = 128
X = da.from_array(dataset.X,
chunks=(chunk, dataset.X.shape[1]))
y = da.from_array(dataset.y, chunks=(chunk, ))
if dataset.w is not None:
w = da.from_array(dataset.w, chunks=(chunk, ))
else:
w = None

m = xgb.dask.DaskDMatrix(
client, data=X, label=y, weight=w)
history = xgb.dask.train(client, params=params, dtrain=m,
num_boost_round=num_rounds,
evals=[(m, 'train')])['history']
note(history)
assert tm.non_increasing(history['train'][dataset.metric])

@given(hist_parameter_strategy, strategies.integers(10, 20),
       tm.dataset_strategy)
@settings(deadline=None)
def test_hist(self, params, num_rounds, dataset):
    """Run the updater check with ``hist`` for a generated round count.

    ``num_rounds`` is drawn from ``strategies.integers(10, 20)``; a
    fresh ``LocalCluster`` and ``Client`` are created per example.
    """
    with LocalCluster() as cluster, Client(cluster) as client:
        self.run_updater_test(client, params, num_rounds, dataset,
                              tree_method='hist')

@given(exact_parameter_strategy, strategies.integers(10, 20),
       tm.dataset_strategy)
@settings(deadline=None)
def test_approx(self, params, num_rounds, dataset):
    """Run the updater check with ``approx`` for a generated round count.

    ``num_rounds`` is drawn from ``strategies.integers(10, 20)``; a
    fresh ``LocalCluster`` and ``Client`` are created per example.
    """
    with LocalCluster() as cluster, Client(cluster) as client:
        self.run_updater_test(client, params, num_rounds, dataset,
                              tree_method='approx')

class TestWithDask(unittest.TestCase):
def run_quantile(self, name):
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")
Expand Down

0 comments on commit 8587012

Please sign in to comment.