[REFACTOR] Add alias, allow missing variables, init gbm interface

tqchen 2016-01-02 04:40:49 -08:00
parent 4f26d98150
commit e4567bbc47
16 changed files with 145 additions and 96 deletions

@@ -1 +1 @@
-Subproject commit e5c8ed0342fbbdf7e38cafafb126f91bcca5ec72
+Subproject commit ec454218564fee8e531aee02b8943a4634330ce1

View File

@@ -9,6 +9,14 @@
 #include <dmlc/base.h>
 #include <dmlc/omp.h>
+/*!
+ * \brief string flag for R library, to leave hooks when needed.
+ */
+#ifndef XGBOOST_STRICT_R_MODE
+#define XGBOOST_STRICT_R_MODE 0
+#endif
+/*! \brief namespace of xgboost */
 namespace xgboost {
 /*!
  * \brief unsigned integer type used in boost,
@@ -28,11 +36,14 @@ struct bst_gpair {
   bst_gpair(bst_float grad, bst_float hess) : grad(grad), hess(hess) {}
 };
+/*! \brief small eps gap for minimum split decision. */
 const float rt_eps = 1e-5f;
-// min gap between feature values to allow a split happen
+/*! \brief min gap between feature values to allow a split happen */
 const float rt_2eps = rt_eps * 2.0f;
+/*! \brief define unsigned long for openmp loop */
 typedef dmlc::omp_ulong omp_ulong;
+/*! \brief define unsigned int for openmp loop */
 typedef dmlc::omp_uint bst_omp_uint;
 /*!

View File

@@ -1,52 +1,58 @@
 /*!
  * Copyright by Contributors
  * \file gbm.h
- * \brief interface of gradient booster, that learns through gradient statistics
+ * \brief Interface of gradient booster,
+ *  that learns through gradient statistics.
  * \author Tianqi Chen
  */
-#ifndef XGBOOST_GBM_GBM_H_
-#define XGBOOST_GBM_GBM_H_
+#ifndef XGBOOST_GBM_H_
+#define XGBOOST_GBM_H_
+#include <dmlc/registry.h>
 #include <vector>
+#include <utility>
 #include <string>
-#include "../data.h"
-#include "../utils/io.h"
-#include "../utils/fmap.h"
+#include <functional>
+#include "./base.h"
+#include "./data.h"
+#include "./feature_map.h"
 namespace xgboost {
-/*! \brief namespace for gradient booster */
-namespace gbm {
 /*!
- * \brief interface of gradient boosting model
+ * \brief interface of gradient boosting model.
  */
-class IGradBooster {
+class GradientBooster {
  public:
+  /*! \brief virtual destructor */
+  virtual ~GradientBooster() {}
   /*!
-   * \brief set parameters from outside
-   * \param name name of the parameter
-   * \param val value of the parameter
+   * \brief Set the configuration of gradient boosting.
+   *
+   *  User must call Configure before training.
+   *
+   * \param cfg configurations on both training and model parameters.
    */
-  virtual void SetParam(const char *name, const char *val) = 0;
+  virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
+  /*!
+   * \brief Initialize the model.
+   *  User needs to call Configure before calling InitModel.
+   */
+  virtual void InitModel() = 0;
   /*!
    * \brief load model from stream
-   * \param fi input stream
-   * \param with_pbuffer whether the incoming data contains pbuffer
+   * \param fi input stream.
    */
-  virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0; // NOLINT(*)
+  virtual void LoadModel(dmlc::Stream* fi) = 0;
   /*!
-   * \brief save model to stream
+   * \brief save model to stream.
    * \param fo output stream
-   * \param with_pbuffer whether save out pbuffer
    */
-  virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0; // NOLINT(*)
+  virtual void SaveModel(dmlc::Stream* fo) const = 0;
   /*!
-   * \brief initialize the model
-   */
-  virtual void InitModel(void) = 0;
-  /*!
-   * \brief reset the predict buffer
-   *  this will invalidate all the previous cached results
-   *  and recalculate from scratch
+   * \brief reset the predict buffer size.
+   *  This will invalidate all the previous cached results
+   *  and recalculate from scratch.
+   * \param num_pbuffer The size of predict buffer.
    */
   virtual void ResetPredBuffer(size_t num_pbuffer) {}
   /*!
@@ -54,7 +60,7 @@ class IGradBooster {
    * return true if model is only updated in DoBoost
    * after all Allreduce calls
    */
-  virtual bool AllowLazyCheckPoint(void) const {
+  virtual bool AllowLazyCheckPoint() const {
     return false;
   }
   /*!
@@ -66,26 +72,24 @@ class IGradBooster {
    * \param in_gpair address of the gradient pair statistics of the data
    *  the booster may change content of gpair
    */
-  virtual void DoBoost(IFMatrix *p_fmat,
+  virtual void DoBoost(DMatrix* p_fmat,
                        int64_t buffer_offset,
-                       const BoosterInfo &info,
-                       std::vector<bst_gpair> *in_gpair) = 0;
+                       std::vector<bst_gpair>* in_gpair) = 0;
   /*!
    * \brief generate predictions for given feature matrix
    * \param p_fmat feature matrix
    * \param buffer_offset buffer index offset of these instances, if equals -1
    *  this means we do not have buffer index allocated to the gbm
    *  a buffer index is assigned to each instance that requires repetitive prediction
-   *  the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size")
+   *  the size of buffer is set by convention using GradientBooster.ResetPredBuffer(size);
    * \param info extra side information that may be needed for prediction
    * \param out_preds output vector to hold the predictions
    * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
    *  we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
    */
-  virtual void Predict(IFMatrix *p_fmat,
+  virtual void Predict(DMatrix* dmat,
                        int64_t buffer_offset,
-                       const BoosterInfo &info,
-                       std::vector<float> *out_preds,
+                       std::vector<float>* out_preds,
                        unsigned ntree_limit = 0) = 0;
   /*!
    * \brief online prediction function, predict score for one instance at a time
@@ -99,38 +103,58 @@ class IGradBooster {
    * \param root_index the root index
    * \sa Predict
    */
-  virtual void Predict(const SparseBatch::Inst &inst,
-                       std::vector<float> *out_preds,
+  virtual void Predict(const SparseBatch::Inst& inst,
+                       std::vector<float>* out_preds,
                        unsigned ntree_limit = 0,
                        unsigned root_index = 0) = 0;
   /*!
    * \brief predict the leaf index of each tree, the output will be nsample * ntree vector
    *  this is only valid in gbtree predictor
-   * \param p_fmat feature matrix
-   * \param info extra side information that may be needed for prediction
+   * \param dmat feature matrix
    * \param out_preds output vector to hold the predictions
    * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
    *  we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
    */
-  virtual void PredictLeaf(IFMatrix *p_fmat,
-                           const BoosterInfo &info,
-                           std::vector<float> *out_preds,
+  virtual void PredictLeaf(DMatrix* dmat,
+                           std::vector<float>* out_preds,
                            unsigned ntree_limit = 0) = 0;
   /*!
    * \brief dump the model in text format
    * \param fmap feature map that may help give interpretations of feature
    * \param option extra option of the dump model
-   * \return a vector of dump for boosters
+   * \return a vector of dump for boosters.
    */
-  virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) = 0;
-  // destructor
-  virtual ~IGradBooster(void) {}
+  virtual std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) = 0;
+  /*!
+   * \brief create a gradient booster from given name
+   * \param name name of gradient booster
+   */
+  static GradientBooster* Create(const char *name);
 };
 /*!
- * \brief create a gradient booster from given name
- * \param name name of gradient booster
+ * \brief Registry entry for gradient booster.
  */
-IGradBooster* CreateGradBooster(const char *name);
-} // namespace gbm
+struct GradientBoosterReg
+    : public dmlc::FunctionRegEntryBase<GradientBoosterReg,
+                                        std::function<GradientBooster* ()> > {
+};
+/*!
+ * \brief Macro to register gradient booster.
+ *
+ * \code
+ * // example of registering the gbtree booster
+ * XGBOOST_REGISTER_GBM(GBTree, "gbtree")
+ * .describe("Boosting tree ensembles.")
+ * .set_body([]() {
+ *     return new GBTree();
+ *   });
+ * \endcode
+ */
+#define XGBOOST_REGISTER_GBM(UniqueId, Name) \
+  static ::xgboost::GradientBoosterReg& __make_ ## GradientBoosterReg ## _ ## UniqueId ## __ = \
+      ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->__REGISTER__(#Name)
 } // namespace xgboost
-#endif // XGBOOST_GBM_GBM_H_
+#endif // XGBOOST_GBM_H_
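
For reference, a minimal sketch of how the new interface is meant to be driven, using only the declarations above. The helper function and the parameter names passed to Configure are illustrative, not part of this commit:

#include <memory>
#include <vector>
#include <xgboost/gbm.h>

void TrainOneRound(xgboost::DMatrix* dmat,
                   std::vector<xgboost::bst_gpair>* gpair) {
  // look up a registered booster by name, e.g. the gbtree implementation
  std::unique_ptr<xgboost::GradientBooster> gbm(
      xgboost::GradientBooster::Create("gbtree"));
  gbm->Configure({{"learning_rate", "0.1"}, {"max_depth", "6"}});
  gbm->InitModel();
  // one boosting iteration over the current gradient statistics
  gbm->DoBoost(dmat, /*buffer_offset=*/-1, gpair);
  // predictions without a prediction buffer (buffer_offset == -1)
  std::vector<float> preds;
  gbm->Predict(dmat, -1, &preds);
}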

View File

@@ -70,7 +70,7 @@ struct TreeUpdaterReg
  *
  * \code
  * // example of registering a objective ndcg@k
- * XGBOOST_REGISTER_METRIC(ColMaker, "colmaker")
+ * XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "colmaker")
  * .describe("Column based tree maker.")
  * .set_body([]() {
  *     return new ColMaker<TStats>();
@@ -78,7 +78,8 @@ struct TreeUpdaterReg
  * \endcode
  */
 #define XGBOOST_REGISTER_TREE_UPDATER(UniqueId, Name) \
-  static ::xgboost::TreeUpdaterReg & __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \
+  static ::xgboost::TreeUpdaterReg& __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \
       ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(#Name)
 } // namespace xgboost
 #endif // XGBOOST_TREE_UPDATER_H_

View File

@@ -12,7 +12,7 @@
 namespace xgboost {
 namespace common {
 /*!
- * \brief Random Engine
+ * \brief Define mt19937 as the default random engine.
  */
 typedef std::mt19937 RandomEngine;
 /*!
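
Since RandomEngine is just std::mt19937, it composes directly with the <random> distributions; a minimal usage sketch, assuming the GlobalRandom() accessor defined later in this commit is reachable through this header:

#include <random>
#include "./common/random.h"

float DrawUniform() {
  // shared engine; plugs into any <random> distribution
  xgboost::common::RandomEngine& rng = xgboost::common::GlobalRandom();
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  return dist(rng);
}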

View File

@@ -6,6 +6,7 @@
 #include <xgboost/objective.h>
 #include <xgboost/metric.h>
 #include <xgboost/tree_updater.h>
+#include <xgboost/gbm.h>
 #include "./common/random.h"
 #include "./common/base64.h"
@@ -13,6 +14,7 @@ namespace dmlc {
 DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
 DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
 DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
+DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
 } // namespace dmlc
 namespace xgboost {
@@ -45,7 +47,6 @@ Metric* Metric::Create(const char* name) {
   }
 }
-// implement factory functions
 TreeUpdater* TreeUpdater::Create(const char* name) {
   auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
   if (e == nullptr) {
@@ -54,6 +55,14 @@ TreeUpdater* TreeUpdater::Create(const char* name) {
   return (e->body)();
 }
+GradientBooster* GradientBooster::Create(const char* name) {
+  auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
+  if (e == nullptr) {
+    LOG(FATAL) << "Unknown gbm type " << name;
+  }
+  return (e->body)();
+}
 namespace common {
 RandomEngine& GlobalRandom() {
   static RandomEngine inst;
@@ -61,4 +70,3 @@ RandomEngine& GlobalRandom() {
 }
 }
 } // namespace xgboost
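
Putting the registry pieces together, a rough sketch of what a concrete booster registration could look like once the registry is enabled and GradientBooster::Create exists. MyBooster and "my_gbm" are placeholders, and the no-op overrides are only there to keep the sketch compilable:

#include <xgboost/gbm.h>

namespace xgboost {
namespace gbm {
// Hypothetical booster with no-op implementations of the pure virtuals.
class MyBooster : public GradientBooster {
 public:
  void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {}
  void InitModel() override {}
  void LoadModel(dmlc::Stream* fi) override {}
  void SaveModel(dmlc::Stream* fo) const override {}
  void DoBoost(DMatrix* p_fmat, int64_t buffer_offset,
               std::vector<bst_gpair>* in_gpair) override {}
  void Predict(DMatrix* dmat, int64_t buffer_offset,
               std::vector<float>* out_preds, unsigned ntree_limit) override {}
  void Predict(const SparseBatch::Inst& inst, std::vector<float>* out_preds,
               unsigned ntree_limit, unsigned root_index) override {}
  void PredictLeaf(DMatrix* dmat, std::vector<float>* out_preds,
                   unsigned ntree_limit) override {}
  std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) override {
    return std::vector<std::string>();
  }
};

// After this, GradientBooster::Create("my_gbm") finds the entry above.
XGBOOST_REGISTER_GBM(MyBooster, "my_gbm")
.describe("No-op booster registered for illustration only.")
.set_body([]() { return new MyBooster(); });
}  // namespace gbm
}  // namespace xgboost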

View File

@@ -30,7 +30,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
       : output_prob_(output_prob) {
   }
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param_.Init(args);
+    param_.InitAllowUnknown(args);
   }
   void GetGradient(const std::vector<float>& preds,
                    const MetaInfo& info,
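
The switch from param_.Init(args) to param_.InitAllowUnknown(args) here, and in the other objectives and updaters below, is the "allow missing variables" part of the commit title: each component can now be handed the full shared argument list and will ignore keys it does not declare, instead of raising an error. A minimal sketch of the difference, using a made-up parameter struct (DemoParam and "num_round" exist only for this illustration):

#include <string>
#include <utility>
#include <vector>
#include <dmlc/parameter.h>

// Hypothetical parameter struct with a single declared field.
struct DemoParam : public dmlc::Parameter<DemoParam> {
  float scale;
  DMLC_DECLARE_PARAMETER(DemoParam) {
    DMLC_DECLARE_FIELD(scale).set_default(1.0f);
  }
};
DMLC_REGISTER_PARAMETER(DemoParam);

void Demo() {
  std::vector<std::pair<std::string, std::string> > args = {
      {"scale", "0.5"}, {"num_round", "10"}};  // "num_round" is not a DemoParam field
  DemoParam p;
  // p.Init(args) would complain about the unknown key "num_round";
  // InitAllowUnknown simply skips it and still sets scale = 0.5.
  p.InitAllowUnknown(args);
}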

View File

@@ -33,7 +33,7 @@ struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
 class LambdaRankObj : public ObjFunction {
  public:
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param_.Init(args);
+    param_.InitAllowUnknown(args);
   }
   void GetGradient(const std::vector<float>& preds,
                    const MetaInfo& info,

View File

@@ -77,7 +77,7 @@ template<typename Loss>
 class RegLossObj : public ObjFunction {
  public:
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param_.Init(args);
+    param_.InitAllowUnknown(args);
   }
   void GetGradient(const std::vector<float> &preds,
                    const MetaInfo &info,
@@ -156,7 +156,7 @@ class PoissonRegression : public ObjFunction {
  public:
   // declare functions
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param_.Init(args);
+    param_.InitAllowUnknown(args);
   }
   void GetGradient(const std::vector<float> &preds,

View File

@@ -16,9 +16,9 @@ namespace tree {
 /*! \brief training parameters for regression tree */
 struct TrainParam : public dmlc::Parameter<TrainParam> {
   // learning step size for a time
-  float eta;
+  float learning_rate;
   // minimum loss change required for a split
-  float gamma;
+  float min_split_loss;
   // maximum depth of a tree
   int max_depth;
   //----- the rest parameters are less important ----
@@ -59,9 +59,9 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   bool silent;
   // declare the parameters
   DMLC_DECLARE_PARAMETER(TrainParam) {
-    DMLC_DECLARE_FIELD(eta).set_lower_bound(0.0f).set_default(0.3f)
+    DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
         .describe("Learning rate(step size) of update.");
-    DMLC_DECLARE_FIELD(gamma).set_lower_bound(0.0f).set_default(0.0f)
+    DMLC_DECLARE_FIELD(min_split_loss).set_lower_bound(0.0f).set_default(0.0f)
         .describe("Minimum loss reduction required to make a further partition.");
     DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6)
         .describe("Maximum depth of the tree.");
@@ -101,6 +101,11 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
         .describe("Number of threads used for training.");
     DMLC_DECLARE_FIELD(silent).set_default(false)
         .describe("Not print information during training.");
+    // add alias of parameters
+    DMLC_DECLARE_ALIAS(reg_lambda, lambda);
+    DMLC_DECLARE_ALIAS(reg_alpha, alpha);
+    DMLC_DECLARE_ALIAS(min_split_loss, gamma);
+    DMLC_DECLARE_ALIAS(learning_rate, eta);
   }
   // calculate the cost of loss function
@@ -159,7 +164,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   }
   /*! \brief given the loss change, whether we need to invoke pruning */
   inline bool need_prune(double loss_chg, int depth) const {
-    return loss_chg < this->gamma;
+    return loss_chg < this->min_split_loss;
   }
   /*! \brief whether we can split with current hessian */
   inline bool cannot_split(double sum_hess, int depth) const {
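
With the DMLC_DECLARE_ALIAS entries above, the old configuration keys keep working while the struct fields carry the new names. A small sketch of the effect (the include path is illustrative):

#include <string>
#include <utility>
#include <vector>
#include "param.h"  // the header shown above; the real path may differ

void ConfigureWithLegacyNames() {
  std::vector<std::pair<std::string, std::string> > args = {
      {"eta", "0.1"}, {"gamma", "1.0"}};  // legacy spellings
  xgboost::tree::TrainParam param;
  param.InitAllowUnknown(args);
  // The aliases route the values to the renamed fields:
  // param.learning_rate == 0.1f, param.min_split_loss == 1.0f.
}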

View File

@@ -29,7 +29,7 @@ namespace tree {
 class BaseMaker: public TreeUpdater {
  public:
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param.Init(args);
+    param.InitAllowUnknown(args);
   }
  protected:

View File

@@ -20,7 +20,7 @@ template<typename TStats>
 class ColMaker: public TreeUpdater {
  public:
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param.Init(args);
+    param.InitAllowUnknown(args);
   }
   void Update(const std::vector<bst_gpair> &gpair,
@@ -28,14 +28,14 @@ class ColMaker: public TreeUpdater {
               const std::vector<RegTree*> &trees) override {
     TStats::CheckInfo(dmat->info());
     // rescale learning rate according to size of trees
-    float lr = param.eta;
-    param.eta = lr / trees.size();
+    float lr = param.learning_rate;
+    param.learning_rate = lr / trees.size();
     // build tree
     for (size_t i = 0; i < trees.size(); ++i) {
       Builder builder(param);
       builder.Update(gpair, dmat, trees[i]);
     }
-    param.eta = lr;
+    param.learning_rate = lr;
   }
  protected:
@@ -95,7 +95,7 @@ class ColMaker: public TreeUpdater {
     // set all the rest expanding nodes to leaf
     for (size_t i = 0; i < qexpand_.size(); ++i) {
       const int nid = qexpand_[i];
-      (*p_tree)[nid].set_leaf(snode[nid].weight * param.eta);
+      (*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
     }
     // remember auxiliary statistics in the tree node
     for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
@@ -606,7 +606,7 @@ class ColMaker: public TreeUpdater {
         (*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
         (*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
       } else {
-        (*p_tree)[nid].set_leaf(e.weight * param.eta);
+        (*p_tree)[nid].set_leaf(e.weight * param.learning_rate);
       }
     }
   }
@@ -732,7 +732,7 @@ class DistColMaker : public ColMaker<TStats> {
     pruner.reset(TreeUpdater::Create("prune"));
   }
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param.Init(args);
+    param.InitAllowUnknown(args);
     pruner->Init(args);
   }
   void Update(const std::vector<bst_gpair> &gpair,
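
The Update methods in this and the following updaters all follow the same pattern: learning_rate is temporarily divided by the number of trees grown in the round, so that the leaves of all trees together contribute one step of the configured size, and is then restored. A standalone sketch of just that pattern (ParamLike is a stand-in for TrainParam):

#include <cstddef>

struct ParamLike { float learning_rate; };  // stand-in for TrainParam

void UpdateAllTrees(ParamLike& param, std::size_t num_trees) {
  const float lr = param.learning_rate;
  param.learning_rate = lr / num_trees;  // each tree gets a scaled-down step
  for (std::size_t i = 0; i < num_trees; ++i) {
    // ... build or refresh tree i; its leaf values are multiplied by
    //     param.learning_rate, as in the set_leaf calls above ...
  }
  param.learning_rate = lr;  // restore the user-configured value
}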

View File

@@ -23,13 +23,13 @@ class HistMaker: public BaseMaker {
               const std::vector<RegTree*> &trees) override {
     TStats::CheckInfo(p_fmat->info());
     // rescale learning rate according to size of trees
-    float lr = param.eta;
-    param.eta = lr / trees.size();
+    float lr = param.learning_rate;
+    param.learning_rate = lr / trees.size();
     // build tree
     for (size_t i = 0; i < trees.size(); ++i) {
       this->Update(gpair, p_fmat, trees[i]);
     }
-    param.eta = lr;
+    param.learning_rate = lr;
   }
  protected:
@@ -139,7 +139,7 @@ class HistMaker: public BaseMaker {
     }
     for (size_t i = 0; i < qexpand.size(); ++i) {
       const int nid = qexpand[i];
-      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
+      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
     }
   }
   // this function does two jobs
@@ -246,7 +246,7 @@ class HistMaker: public BaseMaker {
       this->SetStats(p_tree, (*p_tree)[nid].cleft(), left_sum[wid]);
       this->SetStats(p_tree, (*p_tree)[nid].cright(), right_sum);
     } else {
-      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
+      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
     }
   }
 }

View File

@@ -22,7 +22,7 @@ class TreePruner: public TreeUpdater {
   }
   // set training parameter
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param.Init(args);
+    param.InitAllowUnknown(args);
     syncher->Init(args);
   }
   // update the tree, do pruning
@@ -30,12 +30,12 @@ class TreePruner: public TreeUpdater {
               DMatrix *p_fmat,
               const std::vector<RegTree*> &trees) override {
     // rescale learning rate according to size of trees
-    float lr = param.eta;
-    param.eta = lr / trees.size();
+    float lr = param.learning_rate;
+    param.learning_rate = lr / trees.size();
     for (size_t i = 0; i < trees.size(); ++i) {
       this->DoPrune(*trees[i]);
     }
-    param.eta = lr;
+    param.learning_rate = lr;
     syncher->Update(gpair, p_fmat, trees);
   }
@@ -48,7 +48,7 @@ class TreePruner: public TreeUpdater {
       ++s.leaf_child_cnt;
       if (s.leaf_child_cnt >= 2 && param.need_prune(s.loss_chg, depth - 1)) {
         // need to be pruned
-        tree.ChangeToLeaf(pid, param.eta * s.base_weight);
+        tree.ChangeToLeaf(pid, param.learning_rate * s.base_weight);
         // tail recursion
         return this->TryPruneLeaf(tree, pid, depth - 1, npruned + 2);
       } else {

View File

@@ -19,7 +19,7 @@ template<typename TStats>
 class TreeRefresher: public TreeUpdater {
  public:
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param.Init(args);
+    param.InitAllowUnknown(args);
   }
   // update the tree, do pruning
   void Update(const std::vector<bst_gpair> &gpair,
@@ -94,8 +94,8 @@ class TreeRefresher: public TreeUpdater {
     reducer.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size());
 #endif
     // rescale learning rate according to size of trees
-    float lr = param.eta;
-    param.eta = lr / trees.size();
+    float lr = param.learning_rate;
+    param.learning_rate = lr / trees.size();
     int offset = 0;
     for (size_t i = 0; i < trees.size(); ++i) {
       for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
@@ -104,7 +104,7 @@ class TreeRefresher: public TreeUpdater {
       offset += trees[i]->param.num_nodes;
     }
     // set learning rate back
-    param.eta = lr;
+    param.learning_rate = lr;
   }
  private:
@@ -131,7 +131,7 @@ class TreeRefresher: public TreeUpdater {
     tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
     gstats[nid].SetLeafVec(param, tree.leafvec(nid));
     if (tree[nid].is_leaf()) {
-      tree[nid].set_leaf(tree.stat(nid).base_weight * param.eta);
+      tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
     } else {
       tree.stat(nid).loss_chg = static_cast<float>(
           gstats[tree[nid].cleft()].CalcGain(param) +

View File

@@ -24,13 +24,13 @@ class SketchMaker: public BaseMaker {
               DMatrix *p_fmat,
               const std::vector<RegTree*> &trees) override {
     // rescale learning rate according to size of trees
-    float lr = param.eta;
-    param.eta = lr / trees.size();
+    float lr = param.learning_rate;
+    param.learning_rate = lr / trees.size();
     // build tree
     for (size_t i = 0; i < trees.size(); ++i) {
       this->Update(gpair, p_fmat, trees[i]);
     }
-    param.eta = lr;
+    param.learning_rate = lr;
   }
  protected:
@@ -67,7 +67,7 @@ class SketchMaker: public BaseMaker {
     // set left leaves
     for (size_t i = 0; i < qexpand.size(); ++i) {
       const int nid = qexpand[i];
-      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
+      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
     }
   }
   // define the sketch we want to use
@@ -302,7 +302,7 @@ class SketchMaker: public BaseMaker {
       (*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
       (*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
     } else {
-      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
+      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
     }
   }
 }