Dask device dmatrix (#5901)

* Fix softprob with empty dmatrix.
This commit is contained in:
Jiaming Yuan 2020-07-17 13:17:43 +08:00 committed by GitHub
parent e471056ec4
commit 7c2686146e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 392 additions and 149 deletions

View File

@ -14,6 +14,7 @@ https://github.com/dask/dask-xgboost
import platform import platform
import logging import logging
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence
from threading import Thread from threading import Thread
import numpy import numpy
@ -28,7 +29,7 @@ from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import CUDF_concat from .compat import CUDF_concat
from .compat import lazy_isinstance from .compat import lazy_isinstance
from .core import DMatrix, Booster, _expect from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
from .training import train as worker_train from .training import train as worker_train
from .tracker import RabitTracker from .tracker import RabitTracker
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
@ -357,6 +358,146 @@ class DaskDMatrix:
return (rows, cols) return (rows, cols)
class DaskPartitionIter(DataIter): # pylint: disable=R0902
'''A data iterator for `DaskDeviceQuantileDMatrix`.
'''
def __init__(self, data, label=None, weight=None, base_margin=None,
label_lower_bound=None, label_upper_bound=None,
feature_names=None, feature_types=None):
self._data = data
self._labels = label
self._weights = weight
self._base_margin = base_margin
self._label_lower_bound = label_lower_bound
self._label_upper_bound = label_upper_bound
self._feature_names = feature_names
self._feature_types = feature_types
assert isinstance(self._data, Sequence)
types = (Sequence, type(None))
assert isinstance(self._labels, types)
assert isinstance(self._weights, types)
assert isinstance(self._base_margin, types)
assert isinstance(self._label_lower_bound, types)
assert isinstance(self._label_upper_bound, types)
self._iter = 0 # set iterator to 0
super().__init__()
def data(self):
'''Utility function for obtaining current batch of data.'''
return self._data[self._iter]
def labels(self):
'''Utility function for obtaining current batch of label.'''
if self._labels is not None:
return self._labels[self._iter]
return None
def weights(self):
'''Utility function for obtaining current batch of label.'''
if self._weights is not None:
return self._weights[self._iter]
return None
def base_margins(self):
'''Utility function for obtaining current batch of base_margin.'''
if self._base_margin is not None:
return self._base_margin[self._iter]
return None
def label_lower_bounds(self):
'''Utility function for obtaining current batch of label_lower_bound.
'''
if self._label_lower_bound is not None:
return self._label_lower_bound[self._iter]
return None
def label_upper_bounds(self):
'''Utility function for obtaining current batch of label_upper_bound.
'''
if self._label_upper_bound is not None:
return self._label_upper_bound[self._iter]
return None
def reset(self):
'''Reset the iterator'''
self._iter = 0
def next(self, input_data):
'''Yield next batch of data'''
if self._iter == len(self._data):
# Return 0 when there's no more batch.
return 0
if self._feature_names:
feature_names = self._feature_names
else:
if hasattr(self.data(), 'columns'):
feature_names = self.data().columns.format()
else:
feature_names = None
input_data(data=self.data(), label=self.labels(),
weight=self.weights(), group=None,
label_lower_bound=self.label_lower_bounds(),
label_upper_bound=self.label_upper_bounds(),
feature_names=feature_names,
feature_types=self._feature_types)
self._iter += 1
return 1
class DaskDeviceQuantileDMatrix(DaskDMatrix):
'''Specialized data type for `gpu_hist` tree method. This class is
used to reduce the memory usage by eliminating data copies.
Internally the data is merged by weighted GK sketching. So the
number of partitions from dask may affect training accuracy as GK
generates error for each merge.
.. versionadded:: 1.2.0
Parameters
----------
max_bin: Number of bins for histogram construction.
'''
def __init__(self, client, data, label=None, weight=None,
missing=None,
feature_names=None,
feature_types=None,
max_bin=256):
super().__init__(client=client, data=data, label=label, weight=weight,
missing=missing,
feature_names=feature_names,
feature_types=feature_types)
self.max_bin = max_bin
def get_worker_data(self, worker):
if worker.address not in set(self.worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(self.worker_map.keys()))
LOGGER.warning(msg)
import cupy # pylint: disable=import-error
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
feature_names=self.feature_names,
feature_types=self.feature_types,
max_bin=self.max_bin)
return d
data, labels, weights = self.get_worker_parts(worker)
it = DaskPartitionIter(data=data, label=labels, weight=weights)
dmatrix = DeviceQuantileDMatrix(it,
missing=self.missing,
feature_names=self.feature_names,
feature_types=self.feature_types,
nthread=worker.nthreads,
max_bin=self.max_bin)
return dmatrix
def _get_rabit_args(worker_map, client): def _get_rabit_args(worker_map, client):
'''Get rabit context arguments from data distribution in DaskDMatrix.''' '''Get rabit context arguments from data distribution in DaskDMatrix.'''
host = distributed_comm.get_address_host(client.scheduler.address) host = distributed_comm.get_address_host(client.scheduler.address)

View File

@ -15,7 +15,7 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
def _warn_unused_missing(data, missing): def _warn_unused_missing(data, missing):
if not (np.isnan(missing) or None): if (not np.isnan(missing)) or (missing is None):
warnings.warn( warnings.warn(
'`missing` is not used for current input data type:' + '`missing` is not used for current input data type:' +
str(type(data))) str(type(data)))

View File

@ -85,6 +85,7 @@ class SketchContainer {
// Initialize Sketches for this dmatrix // Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_); this->columns_ptr_.SetDevice(device_);
this->columns_ptr_.Resize(num_columns + 1); this->columns_ptr_.Resize(num_columns + 1);
CHECK_GE(device, 0);
timer_.Init(__func__); timer_.Init(__func__);
} }
/* \brief Return GPU ID for this container. */ /* \brief Return GPU ID for this container. */

View File

@ -114,7 +114,6 @@ class ArrayInterfaceHandler {
get<Array const>( get<Array const>(
obj.at("data")) obj.at("data"))
.at(0)))); .at(0))));
CHECK(p_data);
return p_data; return p_data;
} }
@ -224,6 +223,9 @@ class ArrayInterfaceHandler {
auto shape = ExtractShape(column); auto shape = ExtractShape(column);
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column); T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
if (!p_data) {
CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
}
return common::Span<T>{p_data, shape.first * shape.second}; return common::Span<T>{p_data, shape.first * shape.second};
} }
}; };
@ -234,7 +236,6 @@ class ArrayInterface {
bool allow_mask = true) { bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column); ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column); data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
CHECK(data) << "Column is null";
auto shape = ArrayInterfaceHandler::ExtractShape(column); auto shape = ArrayInterfaceHandler::ExtractShape(column);
num_rows = shape.first; num_rows = shape.first;
num_cols = shape.second; num_cols = shape.second;

View File

@ -488,6 +488,15 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1, this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
group_ptr.end()); group_ptr.end());
} }
if (!that.feature_names.empty()) {
this->feature_names = that.feature_names;
}
if (!that.feature_type_names.empty()) {
this->feature_type_names = that.feature_type_names;
auto &h_feature_types = feature_types.HostVector();
LoadFeatureType(this->feature_type_names, &h_feature_types);
}
} }
void MetaInfo::Validate(int32_t device) const { void MetaInfo::Validate(int32_t device) const {

View File

@ -69,6 +69,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
<< "Meta info " << key << " should be dense, found validity mask"; << "Meta info " << key << " should be dense, found validity mask";
CHECK_EQ(array_interface.num_cols, 1) CHECK_EQ(array_interface.num_cols, 1)
<< "Meta info should be a single column."; << "Meta info should be a single column.";
if (array_interface.num_rows == 0) {
return;
}
if (key == "label") { if (key == "label") {
CopyInfoImpl(array_interface, &labels_); CopyInfoImpl(array_interface, &labels_);

View File

@ -122,10 +122,14 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian(); CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
std::vector<ArrayInterface> columns; std::vector<ArrayInterface> columns;
auto first_column = ArrayInterface(get<Object const>(json_columns[0])); auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
num_rows_ = first_column.num_rows;
if (num_rows_ == 0) {
return;
}
device_idx_ = dh::CudaGetPointerDevice(first_column.data); device_idx_ = dh::CudaGetPointerDevice(first_column.data);
CHECK_NE(device_idx_, -1); CHECK_NE(device_idx_, -1);
dh::safe_cuda(cudaSetDevice(device_idx_)); dh::safe_cuda(cudaSetDevice(device_idx_));
num_rows_ = first_column.num_rows;
for (auto& json_col : json_columns) { for (auto& json_col : json_columns) {
auto column = ArrayInterface(get<Object const>(json_col)); auto column = ArrayInterface(get<Object const>(json_col));
columns.push_back(column); columns.push_back(column);
@ -183,9 +187,12 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
Json json_array_interface = Json json_array_interface =
Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()}); Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
array_interface_ = ArrayInterface(get<Object const>(json_array_interface), false); array_interface_ = ArrayInterface(get<Object const>(json_array_interface), false);
batch_ = CupyAdapterBatch(array_interface_);
if (array_interface_.num_rows == 0) {
return;
}
device_idx_ = dh::CudaGetPointerDevice(array_interface_.data); device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
CHECK_NE(device_idx_, -1); CHECK_NE(device_idx_, -1);
batch_ = CupyAdapterBatch(array_interface_);
} }
const CupyAdapterBatch& Value() const override { return batch_; } const CupyAdapterBatch& Value() const override { return batch_; }

View File

@ -62,23 +62,30 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t batches = 0; size_t batches = 0;
size_t accumulated_rows = 0; size_t accumulated_rows = 0;
bst_feature_t cols = 0; bst_feature_t cols = 0;
int32_t device = -1; int32_t device = GenericParameter::kCpuId;
int32_t current_device_;
dh::safe_cuda(cudaGetDevice(&current_device_));
auto get_device = [&]() -> int32_t {
int32_t d = GenericParameter::kCpuId ? current_device_ : device;
return d;
};
while (iter.Next()) { while (iter.Next()) {
device = proxy->DeviceIdx(); device = proxy->DeviceIdx();
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(get_device()));
if (cols == 0) { if (cols == 0) {
cols = num_cols(); cols = num_cols();
rabit::Allreduce<rabit::op::Max>(&cols, 1);
} else { } else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
} }
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device); sketch_containers.emplace_back(batch_param_.max_bin, cols, num_rows(), get_device());
auto* p_sketch = &sketch_containers.back(); auto* p_sketch = &sketch_containers.back();
proxy->Info().weights_.SetDevice(device); proxy->Info().weights_.SetDevice(get_device());
Dispatch(proxy, [&](auto const &value) { Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketch(value, batch_param_.max_bin, common::AdapterDeviceSketch(value, batch_param_.max_bin,
proxy->Info(), missing, p_sketch); proxy->Info(), missing, p_sketch);
}); });
auto batch_rows = num_rows(); auto batch_rows = num_rows();
accumulated_rows += batch_rows; accumulated_rows += batch_rows;
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0); dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
@ -86,19 +93,15 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
row_counts.size()); row_counts.size());
row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) { row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
return GetRowCounts(value, row_counts_span, return GetRowCounts(value, row_counts_span,
device, missing); get_device(), missing);
})); }));
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
row_counts.end()); row_counts.end());
batches++; batches++;
} }
iter.Reset();
if (device < 0) { // error or empty dh::safe_cuda(cudaSetDevice(get_device()));
this->page_.reset(new EllpackPage); common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, get_device());
return;
}
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
for (auto const& sketch : sketch_containers) { for (auto const& sketch : sketch_containers) {
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
final_sketch.FixError(); final_sketch.FixError();
@ -113,14 +116,14 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
this->info_.num_row_ = accumulated_rows; this->info_.num_row_ = accumulated_rows;
this->info_.num_nonzero_ = nnz; this->info_.num_nonzero_ = nnz;
auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows]() { auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows,
get_device]() {
if (!page_) { if (!page_) {
// Should be put inside the while loop to protect against empty batch. In // Should be put inside the while loop to protect against empty batch. In
// that case device id is invalid. // that case device id is invalid.
page_.reset(new EllpackPage); page_.reset(new EllpackPage);
*(page_->Impl()) = *(page_->Impl()) = EllpackPageImpl(get_device(), cuts, this->IsDense(),
EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), row_stride, row_stride, accumulated_rows);
accumulated_rows);
} }
}; };
@ -130,21 +133,20 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t n_batches_for_verification = 0; size_t n_batches_for_verification = 0;
while (iter.Next()) { while (iter.Next()) {
init_page(); init_page();
auto device = proxy->DeviceIdx(); dh::safe_cuda(cudaSetDevice(get_device()));
dh::safe_cuda(cudaSetDevice(device));
auto rows = num_rows(); auto rows = num_rows();
dh::caching_device_vector<size_t> row_counts(rows + 1, 0); dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), common::Span<size_t> row_counts_span(row_counts.data().get(),
row_counts.size()); row_counts.size());
Dispatch(proxy, [=](auto const& value) { Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, device, missing); return GetRowCounts(value, row_counts_span, get_device(), missing);
}); });
auto is_dense = this->IsDense(); auto is_dense = this->IsDense();
auto new_impl = Dispatch(proxy, [&](auto const &value) { auto new_impl = Dispatch(proxy, [&](auto const &value) {
return EllpackPageImpl(value, missing, device, is_dense, nthread, return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
row_counts_span, row_stride, rows, cols, cuts); row_counts_span, row_stride, rows, cols, cuts);
}); });
size_t num_elements = page_->Impl()->Copy(device, &new_impl, offset); size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
offset += num_elements; offset += num_elements;
proxy->Info().num_row_ = num_rows(); proxy->Info().num_row_ = num_rows();

View File

@ -158,18 +158,22 @@ struct EvalMClassBase : public Metric {
bst_float Eval(const HostDeviceVector<bst_float> &preds, bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info, const MetaInfo &info,
bool distributed) override { bool distributed) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; if (info.labels_.Size() == 0) {
CHECK(preds.Size() % info.labels_.Size() == 0) CHECK_EQ(preds.Size(), 0);
<< "label and prediction size not match"; } else {
CHECK(preds.Size() % info.labels_.Size() == 0) << "label and prediction size not match";
}
double dat[2] { 0.0, 0.0 };
if (info.labels_.Size() != 0) {
const size_t nclass = preds.Size() / info.labels_.Size(); const size_t nclass = preds.Size() / info.labels_.Size();
CHECK_GE(nclass, 1U) CHECK_GE(nclass, 1U)
<< "mlogloss and merror are only used for multi-class classification," << "mlogloss and merror are only used for multi-class classification,"
<< " use logloss for binary classification"; << " use logloss for binary classification";
int device = tparam_->gpu_id; int device = tparam_->gpu_id;
auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds); auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);
double dat[2] { result.Residue(), result.Weights() }; dat[0] = result.Residue();
dat[1] = result.Weights();
}
if (distributed) { if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2); rabit::Allreduce<rabit::op::Sum>(dat, 2);
} }

View File

@ -49,7 +49,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
const MetaInfo& info, const MetaInfo& info,
int iter, int iter,
HostDeviceVector<GradientPair>* out_gpair) override { HostDeviceVector<GradientPair>* out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; if (info.labels_.Size() == 0) {
return;
}
CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels_.Size())) CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels_.Size()))
<< "SoftmaxMultiClassObj: label size and pred size does not match.\n" << "SoftmaxMultiClassObj: label size and pred size does not match.\n"
<< "label.Size() * num_class: " << "label.Size() * num_class: "

View File

@ -6,13 +6,15 @@ import unittest
import xgboost import xgboost
import subprocess import subprocess
from hypothesis import given, strategies, settings, note from hypothesis import given, strategies, settings, note
from hypothesis._settings import duration
from test_gpu_updaters import parameter_strategy from test_gpu_updaters import parameter_strategy
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True) pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
sys.path.append("tests/python") sys.path.append("tests/python")
from test_with_dask import run_empty_dmatrix # noqa from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import generate_array # noqa from test_with_dask import generate_array # noqa
import testing as tm # noqa import testing as tm # noqa
@ -28,15 +30,7 @@ except ImportError:
pass pass
class TestDistributedGPU(unittest.TestCase): def run_with_dask_dataframe(DMatrixT, client):
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_dask_cudf())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
def test_dask_dataframe(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
import cupy as cp import cupy as cp
cp.cuda.runtime.setDevice(0) cp.cuda.runtime.setDevice(0)
X, y = generate_array() X, y = generate_array()
@ -47,7 +41,7 @@ class TestDistributedGPU(unittest.TestCase):
X = X.map_partitions(cudf.from_pandas) X = X.map_partitions(cudf.from_pandas)
y = y.map_partitions(cudf.from_pandas) y = y.map_partitions(cudf.from_pandas)
dtrain = dxgb.DaskDMatrix(client, X, y) dtrain = DMatrixT(client, X, y)
out = dxgb.train(client, {'tree_method': 'gpu_hist', out = dxgb.train(client, {'tree_method': 'gpu_hist',
'debug_synchronize': True}, 'debug_synchronize': True},
dtrain=dtrain, dtrain=dtrain,
@ -85,53 +79,15 @@ class TestDistributedGPU(unittest.TestCase):
cp.testing.assert_allclose( cp.testing.assert_allclose(
predt.values.compute(), single_node) predt.values.compute(), single_node)
@given(parameter_strategy, strategies.integers(1, 20),
tm.dataset_strategy)
@settings(deadline=None)
@pytest.mark.mgpu
def test_gpu_hist(self, params, num_rounds, dataset):
with LocalCUDACluster(n_workers=2) as cluster:
with Client(cluster) as client:
params['tree_method'] = 'gpu_hist'
params = dataset.set_params(params)
# multi class doesn't handle empty dataset well (empty
# means at least 1 worker has data).
if params['objective'] == "multi:softmax":
return
# It doesn't make sense to distribute a completely
# empty dataset.
if dataset.X.shape[0] == 0:
return
chunk = 128 def run_with_dask_array(DMatrixT, client):
X = da.from_array(dataset.X,
chunks=(chunk, dataset.X.shape[1]))
y = da.from_array(dataset.y, chunks=(chunk, ))
if dataset.w is not None:
w = da.from_array(dataset.w, chunks=(chunk, ))
else:
w = None
m = dxgb.DaskDMatrix(
client, data=X, label=y, weight=w)
history = dxgb.train(client, params=params, dtrain=m,
num_boost_round=num_rounds,
evals=[(m, 'train')])['history']
note(history)
assert tm.non_increasing(history['train'][dataset.metric])
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_dask_array(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
import cupy as cp import cupy as cp
cp.cuda.runtime.setDevice(0) cp.cuda.runtime.setDevice(0)
X, y = generate_array() X, y = generate_array()
X = X.map_blocks(cp.asarray) X = X.map_blocks(cp.asarray)
y = y.map_blocks(cp.asarray) y = y.map_blocks(cp.asarray)
dtrain = dxgb.DaskDMatrix(client, X, y) dtrain = DMatrixT(client, X, y)
out = dxgb.train(client, {'tree_method': 'gpu_hist', out = dxgb.train(client, {'tree_method': 'gpu_hist',
'debug_synchronize': True}, 'debug_synchronize': True},
dtrain=dtrain, dtrain=dtrain,
@ -151,6 +107,81 @@ class TestDistributedGPU(unittest.TestCase):
single_node, single_node,
inplace_predictions) inplace_predictions)
def to_cp(x, DMatrixT):
import cupy
if isinstance(x, np.ndarray) and \
DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
X = cupy.array(x)
else:
X = x
return X
def run_gpu_hist(params, num_rounds, dataset, DMatrixT, client):
params['tree_method'] = 'gpu_hist'
params = dataset.set_params(params)
# It doesn't make sense to distribute a completely
# empty dataset.
if dataset.X.shape[0] == 0:
return
chunk = 128
X = to_cp(dataset.X, DMatrixT)
X = da.from_array(X,
chunks=(chunk, dataset.X.shape[1]))
y = to_cp(dataset.y, DMatrixT)
y = da.from_array(y, chunks=(chunk, ))
if dataset.w is not None:
w = to_cp(dataset.w, DMatrixT)
w = da.from_array(w, chunks=(chunk, ))
else:
w = None
if DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
m = DMatrixT(client, data=X, label=y, weight=w,
max_bin=params.get('max_bin', 256))
else:
m = DMatrixT(client, data=X, label=y, weight=w)
history = dxgb.train(client, params=params, dtrain=m,
num_boost_round=num_rounds,
evals=[(m, 'train')])['history']
note(history)
assert tm.non_increasing(history['train'][dataset.metric])
class TestDistributedGPU(unittest.TestCase):
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_dask_cudf())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
def test_dask_dataframe(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
run_with_dask_dataframe(dxgb.DaskDMatrix, client)
run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client)
@given(parameter_strategy, strategies.integers(1, 20),
tm.dataset_strategy)
@settings(deadline=duration(seconds=120))
@pytest.mark.mgpu
def test_gpu_hist(self, params, num_rounds, dataset):
with LocalCUDACluster(n_workers=2) as cluster:
with Client(cluster) as client:
run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix,
client)
run_gpu_hist(params, num_rounds, dataset,
dxgb.DaskDeviceQuantileDMatrix, client)
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_dask_array(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
run_with_dask_array(dxgb.DaskDMatrix, client)
run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu @pytest.mark.mgpu
@ -159,7 +190,8 @@ class TestDistributedGPU(unittest.TestCase):
with Client(cluster) as client: with Client(cluster) as client:
parameters = {'tree_method': 'gpu_hist', parameters = {'tree_method': 'gpu_hist',
'debug_synchronize': True} 'debug_synchronize': True}
run_empty_dmatrix(client, parameters) run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)
def run_quantile(self, name): def run_quantile(self, name):
if sys.platform.startswith("win"): if sys.platform.startswith("win"):

View File

@ -128,8 +128,7 @@ def test_dask_missing_value_reg():
def test_dask_missing_value_cls(): def test_dask_missing_value_cls():
# Multi-class doesn't handle empty DMatrix well. So we use lesser workers. with LocalCluster() as cluster:
with LocalCluster(n_workers=2) as cluster:
with Client(cluster) as client: with Client(cluster) as client:
X_0 = np.ones((kRows // 2, kCols)) X_0 = np.ones((kRows // 2, kCols))
X_1 = np.zeros((kRows // 2, kCols)) X_1 = np.zeros((kRows // 2, kCols))
@ -234,7 +233,7 @@ def test_sklearn_grid_search():
assert len(means) == len(set(means)) assert len(means) == len(set(means))
def run_empty_dmatrix(client, parameters): def run_empty_dmatrix_reg(client, parameters):
def _check_outputs(out, predictions): def _check_outputs(out, predictions):
assert isinstance(out['booster'], xgb.dask.Booster) assert isinstance(out['booster'], xgb.dask.Booster)
@ -271,6 +270,46 @@ def run_empty_dmatrix(client, parameters):
_check_outputs(out, predictions) _check_outputs(out, predictions)
def run_empty_dmatrix_cls(client, parameters):
n_classes = 4
def _check_outputs(out, predictions):
assert isinstance(out['booster'], xgb.dask.Booster)
assert len(out['history']['validation']['merror']) == 2
assert isinstance(predictions, np.ndarray)
assert predictions.shape[1] == n_classes, predictions.shape
kRows, kCols = 1, 97
X = dd.from_array(np.random.randn(kRows, kCols))
y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows))
dtrain = xgb.dask.DaskDMatrix(client, X, y)
parameters['objective'] = 'multi:softprob'
parameters['num_class'] = n_classes
out = xgb.dask.train(client, parameters,
dtrain=dtrain,
evals=[(dtrain, 'validation')],
num_boost_round=2)
predictions = xgb.dask.predict(client=client, model=out,
data=dtrain).compute()
_check_outputs(out, predictions)
# train has more rows than evals
valid = dtrain
kRows += 1
X = dd.from_array(np.random.randn(kRows, kCols))
y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows))
dtrain = xgb.dask.DaskDMatrix(client, X, y)
out = xgb.dask.train(client, parameters,
dtrain=dtrain,
evals=[(valid, 'validation')],
num_boost_round=2)
predictions = xgb.dask.predict(client=client, model=out,
data=valid).compute()
_check_outputs(out, predictions)
# No test for Exact, as empty DMatrix handling are mostly for distributed # No test for Exact, as empty DMatrix handling are mostly for distributed
# environment and Exact doesn't support it. # environment and Exact doesn't support it.
@ -278,11 +317,13 @@ def test_empty_dmatrix_hist():
with LocalCluster(n_workers=5) as cluster: with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client: with Client(cluster) as client:
parameters = {'tree_method': 'hist'} parameters = {'tree_method': 'hist'}
run_empty_dmatrix(client, parameters) run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)
def test_empty_dmatrix_approx(): def test_empty_dmatrix_approx():
with LocalCluster(n_workers=5) as cluster: with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client: with Client(cluster) as client:
parameters = {'tree_method': 'approx'} parameters = {'tree_method': 'approx'}
run_empty_dmatrix(client, parameters) run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)