start unity refactor

This commit is contained in:
tqchen
2014-08-15 20:15:58 -07:00
parent 5b215742c2
commit 2a92c82b92
49 changed files with 3659 additions and 5803 deletions

492
tree/model.h Normal file
View File

@@ -0,0 +1,492 @@
#ifndef XGBOOST_TREE_MODEL_H_
#define XGBOOST_TREE_MODEL_H_
/*!
* \file model.h
* \brief model structure for tree
* \author Tianqi Chen
*/
#include <string>
#include <cstring>
#include <sstream>
#include <limits>
#include <algorithm>
#include <vector>
#include <cmath>
#include "../utils/io.h"
#include "../utils/fmap.h"
#include "../utils/utils.h"
namespace xgboost {
namespace tree {
/*!
* \brief template class of TreeModel
* \tparam TSplitCond data type to indicate split condition
* \tparam TNodeStat auxiliary statistics of node to help tree building
*/
template<typename TSplitCond, typename TNodeStat>
class TreeModel {
public:
/*! \brief data type to indicate split condition */
typedef TNodeStat NodeStat;
/*! \brief auxiliary statistics of node to help tree building */
typedef TSplitCond SplitCond;
/*! \brief parameters of the tree */
struct Param{
/*! \brief number of start root */
int num_roots;
/*! \brief total number of nodes */
int num_nodes;
/*!\brief number of deleted nodes */
int num_deleted;
/*! \brief maximum depth, this is a statistics of the tree */
int max_depth;
/*! \brief number of features used for tree construction */
int num_feature;
/*! \brief reserved part */
int reserved[32];
/*! \brief constructor */
Param(void) {
max_depth = 0;
memset(reserved, 0, sizeof(reserved));
}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam(const char *name, const char *val) {
if (!strcmp("num_roots", name)) num_roots = atoi(val);
if (!strcmp("num_feature", name)) num_feature = atoi(val);
}
};
/*! \brief tree node */
class Node{
public:
/*! \brief index of left child */
inline int cleft(void) const {
return this->cleft_;
}
/*! \brief index of right child */
inline int cright(void) const {
return this->cright_;
}
/*! \brief index of default child when feature is missing */
inline int cdefault(void) const {
return this->default_left() ? this->cleft() : this->cright();
}
/*! \brief feature index of split condition */
inline unsigned split_index(void) const {
return sindex_ & ((1U << 31) - 1U);
}
/*! \brief when feature is unknown, whether goes to left child */
inline bool default_left(void) const {
return (sindex_ >> 31) != 0;
}
/*! \brief whether current node is leaf node */
inline bool is_leaf(void) const {
return cleft_ == -1;
}
/*! \brief get leaf value of leaf node */
inline float leaf_value(void) const {
return (this->info_).leaf_value;
}
/*! \brief get split condition of the node */
inline TSplitCond split_cond(void) const {
return (this->info_).split_cond;
}
/*! \brief get parent of the node */
inline int parent(void) const {
return parent_ & ((1U << 31) - 1);
}
/*! \brief whether current node is left child */
inline bool is_left_child(void) const {
return (parent_ & (1U << 31)) != 0;
}
/*! \brief whether current node is root */
inline bool is_root(void) const {
return parent_ == -1;
}
/*!
* \brief set the right child
* \param nide node id to right child
*/
inline void set_right_child(int nid) {
this->cright_ = nid;
}
/*!
* \brief set split condition of current node
* \param split_index feature index to split
* \param split_cond split condition
* \param default_left the default direction when feature is unknown
*/
inline void set_split(unsigned split_index, TSplitCond split_cond,
bool default_left = false) {
if (default_left) split_index |= (1U << 31);
this->sindex_ = split_index;
(this->info_).split_cond = split_cond;
}
/*!
* \brief set the leaf value of the node
* \param value leaf value
* \param right right index, could be used to store
* additional information
*/
inline void set_leaf(float value, int right = -1) {
(this->info_).leaf_value = value;
this->cleft_ = -1;
this->cright_ = right;
}
private:
friend class TreeModel<TSplitCond, TNodeStat>;
/*!
* \brief in leaf node, we have weights, in non-leaf nodes,
* we have split condition
*/
union Info{
float leaf_value;
TSplitCond split_cond;
};
// pointer to parent, highest bit is used to
// indicate whether it's a left child or not
int parent_;
// pointer to left, right
int cleft_, cright_;
// split feature index, left split or right split depends on the highest bit
unsigned sindex_;
// extra info
Info info_;
// set parent
inline void set_parent(int pidx, bool is_left_child = true) {
if (is_left_child) pidx |= (1U << 31);
this->parent_ = pidx;
}
};
protected:
// vector of nodes
std::vector<Node> nodes;
// stats of nodes
std::vector<TNodeStat> stats;
// free node space, used during training process
std::vector<int> deleted_nodes;
// allocate a new node,
// !!!!!! NOTE: may cause BUG here, nodes.resize
inline int AllocNode(void) {
if (param.num_deleted != 0) {
int nd = deleted_nodes.back();
deleted_nodes.pop_back();
--param.num_deleted;
return nd;
}
int nd = param.num_nodes++;
utils::Check(param.num_nodes < std::numeric_limits<int>::max(),
"number of nodes in the tree exceed 2^31");
nodes.resize(param.num_nodes);
stats.resize(param.num_nodes);
return nd;
}
// delete a tree node
inline void DeleteNode(int nid) {
utils::Assert(nid >= param.num_roots, "can not delete root");
deleted_nodes.push_back(nid);
nodes[nid].set_parent(-1);
++param.num_deleted;
}
public:
/*!
* \brief change a non leaf node to a leaf node, delete its children
* \param rid node id of the node
* \param new leaf value
*/
inline void ChangeToLeaf(int rid, float value) {
utils::Assert(nodes[nodes[rid].cleft() ].is_leaf(),
"can not delete a non termial child");
utils::Assert(nodes[nodes[rid].cright()].is_leaf(),
"can not delete a non termial child");
this->DeleteNode(nodes[rid].cleft());
this->DeleteNode(nodes[rid].cright());
nodes[rid].set_leaf(value);
}
/*!
* \brief collapse a non leaf node to a leaf node, delete its children
* \param rid node id of the node
* \param new leaf value
*/
inline void CollapseToLeaf(int rid, float value) {
if (nodes[rid].is_leaf()) return;
if (!nodes[nodes[rid].cleft() ].is_leaf()) {
CollapseToLeaf(nodes[rid].cleft(), 0.0f);
}
if (!nodes[nodes[rid].cright() ].is_leaf()) {
CollapseToLeaf(nodes[rid].cright(), 0.0f);
}
this->ChangeToLeaf(rid, value);
}
public:
/*! \brief model parameter */
Param param;
/*! \brief constructor */
TreeModel(void) {
param.num_nodes = 1;
param.num_roots = 1;
param.num_deleted = 0;
nodes.resize(1);
}
/*! \brief get node given nid */
inline Node &operator[](int nid) {
return nodes[nid];
}
/*! \brief get node given nid */
inline const Node &operator[](int nid) const {
return nodes[nid];
}
/*! \brief get node statistics given nid */
inline NodeStat &stat(int nid) {
return stats[nid];
}
/*! \brief initialize the model */
inline void InitModel(void) {
param.num_nodes = param.num_roots;
nodes.resize(param.num_nodes);
stats.resize(param.num_nodes);
for (int i = 0; i < param.num_nodes; i ++) {
nodes[i].set_leaf(0.0f);
nodes[i].set_parent(-1);
}
}
/*!
* \brief load model from stream
* \param fi input stream
*/
inline void LoadModel(utils::IStream &fi) {
utils::Check(fi.Read(&param, sizeof(Param)) > 0,
"TreeModel: wrong format");
nodes.resize(param.num_nodes); stats.resize(param.num_nodes);
utils::Check(fi.Read(&nodes[0], sizeof(Node) * nodes.size()) > 0,
"TreeModel: wrong format");
utils::Check(fi.Read(&stats[0], sizeof(NodeStat) * stats.size()) > 0,
"TreeModel: wrong format");
// chg deleted nodes
deleted_nodes.resize(0);
for (int i = param.num_roots; i < param.num_nodes; i ++) {
if (nodes[i].is_root()) deleted_nodes.push_back(i);
}
utils::Assert(static_cast<int>(deleted_nodes.size()) == param.num_deleted,
"number of deleted nodes do not match");
}
/*!
* \brief save model to stream
* \param fo output stream
*/
inline void SaveModel(utils::IStream &fo) const {
utils::Assert(param.num_nodes == static_cast<int>(nodes.size()),
"Tree::SaveModel");
utils::Assert(param.num_nodes == static_cast<int>(stats.size()),
"Tree::SaveModel");
fo.Write(&param, sizeof(Param));
fo.Write(&nodes[0], sizeof(Node) * nodes.size());
fo.Write(&stats[0], sizeof(NodeStat) * nodes.size());
}
/*!
* \brief add child nodes to node
* \param nid node id to add childs
*/
inline void AddChilds(int nid) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
nodes[nid].cleft_ = pleft;
nodes[nid].cright_ = pright;
nodes[nodes[nid].cleft() ].set_parent(nid, true);
nodes[nodes[nid].cright()].set_parent(nid, false);
}
/*!
* \brief only add a right child to a leaf node
* \param node id to add right child
*/
inline void AddRightChild(int nid) {
int pright = this->AllocNode();
nodes[nid].right = pright;
nodes[nodes[nid].right].set_parent(nid, false);
}
/*!
* \brief get current depth
* \param nid node id
* \param pass_rchild whether right child is not counted in depth
*/
inline int GetDepth(int nid, bool pass_rchild = false) const {
int depth = 0;
while (!nodes[nid].is_root()) {
if (!pass_rchild || nodes[nid].is_left_child()) ++depth;
nid = nodes[nid].parent();
}
return depth;
}
/*!
* \brief get maximum depth
* \param nid node id
*/
inline int MaxDepth(int nid) const {
if (nodes[nid].is_leaf()) return 0;
return std::max(MaxDepth(nodes[nid].cleft())+1,
MaxDepth(nodes[nid].cright())+1);
}
/*!
* \brief get maximum depth
*/
inline int MaxDepth(void) {
int maxd = 0;
for (int i = 0; i < param.num_roots; ++i) {
maxd = std::max(maxd, MaxDepth(i));
}
return maxd;
}
/*! \brief number of extra nodes besides the root */
inline int num_extra_nodes(void) const {
return param.num_nodes - param.num_roots - param.num_deleted;
}
/*!
* \brief dump model to text string
* \param fmap feature map of feature types
* \param with_stats whether dump out statistics as well
* \return the string of dumped model
*/
inline std::string DumpModel(const utils::FeatMap& fmap, bool with_stats) {
std::stringstream fo("");
for (int i = 0; i < param.num_roots; ++i) {
this->Dump(i, fo, fmap, 0, with_stats);
}
return fo.str();
}
private:
void Dump(int nid, std::stringstream &fo,
const utils::FeatMap& fmap, int depth, bool with_stats) {
for (int i = 0; i < depth; ++i) {
fo << '\t';
}
if (nodes[nid].is_leaf()) {
fo << nid << ":leaf=" << nodes[nid].leaf_value();
if (with_stats) {
stat(nid).Print(fo, true);
}
fo << '\n';
} else {
// right then left,
TSplitCond cond = nodes[nid].split_cond();
const unsigned split_index = nodes[nid].split_index();
if (split_index < fmap.size()) {
switch (fmap.type(split_index)) {
case utils::FeatMap::kIndicator: {
int nyes = nodes[nid].default_left() ?
nodes[nid].cright() : nodes[nid].cleft();
fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes
<< ",no=" << nodes[nid].cdefault();
break;
}
case utils::FeatMap::kInteger: {
fo << nid << ":[" << fmap.name(split_index) << "<"
<< int(float(cond)+1.0f)
<< "] yes=" << nodes[nid].cleft()
<< ",no=" << nodes[nid].cright()
<< ",missing=" << nodes[nid].cdefault();
break;
}
case utils::FeatMap::kFloat:
case utils::FeatMap::kQuantitive: {
fo << nid << ":[" << fmap.name(split_index) << "<"<< float(cond)
<< "] yes=" << nodes[nid].cleft()
<< ",no=" << nodes[nid].cright()
<< ",missing=" << nodes[nid].cdefault();
break;
}
default: utils::Error("unknown fmap type");
}
} else {
fo << nid << ":[f" << split_index << "<"<< float(cond)
<< "] yes=" << nodes[nid].cleft()
<< ",no=" << nodes[nid].cright()
<< ",missing=" << nodes[nid].cdefault();
}
if (with_stats) {
fo << ' ';
stat(nid).Print(fo, false);
}
fo << '\n';
this->Dump(nodes[nid].cleft(), fo, fmap, depth+1, with_stats);
this->Dump(nodes[nid].cright(), fo, fmap, depth+1, with_stats);
}
}
};
/*! \brief node statistics used in regression tree */
struct RTreeNodeStat{
/*! \brief loss chg caused by current split */
float loss_chg;
/*! \brief sum of hessian values, used to measure coverage of data */
float sum_hess;
/*! \brief weight of current node */
float base_weight;
/*! \brief number of child that is leaf node known up to now */
int leaf_child_cnt;
/*! \brief print information of current stats to fo */
inline void Print(std::stringstream &fo, bool is_leaf) const {
if (!is_leaf) {
fo << "gain=" << loss_chg << ",cover=" << sum_hess;
} else {
fo << "cover=" << sum_hess;
}
}
};
/*! \brief define regression tree to be the most common tree model */
class RegTree: public TreeModel<bst_float, RTreeNodeStat>{
public:
/*!
* \brief get the leaf index
* \param feats dense feature vector, if the feature is missing the field is set to NaN
* \param root_gid starting root index of the instance
* \return the leaf index of the given feature
*/
inline int GetLeafIndex(const std::vector<float> &feat, unsigned root_id = 0) const {
// start from groups that belongs to current data
int pid = static_cast<int>(root_id);
// tranverse tree
while (!(*this)[ pid ].is_leaf()) {
unsigned split_index = (*this)[pid].split_index();
const float fvalue = feat[split_index];
pid = this->GetNext(pid, fvalue, std::isnan(fvalue));
}
return pid;
}
/*!
* \brief get the prediction of regression tree, only accepts dense feature vector
* \param feats dense feature vector, if the feature is missing the field is set to NaN
* \param root_gid starting root index of the instance
* \return the leaf index of the given feature
*/
inline float Predict(const std::vector<float> &feat, unsigned root_id = 0) const {
int pid = this->GetLeafIndex(feat, root_id);
return (*this)[pid].leaf_value();
}
private:
/*! \brief get next position of the tree given current pid */
inline int GetNext(int pid, float fvalue, bool is_unknown) const {
float split_value = (*this)[pid].split_cond();
if (is_unknown) {
return (*this)[pid].cdefault();
} else {
if (fvalue < split_value) {
return (*this)[pid].cleft();
} else {
return (*this)[pid].cright();
}
}
}
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_MODEL_H_

262
tree/param.h Normal file
View File

@@ -0,0 +1,262 @@
#ifndef XGBOOST_TREE_PARAM_H_
#define XGBOOST_TREE_PARAM_H_
/*!
* \file param.h
* \brief training parameters, statistics used to support tree construction
* \author Tianqi Chen
*/
#include <cstring>
#include "../data.h"
namespace xgboost {
namespace tree {
/*! \brief core statistics used for tree construction */
struct GradStats {
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief constructor */
GradStats(void) {
this->Clear();
}
/*! \brief clear the statistics */
inline void Clear(void) {
sum_grad = sum_hess = 0.0f;
}
/*! \brief add statistics to the data */
inline void Add(double grad, double hess) {
sum_grad += grad; sum_hess += hess;
}
/*! \brief add statistics to the data */
inline void Add(const bst_gpair& b) {
this->Add(b.grad, b.hess);
}
/*! \brief add statistics to the data */
inline void Add(const GradStats &b) {
this->Add(b.sum_grad, b.sum_hess);
}
/*! \brief substract the statistics by b */
inline GradStats Substract(const GradStats &b) const {
GradStats res;
res.sum_grad = this->sum_grad - b.sum_grad;
res.sum_hess = this->sum_hess - b.sum_hess;
return res;
}
/*! \return whether the statistics is not used yet */
inline bool Empty(void) const {
return sum_hess == 0.0;
}
};
/*! \brief training parameters for regression tree */
struct TrainParam{
// learning step size for a time
float learning_rate;
// minimum loss change required for a split
float min_split_loss;
// maximum depth of a tree
int max_depth;
//----- the rest parameters are less important ----
// minimum amount of hessian(weight) allowed in a child
float min_child_weight;
// weight decay parameter used to control leaf fitting
float reg_lambda;
// reg method
int reg_method;
// default direction choice
int default_direction;
// whether we want to do subsample
float subsample;
// whether to subsample columns each split, in each level
float colsample_bylevel;
// whether to subsample columns during tree construction
float colsample_bytree;
// speed optimization for dense column
float opt_dense_col;
// number of threads to be used for tree construction,
// if OpenMP is enabled, if equals 0, use system default
int nthread;
/*! \brief constructor */
TrainParam(void) {
learning_rate = 0.3f;
min_child_weight = 1.0f;
max_depth = 6;
reg_lambda = 1.0f;
reg_method = 2;
default_direction = 0;
subsample = 1.0f;
colsample_bytree = 1.0f;
colsample_bylevel = 1.0f;
opt_dense_col = 1.0f;
nthread = 0;
}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam(const char *name, const char *val) {
// sync-names
if (!strcmp(name, "gamma")) min_split_loss = static_cast<float>(atof(val));
if (!strcmp(name, "eta")) learning_rate = static_cast<float>(atof(val));
if (!strcmp(name, "lambda")) reg_lambda = static_cast<float>(atof(val));
if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
if (!strcmp(name, "reg_method")) reg_method = static_cast<float>(atof(val));
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
if (!strcmp(name, "max_depth")) max_depth = atoi(val);
if (!strcmp(name, "nthread")) nthread = atoi(val);
if (!strcmp(name, "default_direction")) {
if (!strcmp(val, "learn")) default_direction = 0;
if (!strcmp(val, "left")) default_direction = 1;
if (!strcmp(val, "right")) default_direction = 2;
}
}
// calculate the cost of loss function
inline double CalcGain(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) {
return 0.0;
}
switch (reg_method) {
case 1 : return Sqr(ThresholdL1(sum_grad, reg_lambda)) / sum_hess;
case 2 : return Sqr(sum_grad) / (sum_hess + reg_lambda);
case 3 : return
Sqr(ThresholdL1(sum_grad, 0.5 * reg_lambda)) /
(sum_hess + 0.5 * reg_lambda);
default: return Sqr(sum_grad) / sum_hess;
}
}
// calculate weight given the statistics
inline double CalcWeight(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) {
return 0.0;
} else {
switch (reg_method) {
case 1: return - ThresholdL1(sum_grad, reg_lambda) / sum_hess;
case 2: return - sum_grad / (sum_hess + reg_lambda);
case 3: return
- ThresholdL1(sum_grad, 0.5 * reg_lambda) /
(sum_hess + 0.5 * reg_lambda);
default: return - sum_grad / sum_hess;
}
}
}
/*! \brief whether need forward small to big search: default right */
inline bool need_forward_search(float col_density = 0.0f) const {
return this->default_direction == 2 ||
(default_direction == 0 && (col_density < opt_dense_col));
}
/*! \brief whether need backward big to small search: default left */
inline bool need_backward_search(float col_density = 0.0f) const {
return this->default_direction != 2;
}
/*! \brief given the loss change, whether we need to invode prunning */
inline bool need_prune(double loss_chg, int depth) const {
return loss_chg < this->min_split_loss;
}
/*! \brief whether we can split with current hessian */
inline bool cannot_split(double sum_hess, int depth) const {
return sum_hess < this->min_child_weight * 2.0;
}
// code support for template data
inline double CalcWeight(const GradStats &d) const {
return this->CalcWeight(d.sum_grad, d.sum_hess);
}
inline double CalcGain(const GradStats &d) const {
return this->CalcGain(d.sum_grad, d.sum_hess);
}
protected:
// functions for L1 cost
inline static double ThresholdL1(double w, double lambda) {
if (w > +lambda) return w - lambda;
if (w < -lambda) return w + lambda;
return 0.0;
}
inline static double Sqr(double a) {
return a * a;
}
};
/*!
* \brief statistics that is helpful to store
* and represent a split solution for the tree
*/
struct SplitEntry{
/*! \brief loss change after split this node */
bst_float loss_chg;
/*! \brief split index */
unsigned sindex;
/*! \brief split value */
float split_value;
/*! \brief constructor */
SplitEntry(void) : loss_chg(0.0f), sindex(0), split_value(0.0f) {}
/*!
* \brief decides whether a we can replace current entry with the statistics given
* This function gives better priority to lower index when loss_chg equals
* not the best way, but helps to give consistent result during multi-thread execution
* \param loss_chg the loss reduction get through the split
* \param split_index the feature index where the split is on
*/
inline bool NeedReplace(bst_float loss_chg, unsigned split_index) const {
if (this->split_index() <= split_index) {
return loss_chg > this->loss_chg;
} else {
return !(this->loss_chg > loss_chg);
}
}
/*!
* \brief update the split entry, replace it if e is better
* \param e candidate split solution
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(const SplitEntry &e) {
if (this->NeedReplace(e.loss_chg, e.split_index())) {
this->loss_chg = e.loss_chg;
this->sindex = e.sindex;
this->split_value = e.split_value;
return true;
} else {
return false;
}
}
/*!
* \brief update the split entry, replace it if e is better
* \param loss_chg loss reduction of new candidate
* \param split_index feature index to split on
* \param split_value the split point
* \param default_left whether the missing value goes to left
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(bst_float loss_chg, unsigned split_index,
float split_value, bool default_left) {
if (this->NeedReplace(loss_chg, split_index)) {
this->loss_chg = loss_chg;
if (default_left) split_index |= (1U << 31);
this->sindex = split_index;
this->split_value = split_value;
return true;
} else {
return false;
}
}
/*!\return feature index to split on */
inline unsigned split_index(void) const {
return sindex & ((1U << 31) - 1U);
}
/*!\return whether missing value goes to left branch */
inline bool default_left(void) const {
return (sindex >> 31) != 0;
}
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_PARAM_H_

70
tree/updater.h Normal file
View File

@@ -0,0 +1,70 @@
#ifndef XGBOOST_TREE_UPDATER_H_
#define XGBOOST_TREE_UPDATER_H_
/*!
* \file updater.h
* \brief interface to update the tree
* \author Tianqi Chen
*/
#include <vector>
#include "../data.h"
#include "./model.h"
namespace xgboost {
namespace tree {
/*!
* \brief interface of tree update module, that performs update of a tree
* \tparam FMatrix the data type updater taking
*/
template<typename FMatrix>
class IUpdater {
public:
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
virtual void SetParam(const char *name, const char *val) = 0;
/*!
* \brief peform update to the tree models
* \param gpair the gradient pair statistics of the data
* \param fmat feature matrix that provide access to features
* \param root_index pre-partitioned root_index of each instance,
* root_index.size() can be 0 which indicates that no pre-partition involved
* \param trees pointer to the trese to be updated, upater will change the content of the tree
* note: all the trees in the vector are updated, with the same statistics,
* but maybe different random seeds, usually one tree is passed in at a time,
* there can be multiple trees when we train random forest style model
*/
virtual void Update(const std::vector<bst_gpair> &gpair,
FMatrix &fmat,
const std::vector<unsigned> &root_index,
const std::vector<RegTree*> &trees) = 0;
// destructor
virtual ~IUpdater(void) {}
};
} // namespace tree
} // namespace xgboost
#include "./updater_prune-inl.hpp"
#include "./updater_colmaker-inl.hpp"
namespace xgboost {
namespace tree {
/*!
* \brief create a updater based on name
* \param name name of updater
* \return return the updater instance
*/
template<typename FMatrix>
inline IUpdater<FMatrix>* CreateUpdater(const char *name) {
if (!strcmp(name, "prune")) return new TreePruner<FMatrix>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<FMatrix, GradStats>();
utils::Error("unknown updater:%s", name);
return NULL;
}
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_H_

View File

@@ -0,0 +1,357 @@
#ifndef XGBOOST_TREE_UPDATER_COLMAKER_INL_HPP_
#define XGBOOST_TREE_UPDATER_COLMAKER_INL_HPP_
/*!
* \file updater_colmaker-inl.hpp
* \brief use columnwise update to construct a tree
* \author Tianqi Chen
*/
#include <vector>
#include <algorithm>
#include "./param.h"
#include "./updater.h"
#include "../utils/omp.h"
#include "../utils/random.h"
namespace xgboost {
namespace tree {
/*! \brief pruner that prunes a tree after growing finishs */
template<typename FMatrix, typename TStats>
class ColMaker: public IUpdater<FMatrix> {
public:
virtual ~ColMaker(void) {}
// set training parameter
virtual void SetParam(const char *name, const char *val) {
param.SetParam(name, val);
}
virtual void Update(const std::vector<bst_gpair> &gpair,
FMatrix &fmat,
const std::vector<unsigned> &root_index,
const std::vector<RegTree*> &trees) {
fmat.InitColAccess();
for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param);
builder.Update(gpair, fmat, root_index, trees[i]);
}
}
private:
// training parameter
TrainParam param;
// data structure
/*! \brief per thread x per node entry to store tmp data */
struct ThreadEntry {
/*! \brief statistics of data*/
TStats stats;
/*! \brief last feature value scanned */
float last_fvalue;
/*! \brief current best solution */
SplitEntry best;
// constructor
ThreadEntry(void) {
stats.Clear();
}
};
struct NodeEntry {
/*! \brief statics for node entry */
TStats stats;
/*! \brief loss of this node, without split */
bst_float root_gain;
/*! \brief weight calculated related to current data */
float weight;
/*! \brief current best solution */
SplitEntry best;
// constructor
NodeEntry(void) : root_gain(0.0f), weight(0.0f){
stats.Clear();
}
};
// actual builder that runs the algorithm
struct Builder{
public:
// constructor
explicit Builder(const TrainParam &param) : param(param) {}
// update one tree, growing
virtual void Update(const std::vector<bst_gpair> &gpair, FMatrix &fmat,
const std::vector<unsigned> &root_index,
RegTree *p_tree) {
this->InitData(gpair, fmat, root_index, *p_tree);
this->InitNewNode(qexpand, gpair, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) {
this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree);
this->ResetPosition(this->qexpand, fmat, *p_tree);
this->UpdateQueueExpand(*p_tree, &this->qexpand);
this->InitNewNode(qexpand, gpair, *p_tree);
// if nothing left to be expand, break
if (qexpand.size() == 0) break;
}
// set all the rest expanding nodes to leaf
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
(*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
}
// remember auxiliary statistics in the tree node
for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg;
p_tree->stat(nid).base_weight = snode[nid].weight;
p_tree->stat(nid).sum_hess = static_cast<float>(snode[nid].stats.sum_hess);
}
}
private:
// initialize temp data structure
inline void InitData(const std::vector<bst_gpair> &gpair, FMatrix &fmat,
const std::vector<unsigned> &root_index, const RegTree &tree) {
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "ColMaker: can only grow new tree");
{// setup position
position.resize(gpair.size());
if (root_index.size() == 0) {
std::fill(position.begin(), position.end(), 0);
} else {
for (size_t i = 0; i < root_index.size(); ++i) {
position[i] = root_index[i];
utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, "root index exceed setting");
}
}
// mark delete for the deleted datas
for (size_t i = 0; i < gpair.size(); ++i) {
if (gpair[i].hess < 0.0f) position[i] = -1;
}
// mark subsample
if (param.subsample < 1.0f) {
for (size_t i = 0; i < gpair.size(); ++i) {
if (gpair[i].hess < 0.0f) continue;
if (random::SampleBinary(param.subsample) == 0) position[i] = -1;
}
}
}
{
// initialize feature index
unsigned ncol = static_cast<unsigned>(fmat.NumCol());
for (unsigned i = 0; i < ncol; ++i) {
if (fmat.GetColSize(i) != 0) feat_index.push_back(i);
}
unsigned n = static_cast<unsigned>(param.colsample_bytree * feat_index.size());
random::Shuffle(feat_index);
utils::Check(n > 0, "colsample_bytree is too small that no feature can be included");
feat_index.resize(n);
}
{// setup temp space for each thread
#pragma omp parallel
{
this->nthread = omp_get_num_threads();
}
// reserve a small space
stemp.clear();
stemp.resize(this->nthread, std::vector<ThreadEntry>());
for (size_t i = 0; i < stemp.size(); ++i) {
stemp[i].clear(); stemp[i].reserve(256);
}
snode.reserve(256);
}
{// expand query
qexpand.reserve(256); qexpand.clear();
for (int i = 0; i < tree.param.num_roots; ++i) {
qexpand.push_back(i);
}
}
}
/*! \brief initialize the base_weight, root_gain, and NodeEntry for all the new nodes in qexpand */
inline void InitNewNode(const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair,
const RegTree &tree) {
{// setup statistics space for each tree node
for (size_t i = 0; i < stemp.size(); ++i) {
stemp[i].resize(tree.param.num_nodes, ThreadEntry());
}
snode.resize(tree.param.num_nodes, NodeEntry());
}
// setup position
const unsigned ndata = static_cast<unsigned>(position.size());
#pragma omp parallel for schedule(static)
for (unsigned i = 0; i < ndata; ++i) {
const int tid = omp_get_thread_num();
if (position[i] < 0) continue;
stemp[tid][position[i]].stats.Add(gpair[i]);
}
// sum the per thread statistics together
for (size_t j = 0; j < qexpand.size(); ++j) {
const int nid = qexpand[j];
TStats stats; stats.Clear();
for (size_t tid = 0; tid < stemp.size(); ++tid) {
stats.Add(stemp[tid][nid].stats);
}
// update node statistics
snode[nid].stats = stats;
snode[nid].root_gain = param.CalcGain(stats);
snode[nid].weight = param.CalcWeight(stats);
}
}
/*! \brief update queue expand add in new leaves */
inline void UpdateQueueExpand(const RegTree &tree, std::vector<int> *p_qexpand) {
std::vector<int> &qexpand = *p_qexpand;
std::vector<int> newnodes;
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
if (!tree[ nid ].is_leaf()) {
newnodes.push_back(tree[nid].cleft());
newnodes.push_back(tree[nid].cright());
}
}
// use new nodes for qexpand
qexpand = newnodes;
}
// enumerate the split values of specific feature
template<typename Iter>
inline void EnumerateSplit(Iter it, unsigned fid,
const std::vector<bst_gpair> &gpair,
std::vector<ThreadEntry> &temp,
bool is_forward_search) {
// clear all the temp statistics
for (size_t j = 0; j < qexpand.size(); ++j) {
temp[qexpand[j]].stats.Clear();
}
while (it.Next()) {
const bst_uint ridx = it.rindex();
const int nid = position[ridx];
if (nid < 0) continue;
// start working
const float fvalue = it.fvalue();
// get the statistics of nid
ThreadEntry &e = temp[nid];
// test if first hit, this is fine, because we set 0 during init
if (e.stats.Empty()) {
e.stats.Add(gpair[ridx]);
e.last_fvalue = fvalue;
} else {
// try to find a split
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
TStats c = snode[nid].stats.Substract(e.stats);
if (c.sum_hess >= param.min_child_weight) {
double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
}
}
// update the statistics
e.stats.Add(gpair[ridx]);
e.last_fvalue = fvalue;
}
}
// finish updating all statistics, check if it is possible to include all sum statistics
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
ThreadEntry &e = temp[nid];
TStats c = snode[nid].stats.Substract(e.stats);
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
const double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
const float delta = is_forward_search ? rt_eps : -rt_eps;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
}
}
}
// find splits at current level, do split per level
inline void FindSplit(int depth, const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair, const FMatrix &fmat,
RegTree *p_tree) {
std::vector<unsigned> feat_set = feat_index;
if (param.colsample_bylevel != 1.0f) {
random::Shuffle(feat_set);
unsigned n = static_cast<unsigned>(param.colsample_bylevel * feat_index.size());
utils::Check(n > 0, "colsample_bylevel is too small that no feature can be included");
feat_set.resize(n);
}
// start enumeration
const unsigned nsize = static_cast<unsigned>(feat_set.size());
#pragma omp parallel for schedule(dynamic, 1)
for (unsigned i = 0; i < nsize; ++i) {
const unsigned fid = feat_set[i];
const int tid = omp_get_thread_num();
if (param.need_forward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, stemp[tid], true);
}
if (param.need_backward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, stemp[tid], false);
}
}
// after this each thread's stemp will get the best candidates, aggregate results
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
NodeEntry &e = snode[nid];
for (int tid = 0; tid < this->nthread; ++tid) {
e.best.Update(stemp[tid][nid].best);
}
// now we know the solution in snode[nid], set split
if (e.best.loss_chg > rt_eps) {
p_tree->AddChilds(nid);
(*p_tree)[nid].set_split(e.best.split_index(), e.best.split_value, e.best.default_left());
} else {
(*p_tree)[nid].set_leaf(e.weight * param.learning_rate);
}
}
}
// reset position of each data points after split is created in the tree
inline void ResetPosition(const std::vector<int> &qexpand, const FMatrix &fmat, const RegTree &tree) {
// step 1, set default direct nodes to default, and leaf nodes to -1
const unsigned ndata = static_cast<unsigned>(position.size());
#pragma omp parallel for schedule(static)
for (unsigned i = 0; i < ndata; ++i) {
const int nid = position[i];
if (nid >= 0) {
if (tree[nid].is_leaf()) {
position[i] = -1;
} else {
// push to default branch, correct latter
position[i] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
}
}
}
// step 2, classify the non-default data into right places
std::vector<unsigned> fsplits;
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
if (!tree[nid].is_leaf()) fsplits.push_back(tree[nid].split_index());
}
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
// start put things into right place
const unsigned nfeats = static_cast<unsigned>(fsplits.size());
#pragma omp parallel for schedule(dynamic, 1)
for (unsigned i = 0; i < nfeats; ++i) {
const unsigned fid = fsplits[i];
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
const bst_uint ridx = it.rindex();
int nid = position[ridx];
if (nid == -1) continue;
// go back to parent, correct those who are not default
nid = tree[nid].parent();
if (tree[nid].split_index() == fid) {
if (it.fvalue() < tree[nid].split_cond()) {
position[ridx] = tree[nid].cleft();
} else {
position[ridx] = tree[nid].cright();
}
}
}
}
}
//--data fields--
const TrainParam &param;
// number of omp thread used during training
int nthread;
// Per feature: shuffle index of each feature index
std::vector<unsigned> feat_index;
// Instance Data: current node position in the tree of each instance
std::vector<int> position;
// PerThread x PerTreeNode: statistics for per thread construction
std::vector< std::vector<ThreadEntry> > stemp;
/*! \brief TreeNode Data: statistics for each constructed node */
std::vector<NodeEntry> snode;
/*! \brief queue of nodes to be expanded */
std::vector<int> qexpand;
};
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_COLMAKER_INL_HPP_

View File

@@ -0,0 +1,67 @@
#ifndef XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
#define XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
/*!
* \file updater_prune-inl.hpp
* \brief prune a tree given the statistics
* \author Tianqi Chen
*/
#include <vector>
#include "./param.h"
#include "./updater.h"
namespace xgboost {
namespace tree {
/*! \brief pruner that prunes a tree after growing finishs */
template<typename FMatrix>
class TreePruner: public IUpdater<FMatrix> {
public:
virtual ~TreePruner(void) {}
// set training parameter
virtual void SetParam(const char *name, const char *val) {
param.SetParam(name, val);
}
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair, FMatrix &fmat,
const std::vector<unsigned> &root_index,
const std::vector<RegTree*> &trees) {
for (size_t i = 0; i < trees.size(); ++i) {
this->DoPrune(*trees[i]);
}
}
private:
// try to prune off current leaf
inline void TryPruneLeaf(RegTree &tree, int nid, int depth) {
if (tree[nid].is_root()) return;
int pid = tree[nid].parent();
RegTree::NodeStat &s = tree.stat(pid);
++s.leaf_child_cnt;
if (s.leaf_child_cnt >= 2 && param.need_prune(s.loss_chg, depth - 1)) {
// need to be pruned
tree.ChangeToLeaf(pid, param.learning_rate * s.base_weight);
// tail recursion
this->TryPruneLeaf(tree, pid, depth - 1);
}
}
/*! \brief do prunning of a tree */
inline void DoPrune(RegTree &tree) {
// initialize auxiliary statistics
for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
tree.stat(nid).leaf_child_cnt = 0;
}
for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
if (tree[nid].is_leaf()) {
this->TryPruneLeaf(tree, nid, tree.GetDepth(nid));
}
}
}
private:
// training parameter
TrainParam param;
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_