[EM] Python wrapper for the ExtMemQuantileDMatrix. (#10762)

Not exposed in the documentation yet.

- Add C API.
- Add Python API.
- Basic CPU tests.
This commit is contained in:
Jiaming Yuan 2024-08-29 04:08:25 +08:00 committed by GitHub
parent 7510a87466
commit 34937fea41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 208 additions and 27 deletions

View File

@ -472,37 +472,66 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
* @example external_memory.c * @example external_memory.c
*/ */
/*! /**
* \brief Create a Quantile DMatrix with data iterator. * @brief Create a Quantile DMatrix with data iterator.
* *
* Short note for how to use the second set of callback for (GPU)Hist tree method: * Short note for how to use the second set of callback for (GPU)Hist tree method:
* *
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`. * - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
* - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle. * - Step 1: Create a DMatrix proxy by @ref XGProxyDMatrixCreate and hold the handle.
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into * - Step 2: Pass the iterator handle, proxy handle and 2 methods into
* `XGQuantileDMatrixCreateFromCallback`. * `XGQuantileDMatrixCreateFromCallback`.
* - Step 3: Call appropriate data setters in `next` functions. * - Step 3: Call appropriate data setters in `next` functions.
* *
* See test_iterative_dmatrix.cu or Python interface for examples. * See test_iterative_dmatrix.cu or Python interface for examples.
* *
* \param iter A handle to external data iterator. * @param iter A handle to external data iterator.
* \param proxy A DMatrix proxy handle created by \ref XGProxyDMatrixCreate. * @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
* \param ref Reference DMatrix for providing quantile information. * @param ref Reference DMatrix for providing quantile information.
* \param reset Callback function resetting the iterator state. * @param reset Callback function resetting the iterator state.
* \param next Callback function yielding the next batch of data. * @param next Callback function yielding the next batch of data.
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are: * @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* - missing: Which value to represent missing value * - missing: Which value to represent missing value
* - nthread (optional): Number of threads used for initializing DMatrix. * - nthread (optional): Number of threads used for initializing DMatrix.
* - max_bin (optional): Maximum number of bins for building histogram. * - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
* \param out The created Quantile DMatrix. the corresponding booster training parameter.
* @param out The created Quantile DMatrix.
* *
* \return 0 when success, -1 when failure happens * @return 0 when success, -1 when failure happens
*/ */
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref, DataIterResetCallback *reset, DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, char const *config, XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out); DMatrixHandle *out);
/**
* @brief Create a Quantile DMatrix backed by external memory.
*
* @since 3.0.0
*
* @note This is still under development and is not ready for testing yet.
*
* @param iter A handle to external data iterator.
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
* @param ref Reference DMatrix for providing quantile information.
* @param reset Callback function resetting the iterator state.
* @param next Callback function yielding the next batch of data.
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* - missing: Which value to represent missing value
* - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
* - on_host (optional): Whether the data should be cached on host memory instead of the hard
* drive when using GPU with external memory. Defaults to false.
* - nthread (optional): Number of threads used for initializing DMatrix.
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
the corresponding booster training parameter.
* @param out The created Quantile DMatrix.
*
* @return 0 when success, -1 when failure happens
*/
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next,
char const *config, DMatrixHandle *out);
/*! /*!
* \brief Create a Device Quantile DMatrix with data iterator. * \brief Create a Device Quantile DMatrix with data iterator.
* \deprecated since 1.7.0 * \deprecated since 1.7.0

View File

@ -5,7 +5,15 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
from . import tracker # noqa from . import tracker # noqa
from . import collective, dask from . import collective, dask
from .core import Booster, DataIter, DMatrix, QuantileDMatrix, _py_version, build_info from .core import (
Booster,
DataIter,
DMatrix,
ExtMemQuantileDMatrix,
QuantileDMatrix,
_py_version,
build_info,
)
from .tracker import RabitTracker # noqa from .tracker import RabitTracker # noqa
from .training import cv, train from .training import cv, train
@ -31,6 +39,7 @@ __all__ = [
# core # core
"DMatrix", "DMatrix",
"QuantileDMatrix", "QuantileDMatrix",
"ExtMemQuantileDMatrix",
"Booster", "Booster",
"DataIter", "DataIter",
"train", "train",

View File

@ -526,8 +526,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
on_host : on_host :
Whether the data should be cached on host memory instead of harddrive when using Whether the data should be cached on host memory instead of harddrive when using
GPU with external memory. If set to true, then the "external memory" would GPU with external memory. If set to true, then the "external memory" would
simply be CPU (host) memory. This is still working in progress, not ready for simply be CPU (host) memory.
test yet.
.. versionadded:: 3.0.0
.. warning::
This is still working in progress, not ready for test yet.
""" """
@ -927,8 +932,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
if feature_types is not None: if feature_types is not None:
self.feature_types = feature_types self.feature_types = feature_types
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None: def _init_from_iter(self, it: DataIter, enable_categorical: bool) -> None:
it = iterator
args = make_jcargs( args = make_jcargs(
missing=self.missing, missing=self.missing,
nthread=self.nthread, nthread=self.nthread,
@ -1673,6 +1677,63 @@ class QuantileDMatrix(DMatrix):
self.handle = handle self.handle = handle
class ExtMemQuantileDMatrix(DMatrix):
    """The external memory version of the :py:class:`QuantileDMatrix`.

    .. warning::

        This is still a work in progress and is not ready for testing yet.

    .. versionadded:: 3.0.0

    Parameters
    ----------
    data :
        A user-defined :py:class:`DataIter` that yields the batches. Its
        ``cache_prefix`` and ``on_host`` attributes are forwarded to the native
        constructor.
    missing :
        Value treated as missing. Defaults to NaN when not given.
    nthread :
        Number of threads used for construction. Defaults to -1 when not given.
    max_bin :
        Maximum number of histogram bins. When not given, the native default is
        used.
    ref :
        Optional reference DMatrix for providing quantile information.
    enable_categorical :
        Forwarded to :py:meth:`DataIter.get_callbacks`.

    """

    @_deprecate_positional_args
    def __init__(  # pylint: disable=super-init-not-called
        self,
        data: DataIter,
        missing: Optional[float] = None,
        nthread: Optional[int] = None,
        max_bin: Optional[int] = None,
        ref: Optional[DMatrix] = None,
        enable_categorical: bool = False,
    ) -> None:
        self.max_bin = max_bin
        self.missing = missing if missing is not None else np.nan
        self.nthread = nthread if nthread is not None else -1

        self._init(data, ref, enable_categorical)
        assert self.handle is not None

    def _init(
        self, it: DataIter, ref: Optional[DMatrix], enable_categorical: bool
    ) -> None:
        """Build the native handle from the iterator's reset/next callbacks."""
        config = {
            "missing": self.missing,
            "nthread": self.nthread,
            "cache_prefix": it.cache_prefix if it.cache_prefix else "",
            "on_host": it.on_host,
        }
        # Forward the user-specified ``max_bin``; previously it was stored on the
        # instance but never reached the native constructor, which silently fell
        # back to its own default. Only add it when set so the native default
        # still applies otherwise.
        if self.max_bin is not None:
            config["max_bin"] = self.max_bin
        args = make_jcargs(**config)

        handle = ctypes.c_void_p()
        reset_callback, next_callback = it.get_callbacks(enable_categorical)
        # We don't need the iter handle (hence None) in Python as reset,next callbacks
        # are member functions, and ctypes can handle the `self` parameter
        # automatically.
        ret = _LIB.XGExtMemQuantileDMatrixCreateFromCallback(
            None,  # iter
            it.proxy.handle,  # proxy
            ref.handle if ref is not None else ref,  # ref
            reset_callback,  # reset
            next_callback,  # next
            args,  # config
            ctypes.byref(handle),  # out
        )
        it.reraise()
        # delay check_call to throw intermediate exception first
        _check_call(ret)
        self.handle = handle
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]

View File

@ -5,6 +5,7 @@ from functools import partial, update_wrapper
from typing import Any, Dict, List from typing import Any, Dict, List
import numpy as np import numpy as np
import pytest
import xgboost as xgb import xgboost as xgb
import xgboost.testing as tm import xgboost.testing as tm
@ -194,6 +195,43 @@ def check_quantile_loss_extmem(
np.testing.assert_allclose(predt, predt_it) np.testing.assert_allclose(predt, predt_it)
def check_extmem_qdm(
    n_samples_per_batch: int,
    n_features: int,
    n_batches: int,
    device: str,
    on_host: bool,
) -> None:
    """Basic test for the `ExtMemQuantileDMatrix`.

    Builds an `ExtMemQuantileDMatrix` from a batch iterator, checks that a
    non-`hist` tree method is rejected, then compares quantile cuts and
    predictions against an in-memory `QuantileDMatrix` built from the same data.
    """
    batches = tm.make_batches(
        n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
    )
    data_iter = tm.IteratorForTest(*batches, cache="cache", on_host=on_host)
    ext_qdm = xgb.ExtMemQuantileDMatrix(data_iter)

    # Training with anything other than `hist` must be refused.
    with pytest.raises(ValueError, match="Only the `hist`"):
        xgb.train(
            {"device": device, "tree_method": "approx"}, ext_qdm, num_boost_round=8
        )

    ext_booster = xgb.train({"device": device}, ext_qdm, num_boost_round=8)

    # In-memory baseline built from the concatenated batches.
    X, y, w = data_iter.as_arrays()
    mem_qdm = xgb.QuantileDMatrix(X, y, weight=w)
    mem_booster = xgb.train({"device": device}, mem_qdm, num_boost_round=8)

    ext_cut = ext_qdm.get_quantile_cut()
    mem_cut = mem_qdm.get_quantile_cut()
    for idx in (0, 1):
        np.testing.assert_allclose(ext_cut[idx], mem_cut[idx])

    ext_predt = ext_booster.predict(ext_qdm)
    mem_predt = mem_booster.predict(mem_qdm)
    np.testing.assert_allclose(ext_predt, mem_predt)
def check_cut( def check_cut(
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
) -> None: ) -> None:

View File

@ -296,8 +296,8 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
auto jconfig = Json::Load(StringView{config}); auto jconfig = Json::Load(StringView{config});
auto missing = GetMissing(jconfig); auto missing = GetMissing(jconfig);
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__); std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0); auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false); auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
xgboost_CHECK_C_ARG_PTR(next); xgboost_CHECK_C_ARG_PTR(next);
xgboost_CHECK_C_ARG_PTR(reset); xgboost_CHECK_C_ARG_PTR(reset);
@ -308,6 +308,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
API_END(); API_END();
} }
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, XGDMatrixCallbackNext *next, float missing,
@ -320,11 +321,8 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
API_END(); API_END();
} }
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, namespace {
DataIterHandle ref, DataIterResetCallback *reset, std::shared_ptr<DMatrix> GetRefDMatrix(DataIterHandle ref) {
XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out) {
API_BEGIN();
std::shared_ptr<DMatrix> _ref{nullptr}; std::shared_ptr<DMatrix> _ref{nullptr};
if (ref) { if (ref) {
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref); auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
@ -333,6 +331,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
_ref = *pp_ref; _ref = *pp_ref;
CHECK(_ref) << err; CHECK(_ref) << err;
} }
return _ref;
}
} // namespace
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out) {
API_BEGIN();
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
xgboost_CHECK_C_ARG_PTR(config); xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config}); auto jconfig = Json::Load(StringView{config});
@ -345,7 +353,32 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
xgboost_CHECK_C_ARG_PTR(out); xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<xgboost::DMatrix>{ *out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)}; xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing, n_threads, max_bin)};
API_END();
}
// Create an external-memory Quantile DMatrix from iterator callbacks. Parameters are read from the
// JSON `config`: `missing`, `nthread` (default 0), `max_bin` (default 256), `on_host` (default
// false) and the required `cache_prefix`.
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
                                                      DataIterHandle ref,
                                                      DataIterResetCallback *reset,
                                                      XGDMatrixCallbackNext *next,
                                                      char const *config, DMatrixHandle *out) {
  API_BEGIN();
  // Resolve the optional reference DMatrix before anything else.
  auto p_ref = GetRefDMatrix(ref);

  // Parse the JSON configuration.
  xgboost_CHECK_C_ARG_PTR(config);
  auto jconfig = Json::Load(StringView{config});
  auto missing_value = GetMissing(jconfig);
  auto nthread = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
  auto n_bins = OptionalArg<Integer, std::int64_t>(jconfig, "max_bin", 256);
  auto cache_on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
  std::string cache_prefix = RequiredArg<String>(jconfig, "cache_prefix", __func__);

  // Validate the remaining pointer arguments.
  xgboost_CHECK_C_ARG_PTR(next);
  xgboost_CHECK_C_ARG_PTR(reset);
  xgboost_CHECK_C_ARG_PTR(out);

  *out = new std::shared_ptr<xgboost::DMatrix>{
      xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing_value, nthread, n_bins,
                               cache_prefix, cache_on_host)};
  API_END();
}

View File

@ -8,6 +8,7 @@
namespace xgboost::data::detail { namespace xgboost::data::detail {
void CheckParam(BatchParam const& init, BatchParam const& param) { void CheckParam(BatchParam const& init, BatchParam const& param) {
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin(); CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`."; CHECK(!param.regen && param.hess.empty())
<< "Only the `hist` tree method can use the `QuantileDMatrix`.";
} }
} // namespace xgboost::data::detail } // namespace xgboost::data::detail

View File

@ -12,7 +12,7 @@ import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.data import SingleBatchInternalIter as SingleBatch from xgboost.data import SingleBatchInternalIter as SingleBatch
from xgboost.testing import IteratorForTest, make_batches, non_increasing from xgboost.testing import IteratorForTest, make_batches, non_increasing
from xgboost.testing.updater import check_quantile_loss_extmem from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem
pytestmark = tm.timeout(30) pytestmark = tm.timeout(30)
@ -304,3 +304,13 @@ def test_quantile_objective(
"approx", "approx",
"cpu", "cpu",
) )
@given(
    strategies.integers(min_value=1, max_value=4096),
    strategies.integers(min_value=1, max_value=8),
    strategies.integers(min_value=1, max_value=4),
)
@settings(deadline=None, max_examples=10, print_blob=True)
def test_extmem_qdm(n_samples_per_batch: int, n_features: int, n_batches: int) -> None:
    """Run the basic `ExtMemQuantileDMatrix` check on CPU with on-disk caching."""
    check_extmem_qdm(
        n_samples_per_batch, n_features, n_batches, device="cpu", on_host=False
    )