[breaking] Change internal model serialization to UBJSON. (#7556)

* Use typed array for models.
* Change the memory snapshot format.
* Add new C API for saving to raw format.
This commit is contained in:
Jiaming Yuan 2022-01-16 02:11:53 +08:00 committed by GitHub
parent 13b0fa4b97
commit a1bcd33a3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 566 additions and 255 deletions

View File

@ -97,7 +97,7 @@
"default_left": { "default_left": {
"type": "array", "type": "array",
"items": { "items": {
"type": "boolean" "type": "integer"
} }
}, },
"categories": { "categories": {

View File

@ -1081,14 +1081,32 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle,
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
const void *buf, const void *buf,
bst_ulong len); bst_ulong len);
/*! /*!
* \brief save model into binary raw bytes, return header of the array * \brief Save model into raw bytes, return header of the array. User must copy the
* user must copy the result out, before next xgboost call * result out, before next xgboost call
*
* \param handle handle * \param handle handle
* \param out_len the argument to hold the output length * \param json_config JSON encoded string storing parameters for the function. Following
* \param out_dptr the argument to hold the output data pointer * keys are expected in the JSON document:
*
* "format": str
* - json: Output booster will be encoded as JSON.
* - ubj: Output booster will be encoded as Universal Binary JSON.
* - deprecated: Output booster will be encoded as old custom binary format. Do not use
* this format except for compatibility reasons.
*
* \param out_len The argument to hold the output length
* \param out_dptr The argument to hold the output data pointer
*
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_config,
bst_ulong *out_len, char const **out_dptr);
/*!
* \brief Deprecated, use `XGBoosterSaveModelToBuffer` instead.
*/
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, bst_ulong *out_len, XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, bst_ulong *out_len,
const char **out_dptr); const char **out_dptr);

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright (c) 2015-2021 by Contributors * Copyright (c) 2015-2022 by Contributors
* \file data.h * \file data.h
* \brief The input data structure of xgboost. * \brief The input data structure of xgboost.
* \author Tianqi Chen * \author Tianqi Chen
@ -36,10 +36,7 @@ enum class DataType : uint8_t {
kStr = 5 kStr = 5
}; };
enum class FeatureType : uint8_t { enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
kNumerical,
kCategorical
};
/*! /*!
* \brief Meta information about dataset, always sit in memory. * \brief Meta information about dataset, always sit in memory.

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2021 by XGBoost Contributors * Copyright 2021-2022 by XGBoost Contributors
* \file linalg.h * \file linalg.h
* \brief Linear algebra related utilities. * \brief Linear algebra related utilities.
*/ */
@ -567,7 +567,7 @@ template <typename T, int32_t D>
Json ArrayInterface(TensorView<T const, D> const &t) { Json ArrayInterface(TensorView<T const, D> const &t) {
Json array_interface{Object{}}; Json array_interface{Object{}};
array_interface["data"] = std::vector<Json>(2); array_interface["data"] = std::vector<Json>(2);
array_interface["data"][0] = Integer(reinterpret_cast<int64_t>(t.Values().data())); array_interface["data"][0] = Integer{reinterpret_cast<int64_t>(t.Values().data())};
array_interface["data"][1] = Boolean{true}; array_interface["data"][1] = Boolean{true};
if (t.DeviceIdx() >= 0) { if (t.DeviceIdx() >= 0) {
// Change this once we have different CUDA stream. // Change this once we have different CUDA stream.

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2019 by Contributors * Copyright 2014-2022 by Contributors
* \file tree_model.h * \file tree_model.h
* \brief model structure for tree * \brief model structure for tree
* \author Tianqi Chen * \author Tianqi Chen
@ -42,7 +42,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
/*! \brief maximum depth, this is a statistics of the tree */ /*! \brief maximum depth, this is a statistics of the tree */
int deprecated_max_depth; int deprecated_max_depth;
/*! \brief number of features used for tree construction */ /*! \brief number of features used for tree construction */
int num_feature; bst_feature_t num_feature;
/*! /*!
* \brief leaf vector size, used for vector tree * \brief leaf vector size, used for vector tree
* used to store more than one dimensional information in tree * used to store more than one dimensional information in tree
@ -629,6 +629,7 @@ class RegTree : public Model {
} }
private: private:
template <bool typed>
void LoadCategoricalSplit(Json const& in); void LoadCategoricalSplit(Json const& in);
void SaveCategoricalSplit(Json* p_out) const; void SaveCategoricalSplit(Json* p_out) const;
// vector of nodes // vector of nodes

View File

@ -1,4 +1,4 @@
// Copyright (c) 2014-2021 by Contributors // Copyright (c) 2014-2022 by Contributors
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <rabit/c_api.h> #include <rabit/c_api.h>
@ -248,22 +248,16 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
#endif #endif
// Create from data iterator // Create from data iterator
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DMatrixHandle proxy, DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
DataIterResetCallback *reset, char const *c_json_config, DMatrixHandle *out) {
XGDMatrixCallbackNext *next,
char const* c_json_config,
DMatrixHandle *out) {
API_BEGIN(); API_BEGIN();
auto config = Json::Load(StringView{c_json_config}); auto config = Json::Load(StringView{c_json_config});
float missing = get<Number const>(config["missing"]); auto missing = GetMissing(config);
std::string cache = get<String const>(config["cache_prefix"]); std::string cache = RequiredArg<String>(config, "cache_prefix", __func__);
int32_t n_threads = omp_get_max_threads(); auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
if (!IsA<Null>(config["nthread"])) { *out = new std::shared_ptr<xgboost::DMatrix>{
n_threads = get<Integer const>(config["nthread"]); xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)};
}
*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
iter, proxy, reset, next, missing, n_threads, cache)};
API_END(); API_END();
} }
@ -358,8 +352,8 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
StringView{data}, ncol); StringView{data}, ncol);
auto config = Json::Load(StringView{c_json_config}); auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config); float missing = GetMissing(config);
auto nthread = get<Integer const>(config["nthread"]); auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread)); *out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
API_END(); API_END();
} }
@ -371,9 +365,9 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data,
xgboost::data::ArrayAdapter(StringView{data})}; xgboost::data::ArrayAdapter(StringView{data})};
auto config = Json::Load(StringView{c_json_config}); auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config); float missing = GetMissing(config);
auto nthread = get<Integer const>(config["nthread"]); auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
*out = *out =
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread)); new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
API_END(); API_END();
} }
@ -765,11 +759,11 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
auto& entry = learner->GetThreadLocal().prediction_entry; auto& entry = learner->GetThreadLocal().prediction_entry;
auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat); auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
auto const& j_config = get<Object const>(config); auto type = PredictionType(RequiredArg<Integer>(config, "type", __func__));
auto type = PredictionType(get<Integer const>(j_config.at("type"))); auto iteration_begin = RequiredArg<Integer>(config, "iteration_begin", __func__);
auto iteration_begin = get<Integer const>(j_config.at("iteration_begin")); auto iteration_end = RequiredArg<Integer>(config, "iteration_end", __func__);
auto iteration_end = get<Integer const>(j_config.at("iteration_end"));
auto const& j_config = get<Object const>(config);
auto ntree_limit_it = j_config.find("ntree_limit"); auto ntree_limit_it = j_config.find("ntree_limit");
if (ntree_limit_it != j_config.cend() && !IsA<Null>(ntree_limit_it->second) && if (ntree_limit_it != j_config.cend() && !IsA<Null>(ntree_limit_it->second) &&
get<Integer const>(ntree_limit_it->second) != 0) { get<Integer const>(ntree_limit_it->second) != 0) {
@ -785,7 +779,7 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
type == PredictionType::kApproxContribution; type == PredictionType::kApproxContribution;
bool interactions = type == PredictionType::kInteraction || bool interactions = type == PredictionType::kInteraction ||
type == PredictionType::kApproxInteraction; type == PredictionType::kApproxInteraction;
bool training = get<Boolean const>(config["training"]); bool training = RequiredArg<Boolean>(config, "training", __func__);
learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions, learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions,
iteration_begin, iteration_end, training, iteration_begin, iteration_end, training,
type == PredictionType::kLeaf, contribs, approximate, type == PredictionType::kLeaf, contribs, approximate,
@ -796,7 +790,7 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
auto rounds = iteration_end - iteration_begin; auto rounds = iteration_end - iteration_begin;
rounds = rounds == 0 ? learner->BoostedRounds() : rounds; rounds = rounds == 0 ? learner->BoostedRounds() : rounds;
// Determine shape // Determine shape
bool strict_shape = get<Boolean const>(config["strict_shape"]); bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
CalcPredictShape(strict_shape, type, p_m->Info().num_row_, CalcPredictShape(strict_shape, type, p_m->Info().num_row_,
p_m->Info().num_col_, chunksize, learner->Groups(), rounds, p_m->Info().num_col_, chunksize, learner->Groups(), rounds,
&shape, out_dim); &shape, out_dim);
@ -814,15 +808,15 @@ void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet"; CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
HostDeviceVector<float>* p_predt { nullptr }; HostDeviceVector<float>* p_predt { nullptr };
auto type = PredictionType(get<Integer const>(config["type"])); auto type = PredictionType(RequiredArg<Integer>(config, "type", __func__));
float missing = GetMissing(config); float missing = GetMissing(config);
learner->InplacePredict(x, p_m, type, missing, &p_predt, learner->InplacePredict(x, p_m, type, missing, &p_predt,
get<Integer const>(config["iteration_begin"]), RequiredArg<Integer>(config, "iteration_begin", __func__),
get<Integer const>(config["iteration_end"])); RequiredArg<Integer>(config, "iteration_end", __func__));
CHECK(p_predt); CHECK(p_predt);
auto &shape = learner->GetThreadLocal().prediction_shape; auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows; auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows;
bool strict_shape = get<Boolean const>(config["strict_shape"]); bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(), CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(),
learner->BoostedRounds(), &shape, out_dim); learner->BoostedRounds(), &shape, out_dim);
*out_result = dmlc::BeginPtr(p_predt->HostVector()); *out_result = dmlc::BeginPtr(p_predt->HostVector());
@ -900,11 +894,20 @@ XGB_DLL int XGBoosterPredictFromCUDAColumnar(
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) { XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
if (common::FileExtension(fname) == "json") { auto read_file = [&]() {
auto str = common::LoadSequentialFile(fname); auto str = common::LoadSequentialFile(fname);
CHECK_GT(str.size(), 2); CHECK_GE(str.size(), 3); // "{}\0"
CHECK_EQ(str[0], '{'); CHECK_EQ(str[0], '{');
Json in { Json::Load({str.c_str(), str.size()}) }; CHECK_EQ(str[str.size() - 2], '}');
return str;
};
if (common::FileExtension(fname) == "json") {
auto str = read_file();
Json in{Json::Load(StringView{str})};
static_cast<Learner*>(handle)->LoadModel(in);
} else if (common::FileExtension(fname) == "ubj") {
auto str = read_file();
Json in = Json::Load(StringView{str}, std::ios::binary);
static_cast<Learner *>(handle)->LoadModel(in); static_cast<Learner *>(handle)->LoadModel(in);
} else { } else {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r")); std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
@ -913,27 +916,45 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
API_END(); API_END();
} }
namespace {
void WarnOldModel() {
if (XGBOOST_VER_MAJOR >= 2) {
LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or "
"`ubj`. Model format will default to JSON in XGBoost 2.2 if not specified.";
}
}
} // anonymous namespace
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *c_fname) { XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *c_fname) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(c_fname, "w")); std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(c_fname, "w"));
auto *learner = static_cast<Learner *>(handle); auto *learner = static_cast<Learner *>(handle);
learner->Configure(); learner->Configure();
if (common::FileExtension(c_fname) == "json") { auto save_json = [&](std::ios::openmode mode) {
Json out{Object()}; Json out{Object()};
learner->SaveModel(&out); learner->SaveModel(&out);
std::string str; std::vector<char> str;
Json::Dump(out, &str); Json::Dump(out, &str, mode);
fo->Write(str.c_str(), str.size()); fo->Write(str.data(), str.size());
};
if (common::FileExtension(c_fname) == "json") {
save_json(std::ios::out);
} else if (common::FileExtension(c_fname) == "ubj") {
save_json(std::ios::binary);
} else if (XGBOOST_VER_MAJOR == 2 && XGBOOST_VER_MINOR >= 2) {
LOG(WARNING) << "Saving model to JSON as default. You can use file extension `json`, `ubj` or "
"`deprecated` to choose between formats.";
save_json(std::ios::out);
} else { } else {
WarnOldModel();
auto *bst = static_cast<Learner *>(handle); auto *bst = static_cast<Learner *>(handle);
bst->SaveModel(fo.get()); bst->SaveModel(fo.get());
} }
API_END(); API_END();
} }
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf,
const void* buf,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
@ -942,6 +963,39 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
API_END(); API_END();
} }
XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_config,
xgboost::bst_ulong *out_len, char const **out_dptr) {
API_BEGIN();
CHECK_HANDLE();
auto config = Json::Load(StringView{json_config});
auto format = RequiredArg<String>(config, "format", __func__);
auto *learner = static_cast<Learner *>(handle);
std::string &raw_str = learner->GetThreadLocal().ret_str;
raw_str.clear();
learner->Configure();
Json out{Object{}};
if (format == "json") {
learner->SaveModel(&out);
Json::Dump(out, &raw_str);
} else if (format == "ubj") {
learner->SaveModel(&out);
Json::Dump(out, &raw_str, std::ios::binary);
} else if (format == "deprecated") {
WarnOldModel();
common::MemoryBufferStream fo(&raw_str);
learner->SaveModel(&fo);
} else {
LOG(FATAL) << "Unknown format: `" << format << "`";
}
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
}
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
xgboost::bst_ulong* out_len, xgboost::bst_ulong* out_len,
const char** out_dptr) { const char** out_dptr) {
@ -952,6 +1006,8 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
raw_str.resize(0); raw_str.resize(0);
common::MemoryBufferStream fo(&raw_str); common::MemoryBufferStream fo(&raw_str);
LOG(WARNING) << "`" << __func__
<< "` is deprecated, please use `XGBoosterSaveModelToBuffer` instead.";
learner->Configure(); learner->Configure();
learner->SaveModel(&fo); learner->SaveModel(&fo);
@ -1208,7 +1264,8 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
CHECK_HANDLE(); CHECK_HANDLE();
auto *learner = static_cast<Learner *>(handle); auto *learner = static_cast<Learner *>(handle);
auto config = Json::Load(StringView{json_config}); auto config = Json::Load(StringView{json_config});
auto importance = get<String const>(config["importance_type"]);
auto importance = RequiredArg<String>(config, "importance_type", __func__);
std::string feature_map_uri; std::string feature_map_uri;
if (!IsA<Null>(config["feature_map"])) { if (!IsA<Null>(config["feature_map"])) {
feature_map_uri = get<String const>(config["feature_map"]); feature_map_uri = get<String const>(config["feature_map"]);

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright (c) 2021 by XGBoost Contributors * Copyright (c) 2021-2022 by XGBoost Contributors
*/ */
#ifndef XGBOOST_C_API_C_API_UTILS_H_ #ifndef XGBOOST_C_API_C_API_UTILS_H_
#define XGBOOST_C_API_C_API_UTILS_H_ #define XGBOOST_C_API_C_API_UTILS_H_
@ -241,5 +241,25 @@ inline void GenerateFeatureMap(Learner const *learner,
} }
void XGBBuildInfoDevice(Json* p_info); void XGBBuildInfoDevice(Json* p_info);
template <typename JT>
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it == obj.cend() || IsA<Null>(it->second)) {
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`";
}
return get<std::remove_const_t<JT> const>(it->second);
}
template <typename JT, typename T>
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it != obj.cend()) {
return get<std::remove_const_t<JT> const>(it->second);
}
return dft;
}
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_ #endif // XGBOOST_C_API_C_API_UTILS_H_

View File

@ -111,7 +111,7 @@ class ConfigParser {
const auto last_char = str.find_last_not_of(" \t\n\r"); const auto last_char = str.find_last_not_of(" \t\n\r");
if (first_char == std::string::npos) { if (first_char == std::string::npos) {
// Every character in str is a whitespace // Every character in str is a whitespace
return std::string(); return {};
} }
CHECK_NE(last_char, std::string::npos); CHECK_NE(last_char, std::string::npos);
const auto substr_len = last_char + 1 - first_char; const auto substr_len = last_char + 1 - first_char;

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright (c) by XGBoost Contributors 2019 * Copyright (c) by XGBoost Contributors 2019-2022
*/ */
#if defined(__unix__) #if defined(__unix__)
#include <sys/stat.h> #include <sys/stat.h>
@ -142,5 +142,18 @@ std::string LoadSequentialFile(std::string uri, bool stream) {
buffer.resize(total); buffer.resize(total);
return buffer; return buffer;
} }
std::string FileExtension(std::string fname, bool lower) {
if (lower) {
std::transform(fname.begin(), fname.end(), fname.begin(),
[](char c) { return std::tolower(c); });
}
auto splited = Split(fname, '.');
if (splited.size() > 1) {
return splited.back();
} else {
return "";
}
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright by XGBoost Contributors 2014-2022
* \file io.h * \file io.h
* \brief general stream interface for serialization, I/O * \brief general stream interface for serialization, I/O
* \author Tianqi Chen * \author Tianqi Chen
@ -86,15 +86,31 @@ class FixedSizeStream : public PeekableInStream {
*/ */
std::string LoadSequentialFile(std::string uri, bool stream = false); std::string LoadSequentialFile(std::string uri, bool stream = false);
inline std::string FileExtension(std::string const& fname) { /**
auto splited = Split(fname, '.'); * \brief Get file extension from file name.
if (splited.size() > 1) { *
return splited.back(); * \param lower Return in lower case.
} else { *
return ""; * \return File extension without the `.`
} */
} std::string FileExtension(std::string fname, bool lower = true);
/**
* \brief Read the whole buffer from dmlc stream.
*/
inline std::string ReadAll(dmlc::Stream* fi, PeekableInStream* fp) {
std::string buffer;
if (auto fixed_size = dynamic_cast<common::MemoryFixSizeBuffer*>(fi)) {
fixed_size->Seek(common::MemoryFixSizeBuffer::kSeekEnd);
size_t size = fixed_size->Tell();
buffer.resize(size);
fixed_size->Seek(0);
CHECK_EQ(fixed_size->Read(&buffer[0], size), size);
} else {
FixedSizeStream{fp}.Take(&buffer);
}
return buffer;
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_IO_H_ #endif // XGBOOST_COMMON_IO_H_

View File

@ -1,6 +1,7 @@
/*! /*!
* Copyright 2019-2021 by Contributors * Copyright 2019-2022 by Contributors
*/ */
#include <algorithm>
#include <utility> #include <utility>
#include <limits> #include <limits>
#include "xgboost/json.h" #include "xgboost/json.h"
@ -13,22 +14,28 @@ void GBLinearModel::SaveModel(Json* p_out) const {
auto& out = *p_out; auto& out = *p_out;
size_t const n_weights = weight.size(); size_t const n_weights = weight.size();
std::vector<Json> j_weights(n_weights); F32Array j_weights{n_weights};
for (size_t i = 0; i < n_weights; ++i) { std::copy(weight.begin(), weight.end(), j_weights.GetArray().begin());
j_weights[i] = weight[i];
}
out["weights"] = std::move(j_weights); out["weights"] = std::move(j_weights);
out["boosted_rounds"] = Json{this->num_boosted_rounds}; out["boosted_rounds"] = Json{this->num_boosted_rounds};
} }
void GBLinearModel::LoadModel(Json const& in) { void GBLinearModel::LoadModel(Json const& in) {
auto const& j_weights = get<Array const>(in["weights"]); auto const& obj = get<Object const>(in);
auto weight_it = obj.find("weights");
if (IsA<F32Array>(weight_it->second)) {
auto const& j_weights = get<F32Array const>(weight_it->second);
weight.resize(j_weights.size());
std::copy(j_weights.begin(), j_weights.end(), weight.begin());
} else {
auto const& j_weights = get<Array const>(weight_it->second);
auto n_weights = j_weights.size(); auto n_weights = j_weights.size();
weight.resize(n_weights); weight.resize(n_weights);
for (size_t i = 0; i < n_weights; ++i) { for (size_t i = 0; i < n_weights; ++i) {
weight[i] = get<Number const>(j_weights[i]); weight[i] = get<Number const>(j_weights[i]);
} }
auto const& obj = get<Object const>(in); }
auto boosted_rounds = obj.find("boosted_rounds"); auto boosted_rounds = obj.find("boosted_rounds");
if (boosted_rounds != obj.cend()) { if (boosted_rounds != obj.cend()) {
this->num_boosted_rounds = get<Integer const>(boosted_rounds->second); this->num_boosted_rounds = get<Integer const>(boosted_rounds->second);

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2019-2020 by Contributors * Copyright 2019-2022 by Contributors
*/ */
#include <utility> #include <utility>
@ -69,13 +69,13 @@ void GBTreeModel::SaveModel(Json* p_out) const {
out["gbtree_model_param"] = ToJson(param); out["gbtree_model_param"] = ToJson(param);
std::vector<Json> trees_json(trees.size()); std::vector<Json> trees_json(trees.size());
for (size_t t = 0; t < trees.size(); ++t) { common::ParallelFor(trees.size(), omp_get_max_threads(), [&](auto t) {
auto const& tree = trees[t]; auto const& tree = trees[t];
Json tree_json{Object()}; Json tree_json{Object()};
tree->SaveModel(&tree_json); tree->SaveModel(&tree_json);
tree_json["id"] = Integer(static_cast<Integer::Int>(t)); tree_json["id"] = Integer{static_cast<Integer::Int>(t)};
trees_json[t] = std::move(tree_json); trees_json[t] = std::move(tree_json);
} });
std::vector<Json> tree_info_json(tree_info.size()); std::vector<Json> tree_info_json(tree_info.size());
for (size_t i = 0; i < tree_info.size(); ++i) { for (size_t i = 0; i < tree_info.size(); ++i) {
@ -95,11 +95,11 @@ void GBTreeModel::LoadModel(Json const& in) {
auto const& trees_json = get<Array const>(in["trees"]); auto const& trees_json = get<Array const>(in["trees"]);
trees.resize(trees_json.size()); trees.resize(trees_json.size());
for (size_t t = 0; t < trees_json.size(); ++t) { // NOLINT common::ParallelFor(trees_json.size(), omp_get_max_threads(), [&](auto t) {
auto tree_id = get<Integer>(trees_json[t]["id"]); auto tree_id = get<Integer>(trees_json[t]["id"]);
trees.at(tree_id).reset(new RegTree()); trees.at(tree_id).reset(new RegTree());
trees.at(tree_id)->LoadModel(trees_json[t]); trees.at(tree_id)->LoadModel(trees_json[t]);
} });
tree_info.resize(param.num_trees); tree_info.resize(param.num_trees);
auto const& tree_info_json = get<Array const>(in["tree_info"]); auto const& tree_info_json = get<Array const>(in["tree_info"]);

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2021 by Contributors * Copyright 2014-2022 by Contributors
* \file learner.cc * \file learner.cc
* \brief Implementation of learning algorithm. * \brief Implementation of learning algorithm.
* \author Tianqi Chen * \author Tianqi Chen
@ -706,6 +706,21 @@ class LearnerConfiguration : public Learner {
std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT
namespace {
StringView ModelMsg() {
return StringView{
R"doc(
If you are loading a serialized model (like pickle in Python, RDS in R) generated by
older XGBoost, please export the model by calling `Booster.save_model` from that version
first, then load it back in current version. See:
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
for more details about differences between saving model and serializing.
)doc"};
}
} // anonymous namespace
class LearnerIO : public LearnerConfiguration { class LearnerIO : public LearnerConfiguration {
private: private:
std::set<std::string> saved_configs_ = {"num_round"}; std::set<std::string> saved_configs_ = {"num_round"};
@ -714,12 +729,17 @@ class LearnerIO : public LearnerConfiguration {
std::string const serialisation_header_ { u8"CONFIG-offset:" }; std::string const serialisation_header_ { u8"CONFIG-offset:" };
public: public:
explicit LearnerIO(std::vector<std::shared_ptr<DMatrix> > cache) : explicit LearnerIO(std::vector<std::shared_ptr<DMatrix>> cache) : LearnerConfiguration{cache} {}
LearnerConfiguration{cache} {}
void LoadModel(Json const& in) override { void LoadModel(Json const& in) override {
CHECK(IsA<Object>(in)); CHECK(IsA<Object>(in));
Version::Load(in); auto version = Version::Load(in);
if (std::get<0>(version) == 1 && std::get<1>(version) < 6) {
LOG(WARNING)
<< "Found JSON model saved before XGBoost 1.6, please save the model using current "
"version again. The support for old JSON model will be discontinued in XGBoost 2.3.";
}
auto const& learner = get<Object>(in["learner"]); auto const& learner = get<Object>(in["learner"]);
mparam_.FromJson(learner.at("learner_model_param")); mparam_.FromJson(learner.at("learner_model_param"));
@ -733,8 +753,8 @@ class LearnerIO : public LearnerConfiguration {
auto const& gradient_booster = learner.at("gradient_booster"); auto const& gradient_booster = learner.at("gradient_booster");
name = get<String>(gradient_booster["name"]); name = get<String>(gradient_booster["name"]);
tparam_.UpdateAllowUnknown(Args{{"booster", name}}); tparam_.UpdateAllowUnknown(Args{{"booster", name}});
gbm_.reset(GradientBooster::Create(tparam_.booster, gbm_.reset(
&generic_parameters_, &learner_model_param_)); GradientBooster::Create(tparam_.booster, &generic_parameters_, &learner_model_param_));
gbm_->LoadModel(gradient_booster); gbm_->LoadModel(gradient_booster);
auto const& j_attributes = get<Object const>(learner.at("attributes")); auto const& j_attributes = get<Object const>(learner.at("attributes"));
@ -747,19 +767,16 @@ class LearnerIO : public LearnerConfiguration {
auto it = learner.find("feature_names"); auto it = learner.find("feature_names");
if (it != learner.cend()) { if (it != learner.cend()) {
auto const& feature_names = get<Array const>(it->second); auto const& feature_names = get<Array const>(it->second);
feature_names_.clear(); feature_names_.resize(feature_names.size());
for (auto const &name : feature_names) { std::transform(feature_names.cbegin(), feature_names.cend(), feature_names_.begin(),
feature_names_.emplace_back(get<String const>(name)); [](Json const& fn) { return get<String const>(fn); });
}
} }
it = learner.find("feature_types"); it = learner.find("feature_types");
if (it != learner.cend()) { if (it != learner.cend()) {
auto const& feature_types = get<Array const>(it->second); auto const& feature_types = get<Array const>(it->second);
feature_types_.clear(); feature_types_.resize(feature_types.size());
for (auto const &name : feature_types) { std::transform(feature_types.cbegin(), feature_types.cend(), feature_types_.begin(),
auto type = get<String const>(name); [](Json const& fn) { return get<String const>(fn); });
feature_types_.emplace_back(type);
}
} }
this->need_configuration_ = true; this->need_configuration_ = true;
@ -799,6 +816,7 @@ class LearnerIO : public LearnerConfiguration {
feature_types.emplace_back(type); feature_types.emplace_back(type);
} }
} }
// About to be deprecated by JSON format // About to be deprecated by JSON format
void LoadModel(dmlc::Stream* fi) override { void LoadModel(dmlc::Stream* fi) override {
generic_parameters_.UpdateAllowUnknown(Args{}); generic_parameters_.UpdateAllowUnknown(Args{});
@ -817,15 +835,20 @@ class LearnerIO : public LearnerConfiguration {
} }
} }
if (header[0] == '{') { if (header[0] == '{') { // Dispatch to JSON
// Dispatch to JSON auto buffer = common::ReadAll(fi, &fp);
auto json_stream = common::FixedSizeStream(&fp); Json model;
std::string buffer; if (header[1] == '"') {
json_stream.Take(&buffer); model = Json::Load(StringView{buffer});
auto model = Json::Load({buffer.c_str(), buffer.size()}); } else if (std::isalpha(header[1])) {
model = Json::Load(StringView{buffer}, std::ios::binary);
} else {
LOG(FATAL) << "Invalid model format";
}
this->LoadModel(model); this->LoadModel(model);
return; return;
} }
// use the peekable reader. // use the peekable reader.
fi = &fp; fi = &fp;
// read parameter // read parameter
@ -988,40 +1011,41 @@ class LearnerIO : public LearnerConfiguration {
memory_snapshot["Config"] = Object(); memory_snapshot["Config"] = Object();
auto& config = memory_snapshot["Config"]; auto& config = memory_snapshot["Config"];
this->SaveConfig(&config); this->SaveConfig(&config);
std::string out_str;
Json::Dump(memory_snapshot, &out_str); std::vector<char> stream;
fo->Write(out_str.c_str(), out_str.size()); Json::Dump(memory_snapshot, &stream, std::ios::binary);
fo->Write(stream.data(), stream.size());
} }
void Load(dmlc::Stream* fi) override { void Load(dmlc::Stream* fi) override {
common::PeekableInStream fp(fi); common::PeekableInStream fp(fi);
char c {0}; char header[2];
fp.PeekRead(&c, 1); fp.PeekRead(header, 2);
if (c == '{') { if (header[0] == '{') {
std::string buffer; auto buffer = common::ReadAll(fi, &fp);
common::FixedSizeStream{&fp}.Take(&buffer); Json memory_snapshot;
auto memory_snapshot = Json::Load({buffer.c_str(), buffer.size()}); if (header[1] == '"') {
memory_snapshot = Json::Load(StringView{buffer});
LOG(WARNING) << ModelMsg();
} else if (std::isalpha(header[1])) {
memory_snapshot = Json::Load(StringView{buffer}, std::ios::binary);
} else {
LOG(FATAL) << "Invalid serialization file.";
}
if (IsA<Null>(memory_snapshot["Model"])) {
// R has xgb.load that doesn't distinguish whether configuration is saved.
// We should migrate to use `xgb.load.raw` instead.
this->LoadModel(memory_snapshot);
} else {
this->LoadModel(memory_snapshot["Model"]); this->LoadModel(memory_snapshot["Model"]);
this->LoadConfig(memory_snapshot["Config"]); this->LoadConfig(memory_snapshot["Config"]);
}
} else { } else {
std::string header; std::string header;
header.resize(serialisation_header_.size()); header.resize(serialisation_header_.size());
CHECK_EQ(fp.Read(&header[0], header.size()), serialisation_header_.size()); CHECK_EQ(fp.Read(&header[0], header.size()), serialisation_header_.size());
// Avoid printing the content in loaded header, which might be random binary code. // Avoid printing the content in loaded header, which might be random binary code.
CHECK(header == serialisation_header_) // NOLINT CHECK(header == serialisation_header_) << ModelMsg();
<< R"doc(
If you are loading a serialized model (like pickle in Python) generated by older
XGBoost, please export the model by calling `Booster.save_model` from that version
first, then load it back in current version. There's a simple script for helping
the process. See:
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
for reference to the script, and more details about differences between saving model and
serializing.
)doc";
int64_t sz {-1}; int64_t sz {-1};
CHECK_EQ(fp.Read(&sz, sizeof(sz)), sizeof(sz)); CHECK_EQ(fp.Read(&sz, sizeof(sz)), sizeof(sz));
if (!DMLC_IO_NO_ENDIAN_SWAP) { if (!DMLC_IO_NO_ENDIAN_SWAP) {

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015-2021 by Contributors * Copyright 2015-2022 by Contributors
* \file tree_model.cc * \file tree_model.cc
* \brief model structure for tree * \brief model structure for tree
*/ */
@ -893,27 +893,57 @@ void RegTree::Save(dmlc::Stream* fo) const {
} }
} }
} }
// typed array, not boolean
template <typename JT, typename T>
std::enable_if_t<!std::is_same<T, Json>::value && !std::is_same<JT, Boolean>::value, T> GetElem(
std::vector<T> const& arr, size_t i) {
return arr[i];
}
// typed array boolean
template <typename JT, typename T>
std::enable_if_t<!std::is_same<T, Json>::value && std::is_same<T, uint8_t>::value &&
std::is_same<JT, Boolean>::value,
bool>
GetElem(std::vector<T> const& arr, size_t i) {
return arr[i] == 1;
}
// json array
template <typename JT, typename T>
std::enable_if_t<
std::is_same<T, Json>::value,
std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>>
GetElem(std::vector<T> const& arr, size_t i) {
if (std::is_same<JT, Boolean>::value && !IsA<Boolean>(arr[i])) {
return get<Integer const>(arr[i]) == 1;
}
return get<JT const>(arr[i]);
}
template <bool typed>
void RegTree::LoadCategoricalSplit(Json const& in) { void RegTree::LoadCategoricalSplit(Json const& in) {
auto const& categories_segments = get<Array const>(in["categories_segments"]); using I64ArrayT = std::conditional_t<typed, I64Array const, Array const>;
auto const& categories_sizes = get<Array const>(in["categories_sizes"]); using I32ArrayT = std::conditional_t<typed, I32Array const, Array const>;
auto const& categories_nodes = get<Array const>(in["categories_nodes"]);
auto const& categories = get<Array const>(in["categories"]); auto const& categories_segments = get<I64ArrayT>(in["categories_segments"]);
auto const& categories_sizes = get<I64ArrayT>(in["categories_sizes"]);
auto const& categories_nodes = get<I32ArrayT>(in["categories_nodes"]);
auto const& categories = get<I32ArrayT>(in["categories"]);
size_t cnt = 0; size_t cnt = 0;
bst_node_t last_cat_node = -1; bst_node_t last_cat_node = -1;
if (!categories_nodes.empty()) { if (!categories_nodes.empty()) {
last_cat_node = get<Integer const>(categories_nodes[cnt]); last_cat_node = GetElem<Integer>(categories_nodes, cnt);
} }
for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) { for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
if (nidx == last_cat_node) { if (nidx == last_cat_node) {
auto j_begin = get<Integer const>(categories_segments[cnt]); auto j_begin = GetElem<Integer>(categories_segments, cnt);
auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin; auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()}; bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
CHECK_NE(j_end - j_begin, 0) << nidx; CHECK_NE(j_end - j_begin, 0) << nidx;
for (auto j = j_begin; j < j_end; ++j) { for (auto j = j_begin; j < j_end; ++j) {
auto const &category = get<Integer const>(categories[j]); auto const& category = GetElem<Integer>(categories, j);
auto cat = common::AsCat(category); auto cat = common::AsCat(category);
max_cat = std::max(max_cat, cat); max_cat = std::max(max_cat, cat);
} }
@ -924,7 +954,7 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
std::vector<uint32_t> cat_bits_storage(size, 0); std::vector<uint32_t> cat_bits_storage(size, 0);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)}; common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto j = j_begin; j < j_end; ++j) { for (auto j = j_begin; j < j_end; ++j) {
cat_bits.Set(common::AsCat(get<Integer const>(categories[j]))); cat_bits.Set(common::AsCat(GetElem<Integer>(categories, j)));
} }
auto begin = split_categories_.size(); auto begin = split_categories_.size();
@ -936,9 +966,9 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
++cnt; ++cnt;
if (cnt == categories_nodes.size()) { if (cnt == categories_nodes.size()) {
last_cat_node = -1; last_cat_node = -1; // Don't break, we still need to initialize the remaining nodes.
} else { } else {
last_cat_node = get<Integer const>(categories_nodes[cnt]); last_cat_node = GetElem<Integer>(categories_nodes, cnt);
} }
} else { } else {
split_categories_segments_[nidx].beg = categories.size(); split_categories_segments_[nidx].beg = categories.size();
@ -947,104 +977,144 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
} }
} }
template void RegTree::LoadCategoricalSplit<true>(Json const& in);
template void RegTree::LoadCategoricalSplit<false>(Json const& in);
void RegTree::SaveCategoricalSplit(Json* p_out) const { void RegTree::SaveCategoricalSplit(Json* p_out) const {
auto& out = *p_out; auto& out = *p_out;
CHECK_EQ(this->split_types_.size(), param.num_nodes); CHECK_EQ(this->split_types_.size(), param.num_nodes);
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes); CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
std::vector<Json> categories_segments; I64Array categories_segments;
std::vector<Json> categories_sizes; I64Array categories_sizes;
std::vector<Json> categories; I32Array categories; // bst_cat_t = int32_t
std::vector<Json> categories_nodes; I32Array categories_nodes; // bst_note_t = int32_t
for (size_t i = 0; i < nodes_.size(); ++i) { for (size_t i = 0; i < nodes_.size(); ++i) {
if (this->split_types_[i] == FeatureType::kCategorical) { if (this->split_types_[i] == FeatureType::kCategorical) {
categories_nodes.emplace_back(i); categories_nodes.GetArray().emplace_back(i);
auto begin = categories.size(); auto begin = categories.Size();
categories_segments.emplace_back(static_cast<Integer::Int>(begin)); categories_segments.GetArray().emplace_back(begin);
auto segment = split_categories_segments_[i]; auto segment = split_categories_segments_[i];
auto node_categories = auto node_categories = this->GetSplitCategories().subspan(segment.beg, segment.size);
this->GetSplitCategories().subspan(segment.beg, segment.size);
common::KCatBitField const cat_bits(node_categories); common::KCatBitField const cat_bits(node_categories);
for (size_t i = 0; i < cat_bits.Size(); ++i) { for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) { if (cat_bits.Check(i)) {
categories.emplace_back(static_cast<Integer::Int>(i)); categories.GetArray().emplace_back(i);
} }
} }
size_t size = categories.size() - begin; size_t size = categories.Size() - begin;
categories_sizes.emplace_back(static_cast<Integer::Int>(size)); categories_sizes.GetArray().emplace_back(size);
CHECK_NE(size, 0); CHECK_NE(size, 0);
} }
} }
out["categories_segments"] = categories_segments; out["categories_segments"] = std::move(categories_segments);
out["categories_sizes"] = categories_sizes; out["categories_sizes"] = std::move(categories_sizes);
out["categories_nodes"] = categories_nodes; out["categories_nodes"] = std::move(categories_nodes);
out["categories"] = categories; out["categories"] = std::move(categories);
} }
void RegTree::LoadModel(Json const& in) { template <bool typed, bool feature_is_64,
FromJson(in["tree_param"], &param); typename FloatArrayT = std::conditional_t<typed, F32Array const, Array const>,
auto n_nodes = param.num_nodes; typename U8ArrayT = std::conditional_t<typed, U8Array const, Array const>,
typename I32ArrayT = std::conditional_t<typed, I32Array const, Array const>,
typename I64ArrayT = std::conditional_t<typed, I64Array const, Array const>,
typename IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT, I32ArrayT>>
bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>* p_stats,
std::vector<FeatureType>* p_split_types, std::vector<RegTree::Node>* p_nodes,
std::vector<RegTree::Segment>* p_split_categories_segments) {
auto& stats = *p_stats;
auto& split_types = *p_split_types;
auto& nodes = *p_nodes;
auto& split_categories_segments = *p_split_categories_segments;
FromJson(in["tree_param"], param);
auto n_nodes = param->num_nodes;
CHECK_NE(n_nodes, 0); CHECK_NE(n_nodes, 0);
// stats // stats
auto const& loss_changes = get<Array const>(in["loss_changes"]); auto const& loss_changes = get<FloatArrayT>(in["loss_changes"]);
CHECK_EQ(loss_changes.size(), n_nodes); CHECK_EQ(loss_changes.size(), n_nodes);
auto const& sum_hessian = get<Array const>(in["sum_hessian"]); auto const& sum_hessian = get<FloatArrayT>(in["sum_hessian"]);
CHECK_EQ(sum_hessian.size(), n_nodes); CHECK_EQ(sum_hessian.size(), n_nodes);
auto const& base_weights = get<Array const>(in["base_weights"]); auto const& base_weights = get<FloatArrayT>(in["base_weights"]);
CHECK_EQ(base_weights.size(), n_nodes); CHECK_EQ(base_weights.size(), n_nodes);
// nodes // nodes
auto const& lefts = get<Array const>(in["left_children"]); auto const& lefts = get<I32ArrayT>(in["left_children"]);
CHECK_EQ(lefts.size(), n_nodes); CHECK_EQ(lefts.size(), n_nodes);
auto const& rights = get<Array const>(in["right_children"]); auto const& rights = get<I32ArrayT>(in["right_children"]);
CHECK_EQ(rights.size(), n_nodes); CHECK_EQ(rights.size(), n_nodes);
auto const& parents = get<Array const>(in["parents"]); auto const& parents = get<I32ArrayT>(in["parents"]);
CHECK_EQ(parents.size(), n_nodes); CHECK_EQ(parents.size(), n_nodes);
auto const& indices = get<Array const>(in["split_indices"]); auto const& indices = get<IndexArrayT>(in["split_indices"]);
CHECK_EQ(indices.size(), n_nodes); CHECK_EQ(indices.size(), n_nodes);
auto const& conds = get<Array const>(in["split_conditions"]); auto const& conds = get<FloatArrayT>(in["split_conditions"]);
CHECK_EQ(conds.size(), n_nodes); CHECK_EQ(conds.size(), n_nodes);
auto const& default_left = get<Array const>(in["default_left"]); auto const& default_left = get<U8ArrayT>(in["default_left"]);
CHECK_EQ(default_left.size(), n_nodes); CHECK_EQ(default_left.size(), n_nodes);
bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend(); bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
std::vector<Json> split_type; std::remove_const_t<std::remove_reference_t<decltype(get<U8ArrayT const>(in["split_type"]))>>
split_type;
if (has_cat) { if (has_cat) {
split_type = get<Array const>(in["split_type"]); split_type = get<U8ArrayT const>(in["split_type"]);
} }
stats_.clear(); stats = std::remove_reference_t<decltype(stats)>(n_nodes);
nodes_.clear(); nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
split_categories_segments = std::remove_reference_t<decltype(split_categories_segments)>(n_nodes);
stats_.resize(n_nodes); static_assert(std::is_integral<decltype(GetElem<Integer>(lefts, 0))>::value, "");
nodes_.resize(n_nodes); static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value, "");
split_types_.resize(n_nodes); CHECK_EQ(n_nodes, split_categories_segments.size());
split_categories_segments_.resize(n_nodes);
CHECK_EQ(n_nodes, split_categories_segments_.size());
for (int32_t i = 0; i < n_nodes; ++i) { for (int32_t i = 0; i < n_nodes; ++i) {
auto& s = stats_[i]; auto& s = stats[i];
s.loss_chg = get<Number const>(loss_changes[i]); s.loss_chg = GetElem<Number>(loss_changes, i);
s.sum_hess = get<Number const>(sum_hessian[i]); s.sum_hess = GetElem<Number>(sum_hessian, i);
s.base_weight = get<Number const>(base_weights[i]); s.base_weight = GetElem<Number>(base_weights, i);
auto& n = nodes_[i]; auto& n = nodes[i];
bst_node_t left = get<Integer const>(lefts[i]); bst_node_t left = GetElem<Integer>(lefts, i);
bst_node_t right = get<Integer const>(rights[i]); bst_node_t right = GetElem<Integer>(rights, i);
bst_node_t parent = get<Integer const>(parents[i]); bst_node_t parent = GetElem<Integer>(parents, i);
bst_feature_t ind = get<Integer const>(indices[i]); bst_feature_t ind = GetElem<Integer>(indices, i);
float cond { get<Number const>(conds[i]) }; float cond{GetElem<Number>(conds, i)};
bool dft_left { get<Boolean const>(default_left[i]) }; bool dft_left{GetElem<Boolean>(default_left, i)};
n = Node{left, right, parent, ind, cond, dft_left}; n = RegTree::Node{left, right, parent, ind, cond, dft_left};
if (has_cat) { if (has_cat) {
split_types_[i] = split_types[i] = static_cast<FeatureType>(GetElem<Integer>(split_type, i));
static_cast<FeatureType>(get<Integer const>(split_type[i]));
} }
} }
return has_cat;
}
void RegTree::LoadModel(Json const& in) {
bool has_cat{false};
bool typed = IsA<F32Array>(in["loss_changes"]);
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
if (typed && feature_is_64) {
has_cat = LoadModelImpl<true, true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else if (typed && !feature_is_64) {
has_cat = LoadModelImpl<true, false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else if (!typed && feature_is_64) {
has_cat = LoadModelImpl<false, true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else {
has_cat = LoadModelImpl<false, false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
}
if (has_cat) { if (has_cat) {
this->LoadCategoricalSplit(in); if (typed) {
this->LoadCategoricalSplit<true>(in);
} else {
this->LoadCategoricalSplit<false>(in);
}
} else { } else {
this->split_categories_segments_.resize(this->param.num_nodes); this->split_categories_segments_.resize(this->param.num_nodes);
std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical); std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
@ -1058,7 +1128,7 @@ void RegTree::LoadModel(Json const& in) {
} }
// easier access to [] operator // easier access to [] operator
auto& self = *this; auto& self = *this;
for (auto nid = 1; nid < n_nodes; ++nid) { for (auto nid = 1; nid < param.num_nodes; ++nid) {
auto parent = self[nid].Parent(); auto parent = self[nid].Parent();
CHECK_NE(parent, RegTree::kInvalidNodeId); CHECK_NE(parent, RegTree::kInvalidNodeId);
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid); self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
@ -1079,39 +1149,51 @@ void RegTree::SaveModel(Json* p_out) const {
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size())); CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
out["tree_param"] = ToJson(param); out["tree_param"] = ToJson(param);
CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes)); CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes));
using I = Integer::Int;
auto n_nodes = param.num_nodes; auto n_nodes = param.num_nodes;
// stats // stats
std::vector<Json> loss_changes(n_nodes); F32Array loss_changes(n_nodes);
std::vector<Json> sum_hessian(n_nodes); F32Array sum_hessian(n_nodes);
std::vector<Json> base_weights(n_nodes); F32Array base_weights(n_nodes);
// nodes // nodes
std::vector<Json> lefts(n_nodes); I32Array lefts(n_nodes);
std::vector<Json> rights(n_nodes); I32Array rights(n_nodes);
std::vector<Json> parents(n_nodes); I32Array parents(n_nodes);
std::vector<Json> indices(n_nodes);
std::vector<Json> conds(n_nodes);
std::vector<Json> default_left(n_nodes); F32Array conds(n_nodes);
std::vector<Json> split_type(n_nodes); U8Array default_left(n_nodes);
U8Array split_type(n_nodes);
CHECK_EQ(this->split_types_.size(), param.num_nodes); CHECK_EQ(this->split_types_.size(), param.num_nodes);
auto save_tree = [&](auto* p_indices_array) {
auto& indices_array = *p_indices_array;
for (bst_node_t i = 0; i < n_nodes; ++i) { for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i]; auto const& s = stats_[i];
loss_changes[i] = s.loss_chg; loss_changes.Set(i, s.loss_chg);
sum_hessian[i] = s.sum_hess; sum_hessian.Set(i, s.sum_hess);
base_weights[i] = s.base_weight; base_weights.Set(i, s.base_weight);
auto const& n = nodes_[i]; auto const& n = nodes_[i];
lefts[i] = static_cast<I>(n.LeftChild()); lefts.Set(i, n.LeftChild());
rights[i] = static_cast<I>(n.RightChild()); rights.Set(i, n.RightChild());
parents[i] = static_cast<I>(n.Parent()); parents.Set(i, n.Parent());
indices[i] = static_cast<I>(n.SplitIndex()); indices_array.Set(i, n.SplitIndex());
conds[i] = n.SplitCond(); conds.Set(i, n.SplitCond());
default_left[i] = n.DefaultLeft(); default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));
split_type[i] = static_cast<I>(this->NodeSplitType(i)); split_type.Set(i, static_cast<uint8_t>(this->NodeSplitType(i)));
}
};
if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
I64Array indices_64(n_nodes);
save_tree(&indices_64);
out["split_indices"] = std::move(indices_64);
} else {
I32Array indices_32(n_nodes);
save_tree(&indices_32);
out["split_indices"] = std::move(indices_32);
} }
this->SaveCategoricalSplit(&out); this->SaveCategoricalSplit(&out);
@ -1124,7 +1206,7 @@ void RegTree::SaveModel(Json* p_out) const {
out["left_children"] = std::move(lefts); out["left_children"] = std::move(lefts);
out["right_children"] = std::move(rights); out["right_children"] = std::move(rights);
out["parents"] = std::move(parents); out["parents"] = std::move(parents);
out["split_indices"] = std::move(indices);
out["split_conditions"] = std::move(conds); out["split_conditions"] = std::move(conds);
out["default_left"] = std::move(default_left); out["default_left"] = std::move(default_left);
} }

View File

@ -32,6 +32,7 @@ dependencies:
- awscli - awscli
- numba - numba
- llvmlite - llvmlite
- py-ubjson
- pip: - pip:
- shap - shap
- ipython # required by shap at import time. - ipython # required by shap at import time.

View File

@ -31,6 +31,7 @@ dependencies:
- jsonschema - jsonschema
- boto3 - boto3
- awscli - awscli
- py-ubjson
- pip: - pip:
- sphinx_rtd_theme - sphinx_rtd_theme
- datatable - datatable

View File

@ -18,3 +18,4 @@ dependencies:
- jsonschema - jsonschema
- python-graphviz - python-graphviz
- pip - pip
- py-ubjson

View File

@ -16,3 +16,4 @@ dependencies:
- python-graphviz - python-graphviz
- modin-ray - modin-ray
- pip - pip
- py-ubjson

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2019-2020 XGBoost contributors * Copyright 2019-2022 XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/version_config.h> #include <xgboost/version_config.h>
@ -150,6 +150,33 @@ TEST(CAPI, JsonModelIO) {
ASSERT_EQ(model_str_0.front(), '{'); ASSERT_EQ(model_str_0.front(), '{');
ASSERT_EQ(model_str_0, model_str_1); ASSERT_EQ(model_str_0, model_str_1);
/**
* In memory
*/
bst_ulong len{0};
char const *data;
XGBoosterSaveModelToBuffer(handle, R"({"format": "ubj"})", &len, &data);
ASSERT_GT(len, 3);
XGBoosterLoadModelFromBuffer(handle, data, len);
char const *saved;
bst_ulong saved_len{0};
XGBoosterSaveModelToBuffer(handle, R"({"format": "ubj"})", &saved_len, &saved);
ASSERT_EQ(len, saved_len);
auto l = StringView{data, len};
auto r = StringView{saved, saved_len};
ASSERT_EQ(l.size(), r.size());
ASSERT_EQ(l, r);
std::string buffer;
Json::Dump(Json::Load(l, std::ios::binary), &buffer);
ASSERT_EQ(model_str_0.size() - 1, buffer.size());
ASSERT_EQ(model_str_0.back(), '\0');
ASSERT_TRUE(std::equal(model_str_0.begin(), model_str_0.end() - 1, buffer.begin()));
ASSERT_EQ(XGBoosterSaveModelToBuffer(handle, R"({})", &len, &data), -1);
ASSERT_EQ(XGBoosterSaveModelToBuffer(handle, R"({"format": "foo"})", &len, &data), -1);
} }
TEST(CAPI, CatchDMLCError) { TEST(CAPI, CatchDMLCError) {

View File

@ -178,8 +178,8 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
learner->Save(&fo); learner->Save(&fo);
} }
Json m_0 = Json::Load(StringView{continued_model.c_str(), continued_model.size()}); Json m_0 = Json::Load(StringView{continued_model}, std::ios::binary);
Json m_1 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()}); Json m_1 = Json::Load(StringView{model_at_2kiter}, std::ios::binary);
CompareJSON(m_0, m_1); CompareJSON(m_0, m_1);
} }
@ -214,8 +214,8 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
common::MemoryBufferStream fo(&serialised_model_tmp); common::MemoryBufferStream fo(&serialised_model_tmp);
learner->Save(&fo); learner->Save(&fo);
Json m_0 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()}); Json m_0 = Json::Load(StringView{model_at_2kiter}, std::ios::binary);
Json m_1 = Json::Load(StringView{serialised_model_tmp.c_str(), serialised_model_tmp.size()}); Json m_1 = Json::Load(StringView{serialised_model_tmp}, std::ios::binary);
// GPU ID is changed as data is coming from device. // GPU ID is changed as data is coming from device.
ASSERT_EQ(get<Object>(m_0["Config"]["learner"]["generic_param"]).erase("gpu_id"), ASSERT_EQ(get<Object>(m_0["Config"]["learner"]["generic_param"]).erase("gpu_id"),
get<Object>(m_1["Config"]["learner"]["generic_param"]).erase("gpu_id")); get<Object>(m_1["Config"]["learner"]["generic_param"]).erase("gpu_id"));

View File

@ -198,8 +198,7 @@ void CheckReload(RegTree const &tree) {
Json saved{Object()}; Json saved{Object()};
loaded_tree.SaveModel(&saved); loaded_tree.SaveModel(&saved);
auto same = out == saved; ASSERT_EQ(out, saved);
ASSERT_TRUE(same);
} }
TEST(Tree, CategoricalIO) { TEST(Tree, CategoricalIO) {
@ -433,12 +432,12 @@ TEST(Tree, JsonIO) {
ASSERT_EQ(get<String>(tparam["num_nodes"]), "3"); ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0"); ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0");
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3ul); ASSERT_EQ(get<I32Array const>(j_tree["left_children"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3ul); ASSERT_EQ(get<I32Array const>(j_tree["right_children"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3ul); ASSERT_EQ(get<I32Array const>(j_tree["parents"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3ul); ASSERT_EQ(get<I32Array const>(j_tree["split_indices"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3ul); ASSERT_EQ(get<F32Array const>(j_tree["split_conditions"]).size(), 3ul);
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3ul); ASSERT_EQ(get<U8Array const>(j_tree["default_left"]).size(), 3ul);
RegTree loaded_tree; RegTree loaded_tree;
loaded_tree.LoadModel(j_tree); loaded_tree.LoadModel(j_tree);

View File

@ -14,7 +14,7 @@ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
def json_model(model_path, parameters): def json_model(model_path: str, parameters: dict) -> dict:
X = np.random.random((10, 3)) X = np.random.random((10, 3))
y = np.random.randint(2, size=(10,)) y = np.random.randint(2, size=(10,))
@ -22,9 +22,14 @@ def json_model(model_path, parameters):
bst = xgb.train(parameters, dm1) bst = xgb.train(parameters, dm1)
bst.save_model(model_path) bst.save_model(model_path)
if model_path.endswith("ubj"):
import ubjson
with open(model_path, "rb") as ubjfd:
model = ubjson.load(ubjfd)
else:
with open(model_path, 'r') as fd: with open(model_path, 'r') as fd:
model = json.load(fd) model = json.load(fd)
return model return model
@ -259,23 +264,40 @@ class TestModels:
buf_from_raw = from_raw.save_raw() buf_from_raw = from_raw.save_raw()
assert buf == buf_from_raw assert buf == buf_from_raw
def test_model_json_io(self): def run_model_json_io(self, parameters: dict, ext: str) -> None:
if ext == "ubj" and tm.no_ubjson()["condition"]:
pytest.skip(tm.no_ubjson()["reason"])
loc = locale.getpreferredencoding(False) loc = locale.getpreferredencoding(False)
model_path = 'test_model_json_io.json' model_path = 'test_model_json_io.' + ext
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
j_model = json_model(model_path, parameters) j_model = json_model(model_path, parameters)
assert isinstance(j_model['learner'], dict) assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file=model_path) bst = xgb.Booster(model_file=model_path)
bst.save_model(fname=model_path) bst.save_model(fname=model_path)
if ext == "ubj":
import ubjson
with open(model_path, "rb") as ubjfd:
j_model = ubjson.load(ubjfd)
else:
with open(model_path, 'r') as fd: with open(model_path, 'r') as fd:
j_model = json.load(fd) j_model = json.load(fd)
assert isinstance(j_model['learner'], dict) assert isinstance(j_model['learner'], dict)
os.remove(model_path) os.remove(model_path)
assert locale.getpreferredencoding(False) == loc assert locale.getpreferredencoding(False) == loc
@pytest.mark.parametrize("ext", ["json", "ubj"])
def test_model_json_io(self, ext: str) -> None:
parameters = {"booster": "gbtree", "tree_method": "hist"}
self.run_model_json_io(parameters, ext)
parameters = {"booster": "gblinear"}
self.run_model_json_io(parameters, ext)
parameters = {"booster": "dart", "tree_method": "hist"}
self.run_model_json_io(parameters, ext)
@pytest.mark.skipif(**tm.no_json_schema()) @pytest.mark.skipif(**tm.no_json_schema())
def test_json_io_schema(self): def test_json_io_schema(self):
import jsonschema import jsonschema

View File

@ -2,6 +2,7 @@ import pickle
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
import os import os
import json
kRows = 100 kRows = 100
@ -15,13 +16,14 @@ def generate_data():
class TestPickling: class TestPickling:
def run_model_pickling(self, xgb_params): def run_model_pickling(self, xgb_params) -> str:
X, y = generate_data() X, y = generate_data()
dtrain = xgb.DMatrix(X, y) dtrain = xgb.DMatrix(X, y)
bst = xgb.train(xgb_params, dtrain) bst = xgb.train(xgb_params, dtrain)
dump_0 = bst.get_dump(dump_format='json') dump_0 = bst.get_dump(dump_format='json')
assert dump_0 assert dump_0
config_0 = bst.save_config()
filename = 'model.pkl' filename = 'model.pkl'
@ -42,9 +44,22 @@ class TestPickling:
if os.path.exists(filename): if os.path.exists(filename):
os.remove(filename) os.remove(filename)
config_1 = bst.save_config()
assert config_0 == config_1
return json.loads(config_0)
def test_model_pickling_json(self): def test_model_pickling_json(self):
params = { def check(config):
'nthread': 1, updater = config["learner"]["gradient_booster"]["updater"]
'tree_method': 'hist', if params["tree_method"] == "exact":
} subsample = updater["grow_colmaker"]["train_param"]["subsample"]
self.run_model_pickling(params) else:
subsample = updater["grow_quantile_histmaker"]["train_param"]["subsample"]
assert float(subsample) == 0.5
params = {"nthread": 8, "tree_method": "hist", "subsample": 0.5}
config = self.run_model_pickling(params)
check(config)
params = {"nthread": 8, "tree_method": "exact", "subsample": 0.5}
config = self.run_model_pickling(params)
check(config)

View File

@ -29,6 +29,15 @@ except ImportError:
memory = Memory('./cachedir', verbose=0) memory = Memory('./cachedir', verbose=0)
def no_ubjson():
reason = "ubjson is not intsalled."
try:
import ubjson # noqa
return {"condition": False, "reason": reason}
except ImportError:
return {"condition": True, "reason": reason}
def no_sklearn(): def no_sklearn():
return {'condition': not SKLEARN_INSTALLED, return {'condition': not SKLEARN_INSTALLED,
'reason': 'Scikit-Learn is not installed'} 'reason': 'Scikit-Learn is not installed'}