Support column-wise data split with in-memory inputs (#9628)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Rong Ou
2023-10-16 21:16:39 -07:00
committed by GitHub
parent 4d1607eefd
commit da6803b75b
12 changed files with 307 additions and 27 deletions

View File

@@ -445,8 +445,11 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END();
}
@@ -481,8 +484,11 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END();
}
@@ -722,6 +728,15 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulon
API_END();
}
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
*out = static_cast<xgboost::bst_ulong>(p_m->Info().data_split_mode);
API_END();
}
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
xgboost::bst_ulong *out_indptr, unsigned *out_indices,
float *out_data) {