[dask] Fix missing value for scikit-learn interface. (#5435)
This commit is contained in:
@@ -97,6 +97,58 @@ def test_from_dask_array():
|
||||
assert np.all(single_node_predt == from_arr.compute())
|
||||
|
||||
|
||||
def test_dask_missing_value_reg():
    """Check that ``DaskXGBRegressor`` forwards its ``missing`` parameter.

    Half of the rows are all ones and half are all zeros; the regressor is
    configured with ``missing=0.0``.  The prediction obtained through the
    dask interface must match a single-node prediction made by the same
    booster on a ``DMatrix`` built with the same ``missing`` value.
    """
    with LocalCluster(n_workers=5) as cluster:
        with Client(cluster) as client:
            X_0 = np.ones((20 // 2, kCols))
            X_1 = np.zeros((20 // 2, kCols))
            X = np.concatenate([X_0, X_1], axis=0)
            np.random.shuffle(X)
            X = da.from_array(X)
            # NOTE(review): the second positional argument of ``rechunk`` is
            # ``threshold``, not a per-axis chunk size -- confirm this is the
            # intended call.
            X = X.rechunk(20, 1)
            y = da.random.randint(0, 3, size=20)
            # ``rechunk`` returns a new array instead of modifying ``y`` in
            # place, so the result has to be bound back to the name.
            y = y.rechunk(20)
            regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2,
                                                  missing=0.0)
            regressor.client = client
            regressor.set_params(tree_method='hist')
            regressor.fit(X, y, eval_set=[(X, y)])
            dd_predt = regressor.predict(X).compute()

            # Reference prediction: same booster, materialised data, and the
            # same ``missing`` marker on a plain single-node DMatrix.
            np_X = X.compute()
            np_predt = regressor.get_booster().predict(
                xgb.DMatrix(np_X, missing=0.0))
            np.testing.assert_allclose(np_predt, dd_predt)
|
||||
|
||||
|
||||
def test_dask_missing_value_cls():
    """Check that ``DaskXGBClassifier`` forwards its ``missing`` parameter.

    Trains on a shuffled matrix whose rows are either all ones or all zeros
    with ``missing=0.0``, then compares the dask prediction with the one the
    same booster produces for a single-node ``DMatrix`` using the same
    ``missing`` value.  Finally verifies that a default-constructed
    classifier exposes a ``missing`` attribute.
    """
    # Multi-class doesn't handle empty DMatrix well, so we use fewer workers.
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        half_rows = kRows // 2
        ones_part = np.ones((half_rows, kCols))
        zeros_part = np.zeros((half_rows, kCols))
        features = np.concatenate([ones_part, zeros_part], axis=0)
        np.random.shuffle(features)
        features = da.from_array(features).rechunk(20, None)
        labels = da.random.randint(0, 3, size=kRows).rechunk(20, 1)

        classifier = xgb.dask.DaskXGBClassifier(
            verbosity=1,
            n_estimators=2,
            tree_method='hist',
            missing=0.0,
        )
        classifier.client = client
        classifier.fit(features, labels, eval_set=[(features, labels)])
        distributed_predt = classifier.predict(features).compute()

        # Single-node reference using the trained booster directly.
        local_dmatrix = xgb.DMatrix(features.compute(), missing=0.0)
        local_predt = classifier.get_booster().predict(local_dmatrix)
        np.testing.assert_allclose(local_predt, distributed_predt)

        # The attribute must exist even when ``missing`` is not passed.
        default_classifier = xgb.dask.DaskXGBClassifier()
        assert hasattr(default_classifier, 'missing')
|
||||
|
||||
|
||||
def test_dask_regressor():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
|
||||
Reference in New Issue
Block a user