Model IO in JSON. (#5110)

This commit is contained in:
Jiaming Yuan
2019-12-11 11:20:40 +08:00
committed by GitHub
parent c7cc657a4d
commit 208ab3b1ff
25 changed files with 667 additions and 165 deletions

View File

@@ -5,23 +5,25 @@
#include <cstdio>
#include <cstring>
#include <fstream>
#include <algorithm>
#include <vector>
#include <string>
#include <memory>
#include "xgboost/data.h"
#include "xgboost/learner.h"
#include "xgboost/c_api.h"
#include "xgboost/logging.h"
#include "xgboost/version_config.h"
#include "xgboost/json.h"
#include "c_api_error.h"
#include "../data/simple_csr_source.h"
#include "../common/io.h"
#include "../data/adapter.h"
namespace xgboost {
// declare the data callback.
XGB_EXTERN_C int XGBoostNativeDataIterSetData(
@@ -569,23 +571,43 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
// Load a model from the file `fname` into the booster behind `handle`.
//
// The format is chosen from the file extension:
//   - "json": the whole file is read into memory, parsed with Json::Load and
//     passed to Learner::LoadModel.
//   - anything else: the file is opened as a dmlc stream and passed to
//     Learner::Load.
//
// The model is loaded exactly once; no stream is opened (and no stream-based
// load is attempted) before the format has been decided.
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
  API_BEGIN();
  CHECK_HANDLE();
  if (common::FileExtension(fname) == "json") {
    auto str = common::LoadSequentialFile(fname);
    // Cheap sanity checks before parsing: the content must be longer than two
    // characters and must start with '{'.
    CHECK_GT(str.size(), 2);
    CHECK_EQ(str[0], '{');
    Json in { Json::Load({str.c_str(), str.size()}) };
    static_cast<Learner*>(handle)->LoadModel(in);
  } else {
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
    static_cast<Learner*>(handle)->Load(fi.get());
  }
  API_END();
}
// Save the model held by the booster behind `handle` to the file `c_fname`.
//
// Learner::Configure() is invoked before anything is written. The output
// format is chosen from the file extension:
//   - "json": the model is stored into a Json object via Learner::SaveModel,
//     dumped to a string and written to the output stream.
//   - anything else: the model is written with Learner::Save.
//
// The output stream is created once and the model is written once, in the
// selected format only.
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* c_fname) {
  API_BEGIN();
  CHECK_HANDLE();
  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(c_fname, "w"));
  auto *learner = static_cast<Learner *>(handle);
  learner->Configure();
  if (common::FileExtension(c_fname) == "json") {
    Json out { Object() };
    learner->SaveModel(&out);
    std::string str;
    Json::Dump(out, &str);
    fo->Write(str.c_str(), str.size());
  } else {
    // Reuse the already-cast learner instead of casting `handle` a second time.
    learner->Save(fo.get());
  }
  API_END();
}
// The following two functions are the `Load` and `Save` counterparts for
// memory-based serialization, e.g. Python pickle.
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
const void* buf,
xgboost::bst_ulong len) {
const void* buf,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
@@ -594,16 +616,17 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
}
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
xgboost::bst_ulong* out_len,
const char** out_dptr) {
xgboost::bst_ulong* out_len,
const char** out_dptr) {
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
raw_str.resize(0);
API_BEGIN();
CHECK_HANDLE();
common::MemoryBufferStream fo(&raw_str);
auto *bst = static_cast<Learner*>(handle);
bst->Save(&fo);
auto *learner = static_cast<Learner*>(handle);
learner->Configure();
learner->Save(&fo);
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
@@ -619,6 +642,7 @@ inline void XGBoostDumpModelImpl(
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
auto *bst = static_cast<Learner*>(handle);
bst->Configure();
str_vecs = bst->DumpModel(fmap, with_stats != 0, format);
charp_vecs.resize(str_vecs.size());
for (size_t i = 0; i < str_vecs.size(); ++i) {