Model IO in JSON. (#5110)
This commit is contained in:
@@ -5,23 +5,25 @@
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/version_config.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
#include "c_api_error.h"
|
||||
#include "../data/simple_csr_source.h"
|
||||
#include "../common/io.h"
|
||||
#include "../data/adapter.h"
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
// declare the data callback.
|
||||
XGB_EXTERN_C int XGBoostNativeDataIterSetData(
|
||||
@@ -569,23 +571,43 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
||||
static_cast<Learner*>(handle)->Load(fi.get());
|
||||
if (common::FileExtension(fname) == "json") {
|
||||
auto str = common::LoadSequentialFile(fname);
|
||||
CHECK_GT(str.size(), 2);
|
||||
CHECK_EQ(str[0], '{');
|
||||
Json in { Json::Load({str.c_str(), str.size()}) };
|
||||
static_cast<Learner*>(handle)->LoadModel(in);
|
||||
} else {
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
||||
static_cast<Learner*>(handle)->Load(fi.get());
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
|
||||
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* c_fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
bst->Save(fo.get());
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(c_fname, "w"));
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
learner->Configure();
|
||||
if (common::FileExtension(c_fname) == "json") {
|
||||
Json out { Object() };
|
||||
learner->SaveModel(&out);
|
||||
std::string str;
|
||||
Json::Dump(out, &str);
|
||||
fo->Write(str.c_str(), str.size());
|
||||
} else {
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
bst->Save(fo.get());
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
// The following two functions are `Load` and `Save` for memory based serialization
|
||||
// methods. E.g. Python pickle.
|
||||
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
||||
const void* buf,
|
||||
xgboost::bst_ulong len) {
|
||||
const void* buf,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||
@@ -594,16 +616,17 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
||||
xgboost::bst_ulong* out_len,
|
||||
const char** out_dptr) {
|
||||
xgboost::bst_ulong* out_len,
|
||||
const char** out_dptr) {
|
||||
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
raw_str.resize(0);
|
||||
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
bst->Save(&fo);
|
||||
auto *learner = static_cast<Learner*>(handle);
|
||||
learner->Configure();
|
||||
learner->Save(&fo);
|
||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||
API_END();
|
||||
@@ -619,6 +642,7 @@ inline void XGBoostDumpModelImpl(
|
||||
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
|
||||
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
bst->Configure();
|
||||
str_vecs = bst->DumpModel(fmap, with_stats != 0, format);
|
||||
charp_vecs.resize(str_vecs.size());
|
||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||
|
||||
Reference in New Issue
Block a user