Merge pull request #959 from tqchen/master
Fix continue training in CLI
This commit is contained in:
commit
d02bd41623
@ -26,6 +26,15 @@ inline std::vector<std::string> Split(const std::string& s, char delim) {
|
|||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// simple routine to convert any data to string
|
||||||
|
template<typename T>
|
||||||
|
inline std::string ToString(const T& data) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << data;
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_COMMON_COMMON_H_
|
#endif // XGBOOST_COMMON_COMMON_H_
|
||||||
|
|||||||
@ -126,6 +126,9 @@ class GBTree : public GradientBooster {
|
|||||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees),
|
CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees),
|
||||||
sizeof(int) * mparam.num_trees);
|
sizeof(int) * mparam.num_trees);
|
||||||
}
|
}
|
||||||
|
this->cfg.clear();
|
||||||
|
this->cfg.push_back(std::make_pair(std::string("num_feature"),
|
||||||
|
common::ToString(mparam.num_feature)));
|
||||||
// clear the predict buffer.
|
// clear the predict buffer.
|
||||||
this->ResetPredBuffer(num_pbuffer);
|
this->ResetPredBuffer(num_pbuffer);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,6 +14,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include "./common/io.h"
|
#include "./common/io.h"
|
||||||
|
#include "./common/common.h"
|
||||||
#include "./common/random.h"
|
#include "./common/random.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -27,13 +28,6 @@ Learner::Dump2Text(const FeatureMap& fmap, int option) const {
|
|||||||
return gbm_->Dump2Text(fmap, option);
|
return gbm_->Dump2Text(fmap, option);
|
||||||
}
|
}
|
||||||
|
|
||||||
// simple routine to convert any data to string
|
|
||||||
template<typename T>
|
|
||||||
inline std::string ToString(const T& data) {
|
|
||||||
std::ostringstream os;
|
|
||||||
os << data;
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/*! \brief training parameter for regression */
|
/*! \brief training parameter for regression */
|
||||||
struct LearnerModelParam
|
struct LearnerModelParam
|
||||||
@ -192,7 +186,7 @@ class LearnerImpl : public Learner {
|
|||||||
common::GlobalRandom().seed(tparam.seed);
|
common::GlobalRandom().seed(tparam.seed);
|
||||||
|
|
||||||
// set number of features correctly.
|
// set number of features correctly.
|
||||||
cfg_["num_feature"] = ToString(mparam.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam.num_feature);
|
||||||
if (gbm_.get() != nullptr) {
|
if (gbm_.get() != nullptr) {
|
||||||
gbm_->Configure(cfg_.begin(), cfg_.end());
|
gbm_->Configure(cfg_.begin(), cfg_.end());
|
||||||
}
|
}
|
||||||
@ -252,13 +246,13 @@ class LearnerImpl : public Learner {
|
|||||||
attributes_ = std::map<std::string, std::string>(
|
attributes_ = std::map<std::string, std::string>(
|
||||||
attr.begin(), attr.end());
|
attr.begin(), attr.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (metrics_.size() == 0) {
|
if (metrics_.size() == 0) {
|
||||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||||
}
|
}
|
||||||
this->base_score_ = mparam.base_score;
|
this->base_score_ = mparam.base_score;
|
||||||
gbm_->ResetPredBuffer(pred_buffer_size_);
|
gbm_->ResetPredBuffer(pred_buffer_size_);
|
||||||
cfg_["num_class"] = ToString(mparam.num_class);
|
cfg_["num_class"] = common::ToString(mparam.num_class);
|
||||||
|
cfg_["num_feature"] = common::ToString(mparam.num_feature);
|
||||||
obj_->Configure(cfg_.begin(), cfg_.end());
|
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -395,7 +389,7 @@ class LearnerImpl : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// setup
|
// setup
|
||||||
cfg_["num_feature"] = ToString(mparam.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam.num_feature);
|
||||||
CHECK(obj_.get() == nullptr && gbm_.get() == nullptr);
|
CHECK(obj_.get() == nullptr && gbm_.get() == nullptr);
|
||||||
obj_.reset(ObjFunction::Create(name_obj_));
|
obj_.reset(ObjFunction::Create(name_obj_));
|
||||||
gbm_.reset(GradientBooster::Create(name_gbm_));
|
gbm_.reset(GradientBooster::Create(name_gbm_));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user