[REFACTOR] Add alias, allow missing variables, init gbm interface

This commit is contained in:
tqchen 2016-01-02 04:40:49 -08:00
parent 4f26d98150
commit e4567bbc47
16 changed files with 145 additions and 96 deletions

@ -1 +1 @@
Subproject commit e5c8ed0342fbbdf7e38cafafb126f91bcca5ec72
Subproject commit ec454218564fee8e531aee02b8943a4634330ce1

View File

@ -9,6 +9,14 @@
#include <dmlc/base.h>
#include <dmlc/omp.h>
/*!
* \brief string flag for R library, to leave hooks when needed.
*/
#ifndef XGBOOST_STRICT_R_MODE
#define XGBOOST_STRICT_R_MODE 0
#endif
/*! \brief namespace of xgboo st*/
namespace xgboost {
/*!
* \brief unsigned interger type used in boost,
@ -28,11 +36,14 @@ struct bst_gpair {
bst_gpair(bst_float grad, bst_float hess) : grad(grad), hess(hess) {}
};
/*! \brief small eps gap for minimum split decision. */
const float rt_eps = 1e-5f;
// min gap between feature values to allow a split happen
/*! \brief min gap between feature values to allow a split happen */
const float rt_2eps = rt_eps * 2.0f;
/*! \brief define unsigned long for openmp loop */
typedef dmlc::omp_ulong omp_ulong;
/*! \brief define unsigned int for openmp loop */
typedef dmlc::omp_uint bst_omp_uint;
/*!

View File

@ -1,52 +1,58 @@
/*!
* Copyright by Contributors
* \file gbm.h
* \brief interface of gradient booster, that learns through gradient statistics
* \brief Interface of gradient booster,
* that learns through gradient statistics.
* \author Tianqi Chen
*/
#ifndef XGBOOST_GBM_GBM_H_
#define XGBOOST_GBM_GBM_H_
#ifndef XGBOOST_GBM_H_
#define XGBOOST_GBM_H_
#include <dmlc/registry.h>
#include <vector>
#include <utility>
#include <string>
#include "../data.h"
#include "../utils/io.h"
#include "../utils/fmap.h"
#include <functional>
#include "./base.h"
#include "./data.h"
#include "./feature_map.h"
namespace xgboost {
/*! \brief namespace for gradient booster */
namespace gbm {
/*!
* \brief interface of gradient boosting model
* \brief interface of gradient boosting model.
*/
class IGradBooster {
class GradientBooster {
public:
/*! \brief virtual destructor */
virtual ~GradientBooster() {}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
* \brief Set the configuration of gradient boosting.
*
* User must call configure before trainig.
*
* \param cfg configurations on both training and model parameters.
*/
virtual void SetParam(const char *name, const char *val) = 0;
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
/*!
* \brief Initialize the model.
* User need to call Configure before calling InitModel.
*/
virtual void InitModel() = 0;
/*!
* \brief load model from stream
* \param fi input stream
* \param with_pbuffer whether the incoming data contains pbuffer
* \param fi input stream.
*/
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0; // NOLINT(*)
virtual void LoadModel(dmlc::Stream* fi) = 0;
/*!
* \brief save model to stream
* \brief save model to stream.
* \param fo output stream
* \param with_pbuffer whether save out pbuffer
*/
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0; // NOLINT(*)
virtual void SaveModel(dmlc::Stream* fo) const = 0;
/*!
* \brief initialize the model
*/
virtual void InitModel(void) = 0;
/*!
* \brief reset the predict buffer
* this will invalidate all the previous cached results
* \brief reset the predict buffer size.
* This will invalidate all the previous cached results
* and recalculate from scratch
* \param num_pbuffer The size of predict buffer.
*/
virtual void ResetPredBuffer(size_t num_pbuffer) {}
/*!
@ -54,7 +60,7 @@ class IGradBooster {
* return true if model is only updated in DoBoost
* after all Allreduce calls
*/
virtual bool AllowLazyCheckPoint(void) const {
virtual bool AllowLazyCheckPoint() const {
return false;
}
/*!
@ -66,9 +72,8 @@ class IGradBooster {
* \param in_gpair address of the gradient pair statistics of the data
* the booster may change content of gpair
*/
virtual void DoBoost(IFMatrix *p_fmat,
virtual void DoBoost(DMatrix* p_fmat,
int64_t buffer_offset,
const BoosterInfo &info,
std::vector<bst_gpair>* in_gpair) = 0;
/*!
* \brief generate predictions for given feature matrix
@ -76,15 +81,14 @@ class IGradBooster {
* \param buffer_offset buffer index offset of these instances, if equals -1
* this means we do not have buffer index allocated to the gbm
* a buffer index is assigned to each instance that requires repeative prediction
* the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size")
* the size of buffer is set by convention using GradientBooster.ResetPredBuffer(size);
* \param info extra side information that may be needed for prediction
* \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void Predict(IFMatrix *p_fmat,
virtual void Predict(DMatrix* dmat,
int64_t buffer_offset,
const BoosterInfo &info,
std::vector<float>* out_preds,
unsigned ntree_limit = 0) = 0;
/*!
@ -106,31 +110,51 @@ class IGradBooster {
/*!
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector
* this is only valid in gbtree predictor
* \param p_fmat feature matrix
* \param info extra side information that may be needed for prediction
* \param dmat feature matrix
* \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void PredictLeaf(IFMatrix *p_fmat,
const BoosterInfo &info,
virtual void PredictLeaf(DMatrix* dmat,
std::vector<float>* out_preds,
unsigned ntree_limit = 0) = 0;
/*!
* \brief dump the model in text format
* \param fmap feature map that may help give interpretations of feature
* \param option extra option of the dump model
* \return a vector of dump for boosters
* \return a vector of dump for boosters.
*/
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) = 0;
// destrcutor
virtual ~IGradBooster(void){}
};
virtual std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) = 0;
/*!
* \breif create a gradient booster from given name
* \param name name of gradient booster
*/
IGradBooster* CreateGradBooster(const char *name);
} // namespace gbm
static GradientBooster* Create(const char *name);
};
/*!
* \brief Registry entry for tree updater.
*/
struct GradientBoosterReg
: public dmlc::FunctionRegEntryBase<GradientBoosterReg,
std::function<GradientBooster* ()> > {
};
/*!
* \brief Macro to register gradient booster.
*
* \code
* // example of registering a objective ndcg@k
* XGBOOST_REGISTER_GBM(GBTree, "gbtree")
* .describe("Boosting tree ensembles.")
* .set_body([]() {
* return new GradientBooster<TStats>();
* });
* \endcode
*/
#define XGBOOST_REGISTER_GBM(UniqueId, Name) \
static ::xgboost::GradientBoosterReg & __make_ ## GradientBoosterReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->__REGISTER__(#Name)
} // namespace xgboost
#endif // XGBOOST_GBM_GBM_H_
#endif // XGBOOST_GBM_H_

View File

@ -70,7 +70,7 @@ struct TreeUpdaterReg
*
* \code
* // example of registering a objective ndcg@k
* XGBOOST_REGISTER_METRIC(ColMaker, "colmaker")
* XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "colmaker")
* .describe("Column based tree maker.")
* .set_body([]() {
* return new ColMaker<TStats>();
@ -80,5 +80,6 @@ struct TreeUpdaterReg
#define XGBOOST_REGISTER_TREE_UPDATER(UniqueId, Name) \
static ::xgboost::TreeUpdaterReg& __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(#Name)
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_H_

View File

@ -12,7 +12,7 @@
namespace xgboost {
namespace common {
/*!
* \brief Random Engine
* \brief Define mt19937 as default type Random Engine.
*/
typedef std::mt19937 RandomEngine;
/*!

View File

@ -6,6 +6,7 @@
#include <xgboost/objective.h>
#include <xgboost/metric.h>
#include <xgboost/tree_updater.h>
#include <xgboost/gbm.h>
#include "./common/random.h"
#include "./common/base64.h"
@ -13,6 +14,7 @@ namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
} // namespace dmlc
namespace xgboost {
@ -45,7 +47,6 @@ Metric* Metric::Create(const char* name) {
}
}
// implement factory functions
TreeUpdater* TreeUpdater::Create(const char* name) {
auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
if (e == nullptr) {
@ -54,6 +55,14 @@ TreeUpdater* TreeUpdater::Create(const char* name) {
return (e->body)();
}
GradientBooster* GradientBooster::Create(const char* name) {
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown gbm type " << name;
}
return (e->body)();
}
namespace common {
RandomEngine& GlobalRandom() {
static RandomEngine inst;
@ -61,4 +70,3 @@ RandomEngine& GlobalRandom() {
}
}
} // namespace xgboost

View File

@ -30,7 +30,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
: output_prob_(output_prob) {
}
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.Init(args);
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float>& preds,
const MetaInfo& info,

View File

@ -33,7 +33,7 @@ struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
class LambdaRankObj : public ObjFunction {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.Init(args);
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float>& preds,
const MetaInfo& info,

View File

@ -77,7 +77,7 @@ template<typename Loss>
class RegLossObj : public ObjFunction {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.Init(args);
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float> &preds,
const MetaInfo &info,
@ -156,7 +156,7 @@ class PoissonRegression : public ObjFunction {
public:
// declare functions
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.Init(args);
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<float> &preds,

View File

@ -16,9 +16,9 @@ namespace tree {
/*! \brief training parameters for regression tree */
struct TrainParam : public dmlc::Parameter<TrainParam> {
// learning step size for a time
float eta;
float learning_rate;
// minimum loss change required for a split
float gamma;
float min_split_loss;
// maximum depth of a tree
int max_depth;
//----- the rest parameters are less important ----
@ -59,9 +59,9 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
bool silent;
// declare the parameters
DMLC_DECLARE_PARAMETER(TrainParam) {
DMLC_DECLARE_FIELD(eta).set_lower_bound(0.0f).set_default(0.3f)
DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
.describe("Learning rate(step size) of update.");
DMLC_DECLARE_FIELD(gamma).set_lower_bound(0.0f).set_default(0.0f)
DMLC_DECLARE_FIELD(min_split_loss).set_lower_bound(0.0f).set_default(0.0f)
.describe("Minimum loss reduction required to make a further partition.");
DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6)
.describe("Maximum depth of the tree.");
@ -101,6 +101,11 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
.describe("Number of threads used for training.");
DMLC_DECLARE_FIELD(silent).set_default(false)
.describe("Not print information during trainig.");
// add alias of parameters
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
DMLC_DECLARE_ALIAS(min_split_loss, gamma);
DMLC_DECLARE_ALIAS(learning_rate, eta);
}
// calculate the cost of loss function
@ -159,7 +164,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
}
/*! \brief given the loss change, whether we need to invoke pruning */
inline bool need_prune(double loss_chg, int depth) const {
return loss_chg < this->gamma;
return loss_chg < this->min_split_loss;
}
/*! \brief whether we can split with current hessian */
inline bool cannot_split(double sum_hess, int depth) const {

View File

@ -29,7 +29,7 @@ namespace tree {
class BaseMaker: public TreeUpdater {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param.Init(args);
param.InitAllowUnknown(args);
}
protected:

View File

@ -20,7 +20,7 @@ template<typename TStats>
class ColMaker: public TreeUpdater {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param.Init(args);
param.InitAllowUnknown(args);
}
void Update(const std::vector<bst_gpair> &gpair,
@ -28,14 +28,14 @@ class ColMaker: public TreeUpdater {
const std::vector<RegTree*> &trees) override {
TStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.eta;
param.eta = lr / trees.size();
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param);
builder.Update(gpair, dmat, trees[i]);
}
param.eta = lr;
param.learning_rate = lr;
}
protected:
@ -95,7 +95,7 @@ class ColMaker: public TreeUpdater {
// set all the rest expanding nodes to leaf
for (size_t i = 0; i < qexpand_.size(); ++i) {
const int nid = qexpand_[i];
(*p_tree)[nid].set_leaf(snode[nid].weight * param.eta);
(*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
}
// remember auxiliary statistics in the tree node
for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
@ -606,7 +606,7 @@ class ColMaker: public TreeUpdater {
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
} else {
(*p_tree)[nid].set_leaf(e.weight * param.eta);
(*p_tree)[nid].set_leaf(e.weight * param.learning_rate);
}
}
}
@ -732,7 +732,7 @@ class DistColMaker : public ColMaker<TStats> {
pruner.reset(TreeUpdater::Create("prune"));
}
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param.Init(args);
param.InitAllowUnknown(args);
pruner->Init(args);
}
void Update(const std::vector<bst_gpair> &gpair,

View File

@ -23,13 +23,13 @@ class HistMaker: public BaseMaker {
const std::vector<RegTree*> &trees) override {
TStats::CheckInfo(p_fmat->info());
// rescale learning rate according to size of trees
float lr = param.eta;
param.eta = lr / trees.size();
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
this->Update(gpair, p_fmat, trees[i]);
}
param.eta = lr;
param.learning_rate = lr;
}
protected:
@ -139,7 +139,7 @@ class HistMaker: public BaseMaker {
}
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
}
}
// this function does two jobs
@ -246,7 +246,7 @@ class HistMaker: public BaseMaker {
this->SetStats(p_tree, (*p_tree)[nid].cleft(), left_sum[wid]);
this->SetStats(p_tree, (*p_tree)[nid].cright(), right_sum);
} else {
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
}
}
}

View File

@ -22,7 +22,7 @@ class TreePruner: public TreeUpdater {
}
// set training parameter
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param.Init(args);
param.InitAllowUnknown(args);
syncher->Init(args);
}
// update the tree, do pruning
@ -30,12 +30,12 @@ class TreePruner: public TreeUpdater {
DMatrix *p_fmat,
const std::vector<RegTree*> &trees) override {
// rescale learning rate according to size of trees
float lr = param.eta;
param.eta = lr / trees.size();
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
for (size_t i = 0; i < trees.size(); ++i) {
this->DoPrune(*trees[i]);
}
param.eta = lr;
param.learning_rate = lr;
syncher->Update(gpair, p_fmat, trees);
}
@ -48,7 +48,7 @@ class TreePruner: public TreeUpdater {
++s.leaf_child_cnt;
if (s.leaf_child_cnt >= 2 && param.need_prune(s.loss_chg, depth - 1)) {
// need to be pruned
tree.ChangeToLeaf(pid, param.eta * s.base_weight);
tree.ChangeToLeaf(pid, param.learning_rate * s.base_weight);
// tail recursion
return this->TryPruneLeaf(tree, pid, depth - 1, npruned + 2);
} else {

View File

@ -19,7 +19,7 @@ template<typename TStats>
class TreeRefresher: public TreeUpdater {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
param.Init(args);
param.InitAllowUnknown(args);
}
// update the tree, do pruning
void Update(const std::vector<bst_gpair> &gpair,
@ -94,8 +94,8 @@ class TreeRefresher: public TreeUpdater {
reducer.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size());
#endif
// rescale learning rate according to size of trees
float lr = param.eta;
param.eta = lr / trees.size();
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
int offset = 0;
for (size_t i = 0; i < trees.size(); ++i) {
for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
@ -104,7 +104,7 @@ class TreeRefresher: public TreeUpdater {
offset += trees[i]->param.num_nodes;
}
// set learning rate back
param.eta = lr;
param.learning_rate = lr;
}
private:
@ -131,7 +131,7 @@ class TreeRefresher: public TreeUpdater {
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
if (tree[nid].is_leaf()) {
tree[nid].set_leaf(tree.stat(nid).base_weight * param.eta);
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
} else {
tree.stat(nid).loss_chg = static_cast<float>(
gstats[tree[nid].cleft()].CalcGain(param) +

View File

@ -24,13 +24,13 @@ class SketchMaker: public BaseMaker {
DMatrix *p_fmat,
const std::vector<RegTree*> &trees) override {
// rescale learning rate according to size of trees
float lr = param.eta;
param.eta = lr / trees.size();
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
this->Update(gpair, p_fmat, trees[i]);
}
param.eta = lr;
param.learning_rate = lr;
}
protected:
@ -67,7 +67,7 @@ class SketchMaker: public BaseMaker {
// set left leaves
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
}
}
// define the sketch we want to use
@ -302,7 +302,7 @@ class SketchMaker: public BaseMaker {
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
} else {
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.eta);
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
}
}
}