[dask] Fix empty partition with pandas input. (#7644)
An empty partition is different from an empty dataset. In the former case, each worker has non-empty dask collections, but each collection might contain empty partitions.
This commit is contained in:
parent
1f020a6097
commit
b52c4e13b0
@ -1062,6 +1062,9 @@ def _maybe_dataframe(
|
|||||||
prediction, columns=columns, dtype=numpy.float32, index=index
|
prediction, columns=columns, dtype=numpy.float32, index=index
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if prediction.size == 0:
|
||||||
|
return DataFrame({}, columns=columns, dtype=numpy.float32, index=index)
|
||||||
|
|
||||||
prediction = DataFrame(
|
prediction = DataFrame(
|
||||||
prediction, columns=columns, dtype=numpy.float32, index=index
|
prediction, columns=columns, dtype=numpy.float32, index=index
|
||||||
)
|
)
|
||||||
|
|||||||
@ -85,12 +85,9 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
|
|||||||
ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes);
|
ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes);
|
||||||
out_preds->Copy(*base_margin);
|
out_preds->Copy(*base_margin);
|
||||||
} else {
|
} else {
|
||||||
if (out_preds->Empty()) {
|
out_preds->Resize(n);
|
||||||
out_preds->Resize(n, model.learner_model_param->base_score);
|
// cannot rely on the Resize to fill as it might skip if the size is already correct.
|
||||||
} else {
|
out_preds->Fill(model.learner_model_param->base_score);
|
||||||
out_preds->Resize(n);
|
|
||||||
out_preds->Fill(model.learner_model_param->base_score);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -7,8 +7,6 @@ import numpy as np
|
|||||||
import asyncio
|
import asyncio
|
||||||
import xgboost
|
import xgboost
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
|
||||||
import json
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from hypothesis import given, strategies, settings, note
|
from hypothesis import given, strategies, settings, note
|
||||||
@ -321,9 +319,9 @@ class TestDistributedGPU:
|
|||||||
mult = 100
|
mult = 100
|
||||||
df = cudf.DataFrame(
|
df = cudf.DataFrame(
|
||||||
{
|
{
|
||||||
"a": [1,2,3,4,5.1] * mult,
|
"a": [1, 2, 3, 4, 5.1] * mult,
|
||||||
"b": [10,15,29.3,30,31] * mult,
|
"b": [10, 15, 29.3, 30, 31] * mult,
|
||||||
"y": [10,20,30,40.,50] * mult,
|
"y": [10, 20, 30, 40., 50] * mult,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
|
parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
|
||||||
@ -350,10 +348,34 @@ class TestDistributedGPU:
|
|||||||
y = ddf[["y"]]
|
y = ddf[["y"]]
|
||||||
dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
|
dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
|
||||||
bst = xgb.dask.train(client, parameters, dtrain, evals=[(dtrain, "train")])
|
bst = xgb.dask.train(client, parameters, dtrain, evals=[(dtrain, "train")])
|
||||||
predt = dxgb.predict(client, bst, X).compute().values
|
|
||||||
|
|
||||||
|
predt = dxgb.predict(client, bst, X).compute().values
|
||||||
cupy.testing.assert_allclose(predt, predt_empty)
|
cupy.testing.assert_allclose(predt, predt_empty)
|
||||||
|
|
||||||
|
predt = dxgb.predict(client, bst, dtrain).compute()
|
||||||
|
cupy.testing.assert_allclose(predt, predt_empty)
|
||||||
|
|
||||||
|
predt = dxgb.inplace_predict(client, bst, X).compute().values
|
||||||
|
cupy.testing.assert_allclose(predt, predt_empty)
|
||||||
|
|
||||||
|
df = df.to_pandas()
|
||||||
|
empty = df.iloc[:0]
|
||||||
|
ddf = dd.concat(
|
||||||
|
[dd.from_pandas(empty, npartitions=1)]
|
||||||
|
+ [dd.from_pandas(df, npartitions=3)]
|
||||||
|
+ [dd.from_pandas(df, npartitions=3)]
|
||||||
|
)
|
||||||
|
X = ddf[ddf.columns.difference(["y"])]
|
||||||
|
y = ddf[["y"]]
|
||||||
|
|
||||||
|
predt_empty = cupy.asnumpy(predt_empty)
|
||||||
|
|
||||||
|
predt = dxgb.predict(client, bst_empty, X).compute().values
|
||||||
|
np.testing.assert_allclose(predt, predt_empty)
|
||||||
|
|
||||||
|
in_predt = dxgb.inplace_predict(client, bst_empty, X).compute().values
|
||||||
|
np.testing.assert_allclose(predt, in_predt)
|
||||||
|
|
||||||
def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None:
|
def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None:
|
||||||
with Client(local_cuda_cluster) as client:
|
with Client(local_cuda_cluster) as client:
|
||||||
n_workers = len(_get_client_workers(client))
|
n_workers = len(_get_client_workers(client))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user