diff --git a/src/common/common.h b/src/common/common.h
index d8adadae3..aaf2b8a37 100644
--- a/src/common/common.h
+++ b/src/common/common.h
@@ -26,6 +26,15 @@ inline std::vector<std::string> Split(const std::string& s, char delim) {
   }
   return ret;
 }
+
+// simple routine to convert any data to string
+template<typename T>
+inline std::string ToString(const T& data) {
+  std::ostringstream os;
+  os << data;
+  return os.str();
+}
+
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_COMMON_H_
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 29fb114a6..7e58a060a 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -126,6 +126,9 @@ class GBTree : public GradientBooster {
       CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees),
                sizeof(int) * mparam.num_trees);
     }
+    this->cfg.clear();
+    this->cfg.push_back(std::make_pair(std::string("num_feature"),
+                                       common::ToString(mparam.num_feature)));
     // clear the predict buffer.
     this->ResetPredBuffer(num_pbuffer);
   }
diff --git a/src/learner.cc b/src/learner.cc
index 59f7c53f7..5d00496ef 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -14,6 +14,7 @@
 #include <string>
 #include <vector>
 #include "./common/io.h"
+#include "./common/common.h"
 #include "./common/random.h"
 
 namespace xgboost {
@@ -27,13 +28,6 @@ Learner::Dump2Text(const FeatureMap& fmap, int option) const {
   return gbm_->Dump2Text(fmap, option);
 }
 
-// simple routine to convert any data to string
-template<typename T>
-inline std::string ToString(const T& data) {
-  std::ostringstream os;
-  os << data;
-  return os.str();
-}
 
 /*! \brief training parameter for regression */
 struct LearnerModelParam
@@ -192,7 +186,7 @@ class LearnerImpl : public Learner {
 
     common::GlobalRandom().seed(tparam.seed);
     // set number of features correctly.
-    cfg_["num_feature"] = ToString(mparam.num_feature);
+    cfg_["num_feature"] = common::ToString(mparam.num_feature);
     if (gbm_.get() != nullptr) {
       gbm_->Configure(cfg_.begin(), cfg_.end());
     }
@@ -252,13 +246,13 @@ class LearnerImpl : public Learner {
       attributes_ = std::map<std::string, std::string>(
           attr.begin(), attr.end());
     }
-
     if (metrics_.size() == 0) {
       metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
    }
     this->base_score_ = mparam.base_score;
     gbm_->ResetPredBuffer(pred_buffer_size_);
-    cfg_["num_class"] = ToString(mparam.num_class);
+    cfg_["num_class"] = common::ToString(mparam.num_class);
+    cfg_["num_feature"] = common::ToString(mparam.num_feature);
     obj_->Configure(cfg_.begin(), cfg_.end());
   }
 
@@ -395,7 +389,7 @@ class LearnerImpl : public Learner {
 
     }
     // setup
-    cfg_["num_feature"] = ToString(mparam.num_feature);
+    cfg_["num_feature"] = common::ToString(mparam.num_feature);
     CHECK(obj_.get() == nullptr && gbm_.get() == nullptr);
     obj_.reset(ObjFunction::Create(name_obj_));
     gbm_.reset(GradientBooster::Create(name_gbm_));
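
For reference, a minimal standalone sketch of the pattern this patch introduces: the templated common::ToString helper stringifies a numeric model parameter so it can travel through the string-keyed configuration pairs that Configure() consumes after a model is loaded. The num_feature value below is hypothetical, and the xgboost::common namespace is reproduced here only so the snippet compiles on its own outside the xgboost tree; it is not the library's actual source.

// Standalone sketch (not part of the patch): mirrors the ToString helper moved
// into src/common/common.h and the way GBTree::Load now records num_feature
// as a string key/value pair for a later Configure call.
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace xgboost {
namespace common {
// Same shape as the helper added to common.h above.
template<typename T>
inline std::string ToString(const T& data) {
  std::ostringstream os;
  os << data;
  return os.str();
}
}  // namespace common
}  // namespace xgboost

int main() {
  // Hypothetical value standing in for mparam.num_feature read from a model file.
  unsigned num_feature = 127;
  std::vector<std::pair<std::string, std::string> > cfg;
  cfg.push_back(std::make_pair(std::string("num_feature"),
                               xgboost::common::ToString(num_feature)));
  std::cout << cfg[0].first << " = " << cfg[0].second << std::endl;  // num_feature = 127
  return 0;
}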