Assert dask client at early stage. (#5048)

This commit is contained in:
Jiaming Yuan 2019-11-19 10:55:26 +08:00 committed by GitHub
parent e67388fb8f
commit 98b051269b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -105,6 +105,12 @@ def _get_client_workers(client):
return workers
def _assert_client(client):
if not isinstance(client, (type(get_client()), type(None))):
raise TypeError(
_expect([type(get_client()), type(None)], type(client)))
class DaskDMatrix:
# pylint: disable=missing-docstring, too-many-instance-attributes
'''DMatrix holding on references to Dask DataFrame or Dask Array.
@ -142,6 +148,7 @@ class DaskDMatrix:
feature_names=None,
feature_types=None):
_assert_dask_support()
_assert_client(client)
self._feature_names = feature_names
self._feature_types = feature_types
@ -362,11 +369,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
'''
_assert_dask_support()
_assert_client(client)
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
'The evaluation history is returned as result of training.')
client = _xgb_get_client(client)
workers = list(_get_client_workers(client).keys())
@ -432,6 +440,7 @@ def predict(client, model, data, *args):
'''
_assert_dask_support()
_assert_client(client)
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):