diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 5e7e8624f..b13306cb3 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -14,6 +14,7 @@ https://github.com/dask/dask-xgboost
 import platform
 import logging
 from collections import defaultdict
+from collections.abc import Sequence
 from threading import Thread
 
 import numpy
@@ -28,7 +29,7 @@ from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
 from .compat import CUDF_concat
 from .compat import lazy_isinstance
 
-from .core import DMatrix, Booster, _expect
+from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
 from .training import train as worker_train
 from .tracker import RabitTracker
 from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
@@ -357,6 +358,146 @@ class DaskDMatrix:
 
         return (rows, cols)
 
 
+class DaskPartitionIter(DataIter):  # pylint: disable=R0902
+    '''A data iterator for `DaskDeviceQuantileDMatrix`.'''
+    def __init__(self, data, label=None, weight=None, base_margin=None,
+                 label_lower_bound=None, label_upper_bound=None,
+                 feature_names=None, feature_types=None):
+        self._data = data
+        self._labels = label
+        self._weights = weight
+        self._base_margin = base_margin
+        self._label_lower_bound = label_lower_bound
+        self._label_upper_bound = label_upper_bound
+        self._feature_names = feature_names
+        self._feature_types = feature_types
+
+        assert isinstance(self._data, Sequence)
+
+        types = (Sequence, type(None))
+        assert isinstance(self._labels, types)
+        assert isinstance(self._weights, types)
+        assert isinstance(self._base_margin, types)
+        assert isinstance(self._label_lower_bound, types)
+        assert isinstance(self._label_upper_bound, types)
+
+        self._iter = 0              # set iterator to 0
+        super().__init__()
+
+    def data(self):
+        '''Utility function for obtaining current batch of data.'''
+        return self._data[self._iter]
+
+    def labels(self):
+        '''Utility function for obtaining current batch of label.'''
+        if self._labels is not None:
+            return self._labels[self._iter]
+        return None
+
+    def weights(self):
+        '''Utility function for obtaining current batch of weight.'''
+        if self._weights is not None:
+            return self._weights[self._iter]
+        return None
+
+    def base_margins(self):
+        '''Utility function for obtaining current batch of base_margin.'''
+        if self._base_margin is not None:
+            return self._base_margin[self._iter]
+        return None
+
+    def label_lower_bounds(self):
+        '''Utility function for obtaining current batch of label_lower_bound.'''
+        if self._label_lower_bound is not None:
+            return self._label_lower_bound[self._iter]
+        return None
+
+    def label_upper_bounds(self):
+        '''Utility function for obtaining current batch of label_upper_bound.'''
+        if self._label_upper_bound is not None:
+            return self._label_upper_bound[self._iter]
+        return None
+
+    def reset(self):
+        '''Reset the iterator.'''
+        self._iter = 0
+
+    def next(self, input_data):
+        '''Yield the next batch of data.'''
+        if self._iter == len(self._data):
+            # Return 0 when there are no more batches.
+            return 0
+        if self._feature_names:
+            feature_names = self._feature_names
+        else:
+            if hasattr(self.data(), 'columns'):
+                feature_names = self.data().columns.format()
+            else:
+                feature_names = None
+        input_data(data=self.data(), label=self.labels(),
+                   weight=self.weights(), group=None,
+                   label_lower_bound=self.label_lower_bounds(),
+                   label_upper_bound=self.label_upper_bounds(),
+                   feature_names=feature_names,
+                   feature_types=self._feature_types)
+        self._iter += 1
+        return 1
+
+
+class DaskDeviceQuantileDMatrix(DaskDMatrix):
+    '''Specialized data type for the `gpu_hist` tree method.  This class
+    reduces memory usage by eliminating data copies.  Internally the data is
+    merged by weighted GK sketching, so the number of partitions from dask may
+    affect training accuracy as GK introduces an error for each merge.
+
+    .. versionadded:: 1.2.0
+
+    Parameters
+    ----------
+    max_bin : Number of bins for histogram construction.
+
+    '''
+    def __init__(self, client, data, label=None, weight=None,
+                 missing=None,
+                 feature_names=None,
+                 feature_types=None,
+                 max_bin=256):
+        super().__init__(client=client, data=data, label=label, weight=weight,
+                         missing=missing,
+                         feature_names=feature_names,
+                         feature_types=feature_types)
+        self.max_bin = max_bin
+
+    def get_worker_data(self, worker):
+        if worker.address not in set(self.worker_map.keys()):
+            msg = 'worker {address} has an empty DMatrix.  ' \
+                'All workers associated with this DMatrix: {workers}'.format(
+                    address=worker.address,
+                    workers=set(self.worker_map.keys()))
+            LOGGER.warning(msg)
+            import cupy         # pylint: disable=import-error
+            d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
+                                      feature_names=self.feature_names,
+                                      feature_types=self.feature_types,
+                                      max_bin=self.max_bin)
+            return d
+
+        data, labels, weights = self.get_worker_parts(worker)
+        it = DaskPartitionIter(data=data, label=labels, weight=weights)
+
+        dmatrix = DeviceQuantileDMatrix(it,
+                                        missing=self.missing,
+                                        feature_names=self.feature_names,
+                                        feature_types=self.feature_types,
+                                        nthread=worker.nthreads,
+                                        max_bin=self.max_bin)
+        return dmatrix
+
+
 def _get_rabit_args(worker_map, client):
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
     host = distributed_comm.get_address_host(client.scheduler.address)
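
For reference, a minimal usage sketch of the class added above, assuming a machine with dask-cuda and cupy available; the array names and sizes are illustrative, and the pattern mirrors the test helpers added later in this patch:

    import dask.array as da
    from dask_cuda import LocalCUDACluster
    from distributed import Client
    from xgboost import dask as dxgb

    with LocalCUDACluster() as cluster:
        with Client(cluster) as client:
            import cupy as cp
            # Move the dask-backed data onto the GPU before constructing the
            # DMatrix; DaskDeviceQuantileDMatrix expects device inputs.
            X = da.random.random((100000, 20), chunks=(10000, 20)).map_blocks(cp.asarray)
            y = da.random.random(100000, chunks=(10000,)).map_blocks(cp.asarray)
            dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y, max_bin=256)
            output = dxgb.train(client, {'tree_method': 'gpu_hist'},
                                dtrain, num_boost_round=10,
                                evals=[(dtrain, 'train')])
            booster = output['booster']
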
diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index 33064d4cd..0caab3ec1 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -15,7 +15,7 @@ c_bst_ulong = ctypes.c_uint64  # pylint: disable=invalid-name
 
 
 def _warn_unused_missing(data, missing):
-    if not (np.isnan(missing) or None):
+    if (missing is not None) and (not np.isnan(missing)):
         warnings.warn(
             '`missing` is not used for current input data type:' +
             str(type(data)))
diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh
index e7a1218bb..cd5833914 100644
--- a/src/common/quantile.cuh
+++ b/src/common/quantile.cuh
@@ -85,6 +85,7 @@ class SketchContainer {
     // Initialize Sketches for this dmatrix
     this->columns_ptr_.SetDevice(device_);
     this->columns_ptr_.Resize(num_columns + 1);
+    CHECK_GE(device, 0);
     timer_.Init(__func__);
   }
   /* \brief Return GPU ID for this container. */
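
The check in `_warn_unused_missing` above is meant to fire only when the caller supplied a concrete, non-NaN `missing` value for an input type that ignores it.  A small illustration of that intent (plain Python, no xgboost required; the helper name is only for the example):

    import numpy as np

    def should_warn(missing):
        # Same predicate as in _warn_unused_missing above.
        return (missing is not None) and (not np.isnan(missing))

    assert should_warn(0.0)            # user-specified sentinel -> warn
    assert not should_warn(np.nan)     # default NaN -> silent
    assert not should_warn(None)       # unspecified -> silent
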
diff --git a/src/data/array_interface.h b/src/data/array_interface.h
index c8abb2d45..5539e16f0 100644
--- a/src/data/array_interface.h
+++ b/src/data/array_interface.h
@@ -114,7 +114,6 @@ class ArrayInterfaceHandler {
         get<Array const>(
             obj.at("data"))
             .at(0))));
-    CHECK(p_data);
     return p_data;
   }
 
@@ -224,6 +223,9 @@ class ArrayInterfaceHandler {
     auto shape = ExtractShape(column);
 
     T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T>(column);
+    if (!p_data) {
+      CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
+    }
     return common::Span<T>{p_data, shape.first * shape.second};
   }
 };
@@ -234,7 +236,6 @@ class ArrayInterface {
                           bool allow_mask = true) {
     ArrayInterfaceHandler::Validate(column);
     data = ArrayInterfaceHandler::GetPtrFromArrayData(column);
-    CHECK(data) << "Column is null";
     auto shape = ArrayInterfaceHandler::ExtractShape(column);
     num_rows = shape.first;
     num_cols = shape.second;
diff --git a/src/data/data.cc b/src/data/data.cc
index b3652c45b..e8a9a8582 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -488,6 +488,15 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
     this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
                             group_ptr.end());
   }
+
+  if (!that.feature_names.empty()) {
+    this->feature_names = that.feature_names;
+  }
+  if (!that.feature_type_names.empty()) {
+    this->feature_type_names = that.feature_type_names;
+    auto &h_feature_types = feature_types.HostVector();
+    LoadFeatureType(this->feature_type_names, &h_feature_types);
+  }
 }
 
 void MetaInfo::Validate(int32_t device) const {
diff --git a/src/data/data.cu b/src/data/data.cu
index fb57f4751..5e63a828c 100644
--- a/src/data/data.cu
+++ b/src/data/data.cu
@@ -69,6 +69,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
       << "Meta info " << key << " should be dense, found validity mask";
   CHECK_EQ(array_interface.num_cols, 1)
       << "Meta info should be a single column.";
+  if (array_interface.num_rows == 0) {
+    return;
+  }
 
   if (key == "label") {
     CopyInfoImpl(array_interface, &labels_);
diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh
index 10ae6ba3d..709368f5c 100644
--- a/src/data/device_adapter.cuh
+++ b/src/data/device_adapter.cuh
@@ -122,10 +122,14 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
     CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
     std::vector<ArrayInterface> columns;
     auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
+    num_rows_ = first_column.num_rows;
+    if (num_rows_ == 0) {
+      return;
+    }
+
     device_idx_ = dh::CudaGetPointerDevice(first_column.data);
     CHECK_NE(device_idx_, -1);
     dh::safe_cuda(cudaSetDevice(device_idx_));
-    num_rows_ = first_column.num_rows;
     for (auto& json_col : json_columns) {
       auto column = ArrayInterface(get<Object const>(json_col));
       columns.push_back(column);
@@ -183,9 +187,12 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
     Json json_array_interface =
         Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
     array_interface_ = ArrayInterface(get<Object const>(json_array_interface), false);
+    batch_ = CupyAdapterBatch(array_interface_);
+    if (array_interface_.num_rows == 0) {
+      return;
+    }
     device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
     CHECK_NE(device_idx_, -1);
-    batch_ = CupyAdapterBatch(array_interface_);
   }
 
   const CupyAdapterBatch& Value() const override { return batch_; }
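
The array-interface and adapter changes above make zero-row device inputs legal instead of tripping a null-pointer check.  That is what `get_worker_data` earlier in this patch relies on for workers that received no partition.  A rough illustration from the Python side, assuming a GPU with cupy installed:

    import cupy
    import xgboost

    # A zero-row cupy array can now back a DeviceQuantileDMatrix; workers
    # without a data partition construct exactly this kind of placeholder.
    empty = cupy.zeros((0, 0))
    dmat = xgboost.DeviceQuantileDMatrix(empty, max_bin=256)
    assert dmat.num_row() == 0
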
diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu
index 3f142acd3..b99f99590 100644
--- a/src/data/iterative_device_dmatrix.cu
+++ b/src/data/iterative_device_dmatrix.cu
@@ -62,23 +62,30 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
   size_t batches = 0;
   size_t accumulated_rows = 0;
   bst_feature_t cols = 0;
-  int32_t device = -1;
+  int32_t device = GenericParameter::kCpuId;
+  int32_t current_device_;
+  dh::safe_cuda(cudaGetDevice(&current_device_));
+  auto get_device = [&]() -> int32_t {
+    int32_t d = (device == GenericParameter::kCpuId) ? current_device_ : device;
+    return d;
+  };
+
   while (iter.Next()) {
     device = proxy->DeviceIdx();
-    dh::safe_cuda(cudaSetDevice(device));
+    dh::safe_cuda(cudaSetDevice(get_device()));
     if (cols == 0) {
       cols = num_cols();
+      rabit::Allreduce<rabit::op::Max>(&cols, 1);
     } else {
       CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
     }
-    sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
+    sketch_containers.emplace_back(batch_param_.max_bin, cols, num_rows(), get_device());
     auto* p_sketch = &sketch_containers.back();
-    proxy->Info().weights_.SetDevice(device);
+    proxy->Info().weights_.SetDevice(get_device());
     Dispatch(proxy, [&](auto const &value) {
         common::AdapterDeviceSketch(value, batch_param_.max_bin,
                                     proxy->Info(), missing, p_sketch);
       });
-
     auto batch_rows = num_rows();
     accumulated_rows += batch_rows;
     dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
@@ -86,19 +93,15 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
                                          row_counts.size());
     row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
                             return GetRowCounts(value, row_counts_span,
-                                                device, missing);
+                                                get_device(), missing);
                           }));
     nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
                           row_counts.end());
     batches++;
   }
-
-  if (device < 0) {  // error or empty
-    this->page_.reset(new EllpackPage);
-    return;
-  }
-
-  common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
+  iter.Reset();
+  dh::safe_cuda(cudaSetDevice(get_device()));
+  common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, get_device());
   for (auto const& sketch : sketch_containers) {
     final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
     final_sketch.FixError();
@@ -113,14 +116,14 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
   this->info_.num_row_ = accumulated_rows;
   this->info_.num_nonzero_ = nnz;
 
-  auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows]() {
+  auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows,
+                    get_device]() {
     if (!page_) {
       // Should be put inside the while loop to protect against empty batch. In
       // that case device id is invalid.
      page_.reset(new EllpackPage);
-      *(page_->Impl()) =
-          EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), row_stride,
-                          accumulated_rows);
+      *(page_->Impl()) = EllpackPageImpl(get_device(), cuts, this->IsDense(),
+                                         row_stride, accumulated_rows);
     }
   };
 
@@ -130,21 +133,20 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
   size_t n_batches_for_verification = 0;
   while (iter.Next()) {
     init_page();
-    auto device = proxy->DeviceIdx();
-    dh::safe_cuda(cudaSetDevice(device));
+    dh::safe_cuda(cudaSetDevice(get_device()));
     auto rows = num_rows();
     dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
     common::Span<size_t> row_counts_span(row_counts.data().get(),
                                          row_counts.size());
     Dispatch(proxy, [=](auto const& value) {
-        return GetRowCounts(value, row_counts_span, device, missing);
+        return GetRowCounts(value, row_counts_span, get_device(), missing);
       });
     auto is_dense = this->IsDense();
     auto new_impl = Dispatch(proxy, [&](auto const &value) {
-      return EllpackPageImpl(value, missing, device, is_dense, nthread,
-                             row_counts_span, row_stride, rows, cols, cuts);
+      return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
+                             row_counts_span, row_stride, rows, cols, cuts);
     });
-    size_t num_elements = page_->Impl()->Copy(device, &new_impl, offset);
+    size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
     offset += num_elements;
 
     proxy->Info().num_row_ = num_rows();
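
`IterativeDeviceDMatrix::Initialize` above now always makes two passes over the external data iterator, one for quantile sketching and one for `EllpackPage` construction, which is why `iter.Reset()` is called between the two loops.  Any iterator handed to `DeviceQuantileDMatrix` (such as `DaskPartitionIter` earlier in this patch) therefore has to be restartable.  A simplified Python rendering of the call pattern, for orientation only:

    def build(it, consume_batch):
        # Pass 1: quantile sketching over every batch.
        while it.next(consume_batch):
            pass
        # The same iterator is rewound and consumed again.
        it.reset()
        # Pass 2: Ellpack page construction over every batch.
        while it.next(consume_batch):
            pass
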
diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu
index 377a05010..177402ca1 100644
--- a/src/metric/multiclass_metric.cu
+++ b/src/metric/multiclass_metric.cu
@@ -158,18 +158,22 @@ struct EvalMClassBase : public Metric {
   bst_float Eval(const HostDeviceVector<bst_float> &preds,
                  const MetaInfo &info,
                  bool distributed) override {
-    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
-    CHECK(preds.Size() % info.labels_.Size() == 0)
-        << "label and prediction size not match";
-    const size_t nclass = preds.Size() / info.labels_.Size();
-    CHECK_GE(nclass, 1U)
-        << "mlogloss and merror are only used for multi-class classification,"
-        << " use logloss for binary classification";
-
-    int device = tparam_->gpu_id;
-    auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);
-    double dat[2] { result.Residue(), result.Weights() };
-
+    if (info.labels_.Size() == 0) {
+      CHECK_EQ(preds.Size(), 0);
+    } else {
+      CHECK(preds.Size() % info.labels_.Size() == 0) << "label and prediction size not match";
+    }
+    double dat[2] { 0.0, 0.0 };
+    if (info.labels_.Size() != 0) {
+      const size_t nclass = preds.Size() / info.labels_.Size();
+      CHECK_GE(nclass, 1U)
+          << "mlogloss and merror are only used for multi-class classification,"
+          << " use logloss for binary classification";
+      int device = tparam_->gpu_id;
+      auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);
+      dat[0] = result.Residue();
+      dat[1] = result.Weights();
+    }
     if (distributed) {
       rabit::Allreduce<rabit::op::Sum>(dat, 2);
     }
diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu
index 29af5e0d2..4d26d40e3 100644
--- a/src/objective/multiclass_obj.cu
+++ b/src/objective/multiclass_obj.cu
@@ -49,7 +49,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
                    const MetaInfo& info,
                    int iter,
                    HostDeviceVector<GradientPair>* out_gpair) override {
-    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
+    if (info.labels_.Size() == 0) {
+      return;
+    }
     CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels_.Size()))
        << "SoftmaxMultiClassObj: label size and pred size does not match.\n"
        << "label.Size() * num_class: "
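
With the relaxed checks above, a worker whose data shard carries no labels reports a zero residue and zero weight, and the rabit allreduce then recovers the global metric from the non-empty workers.  A small arithmetic sketch of that reduction (the residue/weight split mirrors the `dat` array above; the final division is how merror-style metrics are assumed to be finalized):

    # One worker per tuple: (residue, weight).  The second worker is empty.
    worker_stats = [(12.0, 8.0), (0.0, 0.0), (6.0, 4.0)]
    residue = sum(r for r, _ in worker_stats)
    weight = sum(w for _, w in worker_stats)
    metric = residue / weight     # 18.0 / 12.0 == 1.5, unaffected by the empty worker
    assert metric == 1.5
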
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index 382835c58..8cb0bc27e 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -6,13 +6,15 @@ import unittest
 import xgboost
 import subprocess
 from hypothesis import given, strategies, settings, note
+from hypothesis._settings import duration
 from test_gpu_updaters import parameter_strategy
 
 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
 
 sys.path.append("tests/python")
-from test_with_dask import run_empty_dmatrix  # noqa
+from test_with_dask import run_empty_dmatrix_reg  # noqa
+from test_with_dask import run_empty_dmatrix_cls  # noqa
 from test_with_dask import generate_array  # noqa
 import testing as tm  # noqa
 
@@ -28,6 +30,126 @@ except ImportError:
     pass
 
 
+def run_with_dask_dataframe(DMatrixT, client):
+    import cupy as cp
+    cp.cuda.runtime.setDevice(0)
+    X, y = generate_array()
+
+    X = dd.from_dask_array(X)
+    y = dd.from_dask_array(y)
+
+    X = X.map_partitions(cudf.from_pandas)
+    y = y.map_partitions(cudf.from_pandas)
+
+    dtrain = DMatrixT(client, X, y)
+    out = dxgb.train(client, {'tree_method': 'gpu_hist',
+                              'debug_synchronize': True},
+                     dtrain=dtrain,
+                     evals=[(dtrain, 'X')],
+                     num_boost_round=4)
+
+    assert isinstance(out['booster'], dxgb.Booster)
+    assert len(out['history']['X']['rmse']) == 4
+
+    predictions = dxgb.predict(client, out, dtrain).compute()
+    assert isinstance(predictions, np.ndarray)
+
+    series_predictions = dxgb.inplace_predict(client, out, X)
+    assert isinstance(series_predictions, dd.Series)
+    series_predictions = series_predictions.compute()
+
+    single_node = out['booster'].predict(
+        xgboost.DMatrix(X.compute()))
+
+    cp.testing.assert_allclose(single_node, predictions)
+    np.testing.assert_allclose(single_node,
+                               series_predictions.to_array())
+
+    predt = dxgb.predict(client, out, X)
+    assert isinstance(predt, dd.Series)
+
+    def is_df(part):
+        assert isinstance(part, cudf.DataFrame), part
+        return part
+
+    predt.map_partitions(
+        is_df,
+        meta=dd.utils.make_meta({'prediction': 'f4'}))
+
+    cp.testing.assert_allclose(
+        predt.values.compute(), single_node)
+
+
+def run_with_dask_array(DMatrixT, client):
+    import cupy as cp
+    cp.cuda.runtime.setDevice(0)
+    X, y = generate_array()
+
+    X = X.map_blocks(cp.asarray)
+    y = y.map_blocks(cp.asarray)
+    dtrain = DMatrixT(client, X, y)
+    out = dxgb.train(client, {'tree_method': 'gpu_hist',
+                              'debug_synchronize': True},
+                     dtrain=dtrain,
+                     evals=[(dtrain, 'X')],
+                     num_boost_round=2)
+    from_dmatrix = dxgb.predict(client, out, dtrain).compute()
+    inplace_predictions = dxgb.inplace_predict(
+        client, out, X).compute()
+    single_node = out['booster'].predict(
+        xgboost.DMatrix(X.compute()))
+    np.testing.assert_allclose(single_node, from_dmatrix)
+    device = cp.cuda.runtime.getDevice()
+    assert device == inplace_predictions.device.id
+    single_node = cp.array(single_node)
+    assert device == single_node.device.id
+    cp.testing.assert_allclose(
+        single_node,
+        inplace_predictions)
+
+
+def to_cp(x, DMatrixT):
+    import cupy
+    if isinstance(x, np.ndarray) and \
+       DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
+        X = cupy.array(x)
+    else:
+        X = x
+    return X
+
+
+def run_gpu_hist(params, num_rounds, dataset, DMatrixT, client):
+    params['tree_method'] = 'gpu_hist'
+    params = dataset.set_params(params)
+    # It doesn't make sense to distribute a completely
+    # empty dataset.
+    if dataset.X.shape[0] == 0:
+        return
+
+    chunk = 128
+    X = to_cp(dataset.X, DMatrixT)
+    X = da.from_array(X,
+                      chunks=(chunk, dataset.X.shape[1]))
+    y = to_cp(dataset.y, DMatrixT)
+    y = da.from_array(y, chunks=(chunk, ))
+    if dataset.w is not None:
+        w = to_cp(dataset.w, DMatrixT)
+        w = da.from_array(w, chunks=(chunk, ))
+    else:
+        w = None
+
+    if DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
+        m = DMatrixT(client, data=X, label=y, weight=w,
+                     max_bin=params.get('max_bin', 256))
+    else:
+        m = DMatrixT(client, data=X, label=y, weight=w)
+    history = dxgb.train(client, params=params, dtrain=m,
+                         num_boost_round=num_rounds,
+                         evals=[(m, 'train')])['history']
+    note(history)
+    assert tm.non_increasing(history['train'][dataset.metric])
+
+
 class TestDistributedGPU(unittest.TestCase):
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_cudf())
@@ -37,119 +159,28 @@ class TestDistributedGPU(unittest.TestCase):
     def test_dask_dataframe(self):
         with LocalCUDACluster() as cluster:
             with Client(cluster) as client:
-                import cupy as cp
-                cp.cuda.runtime.setDevice(0)
-                X, y = generate_array()
-
-                X = dd.from_dask_array(X)
-                y = dd.from_dask_array(y)
-
-                X = X.map_partitions(cudf.from_pandas)
-                y = y.map_partitions(cudf.from_pandas)
-
-                dtrain = dxgb.DaskDMatrix(client, X, y)
-                out = dxgb.train(client, {'tree_method': 'gpu_hist',
-                                          'debug_synchronize': True},
-                                 dtrain=dtrain,
-                                 evals=[(dtrain, 'X')],
-                                 num_boost_round=4)
-
-                assert isinstance(out['booster'], dxgb.Booster)
-                assert len(out['history']['X']['rmse']) == 4
-
-                predictions = dxgb.predict(client, out, dtrain).compute()
-                assert isinstance(predictions, np.ndarray)
-
-                series_predictions = dxgb.inplace_predict(client, out, X)
-                assert isinstance(series_predictions, dd.Series)
-                series_predictions = series_predictions.compute()
-
-                single_node = out['booster'].predict(
-                    xgboost.DMatrix(X.compute()))
-
-                cp.testing.assert_allclose(single_node, predictions)
-                np.testing.assert_allclose(single_node,
-                                           series_predictions.to_array())
-
-                predt = dxgb.predict(client, out, X)
-                assert isinstance(predt, dd.Series)
-
-                def is_df(part):
-                    assert isinstance(part, cudf.DataFrame), part
-                    return part
-
-                predt.map_partitions(
-                    is_df,
-                    meta=dd.utils.make_meta({'prediction': 'f4'}))
-
-                cp.testing.assert_allclose(
-                    predt.values.compute(), single_node)
+                run_with_dask_dataframe(dxgb.DaskDMatrix, client)
+                run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client)
 
     @given(parameter_strategy, strategies.integers(1, 20),
            tm.dataset_strategy)
-    @settings(deadline=None)
+    @settings(deadline=duration(seconds=120))
     @pytest.mark.mgpu
     def test_gpu_hist(self, params, num_rounds, dataset):
         with LocalCUDACluster(n_workers=2) as cluster:
             with Client(cluster) as client:
-                params['tree_method'] = 'gpu_hist'
-                params = dataset.set_params(params)
-                # multi class doesn't handle empty dataset well (empty
-                # means at least 1 worker has data).
-                if params['objective'] == "multi:softmax":
-                    return
-                # It doesn't make sense to distribute a completely
-                # empty dataset.
-                if dataset.X.shape[0] == 0:
-                    return
-
-                chunk = 128
-                X = da.from_array(dataset.X,
-                                  chunks=(chunk, dataset.X.shape[1]))
-                y = da.from_array(dataset.y, chunks=(chunk, ))
-                if dataset.w is not None:
-                    w = da.from_array(dataset.w, chunks=(chunk, ))
-                else:
-                    w = None
-
-                m = dxgb.DaskDMatrix(
-                    client, data=X, label=y, weight=w)
-                history = dxgb.train(client, params=params, dtrain=m,
-                                     num_boost_round=num_rounds,
-                                     evals=[(m, 'train')])['history']
-                note(history)
-                assert tm.non_increasing(history['train'][dataset.metric])
+                run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix,
+                             client)
+                run_gpu_hist(params, num_rounds, dataset,
+                             dxgb.DaskDeviceQuantileDMatrix, client)
 
     @pytest.mark.skipif(**tm.no_cupy())
    @pytest.mark.mgpu
     def test_dask_array(self):
         with LocalCUDACluster() as cluster:
             with Client(cluster) as client:
-                import cupy as cp
-                cp.cuda.runtime.setDevice(0)
-                X, y = generate_array()
-
-                X = X.map_blocks(cp.asarray)
-                y = y.map_blocks(cp.asarray)
-                dtrain = dxgb.DaskDMatrix(client, X, y)
-                out = dxgb.train(client, {'tree_method': 'gpu_hist',
-                                          'debug_synchronize': True},
-                                 dtrain=dtrain,
-                                 evals=[(dtrain, 'X')],
-                                 num_boost_round=2)
-                from_dmatrix = dxgb.predict(client, out, dtrain).compute()
-                inplace_predictions = dxgb.inplace_predict(
-                    client, out, X).compute()
-                single_node = out['booster'].predict(
-                    xgboost.DMatrix(X.compute()))
-                np.testing.assert_allclose(single_node, from_dmatrix)
-                device = cp.cuda.runtime.getDevice()
-                assert device == inplace_predictions.device.id
-                single_node = cp.array(single_node)
-                assert device == single_node.device.id
-                cp.testing.assert_allclose(
-                    single_node,
-                    inplace_predictions)
+                run_with_dask_array(dxgb.DaskDMatrix, client)
+                run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
 
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_dask_cuda())
@@ -159,7 +190,8 @@ class TestDistributedGPU(unittest.TestCase):
             with Client(cluster) as client:
                 parameters = {'tree_method': 'gpu_hist',
                               'debug_synchronize': True}
-                run_empty_dmatrix(client, parameters)
+                run_empty_dmatrix_reg(client, parameters)
+                run_empty_dmatrix_cls(client, parameters)
 
     def run_quantile(self, name):
         if sys.platform.startswith("win"):
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 744da4393..728c7e4b3 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -128,8 +128,7 @@ def test_dask_missing_value_reg():
 
 
 def test_dask_missing_value_cls():
-    # Multi-class doesn't handle empty DMatrix well. So we use lesser workers.
-    with LocalCluster(n_workers=2) as cluster:
+    with LocalCluster() as cluster:
         with Client(cluster) as client:
             X_0 = np.ones((kRows // 2, kCols))
             X_1 = np.zeros((kRows // 2, kCols))
@@ -234,7 +233,7 @@ def test_sklearn_grid_search():
     assert len(means) == len(set(means))
 
 
-def run_empty_dmatrix(client, parameters):
+def run_empty_dmatrix_reg(client, parameters):
 
     def _check_outputs(out, predictions):
         assert isinstance(out['booster'], xgb.dask.Booster)
@@ -271,6 +270,46 @@ def run_empty_dmatrix(client, parameters):
     _check_outputs(out, predictions)
 
 
+def run_empty_dmatrix_cls(client, parameters):
+    n_classes = 4
+
+    def _check_outputs(out, predictions):
+        assert isinstance(out['booster'], xgb.dask.Booster)
+        assert len(out['history']['validation']['merror']) == 2
+        assert isinstance(predictions, np.ndarray)
+        assert predictions.shape[1] == n_classes, predictions.shape
+
+    kRows, kCols = 1, 97
+    X = dd.from_array(np.random.randn(kRows, kCols))
+    y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows))
+    dtrain = xgb.dask.DaskDMatrix(client, X, y)
+    parameters['objective'] = 'multi:softprob'
+    parameters['num_class'] = n_classes
+
+    out = xgb.dask.train(client, parameters,
+                         dtrain=dtrain,
+                         evals=[(dtrain, 'validation')],
+                         num_boost_round=2)
+    predictions = xgb.dask.predict(client=client, model=out,
+                                   data=dtrain).compute()
+    _check_outputs(out, predictions)
+
+    # train has more rows than evals
+    valid = dtrain
+    kRows += 1
+    X = dd.from_array(np.random.randn(kRows, kCols))
+    y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows))
+    dtrain = xgb.dask.DaskDMatrix(client, X, y)
+
+    out = xgb.dask.train(client, parameters,
+                         dtrain=dtrain,
+                         evals=[(valid, 'validation')],
+                         num_boost_round=2)
+    predictions = xgb.dask.predict(client=client, model=out,
+                                   data=valid).compute()
+    _check_outputs(out, predictions)
+
+
 # No test for Exact, as empty DMatrix handling are mostly for distributed
 # environment and Exact doesn't support it.
 
@@ -278,11 +317,13 @@ def test_empty_dmatrix_hist():
     with LocalCluster(n_workers=5) as cluster:
         with Client(cluster) as client:
             parameters = {'tree_method': 'hist'}
-            run_empty_dmatrix(client, parameters)
+            run_empty_dmatrix_reg(client, parameters)
+            run_empty_dmatrix_cls(client, parameters)
 
 
 def test_empty_dmatrix_approx():
     with LocalCluster(n_workers=5) as cluster:
         with Client(cluster) as client:
             parameters = {'tree_method': 'approx'}
-            run_empty_dmatrix(client, parameters)
+            run_empty_dmatrix_reg(client, parameters)
+            run_empty_dmatrix_cls(client, parameters)
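
If additional tree methods or parameter variants need the same coverage, they can drive the two helpers in exactly the same way as the tests above; a hypothetical example (the `grow_policy` variant is illustrative and not part of this patch; it assumes this file's existing imports):

    def test_empty_dmatrix_hist_lossguide():
        with LocalCluster(n_workers=5) as cluster:
            with Client(cluster) as client:
                parameters = {'tree_method': 'hist', 'grow_policy': 'lossguide'}
                run_empty_dmatrix_reg(client, parameters)
                run_empty_dmatrix_cls(client, parameters)
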