[EM] Python wrapper for the ExtMemQuantileDMatrix. (#10762)

Not exposed to the document yet.

- Add C API.
- Add Python API.
- Basic CPU tests.
This commit is contained in:
Jiaming Yuan 2024-08-29 04:08:25 +08:00 committed by GitHub
parent 7510a87466
commit 34937fea41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 208 additions and 27 deletions

View File

@ -472,37 +472,66 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
* @example external_memory.c
*/
/*!
* \brief Create a Quantile DMatrix with data iterator.
/**
* @brief Create a Quantile DMatrix with data iterator.
*
* Short note for how to use the second set of callback for (GPU)Hist tree method:
*
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
* - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle.
* - Step 1: Create a DMatrix proxy by @ref XGProxyDMatrixCreate and hold the handle.
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
* `XGQuantileDMatrixCreateFromCallback`.
* - Step 3: Call appropriate data setters in `next` functions.
*
* See test_iterative_dmatrix.cu or Python interface for examples.
*
* \param iter A handle to external data iterator.
* \param proxy A DMatrix proxy handle created by \ref XGProxyDMatrixCreate.
* \param ref Reference DMatrix for providing quantile information.
* \param reset Callback function resetting the iterator state.
* \param next Callback function yielding the next batch of data.
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* @param iter A handle to external data iterator.
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
* @param ref Reference DMatrix for providing quantile information.
* @param reset Callback function resetting the iterator state.
* @param next Callback function yielding the next batch of data.
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* - missing: Which value to represent missing value
* - nthread (optional): Number of threads used for initializing DMatrix.
* - max_bin (optional): Maximum number of bins for building histogram.
* \param out The created Quantile DMatrix.
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
the corresponding booster training parameter.
* @param out The created Quantile DMatrix.
*
* \return 0 when success, -1 when failure happens
* @return 0 when success, -1 when failure happens
*/
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out);
/**
* @brief Create a Quantile DMatrix backed by external memory.
*
* @since 3.0.0
*
* @note This is still under development, not ready for test yet.
*
* @param iter A handle to external data iterator.
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
* @param ref Reference DMatrix for providing quantile information.
* @param reset Callback function resetting the iterator state.
* @param next Callback function yielding the next batch of data.
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* - missing: Which value to represent missing value
* - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
* - nthread (optional): Number of threads used for initializing DMatrix.
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
the corresponding booster training parameter.
* @param out The created Quantile DMatrix.
*
* @return 0 when success, -1 when failure happens
*/
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next,
char const *config, DMatrixHandle *out);
/*!
* \brief Create a Device Quantile DMatrix with data iterator.
* \deprecated since 1.7.0

View File

@ -5,7 +5,15 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
from . import tracker # noqa
from . import collective, dask
from .core import Booster, DataIter, DMatrix, QuantileDMatrix, _py_version, build_info
from .core import (
Booster,
DataIter,
DMatrix,
ExtMemQuantileDMatrix,
QuantileDMatrix,
_py_version,
build_info,
)
from .tracker import RabitTracker # noqa
from .training import cv, train
@ -31,6 +39,7 @@ __all__ = [
# core
"DMatrix",
"QuantileDMatrix",
"ExtMemQuantileDMatrix",
"Booster",
"DataIter",
"train",

View File

@ -526,8 +526,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
on_host :
Whether the data should be cached on host memory instead of harddrive when using
GPU with external memory. If set to true, then the "external memory" would
simply be CPU (host) memory. This is still working in progress, not ready for
test yet.
simply be CPU (host) memory.
.. versionadded:: 3.0.0
.. warning::
This is still working in progress, not ready for test yet.
"""
@ -927,8 +932,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
if feature_types is not None:
self.feature_types = feature_types
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
it = iterator
def _init_from_iter(self, it: DataIter, enable_categorical: bool) -> None:
args = make_jcargs(
missing=self.missing,
nthread=self.nthread,
@ -1673,6 +1677,63 @@ class QuantileDMatrix(DMatrix):
self.handle = handle
class ExtMemQuantileDMatrix(DMatrix):
"""The external memory version of the :py:class:`QuantileDMatrix`.
.. warning::
This is still working in progress, not ready for test yet.
.. versionadded:: 3.0.0
"""
@_deprecate_positional_args
def __init__( # pylint: disable=super-init-not-called
self,
data: DataIter,
missing: Optional[float] = None,
nthread: Optional[int] = None,
max_bin: Optional[int] = None,
ref: Optional[DMatrix] = None,
enable_categorical: bool = False,
) -> None:
self.max_bin = max_bin
self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else -1
self._init(data, ref, enable_categorical)
assert self.handle is not None
def _init(
self, it: DataIter, ref: Optional[DMatrix], enable_categorical: bool
) -> None:
args = make_jcargs(
missing=self.missing,
nthread=self.nthread,
cache_prefix=it.cache_prefix if it.cache_prefix else "",
on_host=it.on_host,
)
handle = ctypes.c_void_p()
reset_callback, next_callback = it.get_callbacks(enable_categorical)
# We don't need the iter handle (hence None) in Python as reset,next callbacks
# are member functions, and ctypes can handle the `self` parameter
# automatically.
ret = _LIB.XGExtMemQuantileDMatrixCreateFromCallback(
None, # iter
it.proxy.handle, # proxy
ref.handle if ref is not None else ref, # ref
reset_callback, # reset
next_callback, # next
args, # config
ctypes.byref(handle), # out
)
it.reraise()
# delay check_call to throw intermediate exception first
_check_call(ret)
self.handle = handle
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]

View File

@ -5,6 +5,7 @@ from functools import partial, update_wrapper
from typing import Any, Dict, List
import numpy as np
import pytest
import xgboost as xgb
import xgboost.testing as tm
@ -194,6 +195,43 @@ def check_quantile_loss_extmem(
np.testing.assert_allclose(predt, predt_it)
def check_extmem_qdm(
n_samples_per_batch: int,
n_features: int,
n_batches: int,
device: str,
on_host: bool,
) -> None:
"""Basic test for the `ExtMemQuantileDMatrix`."""
it = tm.IteratorForTest(
*tm.make_batches(
n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
),
cache="cache",
on_host=on_host,
)
Xy_it = xgb.ExtMemQuantileDMatrix(it)
with pytest.raises(ValueError, match="Only the `hist`"):
booster_it = xgb.train(
{"device": device, "tree_method": "approx"}, Xy_it, num_boost_round=8
)
booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8)
X, y, w = it.as_arrays()
Xy = xgb.QuantileDMatrix(X, y, weight=w)
booster = xgb.train({"device": device}, Xy, num_boost_round=8)
cut_it = Xy_it.get_quantile_cut()
cut = Xy.get_quantile_cut()
np.testing.assert_allclose(cut_it[0], cut[0])
np.testing.assert_allclose(cut_it[1], cut[1])
predt_it = booster_it.predict(Xy_it)
predt = booster.predict(Xy)
np.testing.assert_allclose(predt_it, predt)
def check_cut(
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
) -> None:

View File

@ -296,8 +296,8 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
auto jconfig = Json::Load(StringView{config});
auto missing = GetMissing(jconfig);
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
xgboost_CHECK_C_ARG_PTR(next);
xgboost_CHECK_C_ARG_PTR(reset);
@ -308,6 +308,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
API_END();
}
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
@ -320,11 +321,8 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
API_END();
}
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out) {
API_BEGIN();
namespace {
std::shared_ptr<DMatrix> GetRefDMatrix(DataIterHandle ref) {
std::shared_ptr<DMatrix> _ref{nullptr};
if (ref) {
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
@ -333,6 +331,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
_ref = *pp_ref;
CHECK(_ref) << err;
}
return _ref;
}
} // namespace
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out) {
API_BEGIN();
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
@ -345,7 +353,32 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing, n_threads, max_bin)};
API_END();
}
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next,
char const *config, DMatrixHandle *out) {
API_BEGIN();
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
auto missing = GetMissing(jconfig);
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
auto max_bin = OptionalArg<Integer, std::int64_t>(jconfig, "max_bin", 256);
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
xgboost_CHECK_C_ARG_PTR(next);
xgboost_CHECK_C_ARG_PTR(reset);
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)};
API_END();
}

View File

@ -8,6 +8,7 @@
namespace xgboost::data::detail {
void CheckParam(BatchParam const& init, BatchParam const& param) {
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`.";
CHECK(!param.regen && param.hess.empty())
<< "Only the `hist` tree method can use the `QuantileDMatrix`.";
}
} // namespace xgboost::data::detail

View File

@ -12,7 +12,7 @@ import xgboost as xgb
from xgboost import testing as tm
from xgboost.data import SingleBatchInternalIter as SingleBatch
from xgboost.testing import IteratorForTest, make_batches, non_increasing
from xgboost.testing.updater import check_quantile_loss_extmem
from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem
pytestmark = tm.timeout(30)
@ -304,3 +304,13 @@ def test_quantile_objective(
"approx",
"cpu",
)
@given(
strategies.integers(1, 4096),
strategies.integers(1, 8),
strategies.integers(1, 4),
)
@settings(deadline=None, max_examples=10, print_blob=True)
def test_extmem_qdm(n_samples_per_batch: int, n_features: int, n_batches: int) -> None:
check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cpu", False)