Add data split mode to DMatrix MetaInfo (#8568)

This commit is contained in:
Rong Ou
2022-12-25 04:37:37 -08:00
committed by GitHub
parent 77b069c25d
commit 3ceeb8c61c
20 changed files with 113 additions and 103 deletions

View File

@@ -206,17 +206,29 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
}
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
API_BEGIN();
auto data_split_mode = DataSplitMode::kNone;
if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
} else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
data_split_mode = DataSplitMode::kRow;
}
xgboost_CHECK_C_ARG_PTR(fname);
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, data_split_mode));
Json config{Object()};
config["uri"] = std::string{fname};
config["silent"] = silent;
std::string config_str;
Json::Dump(config, &config_str);
return XGDMatrixCreateFromURI(config_str.c_str(), out);
}
XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(config);
xgboost_CHECK_C_ARG_PTR(out);
auto jconfig = Json::Load(StringView{config});
std::string uri = RequiredArg<String>(jconfig, "uri", __func__);
auto silent = static_cast<bool>(OptionalArg<Integer, int64_t>(jconfig, "silent", 1));
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(jconfig, "data_split_mode", 0));
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(uri, silent, data_split_mode));
API_END();
}