Quantile DMatrix for CPU. (#8130)

- Add a new `QuantileDMatrix` that works for both CPU and GPU.
- Deprecate `DeviceQuantileDMatrix`.
This commit is contained in:
Jiaming Yuan 2022-08-02 15:51:23 +08:00 committed by GitHub
parent 2cba1d9fcc
commit d87f69215e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 521 additions and 117 deletions

View File

@ -22,6 +22,9 @@ Core Data Structure
:members: :members:
:show-inheritance: :show-inheritance:
.. autoclass:: xgboost.QuantileDMatrix
:show-inheritance:
.. autoclass:: xgboost.DeviceQuantileDMatrix .. autoclass:: xgboost.DeviceQuantileDMatrix
:show-inheritance: :show-inheritance:

View File

@ -415,28 +415,26 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN
* *
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DMatrixHandle proxy, DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
DataIterResetCallback *reset, char const *c_json_config, DMatrixHandle *out);
XGDMatrixCallbackNext *next,
char const* c_json_config,
DMatrixHandle *out);
/*! /*!
* \brief Create a Quantile DMatrix with data iterator. * \brief Create a Quantile DMatrix with data iterator.
* *
* Short note for how to use the second set of callback for GPU Hist tree method: * Short note for how to use the second set of callback for (GPU)Hist tree method:
* *
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`. * - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
* - Step 1: Create a DMatrix proxy by `XGProxyDMatrixCreate` and hold the handle. * - Step 1: Create a DMatrix proxy by `XGProxyDMatrixCreate` and hold the handle.
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into * - Step 2: Pass the iterator handle, proxy handle and 2 methods into
* `XGDeviceQuantileDMatrixCreateFromCallback`. * `XGQuantileDMatrixCreateFromCallback`.
* - Step 3: Call appropriate data setters in `next` functions. * - Step 3: Call appropriate data setters in `next` functions.
* *
* See test_iterative_device_dmatrix.cu or Python interface for examples. * See test_iterative_dmatrix.cu or Python interface for examples.
* *
* \param iter A handle to external data iterator. * \param iter A handle to external data iterator.
* \param proxy A DMatrix proxy handle created by `XGProxyDMatrixCreate`. * \param proxy A DMatrix proxy handle created by `XGProxyDMatrixCreate`.
* \param ref Reference DMatrix for providing quantile information.
* \param reset Callback function resetting the iterator state. * \param reset Callback function resetting the iterator state.
* \param next Callback function yielding the next batch of data. * \param next Callback function yielding the next batch of data.
* \param missing Which value to represent missing value * \param missing Which value to represent missing value
@ -446,10 +444,20 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
* *
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback( XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread, int max_bin, XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out); DMatrixHandle *out);
/*!
* \brief Create a Device Quantile DMatrix with data iterator.
* \deprecated since 2.0
* \see XGQuantileDMatrixCreateFromCallback()
*/
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread, int max_bin, DMatrixHandle *out);
/*! /*!
* \brief Set data on a DMatrix proxy. * \brief Set data on a DMatrix proxy.

View File

@ -6,6 +6,7 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
from .core import ( from .core import (
DMatrix, DMatrix,
DeviceQuantileDMatrix, DeviceQuantileDMatrix,
QuantileDMatrix,
Booster, Booster,
DataIter, DataIter,
build_info, build_info,
@ -33,6 +34,7 @@ __all__ = [
# core # core
"DMatrix", "DMatrix",
"DeviceQuantileDMatrix", "DeviceQuantileDMatrix",
"QuantileDMatrix",
"Booster", "Booster",
"DataIter", "DataIter",
"train", "train",

View File

@ -1146,7 +1146,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
Parameters Parameters
---------- ----------
feature_types : list or None feature_types :
Labels for features. None will reset existing feature names Labels for features. None will reset existing feature names
""" """
@ -1189,7 +1189,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
class _ProxyDMatrix(DMatrix): class _ProxyDMatrix(DMatrix):
"""A placeholder class when DMatrix cannot be constructed (DeviceQuantileDMatrix, """A placeholder class when DMatrix cannot be constructed (QuantileDMatrix,
inplace_predict). inplace_predict).
""" """
@ -1234,17 +1234,35 @@ class _ProxyDMatrix(DMatrix):
) )
class DeviceQuantileDMatrix(DMatrix): class QuantileDMatrix(DMatrix):
"""Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do """A DMatrix variant that generates quantilized data directly from input for
not use this for test/validation tasks as some information may be lost in ``hist`` and ``gpu_hist`` tree methods. This DMatrix is primarily designed to save
quantisation. This DMatrix is primarily designed to save memory in training from memory in training by avoiding intermediate storage. Set ``max_bin`` to control the
device memory inputs by avoiding intermediate storage. Set max_bin to control the number of bins during quantisation, which should be consistent with the training
number of bins during quantisation. See doc string in :py:obj:`xgboost.DMatrix` for parameter ``max_bin``. When ``QuantileDMatrix`` is used for validation/test dataset,
documents on meta info. ``ref`` should be another ``QuantileDMatrix``(or ``DMatrix``, but not recommended as
it defeats the purpose of saving memory) constructed from training dataset. See
:py:obj:`xgboost.DMatrix` for documents on meta info.
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack. .. note::
.. versionadded:: 1.1.0 Do not use ``QuantileDMatrix`` as validation/test dataset without supplying a
reference (the training dataset) ``QuantileDMatrix`` using ``ref`` as some
information may be lost in quantisation.
.. versionadded:: 2.0.0
Parameters
----------
max_bin :
The number of histogram bins, which should be consistent with the training
parameter ``max_bin``.
ref :
The training dataset that provides quantile information, needed when creating a
validation/test dataset with ``QuantileDMatrix``. Supplying the training DMatrix
as a reference means that the same quantisation applied to the training data is
applied to the validation/test data.
""" """
@ -1261,7 +1279,8 @@ class DeviceQuantileDMatrix(DMatrix):
feature_names: Optional[FeatureNames] = None, feature_names: Optional[FeatureNames] = None,
feature_types: Optional[FeatureTypes] = None, feature_types: Optional[FeatureTypes] = None,
nthread: Optional[int] = None, nthread: Optional[int] = None,
max_bin: int = 256, max_bin: Optional[int] = None,
ref: Optional[DMatrix] = None,
group: Optional[ArrayLike] = None, group: Optional[ArrayLike] = None,
qid: Optional[ArrayLike] = None, qid: Optional[ArrayLike] = None,
label_lower_bound: Optional[ArrayLike] = None, label_lower_bound: Optional[ArrayLike] = None,
@ -1269,9 +1288,9 @@ class DeviceQuantileDMatrix(DMatrix):
feature_weights: Optional[ArrayLike] = None, feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False, enable_categorical: bool = False,
) -> None: ) -> None:
self.max_bin = max_bin self.max_bin: int = max_bin if max_bin is not None else 256
self.missing = missing if missing is not None else np.nan self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else 1 self.nthread = nthread if nthread is not None else -1
self._silent = silent # unused, kept for compatibility self._silent = silent # unused, kept for compatibility
if isinstance(data, ctypes.c_void_p): if isinstance(data, ctypes.c_void_p):
@ -1280,12 +1299,13 @@ class DeviceQuantileDMatrix(DMatrix):
if qid is not None and group is not None: if qid is not None and group is not None:
raise ValueError( raise ValueError(
'Only one of the eval_qid or eval_group for each evaluation ' "Only one of the eval_qid or eval_group for each evaluation "
'dataset should be provided.' "dataset should be provided."
) )
self._init( self._init(
data, data,
ref=ref,
label=label, label=label,
weight=weight, weight=weight,
base_margin=base_margin, base_margin=base_margin,
@ -1299,7 +1319,13 @@ class DeviceQuantileDMatrix(DMatrix):
enable_categorical=enable_categorical, enable_categorical=enable_categorical,
) )
def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None: def _init(
self,
data: DataType,
ref: Optional[DMatrix],
enable_categorical: bool,
**meta: Any,
) -> None:
from .data import ( from .data import (
_is_dlpack, _is_dlpack,
_transform_dlpack, _transform_dlpack,
@ -1317,20 +1343,26 @@ class DeviceQuantileDMatrix(DMatrix):
it = SingleBatchInternalIter(data=data, **meta) it = SingleBatchInternalIter(data=data, **meta)
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
reset_callback, next_callback = it.get_callbacks(False, enable_categorical) reset_callback, next_callback = it.get_callbacks(True, enable_categorical)
if it.cache_prefix is not None: if it.cache_prefix is not None:
raise ValueError( raise ValueError(
"DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix " "QuantileDMatrix doesn't cache data, remove the cache_prefix "
"in iterator to fix this error." "in iterator to fix this error."
) )
ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
args = {
"nthread": self.nthread,
"missing": self.missing,
"max_bin": self.max_bin,
}
config = from_pystr_to_cstr(json.dumps(args))
ret = _LIB.XGQuantileDMatrixCreateFromCallback(
None, None,
it.proxy.handle, it.proxy.handle,
ref.handle if ref is not None else ref,
reset_callback, reset_callback,
next_callback, next_callback,
ctypes.c_float(self.missing), config,
ctypes.c_int(self.nthread),
ctypes.c_int(self.max_bin),
ctypes.byref(handle), ctypes.byref(handle),
) )
it.reraise() it.reraise()
@ -1339,6 +1371,20 @@ class DeviceQuantileDMatrix(DMatrix):
self.handle = handle self.handle = handle
class DeviceQuantileDMatrix(QuantileDMatrix):
    """Deprecated subclass of ``QuantileDMatrix``. Use `QuantileDMatrix` instead.

    Constructing this class emits a ``FutureWarning`` and then forwards every
    positional and keyword argument unchanged to ``QuantileDMatrix.__init__``.

    .. deprecated:: 2.0.0

    .. versionadded:: 1.1.0

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # stacklevel=2 attributes the warning to the line that instantiates this
        # class rather than to this constructor, so users can find the deprecated
        # call in their own code.
        warnings.warn(
            "Please use `QuantileDMatrix` instead.", FutureWarning, stacklevel=2
        )
        super().__init__(*args, **kwargs)
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]

View File

@ -35,6 +35,7 @@ import collections
import logging import logging
import platform import platform
import socket import socket
import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial, update_wrapper from functools import partial, update_wrapper
@ -64,10 +65,10 @@ from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
from .core import ( from .core import (
Booster, Booster,
DataIter, DataIter,
DeviceQuantileDMatrix,
DMatrix, DMatrix,
Metric, Metric,
Objective, Objective,
QuantileDMatrix,
_deprecate_positional_args, _deprecate_positional_args,
_expect, _expect,
_has_categorical, _has_categorical,
@ -495,7 +496,7 @@ async def map_worker_partitions(
client: Optional["distributed.Client"], client: Optional["distributed.Client"],
func: Callable[..., _MapRetT], func: Callable[..., _MapRetT],
*refs: Any, *refs: Any,
workers: List[str], workers: Sequence[str],
) -> List[_MapRetT]: ) -> List[_MapRetT]:
"""Map a function onto partitions of each worker.""" """Map a function onto partitions of each worker."""
# Note for function purity: # Note for function purity:
@ -628,22 +629,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
return 1 return 1
class DaskDeviceQuantileDMatrix(DaskDMatrix): class DaskQuantileDMatrix(DaskDMatrix):
"""Specialized data type for `gpu_hist` tree method. This class is used to reduce
the memory usage by eliminating data copies. Internally the all partitions/chunks
of data are merged by weighted GK sketching. So the number of partitions from dask
may affect training accuracy as GK generates bounded error for each merge. See doc
string for :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for
other parameters.
.. versionadded:: 1.2.0
Parameters
----------
max_bin : Number of bins for histogram construction.
"""
@_deprecate_positional_args @_deprecate_positional_args
def __init__( def __init__(
self, self,
@ -657,7 +643,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
silent: bool = False, # disable=unused-argument silent: bool = False, # disable=unused-argument
feature_names: Optional[FeatureNames] = None, feature_names: Optional[FeatureNames] = None,
feature_types: Optional[Union[Any, List[Any]]] = None, feature_types: Optional[Union[Any, List[Any]]] = None,
max_bin: int = 256, max_bin: Optional[int] = None,
ref: Optional[DMatrix] = None,
group: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None,
qid: Optional[_DaskCollection] = None, qid: Optional[_DaskCollection] = None,
label_lower_bound: Optional[_DaskCollection] = None, label_lower_bound: Optional[_DaskCollection] = None,
@ -684,14 +671,31 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
) )
self.max_bin = max_bin self.max_bin = max_bin
self.is_quantile = True self.is_quantile = True
self._ref: Optional[int] = id(ref) if ref is not None else None
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]: def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
args = super()._create_fn_args(worker_addr) args = super()._create_fn_args(worker_addr)
args["max_bin"] = self.max_bin args["max_bin"] = self.max_bin
if self._ref is not None:
args["ref"] = self._ref
return args return args
def _create_device_quantile_dmatrix( class DaskDeviceQuantileDMatrix(DaskQuantileDMatrix):
"""Use `DaskQuantileDMatrix` instead.
.. deprecated:: 2.0.0
.. versionadded:: 1.2.0
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn("Please use `DaskQuantileDMatrix` instead.", FutureWarning)
super().__init__(*args, **kwargs)
def _create_quantile_dmatrix(
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[Union[Any, List[Any]]], feature_types: Optional[Union[Any, List[Any]]],
feature_weights: Optional[Any], feature_weights: Optional[Any],
@ -700,18 +704,20 @@ def _create_device_quantile_dmatrix(
parts: Optional[_DataParts], parts: Optional[_DataParts],
max_bin: int, max_bin: int,
enable_categorical: bool, enable_categorical: bool,
) -> DeviceQuantileDMatrix: ref: Optional[DMatrix] = None,
) -> QuantileDMatrix:
worker = distributed.get_worker() worker = distributed.get_worker()
if parts is None: if parts is None:
msg = f"worker {worker.address} has an empty DMatrix." msg = f"worker {worker.address} has an empty DMatrix."
LOGGER.warning(msg) LOGGER.warning(msg)
import cupy import cupy
d = DeviceQuantileDMatrix( d = QuantileDMatrix(
cupy.zeros((0, 0)), cupy.zeros((0, 0)),
feature_names=feature_names, feature_names=feature_names,
feature_types=feature_types, feature_types=feature_types,
max_bin=max_bin, max_bin=max_bin,
ref=ref,
enable_categorical=enable_categorical, enable_categorical=enable_categorical,
) )
return d return d
@ -719,13 +725,14 @@ def _create_device_quantile_dmatrix(
unzipped_dict = _get_worker_parts(parts) unzipped_dict = _get_worker_parts(parts)
it = DaskPartitionIter(**unzipped_dict) it = DaskPartitionIter(**unzipped_dict)
dmatrix = DeviceQuantileDMatrix( dmatrix = QuantileDMatrix(
it, it,
missing=missing, missing=missing,
feature_names=feature_names, feature_names=feature_names,
feature_types=feature_types, feature_types=feature_types,
nthread=nthread, nthread=nthread,
max_bin=max_bin, max_bin=max_bin,
ref=ref,
enable_categorical=enable_categorical, enable_categorical=enable_categorical,
) )
dmatrix.set_info(feature_weights=feature_weights) dmatrix.set_info(feature_weights=feature_weights)
@ -786,11 +793,9 @@ def _create_dmatrix(
return dmatrix return dmatrix
def _dmatrix_from_list_of_parts( def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix:
is_quantile: bool, **kwargs: Any
) -> Union[DMatrix, DeviceQuantileDMatrix]:
if is_quantile: if is_quantile:
return _create_device_quantile_dmatrix(**kwargs) return _create_quantile_dmatrix(**kwargs)
return _create_dmatrix(**kwargs) return _create_dmatrix(**kwargs)
@ -921,7 +926,18 @@ async def _train_async(
if evals_id[i] == train_id: if evals_id[i] == train_id:
evals.append((Xy, evals_name[i])) evals.append((Xy, evals_name[i]))
continue continue
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads) if ref.get("ref", None) is not None:
if ref["ref"] != train_id:
raise ValueError(
"The training DMatrix should be used as a reference"
" to evaluation `QuantileDMatrix`."
)
del ref["ref"]
eval_Xy = _dmatrix_from_list_of_parts(
**ref, nthread=n_threads, ref=Xy
)
else:
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
evals.append((eval_Xy, evals_name[i])) evals.append((eval_Xy, evals_name[i]))
booster = worker_train( booster = worker_train(
@ -960,12 +976,14 @@ async def _train_async(
results = await map_worker_partitions( results = await map_worker_partitions(
client, client,
dispatched_train, dispatched_train,
# extra function parameters
params, params,
_rabit_args, _rabit_args,
id(dtrain), id(dtrain),
evals_name, evals_name,
evals_id, evals_id,
*([dtrain] + evals_data), *([dtrain] + evals_data),
# workers to be used for training
workers=workers, workers=workers,
) )
return list(filter(lambda ret: ret is not None, results))[0] return list(filter(lambda ret: ret is not None, results))[0]

View File

@ -1167,6 +1167,7 @@ def _proxy_transform(
if _is_dlpack(data): if _is_dlpack(data):
return _transform_dlpack(data), None, feature_names, feature_types return _transform_dlpack(data), None, feature_names, feature_types
if _is_numpy_array(data): if _is_numpy_array(data):
data, _ = _ensure_np_dtype(data, data.dtype)
return data, None, feature_names, feature_types return data, None, feature_names, feature_types
if _is_scipy_csr(data): if _is_scipy_csr(data):
return data, None, feature_names, feature_types return data, None, feature_names, feature_types

View File

@ -281,11 +281,36 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
int nthread, int max_bin, int nthread, int max_bin,
DMatrixHandle *out) { DMatrixHandle *out) {
API_BEGIN(); API_BEGIN();
LOG(WARNING) << __func__ << " is deprecated. Use `XGQuantileDMatrixCreateFromCallback` instead.";
*out = new std::shared_ptr<xgboost::DMatrix>{ *out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, nullptr, reset, next, missing, nthread, max_bin)}; xgboost::DMatrix::Create(iter, proxy, nullptr, reset, next, missing, nthread, max_bin)};
API_END(); API_END();
} }
// Create a quantile DMatrix from a pair of data-iterator callbacks.
//
// - `ref` is optional. When non-null it is treated as a handle to an existing DMatrix
//   (it is cast to `std::shared_ptr<xgboost::DMatrix> *`) and passed on to
//   `DMatrix::Create`.
// - `config` is a JSON document. The missing value is read with `GetMissing`;
//   "nthread" and "max_bin" are optional and default to `common::OmpGetNumThreads(0)`
//   and 256 respectively.
// - `out` receives a newly allocated `std::shared_ptr<xgboost::DMatrix>` handle.
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
                                                DataIterHandle ref, DataIterResetCallback *reset,
                                                XGDMatrixCallbackNext *next, char const *config,
                                                DMatrixHandle *out) {
  API_BEGIN();
  // Validate pointer arguments up front so misuse surfaces as a failed CHECK with a
  // readable message instead of a null dereference further down.
  CHECK(config) << "Invalid pointer argument: config";
  CHECK(out) << "Invalid pointer argument: out";

  std::shared_ptr<DMatrix> _ref{nullptr};
  if (ref) {
    // `ref` is known to be non-null inside this branch, so only the shared_ptr it
    // points at needs validating.
    _ref = *static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
    CHECK(_ref) << "Invalid handle to ref.";
  }

  auto jconfig = Json::Load(StringView{config});
  auto missing = GetMissing(jconfig);
  auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", common::OmpGetNumThreads(0));
  auto max_bin = OptionalArg<Integer, int64_t>(jconfig, "max_bin", 256);

  *out = new std::shared_ptr<xgboost::DMatrix>{
      xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
  API_END();
}
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) { XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
API_BEGIN(); API_BEGIN();
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);; *out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;

View File

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
import pytest import pytest
@ -6,16 +5,14 @@ import sys
sys.path.append("tests/python") sys.path.append("tests/python")
import testing as tm import testing as tm
import test_quantile_dmatrix as tqd
class TestDeviceQuantileDMatrix: class TestDeviceQuantileDMatrix:
def test_dmatrix_numpy_init(self): cputest = tqd.TestQuantileDMatrix()
data = np.random.randn(5, 5)
with pytest.raises(TypeError, match='is not supported'):
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
def test_dmatrix_feature_weights(self): def test_dmatrix_feature_weights(self) -> None:
import cupy as cp import cupy as cp
rng = cp.random.RandomState(1994) rng = cp.random.RandomState(1994)
data = rng.randn(5, 5) data = rng.randn(5, 5)
@ -29,7 +26,7 @@ class TestDeviceQuantileDMatrix:
feature_weights.astype(np.float32)) feature_weights.astype(np.float32))
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
def test_dmatrix_cupy_init(self): def test_dmatrix_cupy_init(self) -> None:
import cupy as cp import cupy as cp
data = cp.random.randn(5, 5) data = cp.random.randn(5, 5)
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64)) xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
@ -55,3 +52,10 @@ class TestDeviceQuantileDMatrix:
cp.testing.assert_allclose(fw, got_fw) cp.testing.assert_allclose(fw, got_fw)
cp.testing.assert_allclose(labels, got_labels) cp.testing.assert_allclose(labels, got_labels)
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_cudf())
def test_ref_dmatrix(self) -> None:
import cupy as cp
rng = cp.random.RandomState(1994)
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)

View File

@ -429,9 +429,10 @@ class TestDistributedGPU:
sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters) sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters)
del sig["client"] del sig["client"]
ddm_names = list(sig.keys()) ddm_names = list(sig.keys())
sig = OrderedDict(signature(dxgb.DaskDeviceQuantileDMatrix).parameters) sig = OrderedDict(signature(dxgb.DaskQuantileDMatrix).parameters)
del sig["client"] del sig["client"]
del sig["max_bin"] del sig["max_bin"]
del sig["ref"]
ddqdm_names = list(sig.keys()) ddqdm_names = list(sig.keys())
assert len(ddm_names) == len(ddqdm_names) assert len(ddm_names) == len(ddqdm_names)
@ -442,9 +443,10 @@ class TestDistributedGPU:
sig = OrderedDict(signature(xgb.DMatrix).parameters) sig = OrderedDict(signature(xgb.DMatrix).parameters)
del sig["nthread"] # no nthread in dask del sig["nthread"] # no nthread in dask
dm_names = list(sig.keys()) dm_names = list(sig.keys())
sig = OrderedDict(signature(xgb.DeviceQuantileDMatrix).parameters) sig = OrderedDict(signature(xgb.QuantileDMatrix).parameters)
del sig["nthread"] del sig["nthread"]
del sig["max_bin"] del sig["max_bin"]
del sig["ref"]
dqdm_names = list(sig.keys()) dqdm_names = list(sig.keys())
# between single node # between single node
@ -499,7 +501,6 @@ class TestDistributedGPU:
for arg in rabit_args: for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port_env = arg.decode('utf-8') port_env = arg.decode('utf-8')
port_env = arg.decode('utf-8')
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"): if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8") uri_env = arg.decode("utf-8")
port = port_env.split('=') port = port_env.split('=')

View File

@ -1,32 +1,12 @@
import xgboost as xgb import xgboost as xgb
from xgboost.data import SingleBatchInternalIter as SingleBatch from xgboost.data import SingleBatchInternalIter as SingleBatch
import numpy as np import numpy as np
from testing import IteratorForTest, non_increasing from testing import IteratorForTest, non_increasing, make_batches
from typing import Tuple, List
import pytest import pytest
from hypothesis import given, strategies, settings from hypothesis import given, strategies, settings
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
def make_batches(
n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
X = []
y = []
if use_cupy:
import cupy
rng = cupy.random.RandomState(1994)
else:
rng = np.random.RandomState(1994)
for i in range(n_batches):
_X = rng.randn(n_samples_per_batch, n_features)
_y = rng.randn(n_samples_per_batch)
X.append(_X)
y.append(_y)
return X, y
def test_single_batch(tree_method: str = "approx") -> None: def test_single_batch(tree_method: str = "approx") -> None:
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
@ -111,8 +91,8 @@ def run_data_iterator(
if not subsample: if not subsample:
assert non_increasing(results_from_it["Train"]["rmse"]) assert non_increasing(results_from_it["Train"]["rmse"])
X, y = it.as_arrays() X, y, w = it.as_arrays()
Xy = xgb.DMatrix(X, y) Xy = xgb.DMatrix(X, y, weight=w)
assert Xy.num_row() == n_samples_per_batch * n_batches assert Xy.num_row() == n_samples_per_batch * n_batches
assert Xy.num_col() == n_features assert Xy.num_col() == n_features

View File

@ -0,0 +1,212 @@
from typing import Dict, List, Any
import numpy as np
import pytest
from scipy import sparse
from testing import IteratorForTest, make_batches, make_batches_sparse, make_categorical
import xgboost as xgb
class TestQuantileDMatrix:
    """Checks for ``xgb.QuantileDMatrix``: construction, training, ``ref`` and predict."""

    def test_basic(self) -> None:
        """Construct from a dense array and from CSR inputs; verify the reported shape."""
        n_samples = 234
        n_features = 8

        rng = np.random.default_rng()
        dense = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape(
            n_samples, n_features
        )
        labels = rng.normal(0, 3, size=n_samples)
        qdm = xgb.QuantileDMatrix(dense, labels)
        assert qdm.num_row() == n_samples
        assert qdm.num_col() == n_features

        # Repeat the shape checks for a low-density and a high-density CSR matrix.
        for density in (0.1, 0.8):
            csr = sparse.random(n_samples, n_features, density=density, format="csr")
            qdm = xgb.QuantileDMatrix(csr, labels)
            assert qdm.num_row() == n_samples
            assert qdm.num_col() == n_features

    @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9])
    def test_with_iterator(self, sparsity: float) -> None:
        """Construct from a multi-batch iterator; rows must add up across batches."""
        n_samples_per_batch = 317
        n_features = 8
        n_batches = 7

        # A sparsity of exactly 0.0 selects the dense batch generator.
        if sparsity == 0.0:
            batches = make_batches(n_samples_per_batch, n_features, n_batches, False)
        else:
            batches = make_batches_sparse(
                n_samples_per_batch, n_features, n_batches, sparsity
            )
        it = IteratorForTest(*batches, None)

        qdm = xgb.QuantileDMatrix(it)
        assert qdm.num_row() == n_samples_per_batch * n_batches
        assert qdm.num_col() == n_features

    @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.5, 0.8, 0.9])
    def test_training(self, sparsity: float) -> None:
        """Train from an iterator-backed QuantileDMatrix and from the equivalent DMatrix,
        compare cross predictions, then check that ``approx`` rejects a QuantileDMatrix.
        """
        n_samples_per_batch = 317
        n_features = 8
        n_batches = 7

        if sparsity == 0.0:
            batches = make_batches(n_samples_per_batch, n_features, n_batches, False)
        else:
            batches = make_batches_sparse(
                n_samples_per_batch, n_features, n_batches, sparsity
            )
        it = IteratorForTest(*batches, None)

        hist_params = {"tree_method": "hist", "max_bin": 256}
        qdm_from_it = xgb.QuantileDMatrix(it, max_bin=hist_params["max_bin"])
        booster_it = xgb.train(hist_params, qdm_from_it)

        X, y, w = it.as_arrays()
        np.testing.assert_allclose(qdm_from_it.get_weight(), w)

        dm_from_arr = xgb.DMatrix(X, y, weight=w)
        booster_arr = xgb.train(hist_params, dm_from_arr)

        pred_arr_on_it = booster_arr.predict(qdm_from_it)
        pred_it_on_arr = booster_it.predict(dm_from_arr)
        np.testing.assert_allclose(pred_arr_on_it, pred_it_on_arr)

        # Shift the labels in place so that all of them are strictly positive.
        y -= y.min()
        y += 0.01
        qdm = xgb.QuantileDMatrix(X, y, weight=w)
        gamma_params = {
            "tree_method": "approx",
            "max_bin": 256,
            "objective": "reg:gamma",
        }
        with pytest.raises(ValueError, match=r"Only.*hist.*"):
            xgb.train(gamma_params, qdm)

    def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
        """Exercise the ``ref`` argument of ``QuantileDMatrix``.

        Parameters
        ----------
        rng :
            Random generator providing ``normal``; it is consumed sequentially, so the
            order of the draws below matters.
        tree_method :
            Tree method passed to ``xgb.train``; ``"gpu_hist"`` additionally converts
            categorical inputs with ``cudf.from_pandas``.
        enable_cat :
            Whether categorical data is generated and ``enable_categorical`` is set.
        """

        def numeric_data(rows: int, cols: int) -> Any:
            # Features first, labels second: the same draw order as writing it inline.
            feats = rng.normal(loc=0, scale=3, size=rows * cols).reshape(rows, cols)
            target = rng.normal(0, 3, size=rows)
            return feats, target

        def rmse_history(dtrain: Any, evals: Any) -> Dict[str, Dict[str, List[float]]]:
            # Train with default rounds and hand back the recorded evaluation results.
            history: Dict[str, Dict[str, List[float]]] = {}
            xgb.train(
                {"tree_method": tree_method},
                dtrain,
                evals=evals,
                evals_result=history,
            )
            return history

        n_samples, n_features = 2048, 17
        if enable_cat:
            X, y = make_categorical(
                n_samples, n_features, n_categories=13, onehot=False
            )
            if tree_method == "gpu_hist":
                import cudf

                X = cudf.from_pandas(X)
                y = cudf.from_pandas(y)
        else:
            X, y = numeric_data(n_samples, n_features)

        # Use ref
        Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
        Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat)
        qdm_results = rmse_history(Xy, [(Xy, "Train"), (Xy_valid, "valid")])
        np.testing.assert_allclose(
            qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"]
        )

        # No ref
        Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
        qdm_results = rmse_history(Xy, [(Xy, "Train"), (Xy_valid, "valid")])
        np.testing.assert_allclose(
            qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"]
        )

        # Different number of features
        Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
        dXy = xgb.DMatrix(X, y, enable_categorical=enable_cat)

        n_samples, n_features = 256, 15
        X, y = numeric_data(n_samples, n_features)
        with pytest.raises(ValueError, match=r".*features\."):
            xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat)

        # Compare training results
        n_samples, n_features = 256, 17
        if enable_cat:
            X, y = make_categorical(n_samples, n_features, 13, onehot=False)
            if tree_method == "gpu_hist":
                import cudf

                X = cudf.from_pandas(X)
                y = cudf.from_pandas(y)
        else:
            X, y = numeric_data(n_samples, n_features)
        Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat)
        # use DMatrix as ref
        Xy_valid_d = xgb.QuantileDMatrix(X, y, ref=dXy, enable_categorical=enable_cat)
        dXy_valid = xgb.DMatrix(X, y, enable_categorical=enable_cat)

        qdm_results = rmse_history(Xy, [(Xy, "Train"), (Xy_valid, "valid")])
        dm_results = rmse_history(
            dXy, [(dXy, "Train"), (dXy_valid, "valid"), (Xy_valid_d, "dvalid")]
        )

        # Each DMatrix-trained metric is compared with its QuantileDMatrix counterpart.
        for dm_key, qdm_key in (
            ("Train", "Train"),
            ("valid", "valid"),
            ("dvalid", "valid"),
        ):
            np.testing.assert_allclose(
                dm_results[dm_key]["rmse"], qdm_results[qdm_key]["rmse"]
            )

    def test_ref_dmatrix(self) -> None:
        """Run the ``ref`` checks with ``hist``, categorical first and numeric second,
        sharing one seeded generator across both runs.
        """
        rng = np.random.RandomState(1994)
        for enable_cat in (True, False):
            self.run_ref_dmatrix(rng, "hist", enable_cat)

    def test_predict(self) -> None:
        """Predictions on a QuantileDMatrix must match those on a plain DMatrix."""
        n_samples, n_features = 16, 2
        X, y = make_categorical(
            n_samples, n_features, n_categories=13, onehot=False
        )
        dtrain = xgb.DMatrix(X, y, enable_categorical=True)
        booster = xgb.train({"tree_method": "hist"}, dtrain)

        dtest = xgb.DMatrix(X, y, enable_categorical=True)
        expected = booster.predict(dtest)
        qtest = xgb.QuantileDMatrix(X, y, enable_categorical=True)
        actual = booster.predict(qtest)
        np.testing.assert_allclose(expected, actual)

View File

@ -1382,6 +1382,42 @@ class TestWithDask:
num_rounds = 30 num_rounds = 30
self.run_updater_test(client, params, num_rounds, dataset, 'hist') self.run_updater_test(client, params, num_rounds, dataset, 'hist')
def test_quantile_dmatrix(self, client: Client) -> None:
    """Compare the evaluation history of DaskQuantileDMatrix with DaskDMatrix.

    The same data is trained on twice with ``tree_method="hist"`` for 10
    rounds; the recorded ``rmse`` curves for both evaluation sets must agree.
    """
    X, y = make_categorical(client, 10000, 30, 13)

    # Baseline: regular dask DMatrix for both training and validation.
    dm_train = xgb.dask.DaskDMatrix(client, X, y, enable_categorical=True)
    dm_valid = xgb.dask.DaskDMatrix(client, X, y, enable_categorical=True)
    dmatrix_hist = xgb.dask.train(
        client,
        {"tree_method": "hist"},
        dm_train,
        num_boost_round=10,
        evals=[(dm_train, "Train"), (dm_valid, "Valid")],
    )["history"]

    # Quantile variant: the validation matrix is built with the training
    # matrix passed as ``ref``.
    qdm_train = xgb.dask.DaskQuantileDMatrix(client, X, y, enable_categorical=True)
    qdm_valid = xgb.dask.DaskQuantileDMatrix(
        client, X, y, enable_categorical=True, ref=qdm_train
    )
    quantile_hist = xgb.dask.train(
        client,
        {"tree_method": "hist"},
        qdm_train,
        num_boost_round=10,
        evals=[(qdm_train, "Train"), (qdm_valid, "Valid")],
    )["history"]

    for split in ("Train", "Valid"):
        np.testing.assert_allclose(
            quantile_hist[split]["rmse"], dmatrix_hist[split]["rmse"]
        )
@given(params=exact_parameter_strategy, @given(params=exact_parameter_strategy,
dataset=tm.dataset_strategy) dataset=tm.dataset_strategy)
@settings(deadline=None, suppress_health_check=suppress, print_blob=True) @settings(deadline=None, suppress_health_check=suppress, print_blob=True)

View File

@ -1,11 +1,11 @@
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import os import os
import multiprocessing import multiprocessing
from typing import Tuple, Union from typing import Tuple, Union, List, Sequence, Callable
import urllib import urllib
import zipfile import zipfile
import sys import sys
from typing import Optional from typing import Optional, Dict, Any
from contextlib import contextmanager from contextlib import contextmanager
from io import StringIO from io import StringIO
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
@ -180,79 +180,148 @@ def skip_s390x():
class IteratorForTest(xgb.core.DataIter):
    """Data iterator that hands pre-split batches to XGBoost, one per call.

    Parameters
    ----------
    X, y :
        Sequences of per-batch predictors and labels. Both must have the same
        number of batches.
    w :
        Optional sequence of per-batch sample weights. ``None`` (or an empty
        sequence) means the batches carry no weights.
    cache :
        Forwarded unchanged to ``DataIter.__init__``.
    """

    def __init__(
        self,
        X: Sequence,
        y: Sequence,
        w: Optional[Sequence] = None,
        cache: Optional[str] = "./",
    ) -> None:
        assert len(X) == len(y)
        self.X = X
        self.y = y
        self.w = w
        self.it = 0  # index of the next batch to emit
        super().__init__(cache)

    def next(self, input_data: Callable) -> int:
        """Pass the current batch to ``input_data``.

        Returns ``0`` once every batch has been consumed, ``1`` otherwise.
        """
        if self.it == len(self.X):
            return 0
        # Use copy to make sure the iterator doesn't hold a reference to the data.
        input_data(
            data=self.X[self.it].copy(),
            label=self.y[self.it].copy(),
            weight=self.w[self.it].copy() if self.w else None,
        )
        gc.collect()  # clear up the copy, see if XGBoost access freed memory.
        self.it += 1
        return 1

    def reset(self) -> None:
        """Rewind the iterator to the first batch."""
        self.it = 0

    def as_arrays(
        self,
    ) -> Tuple[
        Union[np.ndarray, sparse.csr_matrix], np.ndarray, Optional[np.ndarray]
    ]:
        """Concatenate all batches into single ``X``, ``y`` and ``w`` arrays.

        Sparse CSR batches are stacked with ``scipy.sparse.vstack``; any other
        batches go through ``np.concatenate``. ``w`` is ``None`` when the
        iterator was built without weights.
        """
        if isinstance(self.X[0], sparse.csr_matrix):
            X = sparse.vstack(self.X, format="csr")
        else:
            X = np.concatenate(self.X, axis=0)
        y = np.concatenate(self.y, axis=0)
        # Weights are optional; use the same truthiness check as ``next`` so
        # an unweighted iterator does not crash inside ``np.concatenate``.
        w = np.concatenate(self.w, axis=0) if self.w else None
        return X, y, w
def make_batches(
    n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Generate ``n_batches`` random dense batches with a fixed seed (1994).

    Returns three lists of equal length holding, per batch, the feature
    matrix, the label vector and the weight vector (weights drawn uniformly
    from ``[0, 1)``). With ``use_cupy`` the arrays are produced by ``cupy``
    instead of ``numpy``.
    """
    if use_cupy:
        import cupy

        rng = cupy.random.RandomState(1994)
    else:
        rng = np.random.RandomState(1994)

    batches = []
    for _ in range(n_batches):
        # Draw order within a batch (X, then y, then w) fixes the values.
        features = rng.randn(n_samples_per_batch, n_features)
        labels = rng.randn(n_samples_per_batch)
        weights = rng.uniform(low=0, high=1, size=n_samples_per_batch)
        batches.append((features, labels, weights))

    X = [batch[0] for batch in batches]
    y = [batch[1] for batch in batches]
    w = [batch[2] for batch in batches]
    return X, y, w
def make_batches_sparse(
    n_samples_per_batch: int, n_features: int, n_batches: int, sparsity: float
) -> Tuple[List[sparse.csr_matrix], List[np.ndarray], List[np.ndarray]]:
    """Generate ``n_batches`` random CSR batches with a fixed seed (1994).

    ``sparsity`` is the fraction of empty entries; each feature matrix is
    created with density ``1.0 - sparsity`` and ``float32`` values. Labels and
    weights are dense vectors, the weights drawn uniformly from ``[0, 1)``.
    """
    rng = np.random.RandomState(1994)
    density = 1.0 - sparsity

    batches = []
    for _ in range(n_batches):
        # Draw order within a batch (X, then y, then w) fixes the values.
        features = sparse.random(
            n_samples_per_batch,
            n_features,
            density,
            format="csr",
            dtype=np.float32,
            random_state=rng,
        )
        labels = rng.randn(n_samples_per_batch)
        weights = rng.uniform(low=0, high=1, size=n_samples_per_batch)
        batches.append((features, labels, weights))

    X = [batch[0] for batch in batches]
    y = [batch[1] for batch in batches]
    w = [batch[2] for batch in batches]
    return X, y, w
# Contains a dataset in numpy format as well as the relevant objective and metric
class TestDataset:
    """A named dataset bundled with the objective and metric used to train on it."""

    def __init__(
        self, name: str, get_dataset: Callable, objective: str, metric: str
    ) -> None:
        self.name = name
        self.objective = objective
        self.metric = metric
        # The loader is invoked immediately and must return an ``(X, y)`` pair.
        self.X, self.y = get_dataset()
        # Sample weights and base margin start unset.
        self.w: Optional[np.ndarray] = None
        self.margin: Optional[np.ndarray] = None

    def set_params(self, params_in: Dict[str, Any]) -> Dict[str, Any]:
        """Write objective and metric into ``params_in`` and return it.

        The dictionary is modified in place. For ``multi:softmax`` the number
        of classes is derived from the largest label value.
        """
        params_in.update(objective=self.objective, eval_metric=self.metric)
        if self.objective == "multi:softmax":
            params_in["num_class"] = int(np.max(self.y) + 1)
        return params_in

    def get_dmat(self) -> xgb.DMatrix:
        """Build a ``DMatrix`` from the stored arrays."""
        return xgb.DMatrix(
            self.X, self.y, self.w, base_margin=self.margin, enable_categorical=True
        )

    def get_device_dmat(self) -> xgb.DeviceQuantileDMatrix:
        """Build a ``DeviceQuantileDMatrix`` from ``cp`` copies of the arrays."""
        if self.w is None:
            w = None
        else:
            w = cp.array(self.w)
        X = cp.array(self.X, dtype=np.float32)
        y = cp.array(self.y, dtype=np.float32)
        return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin)

    def get_external_dmat(self) -> xgb.DMatrix:
        """Build a ``DMatrix`` by feeding the data through ``IteratorForTest``.

        The rows are cut into 10 consecutive slices of
        ``n_samples // 10 + 1`` rows each (the last one truncated at the end
        of the data). Weights are sliced alongside when present.
        """
        n_samples = self.X.shape[0]
        n_batches = 10
        per_batch = n_samples // n_batches + 1

        predictor = []
        response = []
        weight = []
        for start in range(0, n_batches * per_batch, per_batch):
            stop = min(start + per_batch, n_samples)
            assert stop != start
            predictor.append(self.X[start:stop, ...])
            response.append(self.y[start:stop])
            if self.w is not None:
                weight.append(self.w[start:stop])

        # An empty weight list means "no weights" for the iterator.
        it = IteratorForTest(predictor, response, weight or None)
        return xgb.DMatrix(it)

    def __repr__(self) -> str:
        return self.name

View File

@ -1,4 +1,3 @@
import os
import xgboost as xgb import xgboost as xgb
from sklearn.datasets import make_classification from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_auc_score