[dask] Fix empty partition with pandas input. (#7644)

Empty partition is different from empty dataset.  For the former case, each worker has
non-empty dask collections, but each collection might contain empty partition.
This commit is contained in:
Jiaming Yuan 2022-02-14 19:35:51 +08:00 committed by GitHub
parent 1f020a6097
commit b52c4e13b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 12 deletions

View File

@ -1062,6 +1062,9 @@ def _maybe_dataframe(
prediction, columns=columns, dtype=numpy.float32, index=index prediction, columns=columns, dtype=numpy.float32, index=index
) )
else: else:
if prediction.size == 0:
return DataFrame({}, columns=columns, dtype=numpy.float32, index=index)
prediction = DataFrame( prediction = DataFrame(
prediction, columns=columns, dtype=numpy.float32, index=index prediction, columns=columns, dtype=numpy.float32, index=index
) )

View File

@ -85,12 +85,9 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes); ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes);
out_preds->Copy(*base_margin); out_preds->Copy(*base_margin);
} else { } else {
if (out_preds->Empty()) { out_preds->Resize(n);
out_preds->Resize(n, model.learner_model_param->base_score); // cannot rely on the Resize to fill as it might skip if the size is already correct.
} else { out_preds->Fill(model.learner_model_param->base_score);
out_preds->Resize(n);
out_preds->Fill(model.learner_model_param->base_score);
}
} }
} }
} // namespace xgboost } // namespace xgboost

View File

@ -7,8 +7,6 @@ import numpy as np
import asyncio import asyncio
import xgboost import xgboost
import subprocess import subprocess
import tempfile
import json
from collections import OrderedDict from collections import OrderedDict
from inspect import signature from inspect import signature
from hypothesis import given, strategies, settings, note from hypothesis import given, strategies, settings, note
@ -321,9 +319,9 @@ class TestDistributedGPU:
mult = 100 mult = 100
df = cudf.DataFrame( df = cudf.DataFrame(
{ {
"a": [1,2,3,4,5.1] * mult, "a": [1, 2, 3, 4, 5.1] * mult,
"b": [10,15,29.3,30,31] * mult, "b": [10, 15, 29.3, 30, 31] * mult,
"y": [10,20,30,40.,50] * mult, "y": [10, 20, 30, 40., 50] * mult,
} }
) )
parameters = {"tree_method": "gpu_hist", "debug_synchronize": True} parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
@ -350,10 +348,34 @@ class TestDistributedGPU:
y = ddf[["y"]] y = ddf[["y"]]
dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y) dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
bst = xgb.dask.train(client, parameters, dtrain, evals=[(dtrain, "train")]) bst = xgb.dask.train(client, parameters, dtrain, evals=[(dtrain, "train")])
predt = dxgb.predict(client, bst, X).compute().values
predt = dxgb.predict(client, bst, X).compute().values
cupy.testing.assert_allclose(predt, predt_empty) cupy.testing.assert_allclose(predt, predt_empty)
predt = dxgb.predict(client, bst, dtrain).compute()
cupy.testing.assert_allclose(predt, predt_empty)
predt = dxgb.inplace_predict(client, bst, X).compute().values
cupy.testing.assert_allclose(predt, predt_empty)
df = df.to_pandas()
empty = df.iloc[:0]
ddf = dd.concat(
[dd.from_pandas(empty, npartitions=1)]
+ [dd.from_pandas(df, npartitions=3)]
+ [dd.from_pandas(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
predt_empty = cupy.asnumpy(predt_empty)
predt = dxgb.predict(client, bst_empty, X).compute().values
np.testing.assert_allclose(predt, predt_empty)
in_predt = dxgb.inplace_predict(client, bst_empty, X).compute().values
np.testing.assert_allclose(predt, in_predt)
def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None: def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client: with Client(local_cuda_cluster) as client:
n_workers = len(_get_client_workers(client)) n_workers = len(_get_client_workers(client))