[EM] Python wrapper for the ExtMemQuantileDMatrix. (#10762)
Not exposed to the document yet. - Add C API. - Add Python API. - Basic CPU tests.
This commit is contained in:
parent
7510a87466
commit
34937fea41
@ -472,37 +472,66 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
|||||||
* @example external_memory.c
|
* @example external_memory.c
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/*!
|
/**
|
||||||
* \brief Create a Quantile DMatrix with data iterator.
|
* @brief Create a Quantile DMatrix with data iterator.
|
||||||
*
|
*
|
||||||
* Short note for how to use the second set of callback for (GPU)Hist tree method:
|
* Short note for how to use the second set of callback for (GPU)Hist tree method:
|
||||||
*
|
*
|
||||||
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
|
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
|
||||||
* - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle.
|
* - Step 1: Create a DMatrix proxy by @ref XGProxyDMatrixCreate and hold the handle.
|
||||||
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
|
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
|
||||||
* `XGQuantileDMatrixCreateFromCallback`.
|
* `XGQuantileDMatrixCreateFromCallback`.
|
||||||
* - Step 3: Call appropriate data setters in `next` functions.
|
* - Step 3: Call appropriate data setters in `next` functions.
|
||||||
*
|
*
|
||||||
* See test_iterative_dmatrix.cu or Python interface for examples.
|
* See test_iterative_dmatrix.cu or Python interface for examples.
|
||||||
*
|
*
|
||||||
* \param iter A handle to external data iterator.
|
* @param iter A handle to external data iterator.
|
||||||
* \param proxy A DMatrix proxy handle created by \ref XGProxyDMatrixCreate.
|
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
|
||||||
* \param ref Reference DMatrix for providing quantile information.
|
* @param ref Reference DMatrix for providing quantile information.
|
||||||
* \param reset Callback function resetting the iterator state.
|
* @param reset Callback function resetting the iterator state.
|
||||||
* \param next Callback function yielding the next batch of data.
|
* @param next Callback function yielding the next batch of data.
|
||||||
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
|
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
|
||||||
* - missing: Which value to represent missing value
|
* - missing: Which value to represent missing value
|
||||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||||
* - max_bin (optional): Maximum number of bins for building histogram.
|
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
|
||||||
* \param out The created Quantile DMatrix.
|
the corresponding booster training parameter.
|
||||||
|
* @param out The created Quantile DMatrix.
|
||||||
*
|
*
|
||||||
* \return 0 when success, -1 when failure happens
|
* @return 0 when success, -1 when failure happens
|
||||||
*/
|
*/
|
||||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||||
DataIterHandle ref, DataIterResetCallback *reset,
|
DataIterHandle ref, DataIterResetCallback *reset,
|
||||||
XGDMatrixCallbackNext *next, char const *config,
|
XGDMatrixCallbackNext *next, char const *config,
|
||||||
DMatrixHandle *out);
|
DMatrixHandle *out);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Create a Quantile DMatrix backed by external memory.
|
||||||
|
*
|
||||||
|
* @since 3.0.0
|
||||||
|
*
|
||||||
|
* @note This is still under development, not ready for test yet.
|
||||||
|
*
|
||||||
|
* @param iter A handle to external data iterator.
|
||||||
|
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
|
||||||
|
* @param ref Reference DMatrix for providing quantile information.
|
||||||
|
* @param reset Callback function resetting the iterator state.
|
||||||
|
* @param next Callback function yielding the next batch of data.
|
||||||
|
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
|
||||||
|
* - missing: Which value to represent missing value
|
||||||
|
* - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
|
||||||
|
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||||
|
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
|
||||||
|
the corresponding booster training parameter.
|
||||||
|
* @param out The created Quantile DMatrix.
|
||||||
|
*
|
||||||
|
* @return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||||
|
DataIterHandle ref,
|
||||||
|
DataIterResetCallback *reset,
|
||||||
|
XGDMatrixCallbackNext *next,
|
||||||
|
char const *config, DMatrixHandle *out);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Create a Device Quantile DMatrix with data iterator.
|
* \brief Create a Device Quantile DMatrix with data iterator.
|
||||||
* \deprecated since 1.7.0
|
* \deprecated since 1.7.0
|
||||||
|
|||||||
@ -5,7 +5,15 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
|||||||
|
|
||||||
from . import tracker # noqa
|
from . import tracker # noqa
|
||||||
from . import collective, dask
|
from . import collective, dask
|
||||||
from .core import Booster, DataIter, DMatrix, QuantileDMatrix, _py_version, build_info
|
from .core import (
|
||||||
|
Booster,
|
||||||
|
DataIter,
|
||||||
|
DMatrix,
|
||||||
|
ExtMemQuantileDMatrix,
|
||||||
|
QuantileDMatrix,
|
||||||
|
_py_version,
|
||||||
|
build_info,
|
||||||
|
)
|
||||||
from .tracker import RabitTracker # noqa
|
from .tracker import RabitTracker # noqa
|
||||||
from .training import cv, train
|
from .training import cv, train
|
||||||
|
|
||||||
@ -31,6 +39,7 @@ __all__ = [
|
|||||||
# core
|
# core
|
||||||
"DMatrix",
|
"DMatrix",
|
||||||
"QuantileDMatrix",
|
"QuantileDMatrix",
|
||||||
|
"ExtMemQuantileDMatrix",
|
||||||
"Booster",
|
"Booster",
|
||||||
"DataIter",
|
"DataIter",
|
||||||
"train",
|
"train",
|
||||||
|
|||||||
@ -526,8 +526,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
|||||||
on_host :
|
on_host :
|
||||||
Whether the data should be cached on host memory instead of harddrive when using
|
Whether the data should be cached on host memory instead of harddrive when using
|
||||||
GPU with external memory. If set to true, then the "external memory" would
|
GPU with external memory. If set to true, then the "external memory" would
|
||||||
simply be CPU (host) memory. This is still working in progress, not ready for
|
simply be CPU (host) memory.
|
||||||
test yet.
|
|
||||||
|
.. versionadded:: 3.0.0
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
This is still working in progress, not ready for test yet.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -927,8 +932,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
|||||||
if feature_types is not None:
|
if feature_types is not None:
|
||||||
self.feature_types = feature_types
|
self.feature_types = feature_types
|
||||||
|
|
||||||
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
|
def _init_from_iter(self, it: DataIter, enable_categorical: bool) -> None:
|
||||||
it = iterator
|
|
||||||
args = make_jcargs(
|
args = make_jcargs(
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
nthread=self.nthread,
|
nthread=self.nthread,
|
||||||
@ -1673,6 +1677,63 @@ class QuantileDMatrix(DMatrix):
|
|||||||
self.handle = handle
|
self.handle = handle
|
||||||
|
|
||||||
|
|
||||||
|
class ExtMemQuantileDMatrix(DMatrix):
|
||||||
|
"""The external memory version of the :py:class:`QuantileDMatrix`.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
This is still working in progress, not ready for test yet.
|
||||||
|
|
||||||
|
.. versionadded:: 3.0.0
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@_deprecate_positional_args
|
||||||
|
def __init__( # pylint: disable=super-init-not-called
|
||||||
|
self,
|
||||||
|
data: DataIter,
|
||||||
|
missing: Optional[float] = None,
|
||||||
|
nthread: Optional[int] = None,
|
||||||
|
max_bin: Optional[int] = None,
|
||||||
|
ref: Optional[DMatrix] = None,
|
||||||
|
enable_categorical: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.max_bin = max_bin
|
||||||
|
self.missing = missing if missing is not None else np.nan
|
||||||
|
self.nthread = nthread if nthread is not None else -1
|
||||||
|
|
||||||
|
self._init(data, ref, enable_categorical)
|
||||||
|
assert self.handle is not None
|
||||||
|
|
||||||
|
def _init(
|
||||||
|
self, it: DataIter, ref: Optional[DMatrix], enable_categorical: bool
|
||||||
|
) -> None:
|
||||||
|
args = make_jcargs(
|
||||||
|
missing=self.missing,
|
||||||
|
nthread=self.nthread,
|
||||||
|
cache_prefix=it.cache_prefix if it.cache_prefix else "",
|
||||||
|
on_host=it.on_host,
|
||||||
|
)
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
reset_callback, next_callback = it.get_callbacks(enable_categorical)
|
||||||
|
# We don't need the iter handle (hence None) in Python as reset,next callbacks
|
||||||
|
# are member functions, and ctypes can handle the `self` parameter
|
||||||
|
# automatically.
|
||||||
|
ret = _LIB.XGExtMemQuantileDMatrixCreateFromCallback(
|
||||||
|
None, # iter
|
||||||
|
it.proxy.handle, # proxy
|
||||||
|
ref.handle if ref is not None else ref, # ref
|
||||||
|
reset_callback, # reset
|
||||||
|
next_callback, # next
|
||||||
|
args, # config
|
||||||
|
ctypes.byref(handle), # out
|
||||||
|
)
|
||||||
|
it.reraise()
|
||||||
|
# delay check_call to throw intermediate exception first
|
||||||
|
_check_call(ret)
|
||||||
|
self.handle = handle
|
||||||
|
|
||||||
|
|
||||||
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
|
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
|
||||||
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
|
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from functools import partial, update_wrapper
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import xgboost.testing as tm
|
import xgboost.testing as tm
|
||||||
@ -194,6 +195,43 @@ def check_quantile_loss_extmem(
|
|||||||
np.testing.assert_allclose(predt, predt_it)
|
np.testing.assert_allclose(predt, predt_it)
|
||||||
|
|
||||||
|
|
||||||
|
def check_extmem_qdm(
|
||||||
|
n_samples_per_batch: int,
|
||||||
|
n_features: int,
|
||||||
|
n_batches: int,
|
||||||
|
device: str,
|
||||||
|
on_host: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Basic test for the `ExtMemQuantileDMatrix`."""
|
||||||
|
|
||||||
|
it = tm.IteratorForTest(
|
||||||
|
*tm.make_batches(
|
||||||
|
n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
|
||||||
|
),
|
||||||
|
cache="cache",
|
||||||
|
on_host=on_host,
|
||||||
|
)
|
||||||
|
Xy_it = xgb.ExtMemQuantileDMatrix(it)
|
||||||
|
with pytest.raises(ValueError, match="Only the `hist`"):
|
||||||
|
booster_it = xgb.train(
|
||||||
|
{"device": device, "tree_method": "approx"}, Xy_it, num_boost_round=8
|
||||||
|
)
|
||||||
|
|
||||||
|
booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8)
|
||||||
|
X, y, w = it.as_arrays()
|
||||||
|
Xy = xgb.QuantileDMatrix(X, y, weight=w)
|
||||||
|
booster = xgb.train({"device": device}, Xy, num_boost_round=8)
|
||||||
|
|
||||||
|
cut_it = Xy_it.get_quantile_cut()
|
||||||
|
cut = Xy.get_quantile_cut()
|
||||||
|
np.testing.assert_allclose(cut_it[0], cut[0])
|
||||||
|
np.testing.assert_allclose(cut_it[1], cut[1])
|
||||||
|
|
||||||
|
predt_it = booster_it.predict(Xy_it)
|
||||||
|
predt = booster.predict(Xy)
|
||||||
|
np.testing.assert_allclose(predt_it, predt)
|
||||||
|
|
||||||
|
|
||||||
def check_cut(
|
def check_cut(
|
||||||
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
|
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -296,8 +296,8 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
|||||||
auto jconfig = Json::Load(StringView{config});
|
auto jconfig = Json::Load(StringView{config});
|
||||||
auto missing = GetMissing(jconfig);
|
auto missing = GetMissing(jconfig);
|
||||||
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
|
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
|
||||||
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
|
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
|
||||||
|
|
||||||
xgboost_CHECK_C_ARG_PTR(next);
|
xgboost_CHECK_C_ARG_PTR(next);
|
||||||
xgboost_CHECK_C_ARG_PTR(reset);
|
xgboost_CHECK_C_ARG_PTR(reset);
|
||||||
@ -308,6 +308,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||||
DataIterResetCallback *reset,
|
DataIterResetCallback *reset,
|
||||||
XGDMatrixCallbackNext *next, float missing,
|
XGDMatrixCallbackNext *next, float missing,
|
||||||
@ -320,11 +321,8 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
namespace {
|
||||||
DataIterHandle ref, DataIterResetCallback *reset,
|
std::shared_ptr<DMatrix> GetRefDMatrix(DataIterHandle ref) {
|
||||||
XGDMatrixCallbackNext *next, char const *config,
|
|
||||||
DMatrixHandle *out) {
|
|
||||||
API_BEGIN();
|
|
||||||
std::shared_ptr<DMatrix> _ref{nullptr};
|
std::shared_ptr<DMatrix> _ref{nullptr};
|
||||||
if (ref) {
|
if (ref) {
|
||||||
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
|
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
|
||||||
@ -333,6 +331,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
|||||||
_ref = *pp_ref;
|
_ref = *pp_ref;
|
||||||
CHECK(_ref) << err;
|
CHECK(_ref) << err;
|
||||||
}
|
}
|
||||||
|
return _ref;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||||
|
DataIterHandle ref, DataIterResetCallback *reset,
|
||||||
|
XGDMatrixCallbackNext *next, char const *config,
|
||||||
|
DMatrixHandle *out) {
|
||||||
|
API_BEGIN();
|
||||||
|
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
|
||||||
|
|
||||||
xgboost_CHECK_C_ARG_PTR(config);
|
xgboost_CHECK_C_ARG_PTR(config);
|
||||||
auto jconfig = Json::Load(StringView{config});
|
auto jconfig = Json::Load(StringView{config});
|
||||||
@ -345,7 +353,32 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
|||||||
xgboost_CHECK_C_ARG_PTR(out);
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
|
||||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||||
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
|
xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing, n_threads, max_bin)};
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
|
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||||
|
DataIterHandle ref,
|
||||||
|
DataIterResetCallback *reset,
|
||||||
|
XGDMatrixCallbackNext *next,
|
||||||
|
char const *config, DMatrixHandle *out) {
|
||||||
|
API_BEGIN();
|
||||||
|
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(config);
|
||||||
|
auto jconfig = Json::Load(StringView{config});
|
||||||
|
auto missing = GetMissing(jconfig);
|
||||||
|
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
|
||||||
|
auto max_bin = OptionalArg<Integer, std::int64_t>(jconfig, "max_bin", 256);
|
||||||
|
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
|
||||||
|
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(next);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(reset);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
|
||||||
|
*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
|
||||||
|
iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)};
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
namespace xgboost::data::detail {
|
namespace xgboost::data::detail {
|
||||||
void CheckParam(BatchParam const& init, BatchParam const& param) {
|
void CheckParam(BatchParam const& init, BatchParam const& param) {
|
||||||
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
|
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
|
||||||
CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`.";
|
CHECK(!param.regen && param.hess.empty())
|
||||||
|
<< "Only the `hist` tree method can use the `QuantileDMatrix`.";
|
||||||
}
|
}
|
||||||
} // namespace xgboost::data::detail
|
} // namespace xgboost::data::detail
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import xgboost as xgb
|
|||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
from xgboost.data import SingleBatchInternalIter as SingleBatch
|
from xgboost.data import SingleBatchInternalIter as SingleBatch
|
||||||
from xgboost.testing import IteratorForTest, make_batches, non_increasing
|
from xgboost.testing import IteratorForTest, make_batches, non_increasing
|
||||||
from xgboost.testing.updater import check_quantile_loss_extmem
|
from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem
|
||||||
|
|
||||||
pytestmark = tm.timeout(30)
|
pytestmark = tm.timeout(30)
|
||||||
|
|
||||||
@ -304,3 +304,13 @@ def test_quantile_objective(
|
|||||||
"approx",
|
"approx",
|
||||||
"cpu",
|
"cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@given(
|
||||||
|
strategies.integers(1, 4096),
|
||||||
|
strategies.integers(1, 8),
|
||||||
|
strategies.integers(1, 4),
|
||||||
|
)
|
||||||
|
@settings(deadline=None, max_examples=10, print_blob=True)
|
||||||
|
def test_extmem_qdm(n_samples_per_batch: int, n_features: int, n_batches: int) -> None:
|
||||||
|
check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cpu", False)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user