[dask] Fix ddqdm with empty partition. (#7510)

* Fix empty partition.

* war.
This commit is contained in:
Jiaming Yuan 2021-12-16 20:37:29 +08:00 committed by GitHub
parent a512b4b394
commit 70b12d898a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 59 additions and 35 deletions

View File

@ -1071,7 +1071,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
c_feature_types,
c_bst_ulong(len(feature_types))))
if len(feature_types) != self.num_col():
if len(feature_types) != self.num_col() and self.num_col() != 0:
msg = 'feature_types must have the same length as data'
raise ValueError(msg)
else:

View File

@ -1015,6 +1015,8 @@ def _maybe_dataframe(
index = getattr(data, "index", None)
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
import cudf
if prediction.size == 0:
return cudf.DataFrame({}, columns=columns, dtype=numpy.float32)
prediction = cudf.DataFrame(
prediction, columns=columns, dtype=numpy.float32, index=index

View File

@ -599,7 +599,7 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
}
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
if (size != 0) {
if (size != 0 && this->num_col_ != 0) {
CHECK_EQ(size, this->num_col_)
<< "Length of " << key << " must be equal to number of columns.";
}

View File

@ -39,7 +39,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
private:
common::Span<ArrayInterface<1>> columns_;
size_t num_rows_;
size_t num_rows_{0};
};
/*!

View File

@ -16,8 +16,8 @@ namespace data {
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
auto device =
adapter->DeviceIdx() < 0 ? dh::CurrentDevice() : adapter->DeviceIdx();
auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ? dh::CurrentDevice()
: adapter->DeviceIdx();
CHECK_GE(device, 0);
dh::safe_cuda(cudaSetDevice(device));

View File

@ -56,7 +56,7 @@ TEST(MetaInfo, GetSetFeature) {
std::vector<char const*> c_types(kCols);
std::transform(types.cbegin(), types.cend(), c_types.begin(),
[](auto const &str) { return str.c_str(); });
// Info has 0 column
info.num_col_ = 1;
EXPECT_THROW(
info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()),
dmlc::Error);

View File

@ -18,6 +18,12 @@ if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
sys.path.append("tests/python")
import testing as tm # noqa
if tm.no_dask_cuda()["condition"]:
pytest.skip(tm.no_dask_cuda()["reason"], allow_module_level=True)
from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_auc # noqa
from test_with_dask import run_auc # noqa
@ -30,7 +36,7 @@ from test_with_dask import generate_array # noqa
from test_with_dask import kCols as random_cols # noqa
from test_with_dask import suppress # noqa
from test_with_dask import run_tree_stats # noqa
import testing as tm # noqa
try:
@ -312,11 +318,7 @@ def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
class TestDistributedGPU:
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_dask_cudf())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
def test_dask_dataframe(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
run_with_dask_dataframe(dxgb.DaskDMatrix, client)
@ -328,13 +330,10 @@ class TestDistributedGPU:
dataset=tm.dataset_strategy,
)
@settings(deadline=duration(seconds=120), suppress_health_check=suppress)
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.parametrize(
"local_cuda_cluster", [{"n_workers": 2}], indirect=["local_cuda_cluster"]
)
@pytest.mark.mgpu
def test_gpu_hist(
self,
params: Dict,
@ -349,17 +348,12 @@ class TestDistributedGPU:
)
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
def test_dask_array(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
run_with_dask_array(dxgb.DaskDMatrix, client)
run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
def test_early_stopping(self, local_cuda_cluster: LocalCUDACluster) -> None:
from sklearn.datasets import load_breast_cancer
with Client(local_cuda_cluster) as client:
@ -394,8 +388,6 @@ class TestDistributedGPU:
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.parametrize("model", ["boosting"])
def test_dask_classifier(
self, model: str, local_cuda_cluster: LocalCUDACluster
@ -409,9 +401,6 @@ class TestDistributedGPU:
w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
run_dask_classifier(X, y, w, model, "gpu_hist", client, 10)
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
def test_empty_dmatrix(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
parameters = {'tree_method': 'gpu_hist',
@ -419,6 +408,48 @@ class TestDistributedGPU:
run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)
@pytest.mark.skipif(**tm.no_dask_cudf())
def test_empty_partition(self, local_cuda_cluster: LocalCUDACluster) -> None:
import dask_cudf
import cudf
import cupy
with Client(local_cuda_cluster) as client:
mult = 100
df = cudf.DataFrame(
{
"a": [1,2,3,4,5.1] * mult,
"b": [10,15,29.3,30,31] * mult,
"y": [10,20,30,40.,50] * mult,
}
)
parameters = {"tree_method": "gpu_hist", "debug_synchronize": True}
empty = df.iloc[:0]
ddf = dask_cudf.concat(
[dask_cudf.from_cudf(empty, npartitions=1)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
bst_empty = xgb.dask.train(
client, parameters, dtrain, evals=[(dtrain, "train")]
)
predt_empty = dxgb.predict(client, bst_empty, X).compute().values
ddf = dask_cudf.concat(
[dask_cudf.from_cudf(df, npartitions=3)]
+ [dask_cudf.from_cudf(df, npartitions=3)]
)
X = ddf[ddf.columns.difference(["y"])]
y = ddf[["y"]]
dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y)
bst = xgb.dask.train(client, parameters, dtrain, evals=[(dtrain, "train")])
predt = dxgb.predict(client, bst, X).compute().values
cupy.testing.assert_allclose(predt, predt_empty)
def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
n_workers = len(_get_client_workers(client))
@ -550,16 +581,10 @@ class TestDistributedGPU:
assert msg.find('1 test from GPUQuantile') != -1, msg
assert ret.returncode == 0, msg
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
@pytest.mark.gtest
def test_quantile_basic(self, local_cuda_cluster: LocalCUDACluster) -> None:
self.run_quantile('AllReduceBasic', local_cuda_cluster)
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
@pytest.mark.gtest
def test_quantile_same_on_all_workers(
self, local_cuda_cluster: LocalCUDACluster
@ -594,10 +619,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainRetur
return output
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_with_asyncio(local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
address = client.scheduler.address

View File

@ -503,12 +503,12 @@ def test_empty_dmatrix_training_continuation(client: "Client") -> None:
kRows, kCols = 1, 97
X = dd.from_array(np.random.randn(kRows, kCols))
y = dd.from_array(np.random.rand(kRows))
X.columns = ['X' + str(i) for i in range(0, 97)]
X.columns = ['X' + str(i) for i in range(0, kCols)]
dtrain = xgb.dask.DaskDMatrix(client, X, y)
kRows += 1000
X = dd.from_array(np.random.randn(kRows, kCols), chunksize=10)
X.columns = ['X' + str(i) for i in range(0, 97)]
X.columns = ['X' + str(i) for i in range(0, kCols)]
y = dd.from_array(np.random.rand(kRows), chunksize=10)
valid = xgb.dask.DaskDMatrix(client, X, y)