From 34937fea41a4d4f89738668d1ab65b2d89e34b74 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 29 Aug 2024 04:08:25 +0800 Subject: [PATCH] [EM] Python wrapper for the `ExtMemQuantileDMatrix`. (#10762) Not exposed to the document yet. - Add C API. - Add Python API. - Basic CPU tests. --- include/xgboost/c_api.h | 53 +++++++++++++---- python-package/xgboost/__init__.py | 11 +++- python-package/xgboost/core.py | 69 +++++++++++++++++++++-- python-package/xgboost/testing/updater.py | 38 +++++++++++++ src/c_api/c_api.cc | 49 +++++++++++++--- src/data/batch_utils.cc | 3 +- tests/python/test_data_iterator.py | 12 +++- 7 files changed, 208 insertions(+), 27 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 9f72d1e13..ffff11ddb 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -472,37 +472,66 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy * @example external_memory.c */ -/*! - * \brief Create a Quantile DMatrix with data iterator. +/** + * @brief Create a Quantile DMatrix with data iterator. * * Short note for how to use the second set of callback for (GPU)Hist tree method: * * - Step 0: Define a data iterator with 2 methods `reset`, and `next`. - * - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle. + * - Step 1: Create a DMatrix proxy by @ref XGProxyDMatrixCreate and hold the handle. * - Step 2: Pass the iterator handle, proxy handle and 2 methods into * `XGQuantileDMatrixCreateFromCallback`. * - Step 3: Call appropriate data setters in `next` functions. * * See test_iterative_dmatrix.cu or Python interface for examples. * - * \param iter A handle to external data iterator. - * \param proxy A DMatrix proxy handle created by \ref XGProxyDMatrixCreate. - * \param ref Reference DMatrix for providing quantile information. - * \param reset Callback function resetting the iterator state. - * \param next Callback function yielding the next batch of data. - * \param config JSON encoded parameters for DMatrix construction. Accepted fields are: + * @param iter A handle to external data iterator. + * @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate. + * @param ref Reference DMatrix for providing quantile information. + * @param reset Callback function resetting the iterator state. + * @param next Callback function yielding the next batch of data. + * @param config JSON encoded parameters for DMatrix construction. Accepted fields are: * - missing: Which value to represent missing value * - nthread (optional): Number of threads used for initializing DMatrix. - * - max_bin (optional): Maximum number of bins for building histogram. - * \param out The created Quantile DMatrix. + * - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with + the corresponding booster training parameter. + * @param out The created Quantile DMatrix. * - * \return 0 when success, -1 when failure happens + * @return 0 when success, -1 when failure happens */ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, DataIterHandle ref, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, char const *config, DMatrixHandle *out); +/** + * @brief Create a Quantile DMatrix backed by external memory. + * + * @since 3.0.0 + * + * @note This is still under development, not ready for test yet. + * + * @param iter A handle to external data iterator. + * @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate. + * @param ref Reference DMatrix for providing quantile information. + * @param reset Callback function resetting the iterator state. + * @param next Callback function yielding the next batch of data. + * @param config JSON encoded parameters for DMatrix construction. Accepted fields are: + * - missing: Which value to represent missing value + * - cache_prefix: The path of cache file, caller must initialize all the directories in this path. + * - nthread (optional): Number of threads used for initializing DMatrix. + * - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with + the corresponding booster training parameter. + * @param out The created Quantile DMatrix. + * + * @return 0 when success, -1 when failure happens + */ +XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, + DataIterHandle ref, + DataIterResetCallback *reset, + XGDMatrixCallbackNext *next, + char const *config, DMatrixHandle *out); + /*! * \brief Create a Device Quantile DMatrix with data iterator. * \deprecated since 1.7.0 diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index f6060973e..3030ad2eb 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -5,7 +5,15 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md from . import tracker # noqa from . import collective, dask -from .core import Booster, DataIter, DMatrix, QuantileDMatrix, _py_version, build_info +from .core import ( + Booster, + DataIter, + DMatrix, + ExtMemQuantileDMatrix, + QuantileDMatrix, + _py_version, + build_info, +) from .tracker import RabitTracker # noqa from .training import cv, train @@ -31,6 +39,7 @@ __all__ = [ # core "DMatrix", "QuantileDMatrix", + "ExtMemQuantileDMatrix", "Booster", "DataIter", "train", diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index b65154cad..8f6e560e4 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -526,8 +526,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes on_host : Whether the data should be cached on host memory instead of harddrive when using GPU with external memory. If set to true, then the "external memory" would - simply be CPU (host) memory. This is still working in progress, not ready for - test yet. + simply be CPU (host) memory. + + .. versionadded:: 3.0.0 + + .. warning:: + + This is still working in progress, not ready for test yet. """ @@ -927,8 +932,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m if feature_types is not None: self.feature_types = feature_types - def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None: - it = iterator + def _init_from_iter(self, it: DataIter, enable_categorical: bool) -> None: args = make_jcargs( missing=self.missing, nthread=self.nthread, @@ -1673,6 +1677,63 @@ class QuantileDMatrix(DMatrix): self.handle = handle +class ExtMemQuantileDMatrix(DMatrix): + """The external memory version of the :py:class:`QuantileDMatrix`. + + .. warning:: + + This is still working in progress, not ready for test yet. + + .. versionadded:: 3.0.0 + + """ + + @_deprecate_positional_args + def __init__( # pylint: disable=super-init-not-called + self, + data: DataIter, + missing: Optional[float] = None, + nthread: Optional[int] = None, + max_bin: Optional[int] = None, + ref: Optional[DMatrix] = None, + enable_categorical: bool = False, + ) -> None: + self.max_bin = max_bin + self.missing = missing if missing is not None else np.nan + self.nthread = nthread if nthread is not None else -1 + + self._init(data, ref, enable_categorical) + assert self.handle is not None + + def _init( + self, it: DataIter, ref: Optional[DMatrix], enable_categorical: bool + ) -> None: + args = make_jcargs( + missing=self.missing, + nthread=self.nthread, + cache_prefix=it.cache_prefix if it.cache_prefix else "", + on_host=it.on_host, + ) + handle = ctypes.c_void_p() + reset_callback, next_callback = it.get_callbacks(enable_categorical) + # We don't need the iter handle (hence None) in Python as reset,next callbacks + # are member functions, and ctypes can handle the `self` parameter + # automatically. + ret = _LIB.XGExtMemQuantileDMatrixCreateFromCallback( + None, # iter + it.proxy.handle, # proxy + ref.handle if ref is not None else ref, # ref + reset_callback, # reset + next_callback, # next + args, # config + ctypes.byref(handle), # out + ) + it.reraise() + # delay check_call to throw intermediate exception first + _check_call(ret) + self.handle = handle + + Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] diff --git a/python-package/xgboost/testing/updater.py b/python-package/xgboost/testing/updater.py index 7063a7b01..3a8715a4d 100644 --- a/python-package/xgboost/testing/updater.py +++ b/python-package/xgboost/testing/updater.py @@ -5,6 +5,7 @@ from functools import partial, update_wrapper from typing import Any, Dict, List import numpy as np +import pytest import xgboost as xgb import xgboost.testing as tm @@ -194,6 +195,43 @@ def check_quantile_loss_extmem( np.testing.assert_allclose(predt, predt_it) +def check_extmem_qdm( + n_samples_per_batch: int, + n_features: int, + n_batches: int, + device: str, + on_host: bool, +) -> None: + """Basic test for the `ExtMemQuantileDMatrix`.""" + + it = tm.IteratorForTest( + *tm.make_batches( + n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu" + ), + cache="cache", + on_host=on_host, + ) + Xy_it = xgb.ExtMemQuantileDMatrix(it) + with pytest.raises(ValueError, match="Only the `hist`"): + booster_it = xgb.train( + {"device": device, "tree_method": "approx"}, Xy_it, num_boost_round=8 + ) + + booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8) + X, y, w = it.as_arrays() + Xy = xgb.QuantileDMatrix(X, y, weight=w) + booster = xgb.train({"device": device}, Xy, num_boost_round=8) + + cut_it = Xy_it.get_quantile_cut() + cut = Xy.get_quantile_cut() + np.testing.assert_allclose(cut_it[0], cut[0]) + np.testing.assert_allclose(cut_it[1], cut[1]) + + predt_it = booster_it.predict(Xy_it) + predt = booster.predict(Xy) + np.testing.assert_allclose(predt_it, predt) + + def check_cut( n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any ) -> None: diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 737118865..0b5468d29 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -296,8 +296,8 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy auto jconfig = Json::Load(StringView{config}); auto missing = GetMissing(jconfig); std::string cache = RequiredArg(jconfig, "cache_prefix", __func__); - auto n_threads = OptionalArg(jconfig, "nthread", 0); - auto on_host = OptionalArg(jconfig, "on_host", false); + auto n_threads = OptionalArg(jconfig, "nthread", 0); + auto on_host = OptionalArg(jconfig, "on_host", false); xgboost_CHECK_C_ARG_PTR(next); xgboost_CHECK_C_ARG_PTR(reset); @@ -308,6 +308,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy API_END(); } + XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, float missing, @@ -320,11 +321,8 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr API_END(); } -XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, - DataIterHandle ref, DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, char const *config, - DMatrixHandle *out) { - API_BEGIN(); +namespace { +std::shared_ptr GetRefDMatrix(DataIterHandle ref) { std::shared_ptr _ref{nullptr}; if (ref) { auto pp_ref = static_cast *>(ref); @@ -333,6 +331,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand _ref = *pp_ref; CHECK(_ref) << err; } + return _ref; +} +} // namespace + +XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, + DataIterHandle ref, DataIterResetCallback *reset, + XGDMatrixCallbackNext *next, char const *config, + DMatrixHandle *out) { + API_BEGIN(); + std::shared_ptr p_ref{GetRefDMatrix(ref)}; xgboost_CHECK_C_ARG_PTR(config); auto jconfig = Json::Load(StringView{config}); @@ -345,7 +353,32 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand xgboost_CHECK_C_ARG_PTR(out); *out = new std::shared_ptr{ - xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)}; + xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing, n_threads, max_bin)}; + API_END(); +} + +XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, + DataIterHandle ref, + DataIterResetCallback *reset, + XGDMatrixCallbackNext *next, + char const *config, DMatrixHandle *out) { + API_BEGIN(); + std::shared_ptr p_ref{GetRefDMatrix(ref)}; + + xgboost_CHECK_C_ARG_PTR(config); + auto jconfig = Json::Load(StringView{config}); + auto missing = GetMissing(jconfig); + auto n_threads = OptionalArg(jconfig, "nthread", 0); + auto max_bin = OptionalArg(jconfig, "max_bin", 256); + auto on_host = OptionalArg(jconfig, "on_host", false); + std::string cache = RequiredArg(jconfig, "cache_prefix", __func__); + + xgboost_CHECK_C_ARG_PTR(next); + xgboost_CHECK_C_ARG_PTR(reset); + xgboost_CHECK_C_ARG_PTR(out); + + *out = new std::shared_ptr{xgboost::DMatrix::Create( + iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)}; API_END(); } diff --git a/src/data/batch_utils.cc b/src/data/batch_utils.cc index 0727dfca7..926650f9f 100644 --- a/src/data/batch_utils.cc +++ b/src/data/batch_utils.cc @@ -8,6 +8,7 @@ namespace xgboost::data::detail { void CheckParam(BatchParam const& init, BatchParam const& param) { CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin(); - CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`."; + CHECK(!param.regen && param.hess.empty()) + << "Only the `hist` tree method can use the `QuantileDMatrix`."; } } // namespace xgboost::data::detail diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index a42ad0f75..560e22d05 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -12,7 +12,7 @@ import xgboost as xgb from xgboost import testing as tm from xgboost.data import SingleBatchInternalIter as SingleBatch from xgboost.testing import IteratorForTest, make_batches, non_increasing -from xgboost.testing.updater import check_quantile_loss_extmem +from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem pytestmark = tm.timeout(30) @@ -304,3 +304,13 @@ def test_quantile_objective( "approx", "cpu", ) + + +@given( + strategies.integers(1, 4096), + strategies.integers(1, 8), + strategies.integers(1, 4), +) +@settings(deadline=None, max_examples=10, print_blob=True) +def test_extmem_qdm(n_samples_per_batch: int, n_features: int, n_batches: int) -> None: + check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cpu", False)