Fix dask predict shape infer. (#5989)
parent 9c6e791e64
commit 801e6b6800
@@ -738,7 +738,8 @@ async def _predict_async(client: Client, model, data, *args,
             predt = booster.predict(data=local_x,
                                     validate_features=local_x.num_row() != 0,
                                     *args)
-            ret = (delayed(predt), order)
+            columns = 1 if len(predt.shape) == 1 else predt.shape[1]
+            ret = ((delayed(predt), columns), order)
             predictions.append(ret)
         return predictions
 
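For context on the new columns bookkeeping above: a minimal sketch, outside the patched code, of why the per-partition column count has to be read off the prediction's shape. Booster.predict returns a 1-D array for single-output objectives and a 2-D (n_rows, n_classes) array for multi:softprob; the data, objectives, and variable names below are illustrative only.

    # Illustrative only -- not code from this commit.
    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X = rng.rand(32, 4)

    reg = xgb.train({"objective": "reg:squarederror"},
                    xgb.DMatrix(X, label=rng.rand(32)), num_boost_round=2)
    clf = xgb.train({"objective": "multi:softprob", "num_class": 3},
                    xgb.DMatrix(X, label=rng.randint(0, 3, size=32)),
                    num_boost_round=2)

    print(reg.predict(xgb.DMatrix(X)).shape)  # (32,)   -> columns == 1
    print(clf.predict(xgb.DMatrix(X)).shape)  # (32, 3) -> columns == predt.shape[1]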
@@ -775,7 +776,9 @@ async def _predict_async(client: Client, model, data, *args,
     # See https://docs.dask.org/en/latest/array-creation.html
     arrays = []
     for i, shape in enumerate(shapes):
-        arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
-                                      dtype=numpy.float32))
+        arrays.append(da.from_delayed(
+            results[i][0], shape=(shape[0],)
+            if results[i][1] == 1 else (shape[0], results[i][1]),
+            dtype=numpy.float32))
     predictions = await da.concatenate(arrays, axis=0)
     return predictions
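A minimal, self-contained sketch (not the library code) of the da.from_delayed behaviour the hunk above relies on: the lazy array's metadata comes entirely from the shape= argument, so each multi-class prediction chunk must be declared as (rows, columns) before concatenation. The helper name predict_chunk and the sizes are illustrative.

    # Illustrative only -- not code from this commit.
    import numpy as np
    import dask
    import dask.array as da

    n_rows, n_classes = 100, 3

    @dask.delayed
    def predict_chunk():
        # stand-in for booster.predict() on one data partition
        return np.random.rand(n_rows, n_classes).astype(np.float32)

    # Declaring only (n_rows,) here would make the concatenated prediction
    # look 1-D even though the delayed chunk is really (n_rows, n_classes).
    arr = da.from_delayed(predict_chunk(), shape=(n_rows, n_classes),
                          dtype=np.float32)
    print(arr.shape)            # (100, 3), known before compute()
    print(arr.compute().shape)  # (100, 3)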
@@ -978,6 +981,7 @@ class DaskScikitLearnBase(XGBModel):
     def client(self, clt):
         self._client = clt
 
+
 @xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
                    ['estimators', 'model'])
 class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
@@ -1032,9 +1036,6 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     ['estimators', 'model']
 )
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
-    # pylint: disable=missing-docstring
-    _client = None
-
     async def _fit_async(self, X, y,
                          sample_weights=None,
                          eval_set=None,
@@ -5,6 +5,7 @@ import sys
 import numpy as np
 import json
 import asyncio
+from sklearn.datasets import make_classification
 
 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -36,7 +37,7 @@ def generate_array():
 
 
 def test_from_dask_dataframe():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
 
@@ -74,7 +75,7 @@ def test_from_dask_dataframe():
 
 
 def test_from_dask_array():
-    with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
+    with LocalCluster(n_workers=kWorkers, threads_per_worker=5) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
             dtrain = DaskDMatrix(client, X, y)
@@ -104,8 +105,28 @@ def test_from_dask_array():
             assert np.all(single_node_predt == from_arr.compute())
 
 
+def test_dask_predict_shape_infer():
+    with LocalCluster(n_workers=kWorkers) as cluster:
+        with Client(cluster) as client:
+            X, y = make_classification(n_samples=1000, n_informative=5,
+                                       n_classes=3)
+            X_ = dd.from_array(X, chunksize=100)
+            y_ = dd.from_array(y, chunksize=100)
+            dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
+
+            model = xgb.dask.train(
+                client,
+                {"objective": "multi:softprob", "num_class": 3},
+                dtrain=dtrain
+            )
+
+            preds = xgb.dask.predict(client, model, dtrain)
+            assert preds.shape[0] == preds.compute().shape[0]
+            assert preds.shape[1] == preds.compute().shape[1]
+
+
 def test_dask_missing_value_reg():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X_0 = np.ones((20 // 2, kCols))
             X_1 = np.zeros((20 // 2, kCols))
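A self-contained usage sketch (not part of the commit) of the behaviour the new test above pins down: after the fix, xgb.dask.predict on a multi:softprob model yields a lazy dask array with a known 2-D shape, so axis-1 reductions such as argmax work before compute(). The cluster size, synthetic data, and boosting rounds below are arbitrary choices.

    # Illustrative only -- not code from this commit.
    import xgboost as xgb
    import dask.array as da
    from distributed import Client, LocalCluster

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster:
            with Client(cluster) as client:
                X = da.random.random((1000, 5), chunks=(100, 5))
                y = da.random.randint(0, 3, size=1000, chunks=100)
                dtrain = xgb.dask.DaskDMatrix(client, X, y)
                model = xgb.dask.train(
                    client, {"objective": "multi:softprob", "num_class": 3},
                    dtrain, num_boost_round=2)
                preds = xgb.dask.predict(client, model, dtrain)
                print(preds.shape)             # (1000, 3), known lazily
                labels = preds.argmax(axis=1)  # reduce over the class axis
                print(labels.compute()[:10])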
@@ -156,7 +177,7 @@ def test_dask_missing_value_cls():
 
 
 def test_dask_regressor():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
             regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
@@ -178,7 +199,7 @@ def test_dask_regressor():
 
 
 def test_dask_classifier():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
             y = (y * 10).astype(np.int32)
@@ -188,7 +209,7 @@ def test_dask_classifier():
             classifier.fit(X, y, eval_set=[(X, y)])
             prediction = classifier.predict(X)
 
-            assert prediction.ndim == 1
+            assert prediction.ndim == 2
             assert prediction.shape[0] == kRows
 
             history = classifier.evals_result()
@@ -211,14 +232,14 @@ def test_dask_classifier():
             assert classifier.n_classes_ == 10
             prediction = classifier.predict(X_d)
 
-            assert prediction.ndim == 1
+            assert prediction.ndim == 2
             assert prediction.shape[0] == kRows
 
 
 @pytest.mark.skipif(**tm.no_sklearn())
 def test_sklearn_grid_search():
     from sklearn.model_selection import GridSearchCV
-    with LocalCluster(n_workers=4) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
             reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
@@ -292,7 +313,9 @@ def run_empty_dmatrix_cls(client, parameters):
                          evals=[(dtrain, 'validation')],
                          num_boost_round=2)
     predictions = xgb.dask.predict(client=client, model=out,
-                                   data=dtrain).compute()
+                                   data=dtrain)
+    assert predictions.shape[1] == n_classes
+    predictions = predictions.compute()
     _check_outputs(out, predictions)
 
     # train has more rows than evals
@@ -315,7 +338,7 @@ def run_empty_dmatrix_cls(client, parameters):
 # environment and Exact doesn't support it.
 
 def test_empty_dmatrix_hist():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             parameters = {'tree_method': 'hist'}
             run_empty_dmatrix_reg(client, parameters)
@@ -323,7 +346,7 @@ def test_empty_dmatrix_hist():
 
 
 def test_empty_dmatrix_approx():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             parameters = {'tree_method': 'approx'}
             run_empty_dmatrix_reg(client, parameters)
@@ -384,7 +407,7 @@ async def run_dask_classifier_asyncio(scheduler_address):
         await classifier.fit(X, y, eval_set=[(X, y)])
         prediction = await classifier.predict(X)
 
-        assert prediction.ndim == 1
+        assert prediction.ndim == 2
         assert prediction.shape[0] == kRows
 
         history = classifier.evals_result()
@@ -407,8 +430,9 @@ async def run_dask_classifier_asyncio(scheduler_address):
         assert classifier.n_classes_ == 10
         prediction = await classifier.predict(X_d)
 
-        assert prediction.ndim == 1
+        assert prediction.ndim == 2
         assert prediction.shape[0] == kRows
+        assert prediction.shape[1] == 10
 
 
 def test_with_asyncio():