Support column-wise data split with in-memory inputs (#9628)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -445,8 +445,11 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
|
||||
auto data_split_mode =
|
||||
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -481,8 +484,11 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
auto data_split_mode =
|
||||
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
|
||||
|
||||
API_END();
|
||||
}
|
||||
@@ -722,6 +728,15 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulon
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().data_split_mode);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
|
||||
xgboost::bst_ulong *out_indptr, unsigned *out_indices,
|
||||
float *out_data) {
|
||||
|
||||
Reference in New Issue
Block a user