[PYTHON-DIST] Distributed xgboost python training API.

This commit is contained in:
tqchen
2016-02-29 10:00:37 -08:00
parent 51bb556898
commit ecb3a271be
16 changed files with 427 additions and 32 deletions

View File

@@ -91,7 +91,7 @@ int XGDMatrixCreateFromFile(const char *fname,
<< "will split data among workers";
}
*out = DMatrix::Load(
fname, silent != 0, false);
fname, false, true);
API_END();
}
@@ -533,18 +533,44 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
API_END();
}
int XGBoosterGetAttr(BoosterHandle handle,
const char* key,
const char** out,
int* success) {
Booster* bst = static_cast<Booster*>(handle);
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN();
if (bst->learner()->GetAttr(key, &ret_str)) {
*out = ret_str.c_str();
*success = 1;
} else {
*out = nullptr;
*success = 0;
}
API_END();
}
int XGBoosterSetAttr(BoosterHandle handle,
const char* key,
const char* value) {
Booster* bst = static_cast<Booster*>(handle);
API_BEGIN();
bst->learner()->SetAttr(key, value);
API_END();
}
int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
*version = rabit::LoadCheckPoint(bst->learner());
if (version != 0) {
if (*version != 0) {
bst->initialized_ = true;
}
API_END();
}
int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
if (bst->learner()->AllowLazyCheckPoint()) {