Quantile DMatrix for CPU. (#8130)
- Add a new `QuantileDMatrix` that works for both CPU and GPU. - Deprecate `DeviceQuantileDMatrix`.
This commit is contained in:
parent
2cba1d9fcc
commit
d87f69215e
@ -22,6 +22,9 @@ Core Data Structure
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: xgboost.QuantileDMatrix
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: xgboost.DeviceQuantileDMatrix
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
@ -415,28 +415,26 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
|
||||
DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next,
|
||||
char const* c_json_config,
|
||||
DMatrixHandle *out);
|
||||
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
|
||||
char const *c_json_config, DMatrixHandle *out);
|
||||
|
||||
/*!
|
||||
* \brief Create a Quantile DMatrix with data iterator.
|
||||
*
|
||||
* Short note for how to use the second set of callback for GPU Hist tree method:
|
||||
* Short note for how to use the second set of callbacks for the (GPU)Hist tree method:
|
||||
*
|
||||
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
|
||||
* - Step 1: Create a DMatrix proxy by `XGProxyDMatrixCreate` and hold the handle.
|
||||
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
|
||||
* `XGDeviceQuantileDMatrixCreateFromCallback`.
|
||||
* `XGQuantileDMatrixCreateFromCallback`.
|
||||
* - Step 3: Call appropriate data setters in `next` functions.
|
||||
*
|
||||
* See test_iterative_device_dmatrix.cu or Python interface for examples.
|
||||
* See test_iterative_dmatrix.cu or Python interface for examples.
|
||||
*
|
||||
* \param iter A handle to external data iterator.
|
||||
* \param proxy A DMatrix proxy handle created by `XGProxyDMatrixCreate`.
|
||||
* \param ref Reference DMatrix for providing quantile information.
|
||||
* \param reset Callback function resetting the iterator state.
|
||||
* \param next Callback function yielding the next batch of data.
|
||||
* \param missing Which value represents a missing value
|
||||
@ -446,11 +444,21 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing, int nthread, int max_bin,
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out);
|
||||
|
||||
/*!
|
||||
* \brief Create a Device Quantile DMatrix with data iterator.
|
||||
* \deprecated since 2.0
|
||||
* \see XGQuantileDMatrixCreateFromCallback()
|
||||
*/
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing,
|
||||
int nthread, int max_bin, DMatrixHandle *out);
|
||||
|
||||
/*!
|
||||
* \brief Set data on a DMatrix proxy.
|
||||
*
|
||||
|
||||
@ -6,6 +6,7 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
from .core import (
|
||||
DMatrix,
|
||||
DeviceQuantileDMatrix,
|
||||
QuantileDMatrix,
|
||||
Booster,
|
||||
DataIter,
|
||||
build_info,
|
||||
@ -33,6 +34,7 @@ __all__ = [
|
||||
# core
|
||||
"DMatrix",
|
||||
"DeviceQuantileDMatrix",
|
||||
"QuantileDMatrix",
|
||||
"Booster",
|
||||
"DataIter",
|
||||
"train",
|
||||
|
||||
@ -1146,7 +1146,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature_types : list or None
|
||||
feature_types :
|
||||
Labels for features. None will reset existing feature names
|
||||
|
||||
"""
|
||||
@ -1189,7 +1189,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
|
||||
class _ProxyDMatrix(DMatrix):
|
||||
"""A placeholder class when DMatrix cannot be constructed (DeviceQuantileDMatrix,
|
||||
"""A placeholder class when DMatrix cannot be constructed (QuantileDMatrix,
|
||||
inplace_predict).
|
||||
|
||||
"""
|
||||
@ -1234,17 +1234,35 @@ class _ProxyDMatrix(DMatrix):
|
||||
)
|
||||
|
||||
|
||||
class DeviceQuantileDMatrix(DMatrix):
|
||||
"""Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do
|
||||
not use this for test/validation tasks as some information may be lost in
|
||||
quantisation. This DMatrix is primarily designed to save memory in training from
|
||||
device memory inputs by avoiding intermediate storage. Set max_bin to control the
|
||||
number of bins during quantisation. See doc string in :py:obj:`xgboost.DMatrix` for
|
||||
documents on meta info.
|
||||
class QuantileDMatrix(DMatrix):
|
||||
"""A DMatrix variant that generates quantized data directly from input for
|
||||
``hist`` and ``gpu_hist`` tree methods. This DMatrix is primarily designed to save
|
||||
memory in training by avoiding intermediate storage. Set ``max_bin`` to control the
|
||||
number of bins during quantisation, which should be consistent with the training
|
||||
parameter ``max_bin``. When ``QuantileDMatrix`` is used for validation/test dataset,
|
||||
``ref`` should be another ``QuantileDMatrix`` (or ``DMatrix``, but not recommended as
|
||||
it defeats the purpose of saving memory) constructed from training dataset. See
|
||||
:py:obj:`xgboost.DMatrix` for documents on meta info.
|
||||
|
||||
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
|
||||
.. note::
|
||||
|
||||
.. versionadded:: 1.1.0
|
||||
Do not use ``QuantileDMatrix`` as validation/test dataset without supplying a
|
||||
reference (the training dataset) ``QuantileDMatrix`` using ``ref`` as some
|
||||
information may be lost in quantisation.
|
||||
|
||||
.. versionadded:: 2.0.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_bin :
|
||||
The number of histogram bins; should be consistent with the training parameter
|
||||
``max_bin``.
|
||||
|
||||
ref :
|
||||
The training dataset that provides quantile information, needed when creating
|
||||
validation/test dataset with ``QuantileDMatrix``. Supplying the training DMatrix
|
||||
as a reference means that the same quantisation applied to the training data is
|
||||
applied to the validation/test data.
|
||||
|
||||
"""
|
||||
|
||||
@ -1261,7 +1279,8 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
feature_names: Optional[FeatureNames] = None,
|
||||
feature_types: Optional[FeatureTypes] = None,
|
||||
nthread: Optional[int] = None,
|
||||
max_bin: int = 256,
|
||||
max_bin: Optional[int] = None,
|
||||
ref: Optional[DMatrix] = None,
|
||||
group: Optional[ArrayLike] = None,
|
||||
qid: Optional[ArrayLike] = None,
|
||||
label_lower_bound: Optional[ArrayLike] = None,
|
||||
@ -1269,9 +1288,9 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
feature_weights: Optional[ArrayLike] = None,
|
||||
enable_categorical: bool = False,
|
||||
) -> None:
|
||||
self.max_bin = max_bin
|
||||
self.max_bin: int = max_bin if max_bin is not None else 256
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self.nthread = nthread if nthread is not None else 1
|
||||
self.nthread = nthread if nthread is not None else -1
|
||||
self._silent = silent # unused, kept for compatibility
|
||||
|
||||
if isinstance(data, ctypes.c_void_p):
|
||||
@ -1280,12 +1299,13 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
|
||||
if qid is not None and group is not None:
|
||||
raise ValueError(
|
||||
'Only one of the eval_qid or eval_group for each evaluation '
|
||||
'dataset should be provided.'
|
||||
"Only one of the eval_qid or eval_group for each evaluation "
|
||||
"dataset should be provided."
|
||||
)
|
||||
|
||||
self._init(
|
||||
data,
|
||||
ref=ref,
|
||||
label=label,
|
||||
weight=weight,
|
||||
base_margin=base_margin,
|
||||
@ -1299,7 +1319,13 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
|
||||
def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None:
|
||||
def _init(
|
||||
self,
|
||||
data: DataType,
|
||||
ref: Optional[DMatrix],
|
||||
enable_categorical: bool,
|
||||
**meta: Any,
|
||||
) -> None:
|
||||
from .data import (
|
||||
_is_dlpack,
|
||||
_transform_dlpack,
|
||||
@ -1317,20 +1343,26 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
it = SingleBatchInternalIter(data=data, **meta)
|
||||
|
||||
handle = ctypes.c_void_p()
|
||||
reset_callback, next_callback = it.get_callbacks(False, enable_categorical)
|
||||
reset_callback, next_callback = it.get_callbacks(True, enable_categorical)
|
||||
if it.cache_prefix is not None:
|
||||
raise ValueError(
|
||||
"DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix "
|
||||
"QuantileDMatrix doesn't cache data, remove the cache_prefix "
|
||||
"in iterator to fix this error."
|
||||
)
|
||||
ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
|
||||
args = {
|
||||
"nthread": self.nthread,
|
||||
"missing": self.missing,
|
||||
"max_bin": self.max_bin,
|
||||
}
|
||||
config = from_pystr_to_cstr(json.dumps(args))
|
||||
ret = _LIB.XGQuantileDMatrixCreateFromCallback(
|
||||
None,
|
||||
it.proxy.handle,
|
||||
ref.handle if ref is not None else ref,
|
||||
reset_callback,
|
||||
next_callback,
|
||||
ctypes.c_float(self.missing),
|
||||
ctypes.c_int(self.nthread),
|
||||
ctypes.c_int(self.max_bin),
|
||||
config,
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
it.reraise()
|
||||
@ -1339,6 +1371,20 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
self.handle = handle
|
||||
|
||||
|
||||
class DeviceQuantileDMatrix(QuantileDMatrix):
|
||||
""" Use `QuantileDMatrix` instead.
|
||||
|
||||
.. deprecated:: 2.0.0
|
||||
|
||||
.. versionadded:: 1.1.0
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
warnings.warn("Please use `QuantileDMatrix` instead.", FutureWarning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
|
||||
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ import collections
|
||||
import logging
|
||||
import platform
|
||||
import socket
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, update_wrapper
|
||||
@ -64,10 +65,10 @@ from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
|
||||
from .core import (
|
||||
Booster,
|
||||
DataIter,
|
||||
DeviceQuantileDMatrix,
|
||||
DMatrix,
|
||||
Metric,
|
||||
Objective,
|
||||
QuantileDMatrix,
|
||||
_deprecate_positional_args,
|
||||
_expect,
|
||||
_has_categorical,
|
||||
@ -495,7 +496,7 @@ async def map_worker_partitions(
|
||||
client: Optional["distributed.Client"],
|
||||
func: Callable[..., _MapRetT],
|
||||
*refs: Any,
|
||||
workers: List[str],
|
||||
workers: Sequence[str],
|
||||
) -> List[_MapRetT]:
|
||||
"""Map a function onto partitions of each worker."""
|
||||
# Note for function purity:
|
||||
@ -628,22 +629,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||
return 1
|
||||
|
||||
|
||||
class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
"""Specialized data type for `gpu_hist` tree method. This class is used to reduce
|
||||
the memory usage by eliminating data copies. Internally the all partitions/chunks
|
||||
of data are merged by weighted GK sketching. So the number of partitions from dask
|
||||
may affect training accuracy as GK generates bounded error for each merge. See doc
|
||||
string for :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for
|
||||
other parameters.
|
||||
|
||||
.. versionadded:: 1.2.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_bin : Number of bins for histogram construction.
|
||||
|
||||
"""
|
||||
|
||||
class DaskQuantileDMatrix(DaskDMatrix):
|
||||
@_deprecate_positional_args
|
||||
def __init__(
|
||||
self,
|
||||
@ -657,7 +643,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
silent: bool = False, # disable=unused-argument
|
||||
feature_names: Optional[FeatureNames] = None,
|
||||
feature_types: Optional[Union[Any, List[Any]]] = None,
|
||||
max_bin: int = 256,
|
||||
max_bin: Optional[int] = None,
|
||||
ref: Optional[DMatrix] = None,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
qid: Optional[_DaskCollection] = None,
|
||||
label_lower_bound: Optional[_DaskCollection] = None,
|
||||
@ -684,14 +671,31 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
)
|
||||
self.max_bin = max_bin
|
||||
self.is_quantile = True
|
||||
self._ref: Optional[int] = id(ref) if ref is not None else None
|
||||
|
||||
def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
args = super()._create_fn_args(worker_addr)
|
||||
args["max_bin"] = self.max_bin
|
||||
if self._ref is not None:
|
||||
args["ref"] = self._ref
|
||||
return args
|
||||
|
||||
|
||||
def _create_device_quantile_dmatrix(
|
||||
class DaskDeviceQuantileDMatrix(DaskQuantileDMatrix):
|
||||
"""Use `DaskQuantileDMatrix` instead.
|
||||
|
||||
.. deprecated:: 2.0.0
|
||||
|
||||
.. versionadded:: 1.2.0
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
warnings.warn("Please use `DaskQuantileDMatrix` instead.", FutureWarning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def _create_quantile_dmatrix(
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[Union[Any, List[Any]]],
|
||||
feature_weights: Optional[Any],
|
||||
@ -700,18 +704,20 @@ def _create_device_quantile_dmatrix(
|
||||
parts: Optional[_DataParts],
|
||||
max_bin: int,
|
||||
enable_categorical: bool,
|
||||
) -> DeviceQuantileDMatrix:
|
||||
ref: Optional[DMatrix] = None,
|
||||
) -> QuantileDMatrix:
|
||||
worker = distributed.get_worker()
|
||||
if parts is None:
|
||||
msg = f"worker {worker.address} has an empty DMatrix."
|
||||
LOGGER.warning(msg)
|
||||
import cupy
|
||||
|
||||
d = DeviceQuantileDMatrix(
|
||||
d = QuantileDMatrix(
|
||||
cupy.zeros((0, 0)),
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
max_bin=max_bin,
|
||||
ref=ref,
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
return d
|
||||
@ -719,13 +725,14 @@ def _create_device_quantile_dmatrix(
|
||||
unzipped_dict = _get_worker_parts(parts)
|
||||
it = DaskPartitionIter(**unzipped_dict)
|
||||
|
||||
dmatrix = DeviceQuantileDMatrix(
|
||||
dmatrix = QuantileDMatrix(
|
||||
it,
|
||||
missing=missing,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=nthread,
|
||||
max_bin=max_bin,
|
||||
ref=ref,
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
dmatrix.set_info(feature_weights=feature_weights)
|
||||
@ -786,11 +793,9 @@ def _create_dmatrix(
|
||||
return dmatrix
|
||||
|
||||
|
||||
def _dmatrix_from_list_of_parts(
|
||||
is_quantile: bool, **kwargs: Any
|
||||
) -> Union[DMatrix, DeviceQuantileDMatrix]:
|
||||
def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix:
|
||||
if is_quantile:
|
||||
return _create_device_quantile_dmatrix(**kwargs)
|
||||
return _create_quantile_dmatrix(**kwargs)
|
||||
return _create_dmatrix(**kwargs)
|
||||
|
||||
|
||||
@ -921,6 +926,17 @@ async def _train_async(
|
||||
if evals_id[i] == train_id:
|
||||
evals.append((Xy, evals_name[i]))
|
||||
continue
|
||||
if ref.get("ref", None) is not None:
|
||||
if ref["ref"] != train_id:
|
||||
raise ValueError(
|
||||
"The training DMatrix should be used as a reference"
|
||||
" to evaluation `QuantileDMatrix`."
|
||||
)
|
||||
del ref["ref"]
|
||||
eval_Xy = _dmatrix_from_list_of_parts(
|
||||
**ref, nthread=n_threads, ref=Xy
|
||||
)
|
||||
else:
|
||||
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
|
||||
evals.append((eval_Xy, evals_name[i]))
|
||||
|
||||
@ -960,12 +976,14 @@ async def _train_async(
|
||||
results = await map_worker_partitions(
|
||||
client,
|
||||
dispatched_train,
|
||||
# extra function parameters
|
||||
params,
|
||||
_rabit_args,
|
||||
id(dtrain),
|
||||
evals_name,
|
||||
evals_id,
|
||||
*([dtrain] + evals_data),
|
||||
# workers to be used for training
|
||||
workers=workers,
|
||||
)
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
|
||||
@ -1167,6 +1167,7 @@ def _proxy_transform(
|
||||
if _is_dlpack(data):
|
||||
return _transform_dlpack(data), None, feature_names, feature_types
|
||||
if _is_numpy_array(data):
|
||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||
return data, None, feature_names, feature_types
|
||||
if _is_scipy_csr(data):
|
||||
return data, None, feature_names, feature_types
|
||||
|
||||
@ -281,11 +281,36 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
|
||||
int nthread, int max_bin,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
LOG(WARNING) << __func__ << " is deprecated. Use `XGQuantileDMatrixCreateFromCallback` instead.";
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||
xgboost::DMatrix::Create(iter, proxy, nullptr, reset, next, missing, nthread, max_bin)};
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
std::shared_ptr<DMatrix> _ref{nullptr};
|
||||
if (ref) {
|
||||
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
|
||||
StringView err{"Invalid handle to ref."};
|
||||
CHECK(pp_ref) << err;
|
||||
_ref = *pp_ref;
|
||||
CHECK(_ref) << err;
|
||||
}
|
||||
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", common::OmpGetNumThreads(0));
|
||||
auto max_bin = OptionalArg<Integer, int64_t>(jconfig, "max_bin", 256);
|
||||
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import pytest
|
||||
@ -6,16 +5,14 @@ import sys
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
import test_quantile_dmatrix as tqd
|
||||
|
||||
|
||||
class TestDeviceQuantileDMatrix:
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
with pytest.raises(TypeError, match='is not supported'):
|
||||
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
|
||||
cputest = tqd.TestQuantileDMatrix()
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dmatrix_feature_weights(self):
|
||||
def test_dmatrix_feature_weights(self) -> None:
|
||||
import cupy as cp
|
||||
rng = cp.random.RandomState(1994)
|
||||
data = rng.randn(5, 5)
|
||||
@ -29,7 +26,7 @@ class TestDeviceQuantileDMatrix:
|
||||
feature_weights.astype(np.float32))
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dmatrix_cupy_init(self):
|
||||
def test_dmatrix_cupy_init(self) -> None:
|
||||
import cupy as cp
|
||||
data = cp.random.randn(5, 5)
|
||||
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
||||
@ -55,3 +52,10 @@ class TestDeviceQuantileDMatrix:
|
||||
|
||||
cp.testing.assert_allclose(fw, got_fw)
|
||||
cp.testing.assert_allclose(labels, got_labels)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_ref_dmatrix(self) -> None:
|
||||
import cupy as cp
|
||||
rng = cp.random.RandomState(1994)
|
||||
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)
|
||||
|
||||
@ -429,9 +429,10 @@ class TestDistributedGPU:
|
||||
sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters)
|
||||
del sig["client"]
|
||||
ddm_names = list(sig.keys())
|
||||
sig = OrderedDict(signature(dxgb.DaskDeviceQuantileDMatrix).parameters)
|
||||
sig = OrderedDict(signature(dxgb.DaskQuantileDMatrix).parameters)
|
||||
del sig["client"]
|
||||
del sig["max_bin"]
|
||||
del sig["ref"]
|
||||
ddqdm_names = list(sig.keys())
|
||||
assert len(ddm_names) == len(ddqdm_names)
|
||||
|
||||
@ -442,9 +443,10 @@ class TestDistributedGPU:
|
||||
sig = OrderedDict(signature(xgb.DMatrix).parameters)
|
||||
del sig["nthread"] # no nthread in dask
|
||||
dm_names = list(sig.keys())
|
||||
sig = OrderedDict(signature(xgb.DeviceQuantileDMatrix).parameters)
|
||||
sig = OrderedDict(signature(xgb.QuantileDMatrix).parameters)
|
||||
del sig["nthread"]
|
||||
del sig["max_bin"]
|
||||
del sig["ref"]
|
||||
dqdm_names = list(sig.keys())
|
||||
|
||||
# between single node
|
||||
@ -499,7 +501,6 @@ class TestDistributedGPU:
|
||||
for arg in rabit_args:
|
||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||
port_env = arg.decode('utf-8')
|
||||
port_env = arg.decode('utf-8')
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split('=')
|
||||
|
||||
@ -1,32 +1,12 @@
|
||||
import xgboost as xgb
|
||||
from xgboost.data import SingleBatchInternalIter as SingleBatch
|
||||
import numpy as np
|
||||
from testing import IteratorForTest, non_increasing
|
||||
from typing import Tuple, List
|
||||
from testing import IteratorForTest, non_increasing, make_batches
|
||||
import pytest
|
||||
from hypothesis import given, strategies, settings
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
|
||||
def make_batches(
|
||||
n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
|
||||
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
|
||||
X = []
|
||||
y = []
|
||||
if use_cupy:
|
||||
import cupy
|
||||
|
||||
rng = cupy.random.RandomState(1994)
|
||||
else:
|
||||
rng = np.random.RandomState(1994)
|
||||
for i in range(n_batches):
|
||||
_X = rng.randn(n_samples_per_batch, n_features)
|
||||
_y = rng.randn(n_samples_per_batch)
|
||||
X.append(_X)
|
||||
y.append(_y)
|
||||
return X, y
|
||||
|
||||
|
||||
def test_single_batch(tree_method: str = "approx") -> None:
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
|
||||
@ -111,8 +91,8 @@ def run_data_iterator(
|
||||
if not subsample:
|
||||
assert non_increasing(results_from_it["Train"]["rmse"])
|
||||
|
||||
X, y = it.as_arrays()
|
||||
Xy = xgb.DMatrix(X, y)
|
||||
X, y, w = it.as_arrays()
|
||||
Xy = xgb.DMatrix(X, y, weight=w)
|
||||
assert Xy.num_row() == n_samples_per_batch * n_batches
|
||||
assert Xy.num_col() == n_features
|
||||
|
||||
|
||||
212
tests/python/test_quantile_dmatrix.py
Normal file
212
tests/python/test_quantile_dmatrix.py
Normal file
@ -0,0 +1,212 @@
|
||||
from typing import Dict, List, Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from scipy import sparse
|
||||
from testing import IteratorForTest, make_batches, make_batches_sparse, make_categorical
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
class TestQuantileDMatrix:
|
||||
def test_basic(self) -> None:
|
||||
n_samples = 234
|
||||
n_features = 8
|
||||
|
||||
rng = np.random.default_rng()
|
||||
X = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape(
|
||||
n_samples, n_features
|
||||
)
|
||||
y = rng.normal(0, 3, size=n_samples)
|
||||
Xy = xgb.QuantileDMatrix(X, y)
|
||||
assert Xy.num_row() == n_samples
|
||||
assert Xy.num_col() == n_features
|
||||
|
||||
X = sparse.random(n_samples, n_features, density=0.1, format="csr")
|
||||
Xy = xgb.QuantileDMatrix(X, y)
|
||||
assert Xy.num_row() == n_samples
|
||||
assert Xy.num_col() == n_features
|
||||
|
||||
X = sparse.random(n_samples, n_features, density=0.8, format="csr")
|
||||
Xy = xgb.QuantileDMatrix(X, y)
|
||||
assert Xy.num_row() == n_samples
|
||||
assert Xy.num_col() == n_features
|
||||
|
||||
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9])
|
||||
def test_with_iterator(self, sparsity: float) -> None:
|
||||
n_samples_per_batch = 317
|
||||
n_features = 8
|
||||
n_batches = 7
|
||||
|
||||
if sparsity == 0.0:
|
||||
it = IteratorForTest(
|
||||
*make_batches(n_samples_per_batch, n_features, n_batches, False), None
|
||||
)
|
||||
else:
|
||||
it = IteratorForTest(
|
||||
*make_batches_sparse(
|
||||
n_samples_per_batch, n_features, n_batches, sparsity
|
||||
),
|
||||
None
|
||||
)
|
||||
Xy = xgb.QuantileDMatrix(it)
|
||||
assert Xy.num_row() == n_samples_per_batch * n_batches
|
||||
assert Xy.num_col() == n_features
|
||||
|
||||
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.5, 0.8, 0.9])
|
||||
def test_training(self, sparsity: float) -> None:
|
||||
n_samples_per_batch = 317
|
||||
n_features = 8
|
||||
n_batches = 7
|
||||
if sparsity == 0.0:
|
||||
it = IteratorForTest(
|
||||
*make_batches(n_samples_per_batch, n_features, n_batches, False), None
|
||||
)
|
||||
else:
|
||||
it = IteratorForTest(
|
||||
*make_batches_sparse(
|
||||
n_samples_per_batch, n_features, n_batches, sparsity
|
||||
),
|
||||
None
|
||||
)
|
||||
|
||||
parameters = {"tree_method": "hist", "max_bin": 256}
|
||||
Xy_it = xgb.QuantileDMatrix(it, max_bin=parameters["max_bin"])
|
||||
from_it = xgb.train(parameters, Xy_it)
|
||||
|
||||
X, y, w = it.as_arrays()
|
||||
w_it = Xy_it.get_weight()
|
||||
np.testing.assert_allclose(w_it, w)
|
||||
|
||||
Xy_arr = xgb.DMatrix(X, y, weight=w)
|
||||
from_arr = xgb.train(parameters, Xy_arr)
|
||||
|
||||
np.testing.assert_allclose(from_arr.predict(Xy_it), from_it.predict(Xy_arr))
|
||||
|
||||
y -= y.min()
|
||||
y += 0.01
|
||||
Xy = xgb.QuantileDMatrix(X, y, weight=w)
|
||||
with pytest.raises(ValueError, match=r"Only.*hist.*"):
|
||||
parameters = {
|
||||
"tree_method": "approx",
|
||||
"max_bin": 256,
|
||||
"objective": "reg:gamma",
|
||||
}
|
||||
xgb.train(parameters, Xy)
|
||||
|
||||
def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
|
||||
n_samples, n_features = 2048, 17
|
||||
if enable_cat:
|
||||
X, y = make_categorical(
|
||||
n_samples, n_features, n_categories=13, onehot=False
|
||||
)
|
||||
if tree_method == "gpu_hist":
|
||||
import cudf
|
||||
X = cudf.from_pandas(X)
|
||||
y = cudf.from_pandas(y)
|
||||
else:
|
||||
X = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape(
|
||||
n_samples, n_features
|
||||
)
|
||||
y = rng.normal(0, 3, size=n_samples)
|
||||
|
||||
# Use ref
|
||||
Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
|
||||
Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat)
|
||||
qdm_results: Dict[str, Dict[str, List[float]]] = {}
|
||||
xgb.train(
|
||||
{"tree_method": tree_method},
|
||||
Xy,
|
||||
evals=[(Xy, "Train"), (Xy_valid, "valid")],
|
||||
evals_result=qdm_results,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"]
|
||||
)
|
||||
# No ref
|
||||
Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
|
||||
qdm_results = {}
|
||||
xgb.train(
|
||||
{"tree_method": tree_method},
|
||||
Xy,
|
||||
evals=[(Xy, "Train"), (Xy_valid, "valid")],
|
||||
evals_result=qdm_results,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"]
|
||||
)
|
||||
|
||||
# Different number of features
|
||||
Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
|
||||
dXy = xgb.DMatrix(X, y, enable_categorical=enable_cat)
|
||||
|
||||
n_samples, n_features = 256, 15
|
||||
X = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape(
|
||||
n_samples, n_features
|
||||
)
|
||||
y = rng.normal(0, 3, size=n_samples)
|
||||
with pytest.raises(ValueError, match=r".*features\."):
|
||||
xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat)
|
||||
|
||||
# Compare training results
|
||||
n_samples, n_features = 256, 17
|
||||
if enable_cat:
|
||||
X, y = make_categorical(n_samples, n_features, 13, onehot=False)
|
||||
if tree_method == "gpu_hist":
|
||||
import cudf
|
||||
X = cudf.from_pandas(X)
|
||||
y = cudf.from_pandas(y)
|
||||
else:
|
||||
X = rng.normal(loc=0, scale=3, size=n_samples * n_features).reshape(
|
||||
n_samples, n_features
|
||||
)
|
||||
y = rng.normal(0, 3, size=n_samples)
|
||||
Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat)
|
||||
# use DMatrix as ref
|
||||
Xy_valid_d = xgb.QuantileDMatrix(X, y, ref=dXy, enable_categorical=enable_cat)
|
||||
dXy_valid = xgb.DMatrix(X, y, enable_categorical=enable_cat)
|
||||
|
||||
qdm_results = {}
|
||||
xgb.train(
|
||||
{"tree_method": tree_method},
|
||||
Xy,
|
||||
evals=[(Xy, "Train"), (Xy_valid, "valid")],
|
||||
evals_result=qdm_results,
|
||||
)
|
||||
|
||||
dm_results: Dict[str, Dict[str, List[float]]] = {}
|
||||
xgb.train(
|
||||
{"tree_method": tree_method},
|
||||
dXy,
|
||||
evals=[(dXy, "Train"), (dXy_valid, "valid"), (Xy_valid_d, "dvalid")],
|
||||
evals_result=dm_results,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
dm_results["Train"]["rmse"], qdm_results["Train"]["rmse"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
dm_results["valid"]["rmse"], qdm_results["valid"]["rmse"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"]
|
||||
)
|
||||
|
||||
def test_ref_dmatrix(self) -> None:
|
||||
rng = np.random.RandomState(1994)
|
||||
self.run_ref_dmatrix(rng, "hist", True)
|
||||
self.run_ref_dmatrix(rng, "hist", False)
|
||||
|
||||
def test_predict(self) -> None:
|
||||
n_samples, n_features = 16, 2
|
||||
X, y = make_categorical(
|
||||
n_samples, n_features, n_categories=13, onehot=False
|
||||
)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
|
||||
booster = xgb.train({"tree_method": "hist"}, Xy)
|
||||
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
a = booster.predict(Xy)
|
||||
qXy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
|
||||
b = booster.predict(qXy)
|
||||
np.testing.assert_allclose(a, b)
|
||||
@ -1382,6 +1382,42 @@ class TestWithDask:
|
||||
num_rounds = 30
|
||||
self.run_updater_test(client, params, num_rounds, dataset, 'hist')
|
||||
|
||||
def test_quantile_dmatrix(self, client: Client) -> None:
|
||||
X, y = make_categorical(client, 10000, 30, 13)
|
||||
|
||||
Xy = xgb.dask.DaskDMatrix(client, X, y, enable_categorical=True)
|
||||
valid_Xy = xgb.dask.DaskDMatrix(client, X, y, enable_categorical=True)
|
||||
|
||||
output = xgb.dask.train(
|
||||
client,
|
||||
{"tree_method": "hist"},
|
||||
Xy,
|
||||
num_boost_round=10,
|
||||
evals=[(Xy, "Train"), (valid_Xy, "Valid")]
|
||||
)
|
||||
dmatrix_hist = output["history"]
|
||||
|
||||
Xy = xgb.dask.DaskQuantileDMatrix(client, X, y, enable_categorical=True)
|
||||
valid_Xy = xgb.dask.DaskQuantileDMatrix(
|
||||
client, X, y, enable_categorical=True, ref=Xy
|
||||
)
|
||||
|
||||
output = xgb.dask.train(
|
||||
client,
|
||||
{"tree_method": "hist"},
|
||||
Xy,
|
||||
num_boost_round=10,
|
||||
evals=[(Xy, "Train"), (valid_Xy, "Valid")]
|
||||
)
|
||||
quantile_hist = output["history"]
|
||||
|
||||
np.testing.assert_allclose(
|
||||
quantile_hist["Train"]["rmse"], dmatrix_hist["Train"]["rmse"]
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
quantile_hist["Valid"]["rmse"], dmatrix_hist["Valid"]["rmse"]
|
||||
)
|
||||
|
||||
@given(params=exact_parameter_strategy,
|
||||
dataset=tm.dataset_strategy)
|
||||
@settings(deadline=None, suppress_health_check=suppress, print_blob=True)
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
import multiprocessing
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple, Union, List, Sequence, Callable
|
||||
import urllib
|
||||
import zipfile
|
||||
import sys
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Any
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
@ -180,79 +180,148 @@ def skip_s390x():
|
||||
|
||||
|
||||
class IteratorForTest(xgb.core.DataIter):
    """Data iterator over pre-split batches, used to feed XGBoost in tests.

    Parameters
    ----------
    X :
        Sequence of predictor batches (dense arrays or CSR matrices).
    y :
        Sequence of label batches; must have as many batches as ``X``.
    w :
        Optional sequence of weight batches. ``None`` (or an empty sequence)
        means the data carries no weights.
    cache :
        Forwarded unchanged to ``xgb.core.DataIter.__init__``.
    """

    def __init__(
        self,
        X: Sequence,
        y: Sequence,
        w: Optional[Sequence],
        cache: Optional[str] = "./"
    ) -> None:
        assert len(X) == len(y)
        self.X = X
        self.y = y
        self.w = w
        self.it = 0  # index of the next batch to hand out
        super().__init__(cache)

    def next(self, input_data: Callable) -> int:
        """Pass the next batch to ``input_data``.

        Returns 1 when a batch was passed and 0 once all batches are consumed.
        """
        if self.it == len(self.X):
            return 0
        # Use copy to make sure the iterator doesn't hold a reference to the data.
        input_data(
            data=self.X[self.it].copy(),
            label=self.y[self.it].copy(),
            weight=self.w[self.it].copy() if self.w else None,
        )
        gc.collect()  # clear up the copy, see if XGBoost access freed memory.
        self.it += 1
        return 1

    def reset(self) -> None:
        """Rewind the iterator to the first batch."""
        self.it = 0

    def as_arrays(
        self,
    ) -> Tuple[Union[np.ndarray, sparse.csr_matrix], np.ndarray, Optional[np.ndarray]]:
        """Concatenate all batches into single arrays.

        Returns ``(X, y, w)``. ``X`` is a CSR matrix when the first batch is
        one, otherwise a dense array. ``w`` is ``None`` when the iterator was
        built without weights.
        """
        if isinstance(self.X[0], sparse.csr_matrix):
            X = sparse.vstack(self.X, format="csr")
        else:
            X = np.concatenate(self.X, axis=0)
        y = np.concatenate(self.y, axis=0)
        # Weights are optional; concatenating ``None`` would raise, so mirror
        # the ``if self.w`` guard used in ``next``.
        w = np.concatenate(self.w, axis=0) if self.w else None
        return X, y, w
|
||||
|
||||
|
||||
def make_batches(
    n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Generate random dense batches of predictors, labels and weights.

    Returns three lists of length ``n_batches``: predictors of shape
    ``(n_samples_per_batch, n_features)``, labels and weights of length
    ``n_samples_per_batch``.  The generator is seeded with 1994.  With
    ``use_cupy`` the arrays are drawn from ``cupy.random`` instead of numpy.
    """
    if use_cupy:
        import cupy as backend
    else:
        backend = np
    rng = backend.random.RandomState(1994)

    def draw_batch():
        # The draw order (predictors, labels, weights) determines the random
        # stream, so it must stay fixed.
        features = rng.randn(n_samples_per_batch, n_features)
        labels = rng.randn(n_samples_per_batch)
        weights = rng.uniform(low=0, high=1, size=n_samples_per_batch)
        return features, labels, weights

    batches = [draw_batch() for _ in range(n_batches)]
    X = [batch[0] for batch in batches]
    y = [batch[1] for batch in batches]
    w = [batch[2] for batch in batches]
    return X, y, w
|
||||
|
||||
|
||||
def make_batches_sparse(
    n_samples_per_batch: int, n_features: int, n_batches: int, sparsity: float
) -> Tuple[List[sparse.csr_matrix], List[np.ndarray], List[np.ndarray]]:
    """Generate random CSR batches of predictors with dense labels and weights.

    Each predictor batch is a float32 CSR matrix of shape
    ``(n_samples_per_batch, n_features)`` with density ``1.0 - sparsity``.
    The generator is seeded with 1994.
    """
    rng = np.random.RandomState(1994)
    density = 1.0 - sparsity

    def draw_batch():
        # The draw order (predictors, labels, weights) determines the random
        # stream, so it must stay fixed.
        features = sparse.random(
            n_samples_per_batch,
            n_features,
            density,
            format="csr",
            dtype=np.float32,
            random_state=rng,
        )
        labels = rng.randn(n_samples_per_batch)
        weights = rng.uniform(low=0, high=1, size=n_samples_per_batch)
        return features, labels, weights

    batches = [draw_batch() for _ in range(n_batches)]
    X = [batch[0] for batch in batches]
    y = [batch[1] for batch in batches]
    w = [batch[2] for batch in batches]
    return X, y, w
|
||||
|
||||
|
||||
# Contains a dataset in numpy format as well as the relevant objective and metric
class TestDataset:
    """A named dataset bundled with the objective and metric used to test it.

    ``get_dataset`` is called once at construction and must return ``(X, y)``.
    ``w`` (weights) and ``margin`` (base margin) start as ``None`` and may be
    assigned afterwards.
    """

    def __init__(
        self, name: str, get_dataset: Callable, objective: str, metric: str
    ) -> None:
        self.name = name
        self.objective = objective
        self.metric = metric
        self.X, self.y = get_dataset()
        self.w: Optional[np.ndarray] = None
        self.margin: Optional[np.ndarray] = None

    def set_params(self, params_in: Dict[str, Any]) -> Dict[str, Any]:
        """Write objective and metric into ``params_in`` in place and return it."""
        params_in.update({'objective': self.objective, 'eval_metric': self.metric})
        if self.objective == "multi:softmax":
            # Number of classes is taken as the largest label plus one.
            params_in["num_class"] = int(np.max(self.y) + 1)
        return params_in

    def get_dmat(self) -> xgb.DMatrix:
        """Build an in-memory ``DMatrix`` with categorical support enabled."""
        dmat = xgb.DMatrix(
            self.X, self.y, self.w, base_margin=self.margin, enable_categorical=True
        )
        return dmat

    def get_device_dmat(self) -> xgb.DeviceQuantileDMatrix:
        """Build a ``DeviceQuantileDMatrix`` from cupy copies of the data."""
        X = cp.array(self.X, dtype=np.float32)
        y = cp.array(self.y, dtype=np.float32)
        w = cp.array(self.w) if self.w is not None else None
        return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin)

    def get_external_dmat(self) -> xgb.DMatrix:
        """Build a ``DMatrix`` from an iterator over 10 row-wise batches."""
        n_samples = self.X.shape[0]
        n_batches = 10
        per_batch = n_samples // n_batches + 1

        predictor = []
        response = []
        weight = []
        for batch_idx in range(n_batches):
            start = batch_idx * per_batch
            stop = min(start + per_batch, n_samples)
            # Every batch must contain at least one row.
            assert stop != start
            rows = slice(start, stop)
            predictor.append(self.X[rows, ...])
            response.append(self.y[rows])
            if self.w is not None:
                weight.append(self.w[rows])

        # An empty weight list means the dataset has no weights.
        it = IteratorForTest(predictor, response, weight if weight else None)
        return xgb.DMatrix(it)

    def __repr__(self) -> str:
        return self.name
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import xgboost as xgb
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user