Assert dask client at early stage. (#5048)
This commit is contained in:
parent
e67388fb8f
commit
98b051269b
@ -105,6 +105,12 @@ def _get_client_workers(client):
|
|||||||
return workers
|
return workers
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_client(client):
|
||||||
|
if not isinstance(client, (type(get_client()), type(None))):
|
||||||
|
raise TypeError(
|
||||||
|
_expect([type(get_client()), type(None)], type(client)))
|
||||||
|
|
||||||
|
|
||||||
class DaskDMatrix:
|
class DaskDMatrix:
|
||||||
# pylint: disable=missing-docstring, too-many-instance-attributes
|
# pylint: disable=missing-docstring, too-many-instance-attributes
|
||||||
'''DMatrix holding on references to Dask DataFrame or Dask Array.
|
'''DMatrix holding on references to Dask DataFrame or Dask Array.
|
||||||
@ -142,6 +148,7 @@ class DaskDMatrix:
|
|||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None):
|
feature_types=None):
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
|
_assert_client(client)
|
||||||
|
|
||||||
self._feature_names = feature_names
|
self._feature_names = feature_names
|
||||||
self._feature_types = feature_types
|
self._feature_types = feature_types
|
||||||
@ -362,11 +369,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
|
_assert_client(client)
|
||||||
if 'evals_result' in kwargs.keys():
|
if 'evals_result' in kwargs.keys():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'evals_result is not supported in dask interface.',
|
'evals_result is not supported in dask interface.',
|
||||||
'The evaluation history is returned as result of training.')
|
'The evaluation history is returned as result of training.')
|
||||||
|
|
||||||
client = _xgb_get_client(client)
|
client = _xgb_get_client(client)
|
||||||
workers = list(_get_client_workers(client).keys())
|
workers = list(_get_client_workers(client).keys())
|
||||||
|
|
||||||
@ -432,6 +440,7 @@ def predict(client, model, data, *args):
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
|
_assert_client(client)
|
||||||
if isinstance(model, Booster):
|
if isinstance(model, Booster):
|
||||||
booster = model
|
booster = model
|
||||||
elif isinstance(model, dict):
|
elif isinstance(model, dict):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user