[dask] Fix missing value for scikit-learn interface. (#5435)
This commit is contained in:
parent
4b7e2b7bff
commit
cd7d6f7d59
@ -131,8 +131,14 @@ Basic functionalities including training and generating predictions for regressi
|
||||
classification are implemented. But there are still some other limitations we haven't
|
||||
addressed yet.
|
||||
|
||||
- Label encoding for Scikit-Learn classifier.
|
||||
- Ranking
|
||||
- Label encoding for Scikit-Learn classifier may not be supported, meaning that users need
|
||||
to encode their training labels into discrete values first.
|
||||
- Ranking is not supported right now.
|
||||
- Empty workers are not well supported by the classifier. If training hangs for the classifier
|
||||
with a warning about empty DMatrix, please consider balancing your data first. But
|
||||
regressor works fine with empty DMatrix.
|
||||
- Callback functions are not tested.
|
||||
- To use cross validation one needs to explicitly train different models instead of using
|
||||
a functional API like ``xgboost.cv``.
|
||||
- Only ``GridSearchCV`` from ``scikit-learn`` is supported for the dask interface, meaning
|
||||
that we can distribute data among workers but have to train one model at a time. If you
|
||||
want to scale up grid searching with model parallelism by ``dask-ml``, please consider
|
||||
using the normal ``scikit-learn`` interface like ``xgboost.XGBRegressor`` for now.
|
||||
|
||||
@ -572,7 +572,7 @@ def predict(client, model, data, *args, missing=numpy.nan):
|
||||
return predictions
|
||||
|
||||
|
||||
def _evaluation_matrices(client, validation_set, sample_weights):
|
||||
def _evaluation_matrices(client, validation_set, sample_weights, missing):
|
||||
'''
|
||||
Parameters
|
||||
----------
|
||||
@ -597,7 +597,8 @@ def _evaluation_matrices(client, validation_set, sample_weights):
|
||||
for i, e in enumerate(validation_set):
|
||||
w = (sample_weights[i]
|
||||
if sample_weights is not None else None)
|
||||
dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w)
|
||||
dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w,
|
||||
missing=missing)
|
||||
evals.append((dmat, 'validation_{}'.format(i)))
|
||||
else:
|
||||
evals = None
|
||||
@ -672,10 +673,12 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
verbose=True):
|
||||
_assert_dask_support()
|
||||
dtrain = DaskDMatrix(client=self.client,
|
||||
data=X, label=y, weight=sample_weights)
|
||||
data=X, label=y, weight=sample_weights,
|
||||
missing=self.missing)
|
||||
params = self.get_xgb_params()
|
||||
evals = _evaluation_matrices(self.client,
|
||||
eval_set, sample_weight_eval_set)
|
||||
eval_set, sample_weight_eval_set,
|
||||
self.missing)
|
||||
|
||||
results = train(self.client, params, dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
@ -688,7 +691,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
|
||||
def predict(self, data): # pylint: disable=arguments-differ
|
||||
_assert_dask_support()
|
||||
test_dmatrix = DaskDMatrix(client=self.client, data=data)
|
||||
test_dmatrix = DaskDMatrix(client=self.client, data=data,
|
||||
missing=self.missing)
|
||||
pred_probs = predict(client=self.client,
|
||||
model=self.get_booster(), data=test_dmatrix)
|
||||
return pred_probs
|
||||
@ -711,7 +715,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
verbose=True):
|
||||
_assert_dask_support()
|
||||
dtrain = DaskDMatrix(client=self.client,
|
||||
data=X, label=y, weight=sample_weights)
|
||||
data=X, label=y, weight=sample_weights,
|
||||
missing=self.missing)
|
||||
params = self.get_xgb_params()
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
@ -728,7 +733,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
params["objective"] = "binary:logistic"
|
||||
|
||||
evals = _evaluation_matrices(self.client,
|
||||
eval_set, sample_weight_eval_set)
|
||||
eval_set, sample_weight_eval_set,
|
||||
self.missing)
|
||||
results = train(self.client, params, dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals, verbose_eval=verbose)
|
||||
@ -739,7 +745,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
|
||||
def predict(self, data): # pylint: disable=arguments-differ
|
||||
_assert_dask_support()
|
||||
test_dmatrix = DaskDMatrix(client=self.client, data=data)
|
||||
test_dmatrix = DaskDMatrix(client=self.client, data=data,
|
||||
missing=self.missing)
|
||||
pred_probs = predict(client=self.client,
|
||||
model=self.get_booster(), data=test_dmatrix)
|
||||
return pred_probs
|
||||
|
||||
@ -97,6 +97,58 @@ def test_from_dask_array():
|
||||
assert np.all(single_node_predt == from_arr.compute())
|
||||
|
||||
|
||||
def test_dask_missing_value_reg():
    """Check that ``DaskXGBRegressor`` forwards its ``missing`` value.

    A regressor configured with ``missing=0.0`` is trained on a matrix whose
    rows are either all ones or all zeros.  The predictions obtained through
    the dask interface must match the predictions of the same booster on a
    single-node ``DMatrix`` built with the same ``missing`` value.
    """
    with LocalCluster(n_workers=5) as cluster:
        with Client(cluster) as client:
            # Half of the rows are all ones and half are all zeros; the zeros
            # are the entries declared as missing below.
            X_0 = np.ones((20 // 2, kCols))
            X_1 = np.zeros((20 // 2, kCols))
            X = np.concatenate([X_0, X_1], axis=0)
            np.random.shuffle(X)
            X = da.from_array(X)
            # NOTE(review): the second positional argument of ``rechunk`` is
            # ``threshold``, not a per-column chunk size -- confirm intent.
            X = X.rechunk(20, 1)
            y = da.random.randint(0, 3, size=20)
            # ``rechunk`` returns a new array, so the result must be kept;
            # the original call discarded it and was a no-op.
            y = y.rechunk(20)
            regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2,
                                                  missing=0.0)
            regressor.client = client
            regressor.set_params(tree_method='hist')
            # The training data doubles as the evaluation set so that the
            # evaluation-matrix path is exercised with ``missing`` as well.
            regressor.fit(X, y, eval_set=[(X, y)])
            dd_predt = regressor.predict(X).compute()

            # Reference prediction: same booster, single-node DMatrix.
            np_X = X.compute()
            np_predt = regressor.get_booster().predict(
                xgb.DMatrix(np_X, missing=0.0))
            np.testing.assert_allclose(np_predt, dd_predt)
|
||||
|
||||
|
||||
def test_dask_missing_value_cls():
    """Check that ``DaskXGBClassifier`` forwards its ``missing`` value.

    Trains a classifier with ``missing=0.0`` on rows of all ones and all
    zeros, compares the dask predictions with those of the same booster on a
    single-node ``DMatrix`` using the same ``missing`` value, and finally
    checks that a default-constructed classifier has a ``missing`` attribute.
    """
    # Multi-class doesn't handle empty DMatrix well. So we use fewer workers.
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        half_rows = kRows // 2
        ones_block = np.ones((half_rows, kCols))
        zeros_block = np.zeros((half_rows, kCols))
        features = np.concatenate([ones_block, zeros_block], axis=0)
        np.random.shuffle(features)
        features = da.from_array(features).rechunk(20, None)
        labels = da.random.randint(0, 3, size=kRows).rechunk(20, 1)

        classifier = xgb.dask.DaskXGBClassifier(
            verbosity=1,
            n_estimators=2,
            tree_method='hist',
            missing=0.0,
        )
        classifier.client = client
        classifier.fit(features, labels, eval_set=[(features, labels)])
        dask_predictions = classifier.predict(features).compute()

        # Reference prediction: same booster, single-node DMatrix.
        local_features = features.compute()
        booster = classifier.get_booster()
        local_predictions = booster.predict(
            xgb.DMatrix(local_features, missing=0.0))
        np.testing.assert_allclose(local_predictions, dask_predictions)

        # A classifier built with defaults must still expose ``missing``.
        default_classifier = xgb.dask.DaskXGBClassifier()
        assert hasattr(default_classifier, 'missing')
|
||||
|
||||
|
||||
def test_dask_regressor():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user