From 98b051269b0c78804530d57a71c3ae7cacb333b8 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 19 Nov 2019 10:55:26 +0800 Subject: [PATCH] Assert dask client at early stage. (#5048) --- python-package/xgboost/dask.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 3181c5ebdab4..796608d1d546 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -105,6 +105,12 @@ def _get_client_workers(client): return workers +def _assert_client(client): + if not isinstance(client, (type(get_client()), type(None))): + raise TypeError( + _expect([type(get_client()), type(None)], type(client))) + + class DaskDMatrix: # pylint: disable=missing-docstring, too-many-instance-attributes '''DMatrix holding on references to Dask DataFrame or Dask Array. @@ -142,6 +148,7 @@ def __init__(self, feature_names=None, feature_types=None): _assert_dask_support() + _assert_client(client) self._feature_names = feature_names self._feature_types = feature_types @@ -362,11 +369,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs): ''' _assert_dask_support() - + _assert_client(client) if 'evals_result' in kwargs.keys(): raise ValueError( 'evals_result is not supported in dask interface.', 'The evaluation history is returned as result of training.') + client = _xgb_get_client(client) workers = list(_get_client_workers(client).keys()) @@ -432,6 +440,7 @@ def predict(client, model, data, *args): ''' _assert_dask_support() + _assert_client(client) if isinstance(model, Booster): booster = model elif isinstance(model, dict):