Fix dask prediction. (#4941)
* Fix dask prediction. * Add better error messages for wrong partition.
This commit is contained in:
@@ -21,12 +21,12 @@ except ImportError:
|
||||
pass
|
||||
|
||||
kRows = 1000
|
||||
kCols = 10
|
||||
|
||||
|
||||
def generate_array():
|
||||
n = 10
|
||||
partition_size = 20
|
||||
X = da.random.random((kRows, n), partition_size)
|
||||
X = da.random.random((kRows, kCols), partition_size)
|
||||
y = da.random.random(kRows, partition_size)
|
||||
return X, y
|
||||
|
||||
@@ -44,7 +44,7 @@ def test_from_dask_dataframe(client):
|
||||
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert prediction.shape[0] == kRows, prediction
|
||||
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# evals_result is not supported in dask interface.
|
||||
@@ -59,6 +59,7 @@ def test_from_dask_array(client):
|
||||
result = xgb.dask.train(client, {}, dtrain)
|
||||
|
||||
prediction = xgb.dask.predict(client, result, dtrain)
|
||||
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
|
||||
@@ -71,6 +72,8 @@ def test_regressor(client):
|
||||
regressor.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = regressor.predict(X)
|
||||
|
||||
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
|
||||
|
||||
history = regressor.evals_result()
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
@@ -88,6 +91,8 @@ def test_classifier(client):
|
||||
classifier.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = classifier.predict(X)
|
||||
|
||||
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
|
||||
|
||||
history = classifier.evals_result()
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
|
||||
Reference in New Issue
Block a user