[PYTHON-DIST] Distributed xgboost python training API.

This commit is contained in:
tqchen
2016-02-29 10:00:37 -08:00
parent 51bb556898
commit ecb3a271be
16 changed files with 427 additions and 32 deletions

View File

@@ -5,6 +5,7 @@
* \author Tianqi Chen
*/
#include <xgboost/learner.h>
#include <dmlc/io.h>
#include <algorithm>
#include <vector>
#include <utility>
@@ -43,8 +44,10 @@ struct LearnerModelParam
unsigned num_feature;
/* \brief number of classes, if it is multi-class classification */
int num_class;
/*! \brief Model contain additional properties */
int contain_extra_attrs;
/*! \brief reserved field */
int reserved[31];
int reserved[30];
/*! \brief constructor */
LearnerModelParam() {
std::memset(this, 0, sizeof(LearnerModelParam));
@@ -243,6 +246,12 @@ class LearnerImpl : public Learner {
obj_.reset(ObjFunction::Create(name_obj_));
gbm_.reset(GradientBooster::Create(name_gbm_));
gbm_->Load(fi);
if (mparam.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr;
fi->Read(&attr);
attributes_ = std::map<std::string, std::string>(
attr.begin(), attr.end());
}
if (metrics_.size() == 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
@@ -259,6 +268,11 @@ class LearnerImpl : public Learner {
fo->Write(name_obj_);
fo->Write(name_gbm_);
gbm_->Save(fo);
if (mparam.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr(
attributes_.begin(), attributes_.end());
fo->Write(attr);
}
}
void UpdateOneIter(int iter, DMatrix* train) override {
@@ -300,6 +314,18 @@ class LearnerImpl : public Learner {
return os.str();
}
void SetAttr(const std::string& key, const std::string& value) override {
attributes_[key] = value;
mparam.contain_extra_attrs = 1;
}
bool GetAttr(const std::string& key, std::string* out) const override {
auto it = attributes_.find(key);
if (it == attributes_.end()) return false;
*out = it->second;
return true;
}
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
if (metric == "auto") metric = obj_->DefaultEvalMetric();
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
@@ -427,6 +453,8 @@ class LearnerImpl : public Learner {
LearnerTrainParam tparam;
// configurations
std::map<std::string, std::string> cfg_;
// attributes
std::map<std::string, std::string> attributes_;
// name of gbm
std::string name_gbm_;
// name of objective functon