[breaking] Change internal model serialization to UBJSON. (#7556)
* Use typed array for models. * Change the memory snapshot format. * Add new C API for saving to raw format.
This commit is contained in:
124
src/learner.cc
124
src/learner.cc
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* Copyright 2014-2022 by Contributors
|
||||
* \file learner.cc
|
||||
* \brief Implementation of learning algorithm.
|
||||
* \author Tianqi Chen
|
||||
@@ -706,6 +706,21 @@ class LearnerConfiguration : public Learner {
|
||||
|
||||
std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT
|
||||
|
||||
namespace {
|
||||
StringView ModelMsg() {
|
||||
return StringView{
|
||||
R"doc(
|
||||
If you are loading a serialized model (like pickle in Python, RDS in R) generated by
|
||||
older XGBoost, please export the model by calling `Booster.save_model` from that version
|
||||
first, then load it back in current version. See:
|
||||
|
||||
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
|
||||
|
||||
for more details about differences between saving model and serializing.
|
||||
)doc"};
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
class LearnerIO : public LearnerConfiguration {
|
||||
private:
|
||||
std::set<std::string> saved_configs_ = {"num_round"};
|
||||
@@ -714,12 +729,17 @@ class LearnerIO : public LearnerConfiguration {
|
||||
std::string const serialisation_header_ { u8"CONFIG-offset:" };
|
||||
|
||||
public:
|
||||
explicit LearnerIO(std::vector<std::shared_ptr<DMatrix> > cache) :
|
||||
LearnerConfiguration{cache} {}
|
||||
explicit LearnerIO(std::vector<std::shared_ptr<DMatrix>> cache) : LearnerConfiguration{cache} {}
|
||||
|
||||
void LoadModel(Json const& in) override {
|
||||
CHECK(IsA<Object>(in));
|
||||
Version::Load(in);
|
||||
auto version = Version::Load(in);
|
||||
if (std::get<0>(version) == 1 && std::get<1>(version) < 6) {
|
||||
LOG(WARNING)
|
||||
<< "Found JSON model saved before XGBoost 1.6, please save the model using current "
|
||||
"version again. The support for old JSON model will be discontinued in XGBoost 2.3.";
|
||||
}
|
||||
|
||||
auto const& learner = get<Object>(in["learner"]);
|
||||
mparam_.FromJson(learner.at("learner_model_param"));
|
||||
|
||||
@@ -733,8 +753,8 @@ class LearnerIO : public LearnerConfiguration {
|
||||
auto const& gradient_booster = learner.at("gradient_booster");
|
||||
name = get<String>(gradient_booster["name"]);
|
||||
tparam_.UpdateAllowUnknown(Args{{"booster", name}});
|
||||
gbm_.reset(GradientBooster::Create(tparam_.booster,
|
||||
&generic_parameters_, &learner_model_param_));
|
||||
gbm_.reset(
|
||||
GradientBooster::Create(tparam_.booster, &generic_parameters_, &learner_model_param_));
|
||||
gbm_->LoadModel(gradient_booster);
|
||||
|
||||
auto const& j_attributes = get<Object const>(learner.at("attributes"));
|
||||
@@ -746,20 +766,17 @@ class LearnerIO : public LearnerConfiguration {
|
||||
// feature names and types are saved in xgboost 1.4
|
||||
auto it = learner.find("feature_names");
|
||||
if (it != learner.cend()) {
|
||||
auto const &feature_names = get<Array const>(it->second);
|
||||
feature_names_.clear();
|
||||
for (auto const &name : feature_names) {
|
||||
feature_names_.emplace_back(get<String const>(name));
|
||||
}
|
||||
auto const& feature_names = get<Array const>(it->second);
|
||||
feature_names_.resize(feature_names.size());
|
||||
std::transform(feature_names.cbegin(), feature_names.cend(), feature_names_.begin(),
|
||||
[](Json const& fn) { return get<String const>(fn); });
|
||||
}
|
||||
it = learner.find("feature_types");
|
||||
if (it != learner.cend()) {
|
||||
auto const &feature_types = get<Array const>(it->second);
|
||||
feature_types_.clear();
|
||||
for (auto const &name : feature_types) {
|
||||
auto type = get<String const>(name);
|
||||
feature_types_.emplace_back(type);
|
||||
}
|
||||
auto const& feature_types = get<Array const>(it->second);
|
||||
feature_types_.resize(feature_types.size());
|
||||
std::transform(feature_types.cbegin(), feature_types.cend(), feature_types_.begin(),
|
||||
[](Json const& fn) { return get<String const>(fn); });
|
||||
}
|
||||
|
||||
this->need_configuration_ = true;
|
||||
@@ -799,6 +816,7 @@ class LearnerIO : public LearnerConfiguration {
|
||||
feature_types.emplace_back(type);
|
||||
}
|
||||
}
|
||||
|
||||
// About to be deprecated by JSON format
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
generic_parameters_.UpdateAllowUnknown(Args{});
|
||||
@@ -817,15 +835,20 @@ class LearnerIO : public LearnerConfiguration {
|
||||
}
|
||||
}
|
||||
|
||||
if (header[0] == '{') {
|
||||
// Dispatch to JSON
|
||||
auto json_stream = common::FixedSizeStream(&fp);
|
||||
std::string buffer;
|
||||
json_stream.Take(&buffer);
|
||||
auto model = Json::Load({buffer.c_str(), buffer.size()});
|
||||
if (header[0] == '{') { // Dispatch to JSON
|
||||
auto buffer = common::ReadAll(fi, &fp);
|
||||
Json model;
|
||||
if (header[1] == '"') {
|
||||
model = Json::Load(StringView{buffer});
|
||||
} else if (std::isalpha(header[1])) {
|
||||
model = Json::Load(StringView{buffer}, std::ios::binary);
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid model format";
|
||||
}
|
||||
this->LoadModel(model);
|
||||
return;
|
||||
}
|
||||
|
||||
// use the peekable reader.
|
||||
fi = &fp;
|
||||
// read parameter
|
||||
@@ -983,45 +1006,46 @@ class LearnerIO : public LearnerConfiguration {
|
||||
void Save(dmlc::Stream* fo) const override {
|
||||
Json memory_snapshot{Object()};
|
||||
memory_snapshot["Model"] = Object();
|
||||
auto &model = memory_snapshot["Model"];
|
||||
auto& model = memory_snapshot["Model"];
|
||||
this->SaveModel(&model);
|
||||
memory_snapshot["Config"] = Object();
|
||||
auto &config = memory_snapshot["Config"];
|
||||
auto& config = memory_snapshot["Config"];
|
||||
this->SaveConfig(&config);
|
||||
std::string out_str;
|
||||
Json::Dump(memory_snapshot, &out_str);
|
||||
fo->Write(out_str.c_str(), out_str.size());
|
||||
|
||||
std::vector<char> stream;
|
||||
Json::Dump(memory_snapshot, &stream, std::ios::binary);
|
||||
fo->Write(stream.data(), stream.size());
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
common::PeekableInStream fp(fi);
|
||||
char c {0};
|
||||
fp.PeekRead(&c, 1);
|
||||
if (c == '{') {
|
||||
std::string buffer;
|
||||
common::FixedSizeStream{&fp}.Take(&buffer);
|
||||
auto memory_snapshot = Json::Load({buffer.c_str(), buffer.size()});
|
||||
this->LoadModel(memory_snapshot["Model"]);
|
||||
this->LoadConfig(memory_snapshot["Config"]);
|
||||
char header[2];
|
||||
fp.PeekRead(header, 2);
|
||||
if (header[0] == '{') {
|
||||
auto buffer = common::ReadAll(fi, &fp);
|
||||
Json memory_snapshot;
|
||||
if (header[1] == '"') {
|
||||
memory_snapshot = Json::Load(StringView{buffer});
|
||||
LOG(WARNING) << ModelMsg();
|
||||
} else if (std::isalpha(header[1])) {
|
||||
memory_snapshot = Json::Load(StringView{buffer}, std::ios::binary);
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid serialization file.";
|
||||
}
|
||||
if (IsA<Null>(memory_snapshot["Model"])) {
|
||||
// R has xgb.load that doesn't distinguish whether configuration is saved.
|
||||
// We should migrate to use `xgb.load.raw` instead.
|
||||
this->LoadModel(memory_snapshot);
|
||||
} else {
|
||||
this->LoadModel(memory_snapshot["Model"]);
|
||||
this->LoadConfig(memory_snapshot["Config"]);
|
||||
}
|
||||
} else {
|
||||
std::string header;
|
||||
header.resize(serialisation_header_.size());
|
||||
CHECK_EQ(fp.Read(&header[0], header.size()), serialisation_header_.size());
|
||||
// Avoid printing the content in loaded header, which might be random binary code.
|
||||
CHECK(header == serialisation_header_) // NOLINT
|
||||
<< R"doc(
|
||||
|
||||
If you are loading a serialized model (like pickle in Python) generated by older
|
||||
XGBoost, please export the model by calling `Booster.save_model` from that version
|
||||
first, then load it back in current version. There's a simple script for helping
|
||||
the process. See:
|
||||
|
||||
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
|
||||
|
||||
for reference to the script, and more details about differences between saving model and
|
||||
serializing.
|
||||
|
||||
)doc";
|
||||
CHECK(header == serialisation_header_) << ModelMsg();
|
||||
int64_t sz {-1};
|
||||
CHECK_EQ(fp.Read(&sz, sizeof(sz)), sizeof(sz));
|
||||
if (!DMLC_IO_NO_ENDIAN_SWAP) {
|
||||
|
||||
Reference in New Issue
Block a user