Skip to content

Commit

Permalink
Assert dask client at early stage. (#5048)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Nov 19, 2019
1 parent e67388f commit 98b0512
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def _get_client_workers(client):
return workers


def _assert_client(client):
if not isinstance(client, (type(get_client()), type(None))):
raise TypeError(
_expect([type(get_client()), type(None)], type(client)))


class DaskDMatrix:
# pylint: disable=missing-docstring, too-many-instance-attributes
'''DMatrix holding on references to Dask DataFrame or Dask Array.
Expand Down Expand Up @@ -142,6 +148,7 @@ def __init__(self,
feature_names=None,
feature_types=None):
_assert_dask_support()
_assert_client(client)

self._feature_names = feature_names
self._feature_types = feature_types
Expand Down Expand Up @@ -362,11 +369,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
'''
_assert_dask_support()

_assert_client(client)
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
'The evaluation history is returned as result of training.')

client = _xgb_get_client(client)
workers = list(_get_client_workers(client).keys())

Expand Down Expand Up @@ -432,6 +440,7 @@ def predict(client, model, data, *args):
'''
_assert_dask_support()
_assert_client(client)
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
Expand Down

0 comments on commit 98b0512

Please sign in to comment.