[PYTHON-DIST] Distributed xgboost python training API.
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/learner.h>
|
||||
#include <dmlc/io.h>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
@@ -43,8 +44,10 @@ struct LearnerModelParam
|
||||
unsigned num_feature;
|
||||
/* \brief number of classes, if it is multi-class classification */
|
||||
int num_class;
|
||||
/*! \brief Model contain additional properties */
|
||||
int contain_extra_attrs;
|
||||
/*! \brief reserved field */
|
||||
int reserved[31];
|
||||
int reserved[30];
|
||||
/*! \brief constructor */
|
||||
LearnerModelParam() {
|
||||
std::memset(this, 0, sizeof(LearnerModelParam));
|
||||
@@ -243,6 +246,12 @@ class LearnerImpl : public Learner {
|
||||
obj_.reset(ObjFunction::Create(name_obj_));
|
||||
gbm_.reset(GradientBooster::Create(name_gbm_));
|
||||
gbm_->Load(fi);
|
||||
if (mparam.contain_extra_attrs != 0) {
|
||||
std::vector<std::pair<std::string, std::string> > attr;
|
||||
fi->Read(&attr);
|
||||
attributes_ = std::map<std::string, std::string>(
|
||||
attr.begin(), attr.end());
|
||||
}
|
||||
|
||||
if (metrics_.size() == 0) {
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||
@@ -259,6 +268,11 @@ class LearnerImpl : public Learner {
|
||||
fo->Write(name_obj_);
|
||||
fo->Write(name_gbm_);
|
||||
gbm_->Save(fo);
|
||||
if (mparam.contain_extra_attrs != 0) {
|
||||
std::vector<std::pair<std::string, std::string> > attr(
|
||||
attributes_.begin(), attributes_.end());
|
||||
fo->Write(attr);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||
@@ -300,6 +314,18 @@ class LearnerImpl : public Learner {
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void SetAttr(const std::string& key, const std::string& value) override {
|
||||
attributes_[key] = value;
|
||||
mparam.contain_extra_attrs = 1;
|
||||
}
|
||||
|
||||
bool GetAttr(const std::string& key, std::string* out) const override {
|
||||
auto it = attributes_.find(key);
|
||||
if (it == attributes_.end()) return false;
|
||||
*out = it->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
|
||||
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
||||
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
||||
@@ -427,6 +453,8 @@ class LearnerImpl : public Learner {
|
||||
LearnerTrainParam tparam;
|
||||
// configurations
|
||||
std::map<std::string, std::string> cfg_;
|
||||
// attributes
|
||||
std::map<std::string, std::string> attributes_;
|
||||
// name of gbm
|
||||
std::string name_gbm_;
|
||||
// name of objective functon
|
||||
|
||||
Reference in New Issue
Block a user