parent e471056ec4
commit 7c2686146e
@ -14,6 +14,7 @@ https://github.com/dask/dask-xgboost
|
||||
import platform
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from threading import Thread
|
||||
|
||||
import numpy
|
||||
@ -28,7 +29,7 @@ from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||
from .compat import CUDF_concat
|
||||
from .compat import lazy_isinstance
|
||||
|
||||
from .core import DMatrix, Booster, _expect
|
||||
from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
|
||||
from .training import train as worker_train
|
||||
from .tracker import RabitTracker
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
|
||||
@ -357,6 +358,146 @@ class DaskDMatrix:
        return (rows, cols)


class DaskPartitionIter(DataIter):  # pylint: disable=R0902
    '''A data iterator for `DaskDeviceQuantileDMatrix`.
    '''
    def __init__(self, data, label=None, weight=None, base_margin=None,
                 label_lower_bound=None, label_upper_bound=None,
                 feature_names=None, feature_types=None):
        self._data = data
        self._labels = label
        self._weights = weight
        self._base_margin = base_margin
        self._label_lower_bound = label_lower_bound
        self._label_upper_bound = label_upper_bound
        self._feature_names = feature_names
        self._feature_types = feature_types

        assert isinstance(self._data, Sequence)

        types = (Sequence, type(None))
        assert isinstance(self._labels, types)
        assert isinstance(self._weights, types)
        assert isinstance(self._base_margin, types)
        assert isinstance(self._label_lower_bound, types)
        assert isinstance(self._label_upper_bound, types)

        self._iter = 0  # set iterator to 0
        super().__init__()

    def data(self):
        '''Utility function for obtaining current batch of data.'''
        return self._data[self._iter]

    def labels(self):
        '''Utility function for obtaining current batch of label.'''
        if self._labels is not None:
            return self._labels[self._iter]
        return None

    def weights(self):
        '''Utility function for obtaining current batch of label.'''
        if self._weights is not None:
            return self._weights[self._iter]
        return None

    def base_margins(self):
        '''Utility function for obtaining current batch of base_margin.'''
        if self._base_margin is not None:
            return self._base_margin[self._iter]
        return None

    def label_lower_bounds(self):
        '''Utility function for obtaining current batch of label_lower_bound.
        '''
        if self._label_lower_bound is not None:
            return self._label_lower_bound[self._iter]
        return None

    def label_upper_bounds(self):
        '''Utility function for obtaining current batch of label_upper_bound.
        '''
        if self._label_upper_bound is not None:
            return self._label_upper_bound[self._iter]
        return None

    def reset(self):
        '''Reset the iterator'''
        self._iter = 0

    def next(self, input_data):
        '''Yield next batch of data'''
        if self._iter == len(self._data):
            # Return 0 when there's no more batch.
            return 0
        if self._feature_names:
            feature_names = self._feature_names
        else:
            if hasattr(self.data(), 'columns'):
                feature_names = self.data().columns.format()
            else:
                feature_names = None
        input_data(data=self.data(), label=self.labels(),
                   weight=self.weights(), group=None,
                   label_lower_bound=self.label_lower_bounds(),
                   label_upper_bound=self.label_upper_bounds(),
                   feature_names=feature_names,
                   feature_types=self._feature_types)
        self._iter += 1
        return 1


class DaskDeviceQuantileDMatrix(DaskDMatrix):
    '''Specialized data type for `gpu_hist` tree method. This class is
    used to reduce the memory usage by eliminating data copies.
    Internally the data is merged by weighted GK sketching. So the
    number of partitions from dask may affect training accuracy as GK
    generates error for each merge.

    .. versionadded:: 1.2.0

    Parameters
    ----------
    max_bin: Number of bins for histogram construction.

    '''
    def __init__(self, client, data, label=None, weight=None,
                 missing=None,
                 feature_names=None,
                 feature_types=None,
                 max_bin=256):
        super().__init__(client=client, data=data, label=label, weight=weight,
                         missing=missing,
                         feature_names=feature_names,
                         feature_types=feature_types)
        self.max_bin = max_bin

    def get_worker_data(self, worker):
        if worker.address not in set(self.worker_map.keys()):
            msg = 'worker {address} has an empty DMatrix. ' \
                'All workers associated with this DMatrix: {workers}'.format(
                    address=worker.address,
                    workers=set(self.worker_map.keys()))
            LOGGER.warning(msg)
            import cupy  # pylint: disable=import-error
            d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
                                      feature_names=self.feature_names,
                                      feature_types=self.feature_types,
                                      max_bin=self.max_bin)
            return d

        data, labels, weights = self.get_worker_parts(worker)
        it = DaskPartitionIter(data=data, label=labels, weight=weights)

        dmatrix = DeviceQuantileDMatrix(it,
                                        missing=self.missing,
                                        feature_names=self.feature_names,
                                        feature_types=self.feature_types,
                                        nthread=worker.nthreads,
                                        max_bin=self.max_bin)
        return dmatrix


def _get_rabit_args(worker_map, client):
    '''Get rabit context arguments from data distribution in DaskDMatrix.'''
    host = distributed_comm.get_address_host(client.scheduler.address)

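A minimal usage sketch of the DaskDeviceQuantileDMatrix added above (not part of
this commit): it assumes a dask_cuda LocalCUDACluster and cupy-backed dask
arrays; the array shapes, the max_bin value, and the number of boosting rounds
are illustrative only.

    # Hypothetical example: build the device quantile DMatrix directly from
    # GPU-resident dask arrays so gpu_hist can sketch without extra host copies.
    import dask.array as da
    import cupy as cp
    import xgboost as xgb
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    with LocalCUDACluster() as cluster, Client(cluster) as client:
        # Each dask chunk becomes a cupy array; the number of partitions
        # influences sketch accuracy, as noted in the class docstring.
        X = da.random.random((1000, 10), chunks=(128, 10)).map_blocks(cp.asarray)
        y = da.random.random(1000, chunks=(128,)).map_blocks(cp.asarray)
        dtrain = xgb.dask.DaskDeviceQuantileDMatrix(client, X, y, max_bin=256)
        output = xgb.dask.train(client,
                                {'tree_method': 'gpu_hist', 'max_bin': 256},
                                dtrain, num_boost_round=4,
                                evals=[(dtrain, 'train')])
        print(output['history']['train']['rmse'])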
@ -15,7 +15,7 @@ c_bst_ulong = ctypes.c_uint64  # pylint: disable=invalid-name


def _warn_unused_missing(data, missing):
    if not (np.isnan(missing) or None):
    if (not np.isnan(missing)) or (missing is None):
        warnings.warn(
            '`missing` is not used for current input data type:' +
            str(type(data)))

@ -85,6 +85,7 @@ class SketchContainer {
    // Initialize Sketches for this dmatrix
    this->columns_ptr_.SetDevice(device_);
    this->columns_ptr_.Resize(num_columns + 1);
    CHECK_GE(device, 0);
    timer_.Init(__func__);
  }
  /* \brief Return GPU ID for this container. */

@ -114,7 +114,6 @@ class ArrayInterfaceHandler {
        get<Array const>(
            obj.at("data"))
            .at(0))));
    CHECK(p_data);
    return p_data;
  }

@ -224,6 +223,9 @@ class ArrayInterfaceHandler {
    auto shape = ExtractShape(column);

    T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
    if (!p_data) {
      CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
    }
    return common::Span<T>{p_data, shape.first * shape.second};
  }
};
@ -234,7 +236,6 @@ class ArrayInterface {
                          bool allow_mask = true) {
    ArrayInterfaceHandler::Validate(column);
    data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
    CHECK(data) << "Column is null";
    auto shape = ArrayInterfaceHandler::ExtractShape(column);
    num_rows = shape.first;
    num_cols = shape.second;

@ -488,6 +488,15 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
    this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
                            group_ptr.end());
  }

  if (!that.feature_names.empty()) {
    this->feature_names = that.feature_names;
  }
  if (!that.feature_type_names.empty()) {
    this->feature_type_names = that.feature_type_names;
    auto &h_feature_types = feature_types.HostVector();
    LoadFeatureType(this->feature_type_names, &h_feature_types);
  }
}

void MetaInfo::Validate(int32_t device) const {

@ -69,6 +69,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
      << "Meta info " << key << " should be dense, found validity mask";
  CHECK_EQ(array_interface.num_cols, 1)
      << "Meta info should be a single column.";
  if (array_interface.num_rows == 0) {
    return;
  }

  if (key == "label") {
    CopyInfoImpl(array_interface, &labels_);

@ -122,10 +122,14 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
    CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
    std::vector<ArrayInterface> columns;
    auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
    num_rows_ = first_column.num_rows;
    if (num_rows_ == 0) {
      return;
    }

    device_idx_ = dh::CudaGetPointerDevice(first_column.data);
    CHECK_NE(device_idx_, -1);
    dh::safe_cuda(cudaSetDevice(device_idx_));
    num_rows_ = first_column.num_rows;
    for (auto& json_col : json_columns) {
      auto column = ArrayInterface(get<Object const>(json_col));
      columns.push_back(column);
@ -183,9 +187,12 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
    Json json_array_interface =
        Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
    array_interface_ = ArrayInterface(get<Object const>(json_array_interface), false);
    batch_ = CupyAdapterBatch(array_interface_);
    if (array_interface_.num_rows == 0) {
      return;
    }
    device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
    CHECK_NE(device_idx_, -1);
    batch_ = CupyAdapterBatch(array_interface_);
  }
  const CupyAdapterBatch& Value() const override { return batch_; }


@ -62,23 +62,30 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
  size_t batches = 0;
  size_t accumulated_rows = 0;
  bst_feature_t cols = 0;
  int32_t device = -1;
  int32_t device = GenericParameter::kCpuId;
  int32_t current_device_;
  dh::safe_cuda(cudaGetDevice(&current_device_));
  auto get_device = [&]() -> int32_t {
    int32_t d = GenericParameter::kCpuId ? current_device_ : device;
    return d;
  };

  while (iter.Next()) {
    device = proxy->DeviceIdx();
    dh::safe_cuda(cudaSetDevice(device));
    dh::safe_cuda(cudaSetDevice(get_device()));
    if (cols == 0) {
      cols = num_cols();
      rabit::Allreduce<rabit::op::Max>(&cols, 1);
    } else {
      CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
    }
    sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
    sketch_containers.emplace_back(batch_param_.max_bin, cols, num_rows(), get_device());
    auto* p_sketch = &sketch_containers.back();
    proxy->Info().weights_.SetDevice(device);
    proxy->Info().weights_.SetDevice(get_device());
    Dispatch(proxy, [&](auto const &value) {
      common::AdapterDeviceSketch(value, batch_param_.max_bin,
                                  proxy->Info(), missing, p_sketch);
    });

    auto batch_rows = num_rows();
    accumulated_rows += batch_rows;
    dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
@ -86,19 +93,15 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
                                         row_counts.size());
    row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
      return GetRowCounts(value, row_counts_span,
                          device, missing);
                          get_device(), missing);
    }));
    nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
                          row_counts.end());
    batches++;
  }

  if (device < 0) {  // error or empty
    this->page_.reset(new EllpackPage);
    return;
  }

  common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
  iter.Reset();
  dh::safe_cuda(cudaSetDevice(get_device()));
  common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, get_device());
  for (auto const& sketch : sketch_containers) {
    final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
    final_sketch.FixError();
@ -113,14 +116,14 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
  this->info_.num_row_ = accumulated_rows;
  this->info_.num_nonzero_ = nnz;

  auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows]() {
  auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows,
                    get_device]() {
    if (!page_) {
      // Should be put inside the while loop to protect against empty batch. In
      // that case device id is invalid.
      page_.reset(new EllpackPage);
      *(page_->Impl()) =
          EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), row_stride,
                          accumulated_rows);
      *(page_->Impl()) = EllpackPageImpl(get_device(), cuts, this->IsDense(),
                                         row_stride, accumulated_rows);
    }
  };

@ -130,21 +133,20 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
  size_t n_batches_for_verification = 0;
  while (iter.Next()) {
    init_page();
    auto device = proxy->DeviceIdx();
    dh::safe_cuda(cudaSetDevice(device));
    dh::safe_cuda(cudaSetDevice(get_device()));
    auto rows = num_rows();
    dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
    common::Span<size_t> row_counts_span(row_counts.data().get(),
                                         row_counts.size());
    Dispatch(proxy, [=](auto const& value) {
      return GetRowCounts(value, row_counts_span, device, missing);
      return GetRowCounts(value, row_counts_span, get_device(), missing);
    });
    auto is_dense = this->IsDense();
    auto new_impl = Dispatch(proxy, [&](auto const &value) {
      return EllpackPageImpl(value, missing, device, is_dense, nthread,
                             row_counts_span, row_stride, rows, cols, cuts);
      return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
                             row_counts_span, row_stride, rows, cols, cuts);
    });
    size_t num_elements = page_->Impl()->Copy(device, &new_impl, offset);
    size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
    offset += num_elements;

    proxy->Info().num_row_ = num_rows();

@ -158,18 +158,22 @@ struct EvalMClassBase : public Metric {
  bst_float Eval(const HostDeviceVector<bst_float> &preds,
                 const MetaInfo &info,
                 bool distributed) override {
    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
    CHECK(preds.Size() % info.labels_.Size() == 0)
        << "label and prediction size not match";
    const size_t nclass = preds.Size() / info.labels_.Size();
    CHECK_GE(nclass, 1U)
        << "mlogloss and merror are only used for multi-class classification,"
        << " use logloss for binary classification";

    int device = tparam_->gpu_id;
    auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);
    double dat[2] { result.Residue(), result.Weights() };

    if (info.labels_.Size() == 0) {
      CHECK_EQ(preds.Size(), 0);
    } else {
      CHECK(preds.Size() % info.labels_.Size() == 0) << "label and prediction size not match";
    }
    double dat[2] { 0.0, 0.0 };
    if (info.labels_.Size() != 0) {
      const size_t nclass = preds.Size() / info.labels_.Size();
      CHECK_GE(nclass, 1U)
          << "mlogloss and merror are only used for multi-class classification,"
          << " use logloss for binary classification";
      int device = tparam_->gpu_id;
      auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);
      dat[0] = result.Residue();
      dat[1] = result.Weights();
    }
    if (distributed) {
      rabit::Allreduce<rabit::op::Sum>(dat, 2);
    }

@ -49,7 +49,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
                   const MetaInfo& info,
                   int iter,
                   HostDeviceVector<GradientPair>* out_gpair) override {
    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
    if (info.labels_.Size() == 0) {
      return;
    }
    CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels_.Size()))
        << "SoftmaxMultiClassObj: label size and pred size does not match.\n"
        << "label.Size() * num_class: "

@ -6,13 +6,15 @@ import unittest
import xgboost
import subprocess
from hypothesis import given, strategies, settings, note
from hypothesis._settings import duration
from test_gpu_updaters import parameter_strategy

if sys.platform.startswith("win"):
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

sys.path.append("tests/python")
from test_with_dask import run_empty_dmatrix  # noqa
from test_with_dask import run_empty_dmatrix_reg  # noqa
from test_with_dask import run_empty_dmatrix_cls  # noqa
from test_with_dask import generate_array  # noqa
import testing as tm  # noqa

@ -28,6 +30,126 @@ except ImportError:
    pass


def run_with_dask_dataframe(DMatrixT, client):
    import cupy as cp
    cp.cuda.runtime.setDevice(0)
    X, y = generate_array()

    X = dd.from_dask_array(X)
    y = dd.from_dask_array(y)

    X = X.map_partitions(cudf.from_pandas)
    y = y.map_partitions(cudf.from_pandas)

    dtrain = DMatrixT(client, X, y)
    out = dxgb.train(client, {'tree_method': 'gpu_hist',
                              'debug_synchronize': True},
                     dtrain=dtrain,
                     evals=[(dtrain, 'X')],
                     num_boost_round=4)

    assert isinstance(out['booster'], dxgb.Booster)
    assert len(out['history']['X']['rmse']) == 4

    predictions = dxgb.predict(client, out, dtrain).compute()
    assert isinstance(predictions, np.ndarray)

    series_predictions = dxgb.inplace_predict(client, out, X)
    assert isinstance(series_predictions, dd.Series)
    series_predictions = series_predictions.compute()

    single_node = out['booster'].predict(
        xgboost.DMatrix(X.compute()))

    cp.testing.assert_allclose(single_node, predictions)
    np.testing.assert_allclose(single_node,
                               series_predictions.to_array())

    predt = dxgb.predict(client, out, X)
    assert isinstance(predt, dd.Series)

    def is_df(part):
        assert isinstance(part, cudf.DataFrame), part
        return part

    predt.map_partitions(
        is_df,
        meta=dd.utils.make_meta({'prediction': 'f4'}))

    cp.testing.assert_allclose(
        predt.values.compute(), single_node)


def run_with_dask_array(DMatrixT, client):
    import cupy as cp
    cp.cuda.runtime.setDevice(0)
    X, y = generate_array()

    X = X.map_blocks(cp.asarray)
    y = y.map_blocks(cp.asarray)
    dtrain = DMatrixT(client, X, y)
    out = dxgb.train(client, {'tree_method': 'gpu_hist',
                              'debug_synchronize': True},
                     dtrain=dtrain,
                     evals=[(dtrain, 'X')],
                     num_boost_round=2)
    from_dmatrix = dxgb.predict(client, out, dtrain).compute()
    inplace_predictions = dxgb.inplace_predict(
        client, out, X).compute()
    single_node = out['booster'].predict(
        xgboost.DMatrix(X.compute()))
    np.testing.assert_allclose(single_node, from_dmatrix)
    device = cp.cuda.runtime.getDevice()
    assert device == inplace_predictions.device.id
    single_node = cp.array(single_node)
    assert device == single_node.device.id
    cp.testing.assert_allclose(
        single_node,
        inplace_predictions)


def to_cp(x, DMatrixT):
    import cupy
    if isinstance(x, np.ndarray) and \
       DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
        X = cupy.array(x)
    else:
        X = x
    return X


def run_gpu_hist(params, num_rounds, dataset, DMatrixT, client):
    params['tree_method'] = 'gpu_hist'
    params = dataset.set_params(params)
    # It doesn't make sense to distribute a completely
    # empty dataset.
    if dataset.X.shape[0] == 0:
        return

    chunk = 128
    X = to_cp(dataset.X, DMatrixT)
    X = da.from_array(X,
                      chunks=(chunk, dataset.X.shape[1]))
    y = to_cp(dataset.y, DMatrixT)
    y = da.from_array(y, chunks=(chunk, ))
    if dataset.w is not None:
        w = to_cp(dataset.w, DMatrixT)
        w = da.from_array(w, chunks=(chunk, ))
    else:
        w = None

    if DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
        m = DMatrixT(client, data=X, label=y, weight=w,
                     max_bin=params.get('max_bin', 256))
    else:
        m = DMatrixT(client, data=X, label=y, weight=w)
    history = dxgb.train(client, params=params, dtrain=m,
                         num_boost_round=num_rounds,
                         evals=[(m, 'train')])['history']
    note(history)
    assert tm.non_increasing(history['train'][dataset.metric])


class TestDistributedGPU(unittest.TestCase):
    @pytest.mark.skipif(**tm.no_dask())
    @pytest.mark.skipif(**tm.no_cudf())
@ -37,119 +159,28 @@ class TestDistributedGPU(unittest.TestCase):
    def test_dask_dataframe(self):
        with LocalCUDACluster() as cluster:
            with Client(cluster) as client:
                import cupy as cp
                cp.cuda.runtime.setDevice(0)
                X, y = generate_array()

                X = dd.from_dask_array(X)
                y = dd.from_dask_array(y)

                X = X.map_partitions(cudf.from_pandas)
                y = y.map_partitions(cudf.from_pandas)

                dtrain = dxgb.DaskDMatrix(client, X, y)
                out = dxgb.train(client, {'tree_method': 'gpu_hist',
                                          'debug_synchronize': True},
                                 dtrain=dtrain,
                                 evals=[(dtrain, 'X')],
                                 num_boost_round=4)

                assert isinstance(out['booster'], dxgb.Booster)
                assert len(out['history']['X']['rmse']) == 4

                predictions = dxgb.predict(client, out, dtrain).compute()
                assert isinstance(predictions, np.ndarray)

                series_predictions = dxgb.inplace_predict(client, out, X)
                assert isinstance(series_predictions, dd.Series)
                series_predictions = series_predictions.compute()

                single_node = out['booster'].predict(
                    xgboost.DMatrix(X.compute()))

                cp.testing.assert_allclose(single_node, predictions)
                np.testing.assert_allclose(single_node,
                                           series_predictions.to_array())

                predt = dxgb.predict(client, out, X)
                assert isinstance(predt, dd.Series)

                def is_df(part):
                    assert isinstance(part, cudf.DataFrame), part
                    return part

                predt.map_partitions(
                    is_df,
                    meta=dd.utils.make_meta({'prediction': 'f4'}))

                cp.testing.assert_allclose(
                    predt.values.compute(), single_node)
                run_with_dask_dataframe(dxgb.DaskDMatrix, client)
                run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client)

    @given(parameter_strategy, strategies.integers(1, 20),
           tm.dataset_strategy)
    @settings(deadline=None)
    @settings(deadline=duration(seconds=120))
    @pytest.mark.mgpu
    def test_gpu_hist(self, params, num_rounds, dataset):
        with LocalCUDACluster(n_workers=2) as cluster:
            with Client(cluster) as client:
                params['tree_method'] = 'gpu_hist'
                params = dataset.set_params(params)
                # multi class doesn't handle empty dataset well (empty
                # means at least 1 worker has data).
                if params['objective'] == "multi:softmax":
                    return
                # It doesn't make sense to distribute a completely
                # empty dataset.
                if dataset.X.shape[0] == 0:
                    return

                chunk = 128
                X = da.from_array(dataset.X,
                                  chunks=(chunk, dataset.X.shape[1]))
                y = da.from_array(dataset.y, chunks=(chunk, ))
                if dataset.w is not None:
                    w = da.from_array(dataset.w, chunks=(chunk, ))
                else:
                    w = None

                m = dxgb.DaskDMatrix(
                    client, data=X, label=y, weight=w)
                history = dxgb.train(client, params=params, dtrain=m,
                                     num_boost_round=num_rounds,
                                     evals=[(m, 'train')])['history']
                note(history)
                assert tm.non_increasing(history['train'][dataset.metric])
                run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix,
                             client)
                run_gpu_hist(params, num_rounds, dataset,
                             dxgb.DaskDeviceQuantileDMatrix, client)

    @pytest.mark.skipif(**tm.no_cupy())
    @pytest.mark.mgpu
    def test_dask_array(self):
        with LocalCUDACluster() as cluster:
            with Client(cluster) as client:
                import cupy as cp
                cp.cuda.runtime.setDevice(0)
                X, y = generate_array()

                X = X.map_blocks(cp.asarray)
                y = y.map_blocks(cp.asarray)
                dtrain = dxgb.DaskDMatrix(client, X, y)
                out = dxgb.train(client, {'tree_method': 'gpu_hist',
                                          'debug_synchronize': True},
                                 dtrain=dtrain,
                                 evals=[(dtrain, 'X')],
                                 num_boost_round=2)
                from_dmatrix = dxgb.predict(client, out, dtrain).compute()
                inplace_predictions = dxgb.inplace_predict(
                    client, out, X).compute()
                single_node = out['booster'].predict(
                    xgboost.DMatrix(X.compute()))
                np.testing.assert_allclose(single_node, from_dmatrix)
                device = cp.cuda.runtime.getDevice()
                assert device == inplace_predictions.device.id
                single_node = cp.array(single_node)
                assert device == single_node.device.id
                cp.testing.assert_allclose(
                    single_node,
                    inplace_predictions)
                run_with_dask_array(dxgb.DaskDMatrix, client)
                run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)

    @pytest.mark.skipif(**tm.no_dask())
    @pytest.mark.skipif(**tm.no_dask_cuda())
@ -159,7 +190,8 @@ class TestDistributedGPU(unittest.TestCase):
            with Client(cluster) as client:
                parameters = {'tree_method': 'gpu_hist',
                              'debug_synchronize': True}
                run_empty_dmatrix(client, parameters)
                run_empty_dmatrix_reg(client, parameters)
                run_empty_dmatrix_cls(client, parameters)

    def run_quantile(self, name):
        if sys.platform.startswith("win"):

@ -128,8 +128,7 @@ def test_dask_missing_value_reg():


def test_dask_missing_value_cls():
    # Multi-class doesn't handle empty DMatrix well. So we use lesser workers.
    with LocalCluster(n_workers=2) as cluster:
    with LocalCluster() as cluster:
        with Client(cluster) as client:
            X_0 = np.ones((kRows // 2, kCols))
            X_1 = np.zeros((kRows // 2, kCols))
@ -234,7 +233,7 @@ def test_sklearn_grid_search():
    assert len(means) == len(set(means))


def run_empty_dmatrix(client, parameters):
def run_empty_dmatrix_reg(client, parameters):

    def _check_outputs(out, predictions):
        assert isinstance(out['booster'], xgb.dask.Booster)
@ -271,6 +270,46 @@ def run_empty_dmatrix(client, parameters):
    _check_outputs(out, predictions)


def run_empty_dmatrix_cls(client, parameters):
    n_classes = 4

    def _check_outputs(out, predictions):
        assert isinstance(out['booster'], xgb.dask.Booster)
        assert len(out['history']['validation']['merror']) == 2
        assert isinstance(predictions, np.ndarray)
        assert predictions.shape[1] == n_classes, predictions.shape

    kRows, kCols = 1, 97
    X = dd.from_array(np.random.randn(kRows, kCols))
    y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows))
    dtrain = xgb.dask.DaskDMatrix(client, X, y)
    parameters['objective'] = 'multi:softprob'
    parameters['num_class'] = n_classes

    out = xgb.dask.train(client, parameters,
                         dtrain=dtrain,
                         evals=[(dtrain, 'validation')],
                         num_boost_round=2)
    predictions = xgb.dask.predict(client=client, model=out,
                                   data=dtrain).compute()
    _check_outputs(out, predictions)

    # train has more rows than evals
    valid = dtrain
    kRows += 1
    X = dd.from_array(np.random.randn(kRows, kCols))
    y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows))
    dtrain = xgb.dask.DaskDMatrix(client, X, y)

    out = xgb.dask.train(client, parameters,
                         dtrain=dtrain,
                         evals=[(valid, 'validation')],
                         num_boost_round=2)
    predictions = xgb.dask.predict(client=client, model=out,
                                   data=valid).compute()
    _check_outputs(out, predictions)


# No test for Exact, as empty DMatrix handling are mostly for distributed
# environment and Exact doesn't support it.

@ -278,11 +317,13 @@ def test_empty_dmatrix_hist():
    with LocalCluster(n_workers=5) as cluster:
        with Client(cluster) as client:
            parameters = {'tree_method': 'hist'}
            run_empty_dmatrix(client, parameters)
            run_empty_dmatrix_reg(client, parameters)
            run_empty_dmatrix_cls(client, parameters)


def test_empty_dmatrix_approx():
    with LocalCluster(n_workers=5) as cluster:
        with Client(cluster) as client:
            parameters = {'tree_method': 'approx'}
            run_empty_dmatrix(client, parameters)
            run_empty_dmatrix_reg(client, parameters)
            run_empty_dmatrix_cls(client, parameters)
