[dask] Fix missing value for scikit-learn interface. (#5435)

Jiaming Yuan authored 2020-03-20 22:56:01 +08:00 · committed by GitHub
parent 4b7e2b7bff
commit cd7d6f7d59
3 changed files with 77 additions and 12 deletions

doc/tutorials/dask.rst

@@ -131,8 +131,14 @@
 Basic functionalities including training and generating predictions for regression and
 classification are implemented. But there are still some other limitations we haven't
 addressed yet.
-- Label encoding for Scikit-Learn classifier.
-- Ranking
+- Label encoding for the Scikit-Learn classifier may not be supported, meaning that users
+  need to encode their training labels into discrete values first.
+- Ranking is not supported right now.
+- Empty workers are not well supported by the classifier. If training hangs for the
+  classifier with a warning about an empty DMatrix, please consider balancing your data
+  first. The regressor works fine with an empty DMatrix.
 - Callback functions are not tested.
-- To use cross validation one needs to explicitly train different models instead of using
-  a functional API like ``xgboost.cv``.
+- Only ``GridSearchCV`` from ``scikit-learn`` is supported for the dask interface, meaning
+  that we can distribute data among workers but have to train one model at a time. If you
+  want to scale up grid searching with model parallelism via ``dask-ml``, please consider
+  using the normal ``scikit-learn`` interface like ``xgboost.XGBRegressor`` for now.
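The first limitation above means classification targets must already be consecutive integers starting at zero. A minimal sketch of pre-encoding labels on the client, assuming in-memory labels and scikit-learn's LabelEncoder (variable names are illustrative, not from this commit):

    import numpy as np
    import dask.array as da
    from sklearn.preprocessing import LabelEncoder

    # Hypothetical labels; the dask classifier expects integer classes 0..k-1,
    # so encode strings (or sparse class ids) before calling fit().
    raw_labels = np.array(['spam', 'ham', 'spam', 'eggs'])
    encoded = LabelEncoder().fit_transform(raw_labels)  # array([2, 1, 2, 0])
    y = da.from_array(encoded, chunks=2)                # pass to DaskXGBClassifier.fit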

python-package/xgboost/dask.py

@@ -572,7 +572,7 @@ def predict(client, model, data, *args, missing=numpy.nan):
     return predictions


-def _evaluation_matrices(client, validation_set, sample_weights):
+def _evaluation_matrices(client, validation_set, sample_weights, missing):
     '''
     Parameters
     ----------
@@ -597,7 +597,8 @@ def _evaluation_matrices(client, validation_set, sample_weights):
         for i, e in enumerate(validation_set):
             w = (sample_weights[i]
                  if sample_weights is not None else None)
-            dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w)
+            dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w,
+                               missing=missing)
             evals.append((dmat, 'validation_{}'.format(i)))
     else:
         evals = None
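For context, ``missing`` is the sentinel value XGBoost treats as absent when constructing a DMatrix; the bug fixed here is that the dask scikit-learn wrappers built their matrices without forwarding it. A single-node sketch of the semantics, using toy data that is not part of this commit:

    import numpy as np
    import xgboost as xgb

    X = np.array([[0.0, 1.0],
                  [2.0, 0.0]])
    y = np.array([0.0, 1.0])

    # With missing=0.0, every 0.0 entry is treated as a missing value rather
    # than a real feature value, which changes the learned splits.
    dtrain = xgb.DMatrix(X, label=y, missing=0.0)
    booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=2)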
@@ -672,10 +673,12 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
             verbose=True):
         _assert_dask_support()
         dtrain = DaskDMatrix(client=self.client,
-                             data=X, label=y, weight=sample_weights)
+                             data=X, label=y, weight=sample_weights,
+                             missing=self.missing)
         params = self.get_xgb_params()
         evals = _evaluation_matrices(self.client,
-                                     eval_set, sample_weight_eval_set)
+                                     eval_set, sample_weight_eval_set,
+                                     self.missing)

         results = train(self.client, params, dtrain,
                         num_boost_round=self.get_num_boosting_rounds(),
@@ -688,7 +691,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):

     def predict(self, data):  # pylint: disable=arguments-differ
         _assert_dask_support()
-        test_dmatrix = DaskDMatrix(client=self.client, data=data)
+        test_dmatrix = DaskDMatrix(client=self.client, data=data,
+                                   missing=self.missing)
         pred_probs = predict(client=self.client,
                              model=self.get_booster(), data=test_dmatrix)
         return pred_probs
@@ -711,7 +715,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             verbose=True):
         _assert_dask_support()
         dtrain = DaskDMatrix(client=self.client,
-                             data=X, label=y, weight=sample_weights)
+                             data=X, label=y, weight=sample_weights,
+                             missing=self.missing)
         params = self.get_xgb_params()

         # pylint: disable=attribute-defined-outside-init
@@ -728,7 +733,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             params["objective"] = "binary:logistic"
         evals = _evaluation_matrices(self.client,
-                                     eval_set, sample_weight_eval_set)
+                                     eval_set, sample_weight_eval_set,
+                                     self.missing)

         results = train(self.client, params, dtrain,
                         num_boost_round=self.get_num_boosting_rounds(),
                         evals=evals, verbose_eval=verbose)
@@ -739,7 +745,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):

     def predict(self, data):  # pylint: disable=arguments-differ
         _assert_dask_support()
-        test_dmatrix = DaskDMatrix(client=self.client, data=data)
+        test_dmatrix = DaskDMatrix(client=self.client, data=data,
+                                   missing=self.missing)
         pred_probs = predict(client=self.client,
                              model=self.get_booster(), data=test_dmatrix)
         return pred_probs
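Taken together, these changes thread ``self.missing`` through every DaskDMatrix the wrappers create: the training data, the ``eval_set`` matrices, and the prediction data. A usage sketch mirroring the new tests below (cluster size and data shapes are placeholders):

    import dask.array as da
    from dask.distributed import Client, LocalCluster
    import xgboost as xgb

    with LocalCluster(n_workers=2) as cluster:
        with Client(cluster) as client:
            X = da.random.random((100, 10), chunks=(20, 10))
            y = da.random.random(100, chunks=20)

            # missing=0.0 now reaches dtrain, the eval_set matrices,
            # and the DMatrix built inside predict().
            reg = xgb.dask.DaskXGBRegressor(n_estimators=2, missing=0.0)
            reg.client = client
            reg.fit(X, y, eval_set=[(X, y)])
            predt = reg.predict(X).compute()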

tests/python/test_with_dask.py

@@ -97,6 +97,58 @@ def test_from_dask_array():
         assert np.all(single_node_predt == from_arr.compute())


+def test_dask_missing_value_reg():
+    with LocalCluster(n_workers=5) as cluster:
+        with Client(cluster) as client:
+            X_0 = np.ones((20 // 2, kCols))
+            X_1 = np.zeros((20 // 2, kCols))
+            X = np.concatenate([X_0, X_1], axis=0)
+            np.random.shuffle(X)
+            X = da.from_array(X)
+            X = X.rechunk((20, 1))
+            y = da.random.randint(0, 3, size=20)
+            y = y.rechunk(20)
+            regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2,
+                                                  missing=0.0)
+            regressor.client = client
+            regressor.set_params(tree_method='hist')
+            regressor.fit(X, y, eval_set=[(X, y)])
+            dd_predt = regressor.predict(X).compute()
+
+            np_X = X.compute()
+            np_predt = regressor.get_booster().predict(
+                xgb.DMatrix(np_X, missing=0.0))
+            np.testing.assert_allclose(np_predt, dd_predt)
+
+
+def test_dask_missing_value_cls():
+    # Multi-class doesn't handle an empty DMatrix well, so use fewer workers.
+    with LocalCluster(n_workers=2) as cluster:
+        with Client(cluster) as client:
+            X_0 = np.ones((kRows // 2, kCols))
+            X_1 = np.zeros((kRows // 2, kCols))
+            X = np.concatenate([X_0, X_1], axis=0)
+            np.random.shuffle(X)
+            X = da.from_array(X)
+            X = X.rechunk((20, None))
+            y = da.random.randint(0, 3, size=kRows)
+            y = y.rechunk(20)
+            cls = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2,
+                                             tree_method='hist',
+                                             missing=0.0)
+            cls.client = client
+            cls.fit(X, y, eval_set=[(X, y)])
+            dd_predt = cls.predict(X).compute()
+
+            np_X = X.compute()
+            np_predt = cls.get_booster().predict(
+                xgb.DMatrix(np_X, missing=0.0))
+            np.testing.assert_allclose(np_predt, dd_predt)
+
+            cls = xgb.dask.DaskXGBClassifier()
+            assert hasattr(cls, 'missing')
+
+
 def test_dask_regressor():
     with LocalCluster(n_workers=5) as cluster:
         with Client(cluster) as client: