Allow kwargs in dask predict (#6117)
This commit is contained in:
@@ -215,7 +215,7 @@ def test_dask_classifier():
|
||||
classifier = xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1, n_estimators=2)
|
||||
classifier.client = client
|
||||
classifier.fit(X, y, eval_set=[(X, y)])
|
||||
classifier.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = classifier.predict(X)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
@@ -276,7 +276,6 @@ def test_sklearn_grid_search():
|
||||
|
||||
|
||||
def run_empty_dmatrix_reg(client, parameters):
|
||||
|
||||
def _check_outputs(out, predictions):
|
||||
assert isinstance(out['booster'], xgb.dask.Booster)
|
||||
assert len(out['history']['validation']['rmse']) == 2
|
||||
@@ -424,7 +423,7 @@ async def run_dask_classifier_asyncio(scheduler_address):
|
||||
classifier = await xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1, n_estimators=2)
|
||||
classifier.client = client
|
||||
await classifier.fit(X, y, eval_set=[(X, y)])
|
||||
await classifier.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = await classifier.predict(X)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
@@ -447,7 +446,6 @@ async def run_dask_classifier_asyncio(scheduler_address):
|
||||
assert probas.shape[0] == kRows
|
||||
assert probas.shape[1] == 10
|
||||
|
||||
|
||||
# Test with dataframe.
|
||||
X_d = dd.from_dask_array(X)
|
||||
y_d = dd.from_dask_array(y)
|
||||
@@ -472,6 +470,28 @@ def test_with_asyncio():
|
||||
asyncio.run(run_dask_classifier_asyncio(address))
|
||||
|
||||
|
||||
def test_predict():
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2)['booster']
|
||||
|
||||
pred = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
assert pred.ndim == 1
|
||||
assert pred.shape[0] == kRows
|
||||
|
||||
margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
|
||||
assert margin.ndim == 1
|
||||
assert margin.shape[0] == kRows
|
||||
|
||||
shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
|
||||
assert shap.ndim == 2
|
||||
assert shap.shape[0] == kRows
|
||||
assert shap.shape[1] == kCols + 1
|
||||
|
||||
|
||||
class TestWithDask:
|
||||
def run_updater_test(self, client, params, num_rounds, dataset,
|
||||
tree_method):
|
||||
@@ -489,9 +509,9 @@ class TestWithDask:
|
||||
chunk = 128
|
||||
X = da.from_array(dataset.X,
|
||||
chunks=(chunk, dataset.X.shape[1]))
|
||||
y = da.from_array(dataset.y, chunks=(chunk, ))
|
||||
y = da.from_array(dataset.y, chunks=(chunk,))
|
||||
if dataset.w is not None:
|
||||
w = da.from_array(dataset.w, chunks=(chunk, ))
|
||||
w = da.from_array(dataset.w, chunks=(chunk,))
|
||||
else:
|
||||
w = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user