Implement new save_raw in Python. (#7572)
* Expose the new C API function to Python. * Remove old document and helper script. * Small optimization to the `save_raw` and Json ctors.
This commit is contained in:
@@ -971,28 +971,34 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
|
||||
auto format = RequiredArg<String>(config, "format", __func__);
|
||||
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
std::string &raw_str = learner->GetThreadLocal().ret_str;
|
||||
raw_str.clear();
|
||||
|
||||
learner->Configure();
|
||||
|
||||
auto save_json = [&](std::ios::openmode mode) {
|
||||
std::vector<char> &raw_char_vec = learner->GetThreadLocal().ret_char_vec;
|
||||
Json out{Object{}};
|
||||
learner->SaveModel(&out);
|
||||
Json::Dump(out, &raw_char_vec, mode);
|
||||
*out_dptr = dmlc::BeginPtr(raw_char_vec);
|
||||
*out_len = static_cast<xgboost::bst_ulong>(raw_char_vec.size());
|
||||
};
|
||||
|
||||
Json out{Object{}};
|
||||
if (format == "json") {
|
||||
learner->SaveModel(&out);
|
||||
Json::Dump(out, &raw_str);
|
||||
save_json(std::ios::out);
|
||||
} else if (format == "ubj") {
|
||||
learner->SaveModel(&out);
|
||||
Json::Dump(out, &raw_str, std::ios::binary);
|
||||
save_json(std::ios::binary);
|
||||
} else if (format == "deprecated") {
|
||||
WarnOldModel();
|
||||
auto &raw_str = learner->GetThreadLocal().ret_str;
|
||||
raw_str.clear();
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
learner->SaveModel(&fo);
|
||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.size());
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown format: `" << format << "`";
|
||||
}
|
||||
|
||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user