JSON configuration IO. (#5111)

* Add saving/loading JSON configuration.
* Implement Python pickle interface with new IO routines.
* Basic tests for training continuation.
This commit is contained in:
Jiaming Yuan
2019-12-15 17:31:53 +08:00
committed by GitHub
parent 5aa007d7b2
commit 3136185bc5
24 changed files with 761 additions and 390 deletions

View File

@@ -458,8 +458,8 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
// xgboost implementation
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
xgboost::bst_ulong len,
BoosterHandle *out) {
xgboost::bst_ulong len,
BoosterHandle *out) {
API_BEGIN();
std::vector<std::shared_ptr<DMatrix> > mats;
for (xgboost::bst_ulong i = 0; i < len; ++i) {
@@ -485,6 +485,31 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
API_END();
}
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
API_BEGIN();
CHECK_HANDLE();
std::string str {json_parameters};
Json config { Json::Load(StringView{str.c_str(), str.size()}) };
static_cast<Learner*>(handle)->LoadConfig(config);
API_END();
}
XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle,
xgboost::bst_ulong *out_len,
char const** out_str) {
API_BEGIN();
CHECK_HANDLE();
Json config { Object() };
auto* learner = static_cast<Learner*>(handle);
learner->Configure();
learner->SaveConfig(&config);
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
Json::Dump(config, &raw_str);
*out_str = raw_str.c_str();
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
}
XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
int iter,
DMatrixHandle dtrain) {
@@ -579,7 +604,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
static_cast<Learner*>(handle)->LoadModel(in);
} else {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
static_cast<Learner*>(handle)->Load(fi.get());
static_cast<Learner*>(handle)->LoadModel(fi.get());
}
API_END();
}
@@ -598,20 +623,18 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* c_fname) {
fo->Write(str.c_str(), str.size());
} else {
auto *bst = static_cast<Learner*>(handle);
bst->Save(fo.get());
bst->SaveModel(fo.get());
}
API_END();
}
// The following two functions are `Load` and `Save` for memory based serialization
// methods. E.g. Python pickle.
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
const void* buf,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
static_cast<Learner*>(handle)->Load(&fs);
static_cast<Learner*>(handle)->LoadModel(&fs);
API_END();
}
@@ -621,6 +644,25 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
raw_str.resize(0);
API_BEGIN();
CHECK_HANDLE();
common::MemoryBufferStream fo(&raw_str);
auto *learner = static_cast<Learner*>(handle);
learner->Configure();
learner->SaveModel(&fo);
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
}
// The following two functions are `Load` and `Save` for memory based
// serialization methods. E.g. Python pickle.
XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle,
xgboost::bst_ulong *out_len,
const char **out_dptr) {
std::string &raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
raw_str.resize(0);
API_BEGIN();
CHECK_HANDLE();
common::MemoryBufferStream fo(&raw_str);
@@ -632,6 +674,41 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
API_END();
}
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
static_cast<Learner*>(handle)->Load(&fs);
API_END();
}
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
*version = rabit::LoadCheckPoint(bst);
if (*version != 0) {
bst->Configure();
}
API_END();
}
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto* learner = static_cast<Learner*>(handle);
learner->Configure();
if (learner->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(learner);
} else {
rabit::CheckPoint(learner);
}
API_END();
}
inline void XGBoostDumpModelImpl(
BoosterHandle handle,
const FeatureMap& fmap,
@@ -758,29 +835,5 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
API_END();
}
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
*version = rabit::LoadCheckPoint(bst);
if (*version != 0) {
bst->Configure();
}
API_END();
}
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
if (bst->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(bst);
} else {
rabit::CheckPoint(bst);
}
API_END();
}
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();