[PYTHON-DIST] Distributed xgboost python training API.
This commit is contained in:
@@ -91,7 +91,7 @@ int XGDMatrixCreateFromFile(const char *fname,
|
||||
<< "will split data among workers";
|
||||
}
|
||||
*out = DMatrix::Load(
|
||||
fname, silent != 0, false);
|
||||
fname, false, true);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -533,18 +533,44 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterGetAttr(BoosterHandle handle,
|
||||
const char* key,
|
||||
const char** out,
|
||||
int* success) {
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
API_BEGIN();
|
||||
if (bst->learner()->GetAttr(key, &ret_str)) {
|
||||
*out = ret_str.c_str();
|
||||
*success = 1;
|
||||
} else {
|
||||
*out = nullptr;
|
||||
*success = 0;
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterSetAttr(BoosterHandle handle,
|
||||
const char* key,
|
||||
const char* value) {
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
API_BEGIN();
|
||||
bst->learner()->SetAttr(key, value);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version) {
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
*version = rabit::LoadCheckPoint(bst->learner());
|
||||
if (version != 0) {
|
||||
if (*version != 0) {
|
||||
bst->initialized_ = true;
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
|
||||
int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
if (bst->learner()->AllowLazyCheckPoint()) {
|
||||
|
||||
@@ -184,7 +184,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
||||
<< " of " << npart << " parts";
|
||||
}
|
||||
// legacy handling of binary data loading
|
||||
if (file_format == "auto" && !load_row_split) {
|
||||
if (file_format == "auto" && npart == 1) {
|
||||
int magic;
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
||||
if (fi.get() != nullptr) {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/learner.h>
|
||||
#include <dmlc/io.h>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
@@ -43,8 +44,10 @@ struct LearnerModelParam
|
||||
unsigned num_feature;
|
||||
/* \brief number of classes, if it is multi-class classification */
|
||||
int num_class;
|
||||
/*! \brief Model contain additional properties */
|
||||
int contain_extra_attrs;
|
||||
/*! \brief reserved field */
|
||||
int reserved[31];
|
||||
int reserved[30];
|
||||
/*! \brief constructor */
|
||||
LearnerModelParam() {
|
||||
std::memset(this, 0, sizeof(LearnerModelParam));
|
||||
@@ -243,6 +246,12 @@ class LearnerImpl : public Learner {
|
||||
obj_.reset(ObjFunction::Create(name_obj_));
|
||||
gbm_.reset(GradientBooster::Create(name_gbm_));
|
||||
gbm_->Load(fi);
|
||||
if (mparam.contain_extra_attrs != 0) {
|
||||
std::vector<std::pair<std::string, std::string> > attr;
|
||||
fi->Read(&attr);
|
||||
attributes_ = std::map<std::string, std::string>(
|
||||
attr.begin(), attr.end());
|
||||
}
|
||||
|
||||
if (metrics_.size() == 0) {
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||
@@ -259,6 +268,11 @@ class LearnerImpl : public Learner {
|
||||
fo->Write(name_obj_);
|
||||
fo->Write(name_gbm_);
|
||||
gbm_->Save(fo);
|
||||
if (mparam.contain_extra_attrs != 0) {
|
||||
std::vector<std::pair<std::string, std::string> > attr(
|
||||
attributes_.begin(), attributes_.end());
|
||||
fo->Write(attr);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||
@@ -300,6 +314,18 @@ class LearnerImpl : public Learner {
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void SetAttr(const std::string& key, const std::string& value) override {
|
||||
attributes_[key] = value;
|
||||
mparam.contain_extra_attrs = 1;
|
||||
}
|
||||
|
||||
bool GetAttr(const std::string& key, std::string* out) const override {
|
||||
auto it = attributes_.find(key);
|
||||
if (it == attributes_.end()) return false;
|
||||
*out = it->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
|
||||
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
||||
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
||||
@@ -427,6 +453,8 @@ class LearnerImpl : public Learner {
|
||||
LearnerTrainParam tparam;
|
||||
// configurations
|
||||
std::map<std::string, std::string> cfg_;
|
||||
// attributes
|
||||
std::map<std::string, std::string> attributes_;
|
||||
// name of gbm
|
||||
std::string name_gbm_;
|
||||
// name of objective functon
|
||||
|
||||
@@ -56,6 +56,9 @@ class BaseMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/*! \brief synchronize the information */
|
||||
inline void SyncInfo() {
|
||||
rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
|
||||
}
|
||||
// get feature type, 0:empty 1:binary 2:real
|
||||
|
||||
@@ -313,6 +313,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
feat_helper.InitByCol(p_fmat, tree);
|
||||
cache_dmatrix_ = p_fmat;
|
||||
}
|
||||
feat_helper.SyncInfo();
|
||||
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
|
||||
}
|
||||
// code to create histogram
|
||||
|
||||
Reference in New Issue
Block a user