[breaking] Change internal model serialization to UBJSON. (#7556)
* Use typed array for models. * Change the memory snapshot format. * Add new C API for saving to raw format.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2014-2021 by Contributors
|
||||
// Copyright (c) 2014-2022 by Contributors
|
||||
#include <rabit/rabit.h>
|
||||
#include <rabit/c_api.h>
|
||||
|
||||
@@ -248,22 +248,16 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
|
||||
#endif
|
||||
|
||||
// Create from data iterator
|
||||
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
|
||||
DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next,
|
||||
char const* c_json_config,
|
||||
DMatrixHandle *out) {
|
||||
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
|
||||
char const *c_json_config, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = get<Number const>(config["missing"]);
|
||||
std::string cache = get<String const>(config["cache_prefix"]);
|
||||
int32_t n_threads = omp_get_max_threads();
|
||||
if (!IsA<Null>(config["nthread"])) {
|
||||
n_threads = get<Integer const>(config["nthread"]);
|
||||
}
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
|
||||
iter, proxy, reset, next, missing, n_threads, cache)};
|
||||
auto missing = GetMissing(config);
|
||||
std::string cache = RequiredArg<String>(config, "cache_prefix", __func__);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)};
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -358,8 +352,8 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
|
||||
StringView{data}, ncol);
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto nthread = get<Integer const>(config["nthread"]);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -371,9 +365,9 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data,
|
||||
xgboost::data::ArrayAdapter(StringView{data})};
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto nthread = get<Integer const>(config["nthread"]);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
*out =
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -765,11 +759,11 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
|
||||
auto& entry = learner->GetThreadLocal().prediction_entry;
|
||||
auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
|
||||
|
||||
auto const& j_config = get<Object const>(config);
|
||||
auto type = PredictionType(get<Integer const>(j_config.at("type")));
|
||||
auto iteration_begin = get<Integer const>(j_config.at("iteration_begin"));
|
||||
auto iteration_end = get<Integer const>(j_config.at("iteration_end"));
|
||||
auto type = PredictionType(RequiredArg<Integer>(config, "type", __func__));
|
||||
auto iteration_begin = RequiredArg<Integer>(config, "iteration_begin", __func__);
|
||||
auto iteration_end = RequiredArg<Integer>(config, "iteration_end", __func__);
|
||||
|
||||
auto const& j_config = get<Object const>(config);
|
||||
auto ntree_limit_it = j_config.find("ntree_limit");
|
||||
if (ntree_limit_it != j_config.cend() && !IsA<Null>(ntree_limit_it->second) &&
|
||||
get<Integer const>(ntree_limit_it->second) != 0) {
|
||||
@@ -785,7 +779,7 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
|
||||
type == PredictionType::kApproxContribution;
|
||||
bool interactions = type == PredictionType::kInteraction ||
|
||||
type == PredictionType::kApproxInteraction;
|
||||
bool training = get<Boolean const>(config["training"]);
|
||||
bool training = RequiredArg<Boolean>(config, "training", __func__);
|
||||
learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions,
|
||||
iteration_begin, iteration_end, training,
|
||||
type == PredictionType::kLeaf, contribs, approximate,
|
||||
@@ -796,7 +790,7 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
|
||||
auto rounds = iteration_end - iteration_begin;
|
||||
rounds = rounds == 0 ? learner->BoostedRounds() : rounds;
|
||||
// Determine shape
|
||||
bool strict_shape = get<Boolean const>(config["strict_shape"]);
|
||||
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
|
||||
CalcPredictShape(strict_shape, type, p_m->Info().num_row_,
|
||||
p_m->Info().num_col_, chunksize, learner->Groups(), rounds,
|
||||
&shape, out_dim);
|
||||
@@ -814,15 +808,15 @@ void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
|
||||
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
|
||||
|
||||
HostDeviceVector<float>* p_predt { nullptr };
|
||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||
auto type = PredictionType(RequiredArg<Integer>(config, "type", __func__));
|
||||
float missing = GetMissing(config);
|
||||
learner->InplacePredict(x, p_m, type, missing, &p_predt,
|
||||
get<Integer const>(config["iteration_begin"]),
|
||||
get<Integer const>(config["iteration_end"]));
|
||||
RequiredArg<Integer>(config, "iteration_begin", __func__),
|
||||
RequiredArg<Integer>(config, "iteration_end", __func__));
|
||||
CHECK(p_predt);
|
||||
auto &shape = learner->GetThreadLocal().prediction_shape;
|
||||
auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows;
|
||||
bool strict_shape = get<Boolean const>(config["strict_shape"]);
|
||||
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
|
||||
CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(),
|
||||
learner->BoostedRounds(), &shape, out_dim);
|
||||
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
||||
@@ -900,12 +894,21 @@ XGB_DLL int XGBoosterPredictFromCUDAColumnar(
|
||||
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
if (common::FileExtension(fname) == "json") {
|
||||
auto read_file = [&]() {
|
||||
auto str = common::LoadSequentialFile(fname);
|
||||
CHECK_GT(str.size(), 2);
|
||||
CHECK_GE(str.size(), 3); // "{}\0"
|
||||
CHECK_EQ(str[0], '{');
|
||||
Json in { Json::Load({str.c_str(), str.size()}) };
|
||||
CHECK_EQ(str[str.size() - 2], '}');
|
||||
return str;
|
||||
};
|
||||
if (common::FileExtension(fname) == "json") {
|
||||
auto str = read_file();
|
||||
Json in{Json::Load(StringView{str})};
|
||||
static_cast<Learner*>(handle)->LoadModel(in);
|
||||
} else if (common::FileExtension(fname) == "ubj") {
|
||||
auto str = read_file();
|
||||
Json in = Json::Load(StringView{str}, std::ios::binary);
|
||||
static_cast<Learner *>(handle)->LoadModel(in);
|
||||
} else {
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
||||
static_cast<Learner*>(handle)->LoadModel(fi.get());
|
||||
@@ -913,32 +916,83 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* c_fname) {
|
||||
namespace {
|
||||
void WarnOldModel() {
|
||||
if (XGBOOST_VER_MAJOR >= 2) {
|
||||
LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or "
|
||||
"`ubj`. Model format will default to JSON in XGBoost 2.2 if not specified.";
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *c_fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(c_fname, "w"));
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
learner->Configure();
|
||||
if (common::FileExtension(c_fname) == "json") {
|
||||
Json out { Object() };
|
||||
auto save_json = [&](std::ios::openmode mode) {
|
||||
Json out{Object()};
|
||||
learner->SaveModel(&out);
|
||||
std::string str;
|
||||
Json::Dump(out, &str);
|
||||
fo->Write(str.c_str(), str.size());
|
||||
std::vector<char> str;
|
||||
Json::Dump(out, &str, mode);
|
||||
fo->Write(str.data(), str.size());
|
||||
};
|
||||
if (common::FileExtension(c_fname) == "json") {
|
||||
save_json(std::ios::out);
|
||||
} else if (common::FileExtension(c_fname) == "ubj") {
|
||||
save_json(std::ios::binary);
|
||||
} else if (XGBOOST_VER_MAJOR == 2 && XGBOOST_VER_MINOR >= 2) {
|
||||
LOG(WARNING) << "Saving model to JSON as default. You can use file extension `json`, `ubj` or "
|
||||
"`deprecated` to choose between formats.";
|
||||
save_json(std::ios::out);
|
||||
} else {
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
WarnOldModel();
|
||||
auto *bst = static_cast<Learner *>(handle);
|
||||
bst->SaveModel(fo.get());
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
||||
const void* buf,
|
||||
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||
static_cast<Learner*>(handle)->LoadModel(&fs);
|
||||
common::MemoryFixSizeBuffer fs((void *)buf, len); // NOLINT(*)
|
||||
static_cast<Learner *>(handle)->LoadModel(&fs);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_config,
|
||||
xgboost::bst_ulong *out_len, char const **out_dptr) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto config = Json::Load(StringView{json_config});
|
||||
auto format = RequiredArg<String>(config, "format", __func__);
|
||||
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
std::string &raw_str = learner->GetThreadLocal().ret_str;
|
||||
raw_str.clear();
|
||||
|
||||
learner->Configure();
|
||||
Json out{Object{}};
|
||||
if (format == "json") {
|
||||
learner->SaveModel(&out);
|
||||
Json::Dump(out, &raw_str);
|
||||
} else if (format == "ubj") {
|
||||
learner->SaveModel(&out);
|
||||
Json::Dump(out, &raw_str, std::ios::binary);
|
||||
} else if (format == "deprecated") {
|
||||
WarnOldModel();
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
learner->SaveModel(&fo);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown format: `" << format << "`";
|
||||
}
|
||||
|
||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -952,6 +1006,8 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
||||
raw_str.resize(0);
|
||||
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
LOG(WARNING) << "`" << __func__
|
||||
<< "` is deprecated, please use `XGBoosterSaveModelToBuffer` instead.";
|
||||
|
||||
learner->Configure();
|
||||
learner->SaveModel(&fo);
|
||||
@@ -1208,7 +1264,8 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
|
||||
CHECK_HANDLE();
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
auto config = Json::Load(StringView{json_config});
|
||||
auto importance = get<String const>(config["importance_type"]);
|
||||
|
||||
auto importance = RequiredArg<String>(config, "importance_type", __func__);
|
||||
std::string feature_map_uri;
|
||||
if (!IsA<Null>(config["feature_map"])) {
|
||||
feature_map_uri = get<String const>(config["feature_map"]);
|
||||
|
||||
Reference in New Issue
Block a user