diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 11138140c..e4187a710 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1071,7 +1071,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes c_feature_types, c_bst_ulong(len(feature_types)))) - if len(feature_types) != self.num_col(): + if len(feature_types) != self.num_col() and self.num_col() != 0: msg = 'feature_types must have the same length as data' raise ValueError(msg) else: diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 2f5b732bc..665f7c264 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1015,6 +1015,8 @@ def _maybe_dataframe( index = getattr(data, "index", None) if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): import cudf + if prediction.size == 0: + return cudf.DataFrame({}, columns=columns, dtype=numpy.float32) prediction = cudf.DataFrame( prediction, columns=columns, dtype=numpy.float32, index=index diff --git a/src/data/data.cc b/src/data/data.cc index 465a72bef..fd3f2b6db 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -599,7 +599,7 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, } void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) { - if (size != 0) { + if (size != 0 && this->num_col_ != 0) { CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns."; } diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index d1bda280a..145bb56dd 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -39,7 +39,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { private: common::Span> columns_; - size_t num_rows_; + size_t num_rows_{0}; }; /*! diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index e770d4aa2..da4000ed1 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -16,8 +16,8 @@ namespace data { // be supported in future. Does not currently support inferring row/column size template SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { - auto device = - adapter->DeviceIdx() < 0 ? dh::CurrentDevice() : adapter->DeviceIdx(); + auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ? dh::CurrentDevice() + : adapter->DeviceIdx(); CHECK_GE(device, 0); dh::safe_cuda(cudaSetDevice(device)); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 4a1ef7a72..c8fbe43a4 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -56,7 +56,7 @@ TEST(MetaInfo, GetSetFeature) { std::vector c_types(kCols); std::transform(types.cbegin(), types.cend(), c_types.begin(), [](auto const &str) { return str.c_str(); }); - // Info has 0 column + info.num_col_ = 1; EXPECT_THROW( info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error); diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 7150566c8..63ba4f94c 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -18,6 +18,12 @@ if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows", allow_module_level=True) sys.path.append("tests/python") +import testing as tm # noqa + +if tm.no_dask_cuda()["condition"]: + pytest.skip(tm.no_dask_cuda()["reason"], allow_module_level=True) + + from test_with_dask import run_empty_dmatrix_reg # noqa from test_with_dask import run_empty_dmatrix_auc # noqa from test_with_dask import run_auc # noqa @@ -30,7 +36,7 @@ from test_with_dask import generate_array # noqa from test_with_dask import kCols as random_cols # noqa from test_with_dask import suppress # noqa from test_with_dask import run_tree_stats # noqa -import testing as tm # noqa + try: @@ -312,11 +318,7 @@ def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None: class TestDistributedGPU: - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_dask_cudf()) - @pytest.mark.skipif(**tm.no_dask_cuda()) - @pytest.mark.mgpu def test_dask_dataframe(self, local_cuda_cluster: LocalCUDACluster) -> None: with Client(local_cuda_cluster) as client: run_with_dask_dataframe(dxgb.DaskDMatrix, client) @@ -328,13 +330,10 @@ class TestDistributedGPU: dataset=tm.dataset_strategy, ) @settings(deadline=duration(seconds=120), suppress_health_check=suppress) - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.parametrize( "local_cuda_cluster", [{"n_workers": 2}], indirect=["local_cuda_cluster"] ) - @pytest.mark.mgpu def test_gpu_hist( self, params: Dict, @@ -349,17 +348,12 @@ class TestDistributedGPU: ) @pytest.mark.skipif(**tm.no_cupy()) - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.skipif(**tm.no_dask_cuda()) - @pytest.mark.mgpu def test_dask_array(self, local_cuda_cluster: LocalCUDACluster) -> None: with Client(local_cuda_cluster) as client: run_with_dask_array(dxgb.DaskDMatrix, client) run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client) @pytest.mark.skipif(**tm.no_cupy()) - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.skipif(**tm.no_dask_cuda()) def test_early_stopping(self, local_cuda_cluster: LocalCUDACluster) -> None: from sklearn.datasets import load_breast_cancer with Client(local_cuda_cluster) as client: @@ -394,8 +388,6 @@ class TestDistributedGPU: assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 @pytest.mark.skipif(**tm.no_cudf()) - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.parametrize("model", ["boosting"]) def test_dask_classifier( self, model: str, local_cuda_cluster: LocalCUDACluster @@ -409,9 +401,6 @@ class TestDistributedGPU: w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_)) run_dask_classifier(X, y, w, model, "gpu_hist", client, 10) - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.skipif(**tm.no_dask_cuda()) - @pytest.mark.mgpu def test_empty_dmatrix(self, local_cuda_cluster: LocalCUDACluster) -> None: with Client(local_cuda_cluster) as client: parameters = {'tree_method': 'gpu_hist', @@ -419,6 +408,48 @@ class TestDistributedGPU: run_empty_dmatrix_reg(client, parameters) run_empty_dmatrix_cls(client, parameters) + @pytest.mark.skipif(**tm.no_dask_cudf()) + def test_empty_partition(self, local_cuda_cluster: LocalCUDACluster) -> None: + import dask_cudf + import cudf + import cupy + with Client(local_cuda_cluster) as client: + mult = 100 + df = cudf.DataFrame( + { + "a": [1,2,3,4,5.1] * mult, + "b": [10,15,29.3,30,31] * mult, + "y": [10,20,30,40.,50] * mult, + } + ) + parameters = {"tree_method": "gpu_hist", "debug_synchronize": True} + + empty = df.iloc[:0] + ddf = dask_cudf.concat( + [dask_cudf.from_cudf(empty, npartitions=1)] + + [dask_cudf.from_cudf(df, npartitions=3)] + + [dask_cudf.from_cudf(df, npartitions=3)] + ) + X = ddf[ddf.columns.difference(["y"])] + y = ddf[["y"]] + dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y) + bst_empty = xgb.dask.train( + client, parameters, dtrain, evals=[(dtrain, "train")] + ) + predt_empty = dxgb.predict(client, bst_empty, X).compute().values + + ddf = dask_cudf.concat( + [dask_cudf.from_cudf(df, npartitions=3)] + + [dask_cudf.from_cudf(df, npartitions=3)] + ) + X = ddf[ddf.columns.difference(["y"])] + y = ddf[["y"]] + dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y) + bst = xgb.dask.train(client, parameters, dtrain, evals=[(dtrain, "train")]) + predt = dxgb.predict(client, bst, X).compute().values + + cupy.testing.assert_allclose(predt, predt_empty) + def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None: with Client(local_cuda_cluster) as client: n_workers = len(_get_client_workers(client)) @@ -550,16 +581,10 @@ class TestDistributedGPU: assert msg.find('1 test from GPUQuantile') != -1, msg assert ret.returncode == 0, msg - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.skipif(**tm.no_dask_cuda()) - @pytest.mark.mgpu @pytest.mark.gtest def test_quantile_basic(self, local_cuda_cluster: LocalCUDACluster) -> None: self.run_quantile('AllReduceBasic', local_cuda_cluster) - @pytest.mark.skipif(**tm.no_dask()) - @pytest.mark.skipif(**tm.no_dask_cuda()) - @pytest.mark.mgpu @pytest.mark.gtest def test_quantile_same_on_all_workers( self, local_cuda_cluster: LocalCUDACluster @@ -594,10 +619,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainRetur return output -@pytest.mark.skipif(**tm.no_dask()) -@pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.skipif(**tm.no_cupy()) -@pytest.mark.mgpu def test_with_asyncio(local_cuda_cluster: LocalCUDACluster) -> None: with Client(local_cuda_cluster) as client: address = client.scheduler.address diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 2602529ce..68f3d8eff 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -503,12 +503,12 @@ def test_empty_dmatrix_training_continuation(client: "Client") -> None: kRows, kCols = 1, 97 X = dd.from_array(np.random.randn(kRows, kCols)) y = dd.from_array(np.random.rand(kRows)) - X.columns = ['X' + str(i) for i in range(0, 97)] + X.columns = ['X' + str(i) for i in range(0, kCols)] dtrain = xgb.dask.DaskDMatrix(client, X, y) kRows += 1000 X = dd.from_array(np.random.randn(kRows, kCols), chunksize=10) - X.columns = ['X' + str(i) for i in range(0, 97)] + X.columns = ['X' + str(i) for i in range(0, kCols)] y = dd.from_array(np.random.rand(kRows), chunksize=10) valid = xgb.dask.DaskDMatrix(client, X, y)