xgboost/tree/param.h
2014-08-15 20:15:58 -07:00

263 lines
9.0 KiB
C++

#ifndef XGBOOST_TREE_PARAM_H_
#define XGBOOST_TREE_PARAM_H_
/*!
* \file param.h
* \brief training parameters, statistics used to support tree construction
* \author Tianqi Chen
*/
#include <cstring>
#include "../data.h"
namespace xgboost {
namespace tree {
/*! \brief core statistics used for tree construction */
struct GradStats {
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief constructor */
GradStats(void) {
this->Clear();
}
/*! \brief clear the statistics */
inline void Clear(void) {
sum_grad = sum_hess = 0.0f;
}
/*! \brief add statistics to the data */
inline void Add(double grad, double hess) {
sum_grad += grad; sum_hess += hess;
}
/*! \brief add statistics to the data */
inline void Add(const bst_gpair& b) {
this->Add(b.grad, b.hess);
}
/*! \brief add statistics to the data */
inline void Add(const GradStats &b) {
this->Add(b.sum_grad, b.sum_hess);
}
/*! \brief substract the statistics by b */
inline GradStats Substract(const GradStats &b) const {
GradStats res;
res.sum_grad = this->sum_grad - b.sum_grad;
res.sum_hess = this->sum_hess - b.sum_hess;
return res;
}
/*! \return whether the statistics is not used yet */
inline bool Empty(void) const {
return sum_hess == 0.0;
}
};
/*! \brief training parameters for regression tree */
struct TrainParam{
// learning step size for a time
float learning_rate;
// minimum loss change required for a split
float min_split_loss;
// maximum depth of a tree
int max_depth;
//----- the rest parameters are less important ----
// minimum amount of hessian(weight) allowed in a child
float min_child_weight;
// weight decay parameter used to control leaf fitting
float reg_lambda;
// reg method
int reg_method;
// default direction choice
int default_direction;
// whether we want to do subsample
float subsample;
// whether to subsample columns each split, in each level
float colsample_bylevel;
// whether to subsample columns during tree construction
float colsample_bytree;
// speed optimization for dense column
float opt_dense_col;
// number of threads to be used for tree construction,
// if OpenMP is enabled, if equals 0, use system default
int nthread;
/*! \brief constructor */
TrainParam(void) {
learning_rate = 0.3f;
min_child_weight = 1.0f;
max_depth = 6;
reg_lambda = 1.0f;
reg_method = 2;
default_direction = 0;
subsample = 1.0f;
colsample_bytree = 1.0f;
colsample_bylevel = 1.0f;
opt_dense_col = 1.0f;
nthread = 0;
}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam(const char *name, const char *val) {
// sync-names
if (!strcmp(name, "gamma")) min_split_loss = static_cast<float>(atof(val));
if (!strcmp(name, "eta")) learning_rate = static_cast<float>(atof(val));
if (!strcmp(name, "lambda")) reg_lambda = static_cast<float>(atof(val));
if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
if (!strcmp(name, "reg_method")) reg_method = static_cast<float>(atof(val));
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
if (!strcmp(name, "max_depth")) max_depth = atoi(val);
if (!strcmp(name, "nthread")) nthread = atoi(val);
if (!strcmp(name, "default_direction")) {
if (!strcmp(val, "learn")) default_direction = 0;
if (!strcmp(val, "left")) default_direction = 1;
if (!strcmp(val, "right")) default_direction = 2;
}
}
// calculate the cost of loss function
inline double CalcGain(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) {
return 0.0;
}
switch (reg_method) {
case 1 : return Sqr(ThresholdL1(sum_grad, reg_lambda)) / sum_hess;
case 2 : return Sqr(sum_grad) / (sum_hess + reg_lambda);
case 3 : return
Sqr(ThresholdL1(sum_grad, 0.5 * reg_lambda)) /
(sum_hess + 0.5 * reg_lambda);
default: return Sqr(sum_grad) / sum_hess;
}
}
// calculate weight given the statistics
inline double CalcWeight(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) {
return 0.0;
} else {
switch (reg_method) {
case 1: return - ThresholdL1(sum_grad, reg_lambda) / sum_hess;
case 2: return - sum_grad / (sum_hess + reg_lambda);
case 3: return
- ThresholdL1(sum_grad, 0.5 * reg_lambda) /
(sum_hess + 0.5 * reg_lambda);
default: return - sum_grad / sum_hess;
}
}
}
/*! \brief whether need forward small to big search: default right */
inline bool need_forward_search(float col_density = 0.0f) const {
return this->default_direction == 2 ||
(default_direction == 0 && (col_density < opt_dense_col));
}
/*! \brief whether need backward big to small search: default left */
inline bool need_backward_search(float col_density = 0.0f) const {
return this->default_direction != 2;
}
/*! \brief given the loss change, whether we need to invode prunning */
inline bool need_prune(double loss_chg, int depth) const {
return loss_chg < this->min_split_loss;
}
/*! \brief whether we can split with current hessian */
inline bool cannot_split(double sum_hess, int depth) const {
return sum_hess < this->min_child_weight * 2.0;
}
// code support for template data
inline double CalcWeight(const GradStats &d) const {
return this->CalcWeight(d.sum_grad, d.sum_hess);
}
inline double CalcGain(const GradStats &d) const {
return this->CalcGain(d.sum_grad, d.sum_hess);
}
protected:
// functions for L1 cost
inline static double ThresholdL1(double w, double lambda) {
if (w > +lambda) return w - lambda;
if (w < -lambda) return w + lambda;
return 0.0;
}
inline static double Sqr(double a) {
return a * a;
}
};
/*!
* \brief statistics that is helpful to store
* and represent a split solution for the tree
*/
struct SplitEntry{
/*! \brief loss change after split this node */
bst_float loss_chg;
/*! \brief split index */
unsigned sindex;
/*! \brief split value */
float split_value;
/*! \brief constructor */
SplitEntry(void) : loss_chg(0.0f), sindex(0), split_value(0.0f) {}
/*!
* \brief decides whether a we can replace current entry with the statistics given
* This function gives better priority to lower index when loss_chg equals
* not the best way, but helps to give consistent result during multi-thread execution
* \param loss_chg the loss reduction get through the split
* \param split_index the feature index where the split is on
*/
inline bool NeedReplace(bst_float loss_chg, unsigned split_index) const {
if (this->split_index() <= split_index) {
return loss_chg > this->loss_chg;
} else {
return !(this->loss_chg > loss_chg);
}
}
/*!
* \brief update the split entry, replace it if e is better
* \param e candidate split solution
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(const SplitEntry &e) {
if (this->NeedReplace(e.loss_chg, e.split_index())) {
this->loss_chg = e.loss_chg;
this->sindex = e.sindex;
this->split_value = e.split_value;
return true;
} else {
return false;
}
}
/*!
* \brief update the split entry, replace it if e is better
* \param loss_chg loss reduction of new candidate
* \param split_index feature index to split on
* \param split_value the split point
* \param default_left whether the missing value goes to left
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(bst_float loss_chg, unsigned split_index,
float split_value, bool default_left) {
if (this->NeedReplace(loss_chg, split_index)) {
this->loss_chg = loss_chg;
if (default_left) split_index |= (1U << 31);
this->sindex = split_index;
this->split_value = split_value;
return true;
} else {
return false;
}
}
/*!\return feature index to split on */
inline unsigned split_index(void) const {
return sindex & ((1U << 31) - 1U);
}
/*!\return whether missing value goes to left branch */
inline bool default_left(void) const {
return (sindex >> 31) != 0;
}
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_PARAM_H_