Assert dask client at early stage. (#5048)
This commit is contained in:
parent
e67388fb8f
commit
98b051269b
@ -105,6 +105,12 @@ def _get_client_workers(client):
|
||||
return workers
|
||||
|
||||
|
||||
def _assert_client(client):
|
||||
if not isinstance(client, (type(get_client()), type(None))):
|
||||
raise TypeError(
|
||||
_expect([type(get_client()), type(None)], type(client)))
|
||||
|
||||
|
||||
class DaskDMatrix:
|
||||
# pylint: disable=missing-docstring, too-many-instance-attributes
|
||||
'''DMatrix holding on references to Dask DataFrame or Dask Array.
|
||||
@ -142,6 +148,7 @@ class DaskDMatrix:
|
||||
feature_names=None,
|
||||
feature_types=None):
|
||||
_assert_dask_support()
|
||||
_assert_client(client)
|
||||
|
||||
self._feature_names = feature_names
|
||||
self._feature_types = feature_types
|
||||
@ -362,11 +369,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
|
||||
'''
|
||||
_assert_dask_support()
|
||||
|
||||
_assert_client(client)
|
||||
if 'evals_result' in kwargs.keys():
|
||||
raise ValueError(
|
||||
'evals_result is not supported in dask interface.',
|
||||
'The evaluation history is returned as result of training.')
|
||||
|
||||
client = _xgb_get_client(client)
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
|
||||
@ -432,6 +440,7 @@ def predict(client, model, data, *args):
|
||||
|
||||
'''
|
||||
_assert_dask_support()
|
||||
_assert_client(client)
|
||||
if isinstance(model, Booster):
|
||||
booster = model
|
||||
elif isinstance(model, dict):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user