[EM] Python wrapper for the ExtMemQuantileDMatrix. (#10762)
Not exposed to the document yet. - Add C API. - Add Python API. - Basic CPU tests.
This commit is contained in:
@@ -296,8 +296,8 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
|
||||
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
|
||||
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(next);
|
||||
xgboost_CHECK_C_ARG_PTR(reset);
|
||||
@@ -308,6 +308,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing,
|
||||
@@ -320,11 +321,8 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
namespace {
|
||||
std::shared_ptr<DMatrix> GetRefDMatrix(DataIterHandle ref) {
|
||||
std::shared_ptr<DMatrix> _ref{nullptr};
|
||||
if (ref) {
|
||||
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
|
||||
@@ -333,6 +331,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
||||
_ref = *pp_ref;
|
||||
CHECK(_ref) << err;
|
||||
}
|
||||
return _ref;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
@@ -345,7 +353,32 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
|
||||
xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing, n_threads, max_bin)};
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterHandle ref,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next,
|
||||
char const *config, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
|
||||
auto max_bin = OptionalArg<Integer, std::int64_t>(jconfig, "max_bin", 256);
|
||||
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
|
||||
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(next);
|
||||
xgboost_CHECK_C_ARG_PTR(reset);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
|
||||
iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)};
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
namespace xgboost::data::detail {
|
||||
void CheckParam(BatchParam const& init, BatchParam const& param) {
|
||||
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
|
||||
CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`.";
|
||||
CHECK(!param.regen && param.hess.empty())
|
||||
<< "Only the `hist` tree method can use the `QuantileDMatrix`.";
|
||||
}
|
||||
} // namespace xgboost::data::detail
|
||||
|
||||
Reference in New Issue
Block a user