[EM] Python wrapper for the ExtMemQuantileDMatrix. (#10762)

Not yet exposed in the documentation.

- Add C API.
- Add Python API.
- Basic CPU tests.
This commit is contained in:
Jiaming Yuan
2024-08-29 04:08:25 +08:00
committed by GitHub
parent 7510a87466
commit 34937fea41
7 changed files with 208 additions and 27 deletions

View File

@@ -296,8 +296,8 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
auto jconfig = Json::Load(StringView{config});
auto missing = GetMissing(jconfig);
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
xgboost_CHECK_C_ARG_PTR(next);
xgboost_CHECK_C_ARG_PTR(reset);
@@ -308,6 +308,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
API_END();
}
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
@@ -320,11 +321,8 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
API_END();
}
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out) {
API_BEGIN();
namespace {
std::shared_ptr<DMatrix> GetRefDMatrix(DataIterHandle ref) {
std::shared_ptr<DMatrix> _ref{nullptr};
if (ref) {
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
@@ -333,6 +331,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
_ref = *pp_ref;
CHECK(_ref) << err;
}
return _ref;
}
} // namespace
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterHandle ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, char const *config,
DMatrixHandle *out) {
API_BEGIN();
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
@@ -345,7 +353,32 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing, n_threads, max_bin)};
API_END();
}
/**
 * @brief C API entry point that builds a DMatrix from a callback-driven data iterator,
 *        forwarding a cache prefix and an `on_host` flag to `DMatrix::Create`.
 *
 * Keys read from the JSON `config` string:
 *   - the missing value, extracted by `GetMissing`
 *   - "nthread"      (optional integer, defaults to 0)
 *   - "max_bin"      (optional integer, defaults to 256)
 *   - "on_host"      (optional boolean, defaults to false)
 *   - "cache_prefix" (required string)
 *
 * @param iter   Opaque handle of the data iterator, forwarded to `DMatrix::Create`.
 * @param proxy  Proxy DMatrix handle, forwarded to `DMatrix::Create`.
 * @param ref    Handle of a reference DMatrix, resolved through `GetRefDMatrix`.
 * @param reset  Callback for resetting the iterator; must not be null.
 * @param next   Callback for fetching the next batch; must not be null.
 * @param config JSON encoded parameters; must not be null.
 * @param out    Receives a newly allocated `std::shared_ptr<DMatrix>`; must not be null.
 *
 * @return Status code produced by the API_BEGIN/API_END wrapper.
 */
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
                                                      DataIterHandle ref,
                                                      DataIterResetCallback *reset,
                                                      XGDMatrixCallbackNext *next,
                                                      char const *config, DMatrixHandle *out) {
  API_BEGIN();
  // The reference matrix is resolved before anything else is validated.
  std::shared_ptr<DMatrix> ref_dmat{GetRefDMatrix(ref)};

  // The configuration string has to be valid before it can be parsed.
  xgboost_CHECK_C_ARG_PTR(config);
  auto parsed_config = Json::Load(StringView{config});

  auto missing_value = GetMissing(parsed_config);
  auto thread_count = OptionalArg<Integer, std::int64_t>(parsed_config, "nthread", 0);
  auto bin_limit = OptionalArg<Integer, std::int64_t>(parsed_config, "max_bin", 256);
  auto keep_on_host = OptionalArg<Boolean>(parsed_config, "on_host", false);
  std::string cache_prefix = RequiredArg<String>(parsed_config, "cache_prefix", __func__);

  // Callbacks and the output slot are checked only after the configuration is read.
  xgboost_CHECK_C_ARG_PTR(next);
  xgboost_CHECK_C_ARG_PTR(reset);
  xgboost_CHECK_C_ARG_PTR(out);

  *out = new std::shared_ptr<xgboost::DMatrix>{
      xgboost::DMatrix::Create(iter, proxy, ref_dmat, reset, next, missing_value, thread_count,
                               bin_limit, cache_prefix, keep_on_host)};
  API_END();
}

View File

@@ -8,6 +8,7 @@
namespace xgboost::data::detail {
/**
 * @brief Validate a requested batch parameter against a reference one.
 *
 * Two conditions are enforced through the CHECK macros:
 *   - `param.max_bin` must equal `init.max_bin`;
 *   - `param.regen` must be false and `param.hess` must be empty.
 *
 * @param init  Reference parameters to compare against (presumably the ones the matrix
 *              was initialized with -- confirm against callers).
 * @param param Parameters of the incoming request.
 */
void CheckParam(BatchParam const& init, BatchParam const& param) {
  CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
  // A single check with the corrected wording; the earlier duplicate of this condition
  // made this message unreachable.
  CHECK(!param.regen && param.hess.empty())
      << "Only the `hist` tree method can use the `QuantileDMatrix`.";
}
}  // namespace xgboost::data::detail