Implement new save_raw in Python. (#7572)

* Expose the new C API function to Python.
* Remove the old document and helper script.
* Small optimizations to `save_raw` and the Json constructors.
This commit is contained in:
Jiaming Yuan
2022-01-19 02:27:51 +08:00
committed by GitHub
parent 9f20a3315e
commit dac9eb13bd
8 changed files with 104 additions and 150 deletions

View File

@@ -971,28 +971,34 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
auto format = RequiredArg<String>(config, "format", __func__);
auto *learner = static_cast<Learner *>(handle);
std::string &raw_str = learner->GetThreadLocal().ret_str;
raw_str.clear();
learner->Configure();
auto save_json = [&](std::ios::openmode mode) {
std::vector<char> &raw_char_vec = learner->GetThreadLocal().ret_char_vec;
Json out{Object{}};
learner->SaveModel(&out);
Json::Dump(out, &raw_char_vec, mode);
*out_dptr = dmlc::BeginPtr(raw_char_vec);
*out_len = static_cast<xgboost::bst_ulong>(raw_char_vec.size());
};
Json out{Object{}};
if (format == "json") {
learner->SaveModel(&out);
Json::Dump(out, &raw_str);
save_json(std::ios::out);
} else if (format == "ubj") {
learner->SaveModel(&out);
Json::Dump(out, &raw_str, std::ios::binary);
save_json(std::ios::binary);
} else if (format == "deprecated") {
WarnOldModel();
auto &raw_str = learner->GetThreadLocal().ret_str;
raw_str.clear();
common::MemoryBufferStream fo(&raw_str);
learner->SaveModel(&fo);
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.size());
} else {
LOG(FATAL) << "Unknown format: `" << format << "`";
}
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
}

View File

@@ -195,11 +195,12 @@ Json& Value::operator[](int) {
}
// Json Object
JsonObject::JsonObject(JsonObject && that) noexcept :
Value(ValueKind::kObject), object_{std::move(that.object_)} {}
JsonObject::JsonObject(JsonObject&& that) noexcept : Value(ValueKind::kObject) {
std::swap(that.object_, this->object_);
}
JsonObject::JsonObject(std::map<std::string, Json> &&object) noexcept
: Value(ValueKind::kObject), object_{std::move(object)} {}
JsonObject::JsonObject(std::map<std::string, Json>&& object) noexcept
: Value(ValueKind::kObject), object_{std::forward<std::map<std::string, Json>>(object)} {}
bool JsonObject::operator==(Value const& rhs) const {
if (!IsA<JsonObject>(&rhs)) {
@@ -220,8 +221,9 @@ bool JsonString::operator==(Value const& rhs) const {
void JsonString::Save(JsonWriter* writer) const { writer->Visit(this); }
// Json Array
JsonArray::JsonArray(JsonArray && that) noexcept :
Value(ValueKind::kArray), vec_{std::move(that.vec_)} {}
JsonArray::JsonArray(JsonArray&& that) noexcept : Value(ValueKind::kArray) {
std::swap(that.vec_, this->vec_);
}
bool JsonArray::operator==(Value const& rhs) const {
if (!IsA<JsonArray>(&rhs)) {
@@ -696,6 +698,7 @@ void Json::Dump(Json json, std::string* str, std::ios::openmode mode) {
}
void Json::Dump(Json json, std::vector<char>* str, std::ios::openmode mode) {
str->clear();
if (mode & std::ios::binary) {
UBJWriter writer{str};
writer.Save(json);
@@ -768,9 +771,7 @@ std::string UBJReader::DecodeStr() {
str.resize(bsize);
auto ptr = raw_str_.c_str() + cursor_.Pos();
std::memcpy(&str[0], ptr, bsize);
for (int64_t i = 0; i < bsize; ++i) {
this->cursor_.Forward();
}
this->cursor_.Forward(bsize);
return str;
}