Implement new save_raw in Python. (#7572)

* Expose the new C API function to Python.
* Remove old document and helper script.
* Small optimization to the `save_raw` and Json ctors.
This commit is contained in:
Jiaming Yuan
2022-01-19 02:27:51 +08:00
committed by GitHub
parent 9f20a3315e
commit dac9eb13bd
8 changed files with 104 additions and 150 deletions

View File

@@ -971,28 +971,34 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
auto format = RequiredArg<String>(config, "format", __func__);
auto *learner = static_cast<Learner *>(handle);
std::string &raw_str = learner->GetThreadLocal().ret_str;
raw_str.clear();
learner->Configure();
auto save_json = [&](std::ios::openmode mode) {
std::vector<char> &raw_char_vec = learner->GetThreadLocal().ret_char_vec;
Json out{Object{}};
learner->SaveModel(&out);
Json::Dump(out, &raw_char_vec, mode);
*out_dptr = dmlc::BeginPtr(raw_char_vec);
*out_len = static_cast<xgboost::bst_ulong>(raw_char_vec.size());
};
Json out{Object{}};
if (format == "json") {
learner->SaveModel(&out);
Json::Dump(out, &raw_str);
save_json(std::ios::out);
} else if (format == "ubj") {
learner->SaveModel(&out);
Json::Dump(out, &raw_str, std::ios::binary);
save_json(std::ios::binary);
} else if (format == "deprecated") {
WarnOldModel();
auto &raw_str = learner->GetThreadLocal().ret_str;
raw_str.clear();
common::MemoryBufferStream fo(&raw_str);
learner->SaveModel(&fo);
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.size());
} else {
LOG(FATAL) << "Unknown format: `" << format << "`";
}
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
}