[PYTHON] Simplify training logic, update rabit lib
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/learner.h>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
@@ -84,6 +86,10 @@ int XGDMatrixCreateFromFile(const char *fname,
|
||||
int silent,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
if (rabit::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, "
|
||||
<< "will split data among workers";
|
||||
}
|
||||
*out = DMatrix::Load(
|
||||
fname, silent != 0, false);
|
||||
API_END();
|
||||
@@ -526,3 +532,28 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||
XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version) {
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
*version = rabit::LoadCheckPoint(bst->learner());
|
||||
if (version != 0) {
|
||||
bst->initialized_ = true;
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
if (bst->learner()->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(bst->learner());
|
||||
} else {
|
||||
rabit::CheckPoint(bst->learner());
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
// force link rabit
|
||||
static int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
Reference in New Issue
Block a user