Fix dask prediction. (#4941)

* Fix dask prediction.

* Add better error messages for wrong partition.
This commit is contained in:
Jiaming Yuan
2019-10-14 23:19:34 -04:00
committed by GitHub
parent b61d534472
commit 2ebdec8aa6
5 changed files with 51 additions and 24 deletions

View File

@@ -21,12 +21,12 @@ except ImportError:
pass
# Number of rows in the randomly generated test arrays (X and y).
kRows = 1000
# Number of columns in the randomly generated X array.
kCols = 10
def generate_array():
    """Create random dask arrays to use as test inputs.

    Returns
    -------
    X : dask array of shape (kRows, kCols)
    y : dask array of shape (kRows,)

    Both arrays are built with ``da.random.random`` and chunked with a
    partition size of 20.
    """
    partition_size = 20
    # X uses the module-level kCols so tests can check prediction shapes
    # against the same constant; the earlier local ``n`` and the array built
    # from it were immediately overwritten and have been dropped.
    X = da.random.random((kRows, kCols), partition_size)
    y = da.random.random(kRows, partition_size)
    return X, y
@@ -44,7 +44,7 @@ def test_from_dask_dataframe(client):
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
assert isinstance(prediction, da.Array)
assert prediction.shape[0] == kRows, prediction
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
with pytest.raises(ValueError):
# evals_result is not supported in dask interface.
@@ -59,6 +59,7 @@ def test_from_dask_array(client):
result = xgb.dask.train(client, {}, dtrain)
prediction = xgb.dask.predict(client, result, dtrain)
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
assert isinstance(prediction, da.Array)
@@ -71,6 +72,8 @@ def test_regressor(client):
regressor.fit(X, y, eval_set=[(X, y)])
prediction = regressor.predict(X)
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
history = regressor.evals_result()
assert isinstance(prediction, da.Array)
@@ -88,6 +91,8 @@ def test_classifier(client):
classifier.fit(X, y, eval_set=[(X, y)])
prediction = classifier.predict(X)
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
history = classifier.evals_result()
assert isinstance(prediction, da.Array)