[EM] Python wrapper for the ExtMemQuantileDMatrix. (#10762)
Not exposed to the document yet. - Add C API. - Add Python API. - Basic CPU tests.
This commit is contained in:
parent
7510a87466
commit
34937fea41
@ -472,37 +472,66 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
||||
* @example external_memory.c
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \brief Create a Quantile DMatrix with data iterator.
|
||||
/**
|
||||
* @brief Create a Quantile DMatrix with data iterator.
|
||||
*
|
||||
* Short note for how to use the second set of callback for (GPU)Hist tree method:
|
||||
*
|
||||
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
|
||||
* - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle.
|
||||
* - Step 1: Create a DMatrix proxy by @ref XGProxyDMatrixCreate and hold the handle.
|
||||
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
|
||||
* `XGQuantileDMatrixCreateFromCallback`.
|
||||
* - Step 3: Call appropriate data setters in `next` functions.
|
||||
*
|
||||
* See test_iterative_dmatrix.cu or Python interface for examples.
|
||||
*
|
||||
* \param iter A handle to external data iterator.
|
||||
* \param proxy A DMatrix proxy handle created by \ref XGProxyDMatrixCreate.
|
||||
* \param ref Reference DMatrix for providing quantile information.
|
||||
* \param reset Callback function resetting the iterator state.
|
||||
* \param next Callback function yielding the next batch of data.
|
||||
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
|
||||
* @param iter A handle to external data iterator.
|
||||
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
|
||||
* @param ref Reference DMatrix for providing quantile information.
|
||||
* @param reset Callback function resetting the iterator state.
|
||||
* @param next Callback function yielding the next batch of data.
|
||||
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
|
||||
* - missing: Which value to represent missing value
|
||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||
* - max_bin (optional): Maximum number of bins for building histogram.
|
||||
* \param out The created Quantile DMatrix.
|
||||
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
|
||||
the corresponding booster training parameter.
|
||||
* @param out The created Quantile DMatrix.
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out);
|
||||
|
||||
/**
|
||||
* @brief Create a Quantile DMatrix backed by external memory.
|
||||
*
|
||||
* @since 3.0.0
|
||||
*
|
||||
* @note This is still under development, not ready for test yet.
|
||||
*
|
||||
* @param iter A handle to external data iterator.
|
||||
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
|
||||
* @param ref Reference DMatrix for providing quantile information.
|
||||
* @param reset Callback function resetting the iterator state.
|
||||
* @param next Callback function yielding the next batch of data.
|
||||
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
|
||||
* - missing: Which value to represent missing value
|
||||
* - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
|
||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
|
||||
the corresponding booster training parameter.
|
||||
* @param out The created Quantile DMatrix.
|
||||
*
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next,
|
||||
char const *config, DMatrixHandle *out);
|
||||
|
||||
/*!
|
||||
* \brief Create a Device Quantile DMatrix with data iterator.
|
||||
* \deprecated since 1.7.0
|
||||
|
||||
@ -5,7 +5,15 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
|
||||
from . import tracker # noqa
|
||||
from . import collective, dask
|
||||
from .core import Booster, DataIter, DMatrix, QuantileDMatrix, _py_version, build_info
|
||||
from .core import (
|
||||
Booster,
|
||||
DataIter,
|
||||
DMatrix,
|
||||
ExtMemQuantileDMatrix,
|
||||
QuantileDMatrix,
|
||||
_py_version,
|
||||
build_info,
|
||||
)
|
||||
from .tracker import RabitTracker # noqa
|
||||
from .training import cv, train
|
||||
|
||||
@ -31,6 +39,7 @@ __all__ = [
|
||||
# core
|
||||
"DMatrix",
|
||||
"QuantileDMatrix",
|
||||
"ExtMemQuantileDMatrix",
|
||||
"Booster",
|
||||
"DataIter",
|
||||
"train",
|
||||
|
||||
@ -526,8 +526,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
on_host :
|
||||
Whether the data should be cached on host memory instead of harddrive when using
|
||||
GPU with external memory. If set to true, then the "external memory" would
|
||||
simply be CPU (host) memory. This is still working in progress, not ready for
|
||||
test yet.
|
||||
simply be CPU (host) memory.
|
||||
|
||||
.. versionadded:: 3.0.0
|
||||
|
||||
.. warning::
|
||||
|
||||
This is still working in progress, not ready for test yet.
|
||||
|
||||
"""
|
||||
|
||||
@ -927,8 +932,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
if feature_types is not None:
|
||||
self.feature_types = feature_types
|
||||
|
||||
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
|
||||
it = iterator
|
||||
def _init_from_iter(self, it: DataIter, enable_categorical: bool) -> None:
|
||||
args = make_jcargs(
|
||||
missing=self.missing,
|
||||
nthread=self.nthread,
|
||||
@ -1673,6 +1677,63 @@ class QuantileDMatrix(DMatrix):
|
||||
self.handle = handle
|
||||
|
||||
|
||||
class ExtMemQuantileDMatrix(DMatrix):
|
||||
"""The external memory version of the :py:class:`QuantileDMatrix`.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is still working in progress, not ready for test yet.
|
||||
|
||||
.. versionadded:: 3.0.0
|
||||
|
||||
"""
|
||||
|
||||
@_deprecate_positional_args
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
data: DataIter,
|
||||
missing: Optional[float] = None,
|
||||
nthread: Optional[int] = None,
|
||||
max_bin: Optional[int] = None,
|
||||
ref: Optional[DMatrix] = None,
|
||||
enable_categorical: bool = False,
|
||||
) -> None:
|
||||
self.max_bin = max_bin
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self.nthread = nthread if nthread is not None else -1
|
||||
|
||||
self._init(data, ref, enable_categorical)
|
||||
assert self.handle is not None
|
||||
|
||||
def _init(
|
||||
self, it: DataIter, ref: Optional[DMatrix], enable_categorical: bool
|
||||
) -> None:
|
||||
args = make_jcargs(
|
||||
missing=self.missing,
|
||||
nthread=self.nthread,
|
||||
cache_prefix=it.cache_prefix if it.cache_prefix else "",
|
||||
on_host=it.on_host,
|
||||
)
|
||||
handle = ctypes.c_void_p()
|
||||
reset_callback, next_callback = it.get_callbacks(enable_categorical)
|
||||
# We don't need the iter handle (hence None) in Python as reset,next callbacks
|
||||
# are member functions, and ctypes can handle the `self` parameter
|
||||
# automatically.
|
||||
ret = _LIB.XGExtMemQuantileDMatrixCreateFromCallback(
|
||||
None, # iter
|
||||
it.proxy.handle, # proxy
|
||||
ref.handle if ref is not None else ref, # ref
|
||||
reset_callback, # reset
|
||||
next_callback, # next
|
||||
args, # config
|
||||
ctypes.byref(handle), # out
|
||||
)
|
||||
it.reraise()
|
||||
# delay check_call to throw intermediate exception first
|
||||
_check_call(ret)
|
||||
self.handle = handle
|
||||
|
||||
|
||||
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
|
||||
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from functools import partial, update_wrapper
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
import xgboost.testing as tm
|
||||
@ -194,6 +195,43 @@ def check_quantile_loss_extmem(
|
||||
np.testing.assert_allclose(predt, predt_it)
|
||||
|
||||
|
||||
def check_extmem_qdm(
|
||||
n_samples_per_batch: int,
|
||||
n_features: int,
|
||||
n_batches: int,
|
||||
device: str,
|
||||
on_host: bool,
|
||||
) -> None:
|
||||
"""Basic test for the `ExtMemQuantileDMatrix`."""
|
||||
|
||||
it = tm.IteratorForTest(
|
||||
*tm.make_batches(
|
||||
n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
|
||||
),
|
||||
cache="cache",
|
||||
on_host=on_host,
|
||||
)
|
||||
Xy_it = xgb.ExtMemQuantileDMatrix(it)
|
||||
with pytest.raises(ValueError, match="Only the `hist`"):
|
||||
booster_it = xgb.train(
|
||||
{"device": device, "tree_method": "approx"}, Xy_it, num_boost_round=8
|
||||
)
|
||||
|
||||
booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8)
|
||||
X, y, w = it.as_arrays()
|
||||
Xy = xgb.QuantileDMatrix(X, y, weight=w)
|
||||
booster = xgb.train({"device": device}, Xy, num_boost_round=8)
|
||||
|
||||
cut_it = Xy_it.get_quantile_cut()
|
||||
cut = Xy.get_quantile_cut()
|
||||
np.testing.assert_allclose(cut_it[0], cut[0])
|
||||
np.testing.assert_allclose(cut_it[1], cut[1])
|
||||
|
||||
predt_it = booster_it.predict(Xy_it)
|
||||
predt = booster.predict(Xy)
|
||||
np.testing.assert_allclose(predt_it, predt)
|
||||
|
||||
|
||||
def check_cut(
|
||||
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
|
||||
) -> None:
|
||||
|
||||
@ -296,8 +296,8 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
|
||||
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
|
||||
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(next);
|
||||
xgboost_CHECK_C_ARG_PTR(reset);
|
||||
@ -308,6 +308,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing,
|
||||
@ -320,11 +321,8 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
namespace {
|
||||
std::shared_ptr<DMatrix> GetRefDMatrix(DataIterHandle ref) {
|
||||
std::shared_ptr<DMatrix> _ref{nullptr};
|
||||
if (ref) {
|
||||
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
|
||||
@ -333,6 +331,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
||||
_ref = *pp_ref;
|
||||
CHECK(_ref) << err;
|
||||
}
|
||||
return _ref;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
@ -345,7 +353,32 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
|
||||
xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing, n_threads, max_bin)};
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next,
|
||||
char const *config, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
|
||||
auto max_bin = OptionalArg<Integer, std::int64_t>(jconfig, "max_bin", 256);
|
||||
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
|
||||
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(next);
|
||||
xgboost_CHECK_C_ARG_PTR(reset);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
|
||||
iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)};
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
namespace xgboost::data::detail {
|
||||
void CheckParam(BatchParam const& init, BatchParam const& param) {
|
||||
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
|
||||
CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`.";
|
||||
CHECK(!param.regen && param.hess.empty())
|
||||
<< "Only the `hist` tree method can use the `QuantileDMatrix`.";
|
||||
}
|
||||
} // namespace xgboost::data::detail
|
||||
|
||||
@ -12,7 +12,7 @@ import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.data import SingleBatchInternalIter as SingleBatch
|
||||
from xgboost.testing import IteratorForTest, make_batches, non_increasing
|
||||
from xgboost.testing.updater import check_quantile_loss_extmem
|
||||
from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem
|
||||
|
||||
pytestmark = tm.timeout(30)
|
||||
|
||||
@ -304,3 +304,13 @@ def test_quantile_objective(
|
||||
"approx",
|
||||
"cpu",
|
||||
)
|
||||
|
||||
|
||||
@given(
|
||||
strategies.integers(1, 4096),
|
||||
strategies.integers(1, 8),
|
||||
strategies.integers(1, 4),
|
||||
)
|
||||
@settings(deadline=None, max_examples=10, print_blob=True)
|
||||
def test_extmem_qdm(n_samples_per_batch: int, n_features: int, n_batches: int) -> None:
|
||||
check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cpu", False)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user