Preserve order of saved updaters config. (#9355)
- Save the updater sequence as an array instead of object. - Warn only once. The compatibility is kept, but we should be able to break it as the config is not loaded in pickle model and it's declared to be not stable.
This commit is contained in:
parent
b572a39919
commit
41c6813496
@ -76,32 +76,20 @@ test_that("Models from previous versions of XGBoost can be loaded", {
|
||||
name <- m[3]
|
||||
is_rds <- endsWith(model_file, '.rds')
|
||||
is_json <- endsWith(model_file, '.json')
|
||||
|
||||
cpp_warning <- capture.output({
|
||||
# Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
|
||||
if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0) {
|
||||
# Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
|
||||
if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0) {
|
||||
booster <- readRDS(model_file)
|
||||
expect_warning(predict(booster, newdata = pred_data))
|
||||
booster <- readRDS(model_file)
|
||||
expect_warning(run_booster_check(booster, name))
|
||||
} else {
|
||||
if (is_rds) {
|
||||
booster <- readRDS(model_file)
|
||||
expect_warning(predict(booster, newdata = pred_data))
|
||||
booster <- readRDS(model_file)
|
||||
expect_warning(run_booster_check(booster, name))
|
||||
} else {
|
||||
if (is_rds) {
|
||||
booster <- readRDS(model_file)
|
||||
} else {
|
||||
booster <- xgb.load(model_file)
|
||||
}
|
||||
predict(booster, newdata = pred_data)
|
||||
run_booster_check(booster, name)
|
||||
booster <- xgb.load(model_file)
|
||||
}
|
||||
})
|
||||
cpp_warning <- paste0(cpp_warning, collapse = ' ')
|
||||
if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') >= 0) {
|
||||
# Expect a C++ warning when a model is loaded from RDS and it was generated by old XGBoost`
|
||||
m <- grepl(paste0('.*If you are loading a serialized model ',
|
||||
'\\(like pickle in Python, RDS in R\\).*',
|
||||
'for more details about differences between ',
|
||||
'saving model and serializing.*'), cpp_warning, perl = TRUE)
|
||||
expect_true(length(m) > 0 && all(m))
|
||||
predict(booster, newdata = pred_data)
|
||||
run_booster_check(booster, name)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@ -51,5 +51,33 @@ inline void MaxFeatureSize(std::uint64_t n_features) {
|
||||
constexpr StringView InplacePredictProxy() {
|
||||
return "Inplace predict accepts only DMatrixProxy as input.";
|
||||
}
|
||||
|
||||
inline void MaxSampleSize(std::size_t n) {
|
||||
LOG(FATAL) << "Sample size too large for the current updater. Maximum number of samples:" << n
|
||||
<< ". Consider using a different updater or tree_method.";
|
||||
}
|
||||
|
||||
constexpr StringView OldSerialization() {
|
||||
return R"doc(If you are loading a serialized model (like pickle in Python, RDS in R) or
|
||||
configuration generated by an older version of XGBoost, please export the model by calling
|
||||
`Booster.save_model` from that version first, then load it back in current version. See:
|
||||
|
||||
https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html
|
||||
|
||||
for more details about differences between saving model and serializing.
|
||||
)doc";
|
||||
}
|
||||
|
||||
inline void WarnOldSerialization() {
|
||||
// Display it once is enough. Otherwise this can be really verbose in distributed
|
||||
// environments.
|
||||
static thread_local bool logged{false};
|
||||
if (logged) {
|
||||
return;
|
||||
}
|
||||
|
||||
LOG(WARNING) << OldSerialization();
|
||||
logged = true;
|
||||
}
|
||||
} // namespace xgboost::error
|
||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
@ -21,8 +21,7 @@
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
namespace xgboost::data {
|
||||
MetaInfo& SimpleDMatrix::Info() { return info_; }
|
||||
|
||||
const MetaInfo& SimpleDMatrix::Info() const { return info_; }
|
||||
@ -97,6 +96,10 @@ BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches(Context const* ctx) {
|
||||
// column page doesn't exist, generate it
|
||||
if (!column_page_) {
|
||||
auto n = std::numeric_limits<decltype(Entry::index)>::max();
|
||||
if (this->sparse_page_->Size() > n) {
|
||||
error::MaxSampleSize(n);
|
||||
}
|
||||
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx->Threads())));
|
||||
}
|
||||
auto begin_iter = BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
|
||||
@ -106,6 +109,10 @@ BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches(Context const* ctx) {
|
||||
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches(Context const* ctx) {
|
||||
// Sorted column page doesn't exist, generate it
|
||||
if (!sorted_column_page_) {
|
||||
auto n = std::numeric_limits<decltype(Entry::index)>::max();
|
||||
if (this->sparse_page_->Size() > n) {
|
||||
error::MaxSampleSize(n);
|
||||
}
|
||||
sorted_column_page_.reset(
|
||||
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx->Threads())));
|
||||
sorted_column_page_->SortRows(ctx->Threads());
|
||||
@ -427,5 +434,4 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
|
||||
|
||||
fmat_ctx_ = ctx;
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/error_msg.h" // for UnknownDevice, InplacePredictProxy
|
||||
#include "../common/error_msg.h" // for UnknownDevice, WarnOldSerialization, InplacePredictProxy
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/timer.h"
|
||||
@ -391,19 +391,32 @@ void GBTree::LoadConfig(Json const& in) {
|
||||
LOG(WARNING) << msg << " Changing `tree_method` to `hist`.";
|
||||
}
|
||||
|
||||
auto const& j_updaters = get<Object const>(in["updater"]);
|
||||
std::vector<Json> updater_seq;
|
||||
if (IsA<Object>(in["updater"])) {
|
||||
// before 2.0
|
||||
error::WarnOldSerialization();
|
||||
for (auto const& kv : get<Object const>(in["updater"])) {
|
||||
auto name = kv.first;
|
||||
auto config = kv.second;
|
||||
config["name"] = name;
|
||||
updater_seq.push_back(config);
|
||||
}
|
||||
} else {
|
||||
// after 2.0
|
||||
auto const& j_updaters = get<Array const>(in["updater"]);
|
||||
updater_seq = j_updaters;
|
||||
}
|
||||
|
||||
updaters_.clear();
|
||||
|
||||
for (auto const& kv : j_updaters) {
|
||||
auto name = kv.first;
|
||||
for (auto const& config : updater_seq) {
|
||||
auto name = get<String>(config["name"]);
|
||||
if (n_gpus == 0 && name == "grow_gpu_hist") {
|
||||
name = "grow_quantile_histmaker";
|
||||
LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`.";
|
||||
}
|
||||
std::unique_ptr<TreeUpdater> up{
|
||||
TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task)};
|
||||
up->LoadConfig(kv.second);
|
||||
updaters_.push_back(std::move(up));
|
||||
updaters_.emplace_back(TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task));
|
||||
updaters_.back()->LoadConfig(config);
|
||||
}
|
||||
|
||||
specified_updater_ = get<Boolean>(in["specified_updater"]);
|
||||
@ -425,13 +438,14 @@ void GBTree::SaveConfig(Json* p_out) const {
|
||||
// language binding doesn't need to know about the forest size.
|
||||
out["gbtree_model_param"] = ToJson(model_.param);
|
||||
|
||||
out["updater"] = Object();
|
||||
out["updater"] = Array{};
|
||||
auto& j_updaters = get<Array>(out["updater"]);
|
||||
|
||||
auto& j_updaters = out["updater"];
|
||||
for (auto const& up : updaters_) {
|
||||
j_updaters[up->Name()] = Object();
|
||||
auto& j_up = j_updaters[up->Name()];
|
||||
up->SaveConfig(&j_up);
|
||||
for (auto const& up : this->updaters_) {
|
||||
Json up_config{Object{}};
|
||||
up_config["name"] = String{up->Name()};
|
||||
up->SaveConfig(&up_config);
|
||||
j_updaters.emplace_back(up_config);
|
||||
}
|
||||
out["specified_updater"] = Boolean{specified_updater_};
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@
|
||||
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "common/charconv.h" // for to_chars, to_chars_result, NumericLimits, from_...
|
||||
#include "common/common.h" // for ToString, Split
|
||||
#include "common/error_msg.h" // for MaxFeatureSize
|
||||
#include "common/error_msg.h" // for MaxFeatureSize, WarnOldSerialization
|
||||
#include "common/io.h" // for PeekableInStream, ReadAll, FixedSizeStream, Mem...
|
||||
#include "common/observer.h" // for TrainingObserver
|
||||
#include "common/random.h" // for GlobalRandom
|
||||
@ -357,21 +357,6 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
|
||||
using LearnerAPIThreadLocalStore =
|
||||
dmlc::ThreadLocalStore<std::map<Learner const *, XGBAPIThreadLocalEntry>>;
|
||||
|
||||
namespace {
|
||||
StringView ModelMsg() {
|
||||
return StringView{
|
||||
R"doc(
|
||||
If you are loading a serialized model (like pickle in Python, RDS in R) generated by
|
||||
older XGBoost, please export the model by calling `Booster.save_model` from that version
|
||||
first, then load it back in current version. See:
|
||||
|
||||
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
|
||||
|
||||
for more details about differences between saving model and serializing.
|
||||
)doc"};
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
class LearnerConfiguration : public Learner {
|
||||
private:
|
||||
std::mutex config_lock_;
|
||||
@ -531,7 +516,7 @@ class LearnerConfiguration : public Learner {
|
||||
}
|
||||
|
||||
if (!Version::Same(origin_version)) {
|
||||
LOG(WARNING) << ModelMsg();
|
||||
error::WarnOldSerialization();
|
||||
return; // skip configuration if version is not matched
|
||||
}
|
||||
|
||||
@ -562,7 +547,7 @@ class LearnerConfiguration : public Learner {
|
||||
for (size_t i = 0; i < n_metrics; ++i) {
|
||||
auto old_serialization = IsA<String>(j_metrics[i]);
|
||||
if (old_serialization) {
|
||||
LOG(WARNING) << ModelMsg();
|
||||
error::WarnOldSerialization();
|
||||
metric_names_[i] = get<String>(j_metrics[i]);
|
||||
} else {
|
||||
metric_names_[i] = get<String>(j_metrics[i]["name"]);
|
||||
@ -1173,7 +1158,7 @@ class LearnerIO : public LearnerConfiguration {
|
||||
Json memory_snapshot;
|
||||
if (header[1] == '"') {
|
||||
memory_snapshot = Json::Load(StringView{buffer});
|
||||
LOG(WARNING) << ModelMsg();
|
||||
error::WarnOldSerialization();
|
||||
} else if (std::isalpha(header[1])) {
|
||||
memory_snapshot = Json::Load(StringView{buffer}, std::ios::binary);
|
||||
} else {
|
||||
@ -1192,7 +1177,7 @@ class LearnerIO : public LearnerConfiguration {
|
||||
header.resize(serialisation_header_.size());
|
||||
CHECK_EQ(fp.Read(&header[0], header.size()), serialisation_header_.size());
|
||||
// Avoid printing the content in loaded header, which might be random binary code.
|
||||
CHECK(header == serialisation_header_) << ModelMsg();
|
||||
CHECK(header == serialisation_header_) << error::OldSerialization();
|
||||
int64_t sz {-1};
|
||||
CHECK_EQ(fp.Read(&sz, sizeof(sz)), sizeof(sz));
|
||||
if (!DMLC_IO_NO_ENDIAN_SWAP) {
|
||||
|
||||
@ -174,32 +174,52 @@ TEST(GBTree, JsonIO) {
|
||||
Context ctx;
|
||||
LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
|
||||
|
||||
std::unique_ptr<GradientBooster> gbm {
|
||||
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &ctx) };
|
||||
std::unique_ptr<GradientBooster> gbm{
|
||||
CreateTrainedGBM("gbtree", Args{{"tree_method", "exact"}, {"default_direction", "left"}},
|
||||
kRows, kCols, &mparam, &ctx)};
|
||||
|
||||
Json model {Object()};
|
||||
Json model{Object()};
|
||||
model["model"] = Object();
|
||||
auto& j_model = model["model"];
|
||||
auto j_model = model["model"];
|
||||
|
||||
model["config"] = Object();
|
||||
auto& j_param = model["config"];
|
||||
auto j_config = model["config"];
|
||||
|
||||
gbm->SaveModel(&j_model);
|
||||
gbm->SaveConfig(&j_param);
|
||||
gbm->SaveConfig(&j_config);
|
||||
|
||||
std::string model_str;
|
||||
Json::Dump(model, &model_str);
|
||||
|
||||
model = Json::Load({model_str.c_str(), model_str.size()});
|
||||
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
|
||||
j_model = model["model"];
|
||||
j_config = model["config"];
|
||||
ASSERT_EQ(get<String>(j_model["name"]), "gbtree");
|
||||
|
||||
auto const& gbtree_model = model["model"]["model"];
|
||||
auto gbtree_model = j_model["model"];
|
||||
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1ul);
|
||||
ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
|
||||
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1ul);
|
||||
|
||||
auto j_train_param = model["config"]["gbtree_model_param"];
|
||||
auto j_train_param = j_config["gbtree_model_param"];
|
||||
ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");
|
||||
|
||||
auto check_config = [](Json j_up_config) {
|
||||
auto colmaker = get<Array const>(j_up_config).front();
|
||||
auto pruner = get<Array const>(j_up_config).back();
|
||||
ASSERT_EQ(get<String const>(colmaker["name"]), "grow_colmaker");
|
||||
ASSERT_EQ(get<String const>(pruner["name"]), "prune");
|
||||
ASSERT_EQ(get<String const>(colmaker["colmaker_train_param"]["default_direction"]), "left");
|
||||
};
|
||||
check_config(j_config["updater"]);
|
||||
|
||||
std::unique_ptr<GradientBooster> loaded(gbm::GBTree::Create("gbtree", &ctx, &mparam));
|
||||
loaded->LoadModel(j_model);
|
||||
loaded->LoadConfig(j_config);
|
||||
|
||||
// roundtrip test
|
||||
Json j_config_rt{Object{}};
|
||||
loaded->SaveConfig(&j_config_rt);
|
||||
check_config(j_config_rt["updater"]);
|
||||
}
|
||||
|
||||
TEST(Dart, JsonIO) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user