JSON configuration IO. (#5111)
* Add saving/loading JSON configuration. * Implement Python pickle interface with new IO routines. * Basic tests for training continuation.
This commit is contained in:
@@ -458,8 +458,8 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
|
||||
// xgboost implementation
|
||||
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
xgboost::bst_ulong len,
|
||||
BoosterHandle *out) {
|
||||
xgboost::bst_ulong len,
|
||||
BoosterHandle *out) {
|
||||
API_BEGIN();
|
||||
std::vector<std::shared_ptr<DMatrix> > mats;
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
@@ -485,6 +485,31 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::string str {json_parameters};
|
||||
Json config { Json::Load(StringView{str.c_str(), str.size()}) };
|
||||
static_cast<Learner*>(handle)->LoadConfig(config);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle,
|
||||
xgboost::bst_ulong *out_len,
|
||||
char const** out_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
Json config { Object() };
|
||||
auto* learner = static_cast<Learner*>(handle);
|
||||
learner->Configure();
|
||||
learner->SaveConfig(&config);
|
||||
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
Json::Dump(config, &raw_str);
|
||||
*out_str = raw_str.c_str();
|
||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
|
||||
int iter,
|
||||
DMatrixHandle dtrain) {
|
||||
@@ -579,7 +604,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
static_cast<Learner*>(handle)->LoadModel(in);
|
||||
} else {
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
||||
static_cast<Learner*>(handle)->Load(fi.get());
|
||||
static_cast<Learner*>(handle)->LoadModel(fi.get());
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
@@ -598,20 +623,18 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* c_fname) {
|
||||
fo->Write(str.c_str(), str.size());
|
||||
} else {
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
bst->Save(fo.get());
|
||||
bst->SaveModel(fo.get());
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
// The following two functions are `Load` and `Save` for memory based serialization
|
||||
// methods. E.g. Python pickle.
|
||||
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
||||
const void* buf,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||
static_cast<Learner*>(handle)->Load(&fs);
|
||||
static_cast<Learner*>(handle)->LoadModel(&fs);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -621,6 +644,25 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
||||
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
raw_str.resize(0);
|
||||
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
auto *learner = static_cast<Learner*>(handle);
|
||||
learner->Configure();
|
||||
learner->SaveModel(&fo);
|
||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||
API_END();
|
||||
}
|
||||
|
||||
// The following two functions are `Load` and `Save` for memory based
|
||||
// serialization methods. E.g. Python pickle.
|
||||
XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const char **out_dptr) {
|
||||
std::string &raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
raw_str.resize(0);
|
||||
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
@@ -632,6 +674,41 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
|
||||
const void *buf,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||
static_cast<Learner*>(handle)->Load(&fs);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
*version = rabit::LoadCheckPoint(bst);
|
||||
if (*version != 0) {
|
||||
bst->Configure();
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* learner = static_cast<Learner*>(handle);
|
||||
learner->Configure();
|
||||
if (learner->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(learner);
|
||||
} else {
|
||||
rabit::CheckPoint(learner);
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
inline void XGBoostDumpModelImpl(
|
||||
BoosterHandle handle,
|
||||
const FeatureMap& fmap,
|
||||
@@ -758,29 +835,5 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
*version = rabit::LoadCheckPoint(bst);
|
||||
if (*version != 0) {
|
||||
bst->Configure();
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
if (bst->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(bst);
|
||||
} else {
|
||||
rabit::CheckPoint(bst);
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
Reference in New Issue
Block a user