[PYTHON-DIST] Distributed xgboost python training API.

This commit is contained in:
tqchen
2016-02-29 10:00:37 -08:00
parent 51bb556898
commit ecb3a271be
16 changed files with 427 additions and 32 deletions

View File

@@ -91,7 +91,7 @@ int XGDMatrixCreateFromFile(const char *fname,
<< "will split data among workers";
}
*out = DMatrix::Load(
fname, silent != 0, false);
fname, false, true);
API_END();
}
@@ -533,18 +533,44 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
API_END();
}
int XGBoosterGetAttr(BoosterHandle handle,
const char* key,
const char** out,
int* success) {
Booster* bst = static_cast<Booster*>(handle);
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN();
if (bst->learner()->GetAttr(key, &ret_str)) {
*out = ret_str.c_str();
*success = 1;
} else {
*out = nullptr;
*success = 0;
}
API_END();
}
int XGBoosterSetAttr(BoosterHandle handle,
const char* key,
const char* value) {
Booster* bst = static_cast<Booster*>(handle);
API_BEGIN();
bst->learner()->SetAttr(key, value);
API_END();
}
int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
*version = rabit::LoadCheckPoint(bst->learner());
if (version != 0) {
if (*version != 0) {
bst->initialized_ = true;
}
API_END();
}
int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
if (bst->learner()->AllowLazyCheckPoint()) {

View File

@@ -184,7 +184,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
<< " of " << npart << " parts";
}
// legacy handling of binary data loading
if (file_format == "auto" && !load_row_split) {
if (file_format == "auto" && npart == 1) {
int magic;
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
if (fi.get() != nullptr) {

View File

@@ -5,6 +5,7 @@
* \author Tianqi Chen
*/
#include <xgboost/learner.h>
#include <dmlc/io.h>
#include <algorithm>
#include <vector>
#include <utility>
@@ -43,8 +44,10 @@ struct LearnerModelParam
unsigned num_feature;
/* \brief number of classes, if it is multi-class classification */
int num_class;
/*! \brief Model contains additional properties */
int contain_extra_attrs;
/*! \brief reserved field */
int reserved[31];
int reserved[30];
/*! \brief constructor */
LearnerModelParam() {
std::memset(this, 0, sizeof(LearnerModelParam));
@@ -243,6 +246,12 @@ class LearnerImpl : public Learner {
obj_.reset(ObjFunction::Create(name_obj_));
gbm_.reset(GradientBooster::Create(name_gbm_));
gbm_->Load(fi);
if (mparam.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr;
fi->Read(&attr);
attributes_ = std::map<std::string, std::string>(
attr.begin(), attr.end());
}
if (metrics_.size() == 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
@@ -259,6 +268,11 @@ class LearnerImpl : public Learner {
fo->Write(name_obj_);
fo->Write(name_gbm_);
gbm_->Save(fo);
if (mparam.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr(
attributes_.begin(), attributes_.end());
fo->Write(attr);
}
}
void UpdateOneIter(int iter, DMatrix* train) override {
@@ -300,6 +314,18 @@ class LearnerImpl : public Learner {
return os.str();
}
void SetAttr(const std::string& key, const std::string& value) override {
attributes_[key] = value;
mparam.contain_extra_attrs = 1;
}
bool GetAttr(const std::string& key, std::string* out) const override {
auto it = attributes_.find(key);
if (it == attributes_.end()) return false;
*out = it->second;
return true;
}
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
if (metric == "auto") metric = obj_->DefaultEvalMetric();
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
@@ -427,6 +453,8 @@ class LearnerImpl : public Learner {
LearnerTrainParam tparam;
// configurations
std::map<std::string, std::string> cfg_;
// attributes
std::map<std::string, std::string> attributes_;
// name of gbm
std::string name_gbm_;
// name of objective function

View File

@@ -56,6 +56,9 @@ class BaseMaker: public TreeUpdater {
}
}
}
}
/*! \brief synchronize the information */
inline void SyncInfo() {
rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
}
// get feature type, 0:empty 1:binary 2:real

View File

@@ -313,6 +313,7 @@ class CQHistMaker: public HistMaker<TStats> {
feat_helper.InitByCol(p_fmat, tree);
cache_dmatrix_ = p_fmat;
}
feat_helper.SyncInfo();
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
}
// code to create histogram