[PYTHON] Simplify training logic, update rabit lib

This commit is contained in:
tqchen
2016-02-27 19:56:29 -08:00
parent 90bc7f8f6b
commit 4a16b729fc
10 changed files with 108 additions and 77 deletions

View File

@@ -3,6 +3,8 @@
#include <xgboost/data.h>
#include <xgboost/learner.h>
#include <xgboost/c_api.h>
#include <xgboost/logging.h>
#include <rabit/rabit.h>
#include <cstdio>
#include <vector>
#include <string>
@@ -84,6 +86,10 @@ int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out) {
API_BEGIN();
if (rabit::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
}
*out = DMatrix::Load(
fname, silent != 0, false);
API_END();
@@ -526,3 +532,28 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
API_END();
}
int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
*version = rabit::LoadCheckPoint(bst->learner());
if (version != 0) {
bst->initialized_ = true;
}
API_END();
}
int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
if (bst->learner()->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(bst->learner());
} else {
rabit::CheckPoint(bst->learner());
}
API_END();
}
// force link rabit
static int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();