start unity refactor
This commit is contained in:
492
tree/model.h
Normal file
492
tree/model.h
Normal file
@@ -0,0 +1,492 @@
|
||||
#ifndef XGBOOST_TREE_MODEL_H_
|
||||
#define XGBOOST_TREE_MODEL_H_
|
||||
/*!
|
||||
* \file model.h
|
||||
* \brief model structure for tree
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "../utils/io.h"
|
||||
#include "../utils/fmap.h"
|
||||
#include "../utils/utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*!
|
||||
* \brief template class of TreeModel
|
||||
* \tparam TSplitCond data type to indicate split condition
|
||||
* \tparam TNodeStat auxiliary statistics of node to help tree building
|
||||
*/
|
||||
template<typename TSplitCond, typename TNodeStat>
|
||||
class TreeModel {
|
||||
public:
|
||||
/*! \brief data type to indicate split condition */
|
||||
typedef TNodeStat NodeStat;
|
||||
/*! \brief auxiliary statistics of node to help tree building */
|
||||
typedef TSplitCond SplitCond;
|
||||
/*! \brief parameters of the tree */
|
||||
struct Param{
|
||||
/*! \brief number of start root */
|
||||
int num_roots;
|
||||
/*! \brief total number of nodes */
|
||||
int num_nodes;
|
||||
/*!\brief number of deleted nodes */
|
||||
int num_deleted;
|
||||
/*! \brief maximum depth, this is a statistics of the tree */
|
||||
int max_depth;
|
||||
/*! \brief number of features used for tree construction */
|
||||
int num_feature;
|
||||
/*! \brief reserved part */
|
||||
int reserved[32];
|
||||
/*! \brief constructor */
|
||||
Param(void) {
|
||||
max_depth = 0;
|
||||
memset(reserved, 0, sizeof(reserved));
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
* \param name name of the parameter
|
||||
* \param val value of the parameter
|
||||
*/
|
||||
inline void SetParam(const char *name, const char *val) {
|
||||
if (!strcmp("num_roots", name)) num_roots = atoi(val);
|
||||
if (!strcmp("num_feature", name)) num_feature = atoi(val);
|
||||
}
|
||||
};
|
||||
/*! \brief tree node */
|
||||
class Node{
|
||||
public:
|
||||
/*! \brief index of left child */
|
||||
inline int cleft(void) const {
|
||||
return this->cleft_;
|
||||
}
|
||||
/*! \brief index of right child */
|
||||
inline int cright(void) const {
|
||||
return this->cright_;
|
||||
}
|
||||
/*! \brief index of default child when feature is missing */
|
||||
inline int cdefault(void) const {
|
||||
return this->default_left() ? this->cleft() : this->cright();
|
||||
}
|
||||
/*! \brief feature index of split condition */
|
||||
inline unsigned split_index(void) const {
|
||||
return sindex_ & ((1U << 31) - 1U);
|
||||
}
|
||||
/*! \brief when feature is unknown, whether goes to left child */
|
||||
inline bool default_left(void) const {
|
||||
return (sindex_ >> 31) != 0;
|
||||
}
|
||||
/*! \brief whether current node is leaf node */
|
||||
inline bool is_leaf(void) const {
|
||||
return cleft_ == -1;
|
||||
}
|
||||
/*! \brief get leaf value of leaf node */
|
||||
inline float leaf_value(void) const {
|
||||
return (this->info_).leaf_value;
|
||||
}
|
||||
/*! \brief get split condition of the node */
|
||||
inline TSplitCond split_cond(void) const {
|
||||
return (this->info_).split_cond;
|
||||
}
|
||||
/*! \brief get parent of the node */
|
||||
inline int parent(void) const {
|
||||
return parent_ & ((1U << 31) - 1);
|
||||
}
|
||||
/*! \brief whether current node is left child */
|
||||
inline bool is_left_child(void) const {
|
||||
return (parent_ & (1U << 31)) != 0;
|
||||
}
|
||||
/*! \brief whether current node is root */
|
||||
inline bool is_root(void) const {
|
||||
return parent_ == -1;
|
||||
}
|
||||
/*!
|
||||
* \brief set the right child
|
||||
* \param nide node id to right child
|
||||
*/
|
||||
inline void set_right_child(int nid) {
|
||||
this->cright_ = nid;
|
||||
}
|
||||
/*!
|
||||
* \brief set split condition of current node
|
||||
* \param split_index feature index to split
|
||||
* \param split_cond split condition
|
||||
* \param default_left the default direction when feature is unknown
|
||||
*/
|
||||
inline void set_split(unsigned split_index, TSplitCond split_cond,
|
||||
bool default_left = false) {
|
||||
if (default_left) split_index |= (1U << 31);
|
||||
this->sindex_ = split_index;
|
||||
(this->info_).split_cond = split_cond;
|
||||
}
|
||||
/*!
|
||||
* \brief set the leaf value of the node
|
||||
* \param value leaf value
|
||||
* \param right right index, could be used to store
|
||||
* additional information
|
||||
*/
|
||||
inline void set_leaf(float value, int right = -1) {
|
||||
(this->info_).leaf_value = value;
|
||||
this->cleft_ = -1;
|
||||
this->cright_ = right;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class TreeModel<TSplitCond, TNodeStat>;
|
||||
/*!
|
||||
* \brief in leaf node, we have weights, in non-leaf nodes,
|
||||
* we have split condition
|
||||
*/
|
||||
union Info{
|
||||
float leaf_value;
|
||||
TSplitCond split_cond;
|
||||
};
|
||||
// pointer to parent, highest bit is used to
|
||||
// indicate whether it's a left child or not
|
||||
int parent_;
|
||||
// pointer to left, right
|
||||
int cleft_, cright_;
|
||||
// split feature index, left split or right split depends on the highest bit
|
||||
unsigned sindex_;
|
||||
// extra info
|
||||
Info info_;
|
||||
// set parent
|
||||
inline void set_parent(int pidx, bool is_left_child = true) {
|
||||
if (is_left_child) pidx |= (1U << 31);
|
||||
this->parent_ = pidx;
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
// vector of nodes
|
||||
std::vector<Node> nodes;
|
||||
// stats of nodes
|
||||
std::vector<TNodeStat> stats;
|
||||
// free node space, used during training process
|
||||
std::vector<int> deleted_nodes;
|
||||
// allocate a new node,
|
||||
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
||||
inline int AllocNode(void) {
|
||||
if (param.num_deleted != 0) {
|
||||
int nd = deleted_nodes.back();
|
||||
deleted_nodes.pop_back();
|
||||
--param.num_deleted;
|
||||
return nd;
|
||||
}
|
||||
int nd = param.num_nodes++;
|
||||
utils::Check(param.num_nodes < std::numeric_limits<int>::max(),
|
||||
"number of nodes in the tree exceed 2^31");
|
||||
nodes.resize(param.num_nodes);
|
||||
stats.resize(param.num_nodes);
|
||||
return nd;
|
||||
}
|
||||
// delete a tree node
|
||||
inline void DeleteNode(int nid) {
|
||||
utils::Assert(nid >= param.num_roots, "can not delete root");
|
||||
deleted_nodes.push_back(nid);
|
||||
nodes[nid].set_parent(-1);
|
||||
++param.num_deleted;
|
||||
}
|
||||
|
||||
public:
|
||||
/*!
|
||||
* \brief change a non leaf node to a leaf node, delete its children
|
||||
* \param rid node id of the node
|
||||
* \param new leaf value
|
||||
*/
|
||||
inline void ChangeToLeaf(int rid, float value) {
|
||||
utils::Assert(nodes[nodes[rid].cleft() ].is_leaf(),
|
||||
"can not delete a non termial child");
|
||||
utils::Assert(nodes[nodes[rid].cright()].is_leaf(),
|
||||
"can not delete a non termial child");
|
||||
this->DeleteNode(nodes[rid].cleft());
|
||||
this->DeleteNode(nodes[rid].cright());
|
||||
nodes[rid].set_leaf(value);
|
||||
}
|
||||
/*!
|
||||
* \brief collapse a non leaf node to a leaf node, delete its children
|
||||
* \param rid node id of the node
|
||||
* \param new leaf value
|
||||
*/
|
||||
inline void CollapseToLeaf(int rid, float value) {
|
||||
if (nodes[rid].is_leaf()) return;
|
||||
if (!nodes[nodes[rid].cleft() ].is_leaf()) {
|
||||
CollapseToLeaf(nodes[rid].cleft(), 0.0f);
|
||||
}
|
||||
if (!nodes[nodes[rid].cright() ].is_leaf()) {
|
||||
CollapseToLeaf(nodes[rid].cright(), 0.0f);
|
||||
}
|
||||
this->ChangeToLeaf(rid, value);
|
||||
}
|
||||
|
||||
public:
|
||||
/*! \brief model parameter */
|
||||
Param param;
|
||||
/*! \brief constructor */
|
||||
TreeModel(void) {
|
||||
param.num_nodes = 1;
|
||||
param.num_roots = 1;
|
||||
param.num_deleted = 0;
|
||||
nodes.resize(1);
|
||||
}
|
||||
/*! \brief get node given nid */
|
||||
inline Node &operator[](int nid) {
|
||||
return nodes[nid];
|
||||
}
|
||||
/*! \brief get node given nid */
|
||||
inline const Node &operator[](int nid) const {
|
||||
return nodes[nid];
|
||||
}
|
||||
/*! \brief get node statistics given nid */
|
||||
inline NodeStat &stat(int nid) {
|
||||
return stats[nid];
|
||||
}
|
||||
/*! \brief initialize the model */
|
||||
inline void InitModel(void) {
|
||||
param.num_nodes = param.num_roots;
|
||||
nodes.resize(param.num_nodes);
|
||||
stats.resize(param.num_nodes);
|
||||
for (int i = 0; i < param.num_nodes; i ++) {
|
||||
nodes[i].set_leaf(0.0f);
|
||||
nodes[i].set_parent(-1);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream
|
||||
*/
|
||||
inline void LoadModel(utils::IStream &fi) {
|
||||
utils::Check(fi.Read(¶m, sizeof(Param)) > 0,
|
||||
"TreeModel: wrong format");
|
||||
nodes.resize(param.num_nodes); stats.resize(param.num_nodes);
|
||||
utils::Check(fi.Read(&nodes[0], sizeof(Node) * nodes.size()) > 0,
|
||||
"TreeModel: wrong format");
|
||||
utils::Check(fi.Read(&stats[0], sizeof(NodeStat) * stats.size()) > 0,
|
||||
"TreeModel: wrong format");
|
||||
// chg deleted nodes
|
||||
deleted_nodes.resize(0);
|
||||
for (int i = param.num_roots; i < param.num_nodes; i ++) {
|
||||
if (nodes[i].is_root()) deleted_nodes.push_back(i);
|
||||
}
|
||||
utils::Assert(static_cast<int>(deleted_nodes.size()) == param.num_deleted,
|
||||
"number of deleted nodes do not match");
|
||||
}
|
||||
/*!
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
*/
|
||||
inline void SaveModel(utils::IStream &fo) const {
|
||||
utils::Assert(param.num_nodes == static_cast<int>(nodes.size()),
|
||||
"Tree::SaveModel");
|
||||
utils::Assert(param.num_nodes == static_cast<int>(stats.size()),
|
||||
"Tree::SaveModel");
|
||||
fo.Write(¶m, sizeof(Param));
|
||||
fo.Write(&nodes[0], sizeof(Node) * nodes.size());
|
||||
fo.Write(&stats[0], sizeof(NodeStat) * nodes.size());
|
||||
}
|
||||
/*!
|
||||
* \brief add child nodes to node
|
||||
* \param nid node id to add childs
|
||||
*/
|
||||
inline void AddChilds(int nid) {
|
||||
int pleft = this->AllocNode();
|
||||
int pright = this->AllocNode();
|
||||
nodes[nid].cleft_ = pleft;
|
||||
nodes[nid].cright_ = pright;
|
||||
nodes[nodes[nid].cleft() ].set_parent(nid, true);
|
||||
nodes[nodes[nid].cright()].set_parent(nid, false);
|
||||
}
|
||||
/*!
|
||||
* \brief only add a right child to a leaf node
|
||||
* \param node id to add right child
|
||||
*/
|
||||
inline void AddRightChild(int nid) {
|
||||
int pright = this->AllocNode();
|
||||
nodes[nid].right = pright;
|
||||
nodes[nodes[nid].right].set_parent(nid, false);
|
||||
}
|
||||
/*!
|
||||
* \brief get current depth
|
||||
* \param nid node id
|
||||
* \param pass_rchild whether right child is not counted in depth
|
||||
*/
|
||||
inline int GetDepth(int nid, bool pass_rchild = false) const {
|
||||
int depth = 0;
|
||||
while (!nodes[nid].is_root()) {
|
||||
if (!pass_rchild || nodes[nid].is_left_child()) ++depth;
|
||||
nid = nodes[nid].parent();
|
||||
}
|
||||
return depth;
|
||||
}
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
* \param nid node id
|
||||
*/
|
||||
inline int MaxDepth(int nid) const {
|
||||
if (nodes[nid].is_leaf()) return 0;
|
||||
return std::max(MaxDepth(nodes[nid].cleft())+1,
|
||||
MaxDepth(nodes[nid].cright())+1);
|
||||
}
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
*/
|
||||
inline int MaxDepth(void) {
|
||||
int maxd = 0;
|
||||
for (int i = 0; i < param.num_roots; ++i) {
|
||||
maxd = std::max(maxd, MaxDepth(i));
|
||||
}
|
||||
return maxd;
|
||||
}
|
||||
/*! \brief number of extra nodes besides the root */
|
||||
inline int num_extra_nodes(void) const {
|
||||
return param.num_nodes - param.num_roots - param.num_deleted;
|
||||
}
|
||||
/*!
|
||||
* \brief dump model to text string
|
||||
* \param fmap feature map of feature types
|
||||
* \param with_stats whether dump out statistics as well
|
||||
* \return the string of dumped model
|
||||
*/
|
||||
inline std::string DumpModel(const utils::FeatMap& fmap, bool with_stats) {
|
||||
std::stringstream fo("");
|
||||
for (int i = 0; i < param.num_roots; ++i) {
|
||||
this->Dump(i, fo, fmap, 0, with_stats);
|
||||
}
|
||||
return fo.str();
|
||||
}
|
||||
|
||||
private:
|
||||
void Dump(int nid, std::stringstream &fo,
|
||||
const utils::FeatMap& fmap, int depth, bool with_stats) {
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
fo << '\t';
|
||||
}
|
||||
if (nodes[nid].is_leaf()) {
|
||||
fo << nid << ":leaf=" << nodes[nid].leaf_value();
|
||||
if (with_stats) {
|
||||
stat(nid).Print(fo, true);
|
||||
}
|
||||
fo << '\n';
|
||||
} else {
|
||||
// right then left,
|
||||
TSplitCond cond = nodes[nid].split_cond();
|
||||
const unsigned split_index = nodes[nid].split_index();
|
||||
if (split_index < fmap.size()) {
|
||||
switch (fmap.type(split_index)) {
|
||||
case utils::FeatMap::kIndicator: {
|
||||
int nyes = nodes[nid].default_left() ?
|
||||
nodes[nid].cright() : nodes[nid].cleft();
|
||||
fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes
|
||||
<< ",no=" << nodes[nid].cdefault();
|
||||
break;
|
||||
}
|
||||
case utils::FeatMap::kInteger: {
|
||||
fo << nid << ":[" << fmap.name(split_index) << "<"
|
||||
<< int(float(cond)+1.0f)
|
||||
<< "] yes=" << nodes[nid].cleft()
|
||||
<< ",no=" << nodes[nid].cright()
|
||||
<< ",missing=" << nodes[nid].cdefault();
|
||||
break;
|
||||
}
|
||||
case utils::FeatMap::kFloat:
|
||||
case utils::FeatMap::kQuantitive: {
|
||||
fo << nid << ":[" << fmap.name(split_index) << "<"<< float(cond)
|
||||
<< "] yes=" << nodes[nid].cleft()
|
||||
<< ",no=" << nodes[nid].cright()
|
||||
<< ",missing=" << nodes[nid].cdefault();
|
||||
break;
|
||||
}
|
||||
default: utils::Error("unknown fmap type");
|
||||
}
|
||||
} else {
|
||||
fo << nid << ":[f" << split_index << "<"<< float(cond)
|
||||
<< "] yes=" << nodes[nid].cleft()
|
||||
<< ",no=" << nodes[nid].cright()
|
||||
<< ",missing=" << nodes[nid].cdefault();
|
||||
}
|
||||
if (with_stats) {
|
||||
fo << ' ';
|
||||
stat(nid).Print(fo, false);
|
||||
}
|
||||
fo << '\n';
|
||||
this->Dump(nodes[nid].cleft(), fo, fmap, depth+1, with_stats);
|
||||
this->Dump(nodes[nid].cright(), fo, fmap, depth+1, with_stats);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief node statistics used in regression tree */
|
||||
struct RTreeNodeStat{
|
||||
/*! \brief loss chg caused by current split */
|
||||
float loss_chg;
|
||||
/*! \brief sum of hessian values, used to measure coverage of data */
|
||||
float sum_hess;
|
||||
/*! \brief weight of current node */
|
||||
float base_weight;
|
||||
/*! \brief number of child that is leaf node known up to now */
|
||||
int leaf_child_cnt;
|
||||
/*! \brief print information of current stats to fo */
|
||||
inline void Print(std::stringstream &fo, bool is_leaf) const {
|
||||
if (!is_leaf) {
|
||||
fo << "gain=" << loss_chg << ",cover=" << sum_hess;
|
||||
} else {
|
||||
fo << "cover=" << sum_hess;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief define regression tree to be the most common tree model */
|
||||
class RegTree: public TreeModel<bst_float, RTreeNodeStat>{
|
||||
public:
|
||||
/*!
|
||||
* \brief get the leaf index
|
||||
* \param feats dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \param root_gid starting root index of the instance
|
||||
* \return the leaf index of the given feature
|
||||
*/
|
||||
inline int GetLeafIndex(const std::vector<float> &feat, unsigned root_id = 0) const {
|
||||
// start from groups that belongs to current data
|
||||
int pid = static_cast<int>(root_id);
|
||||
// tranverse tree
|
||||
while (!(*this)[ pid ].is_leaf()) {
|
||||
unsigned split_index = (*this)[pid].split_index();
|
||||
const float fvalue = feat[split_index];
|
||||
pid = this->GetNext(pid, fvalue, std::isnan(fvalue));
|
||||
}
|
||||
return pid;
|
||||
}
|
||||
/*!
|
||||
* \brief get the prediction of regression tree, only accepts dense feature vector
|
||||
* \param feats dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \param root_gid starting root index of the instance
|
||||
* \return the leaf index of the given feature
|
||||
*/
|
||||
inline float Predict(const std::vector<float> &feat, unsigned root_id = 0) const {
|
||||
int pid = this->GetLeafIndex(feat, root_id);
|
||||
return (*this)[pid].leaf_value();
|
||||
}
|
||||
private:
|
||||
/*! \brief get next position of the tree given current pid */
|
||||
inline int GetNext(int pid, float fvalue, bool is_unknown) const {
|
||||
float split_value = (*this)[pid].split_cond();
|
||||
if (is_unknown) {
|
||||
return (*this)[pid].cdefault();
|
||||
} else {
|
||||
if (fvalue < split_value) {
|
||||
return (*this)[pid].cleft();
|
||||
} else {
|
||||
return (*this)[pid].cright();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_MODEL_H_
|
||||
262
tree/param.h
Normal file
262
tree/param.h
Normal file
@@ -0,0 +1,262 @@
|
||||
#ifndef XGBOOST_TREE_PARAM_H_
|
||||
#define XGBOOST_TREE_PARAM_H_
|
||||
/*!
|
||||
* \file param.h
|
||||
* \brief training parameters, statistics used to support tree construction
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <cstring>
|
||||
#include "../data.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
/*! \brief core statistics used for tree construction */
|
||||
struct GradStats {
|
||||
/*! \brief sum gradient statistics */
|
||||
double sum_grad;
|
||||
/*! \brief sum hessian statistics */
|
||||
double sum_hess;
|
||||
/*! \brief constructor */
|
||||
GradStats(void) {
|
||||
this->Clear();
|
||||
}
|
||||
/*! \brief clear the statistics */
|
||||
inline void Clear(void) {
|
||||
sum_grad = sum_hess = 0.0f;
|
||||
}
|
||||
/*! \brief add statistics to the data */
|
||||
inline void Add(double grad, double hess) {
|
||||
sum_grad += grad; sum_hess += hess;
|
||||
}
|
||||
/*! \brief add statistics to the data */
|
||||
inline void Add(const bst_gpair& b) {
|
||||
this->Add(b.grad, b.hess);
|
||||
}
|
||||
/*! \brief add statistics to the data */
|
||||
inline void Add(const GradStats &b) {
|
||||
this->Add(b.sum_grad, b.sum_hess);
|
||||
}
|
||||
/*! \brief substract the statistics by b */
|
||||
inline GradStats Substract(const GradStats &b) const {
|
||||
GradStats res;
|
||||
res.sum_grad = this->sum_grad - b.sum_grad;
|
||||
res.sum_hess = this->sum_hess - b.sum_hess;
|
||||
return res;
|
||||
}
|
||||
/*! \return whether the statistics is not used yet */
|
||||
inline bool Empty(void) const {
|
||||
return sum_hess == 0.0;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief training parameters for regression tree */
|
||||
struct TrainParam{
|
||||
// learning step size for a time
|
||||
float learning_rate;
|
||||
// minimum loss change required for a split
|
||||
float min_split_loss;
|
||||
// maximum depth of a tree
|
||||
int max_depth;
|
||||
//----- the rest parameters are less important ----
|
||||
// minimum amount of hessian(weight) allowed in a child
|
||||
float min_child_weight;
|
||||
// weight decay parameter used to control leaf fitting
|
||||
float reg_lambda;
|
||||
// reg method
|
||||
int reg_method;
|
||||
// default direction choice
|
||||
int default_direction;
|
||||
// whether we want to do subsample
|
||||
float subsample;
|
||||
// whether to subsample columns each split, in each level
|
||||
float colsample_bylevel;
|
||||
// whether to subsample columns during tree construction
|
||||
float colsample_bytree;
|
||||
// speed optimization for dense column
|
||||
float opt_dense_col;
|
||||
// number of threads to be used for tree construction,
|
||||
// if OpenMP is enabled, if equals 0, use system default
|
||||
int nthread;
|
||||
/*! \brief constructor */
|
||||
TrainParam(void) {
|
||||
learning_rate = 0.3f;
|
||||
min_child_weight = 1.0f;
|
||||
max_depth = 6;
|
||||
reg_lambda = 1.0f;
|
||||
reg_method = 2;
|
||||
default_direction = 0;
|
||||
subsample = 1.0f;
|
||||
colsample_bytree = 1.0f;
|
||||
colsample_bylevel = 1.0f;
|
||||
opt_dense_col = 1.0f;
|
||||
nthread = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
* \param name name of the parameter
|
||||
* \param val value of the parameter
|
||||
*/
|
||||
inline void SetParam(const char *name, const char *val) {
|
||||
// sync-names
|
||||
if (!strcmp(name, "gamma")) min_split_loss = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "eta")) learning_rate = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "lambda")) reg_lambda = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "reg_method")) reg_method = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "max_depth")) max_depth = atoi(val);
|
||||
if (!strcmp(name, "nthread")) nthread = atoi(val);
|
||||
if (!strcmp(name, "default_direction")) {
|
||||
if (!strcmp(val, "learn")) default_direction = 0;
|
||||
if (!strcmp(val, "left")) default_direction = 1;
|
||||
if (!strcmp(val, "right")) default_direction = 2;
|
||||
}
|
||||
}
|
||||
// calculate the cost of loss function
|
||||
inline double CalcGain(double sum_grad, double sum_hess) const {
|
||||
if (sum_hess < min_child_weight) {
|
||||
return 0.0;
|
||||
}
|
||||
switch (reg_method) {
|
||||
case 1 : return Sqr(ThresholdL1(sum_grad, reg_lambda)) / sum_hess;
|
||||
case 2 : return Sqr(sum_grad) / (sum_hess + reg_lambda);
|
||||
case 3 : return
|
||||
Sqr(ThresholdL1(sum_grad, 0.5 * reg_lambda)) /
|
||||
(sum_hess + 0.5 * reg_lambda);
|
||||
default: return Sqr(sum_grad) / sum_hess;
|
||||
}
|
||||
}
|
||||
// calculate weight given the statistics
|
||||
inline double CalcWeight(double sum_grad, double sum_hess) const {
|
||||
if (sum_hess < min_child_weight) {
|
||||
return 0.0;
|
||||
} else {
|
||||
switch (reg_method) {
|
||||
case 1: return - ThresholdL1(sum_grad, reg_lambda) / sum_hess;
|
||||
case 2: return - sum_grad / (sum_hess + reg_lambda);
|
||||
case 3: return
|
||||
- ThresholdL1(sum_grad, 0.5 * reg_lambda) /
|
||||
(sum_hess + 0.5 * reg_lambda);
|
||||
default: return - sum_grad / sum_hess;
|
||||
}
|
||||
}
|
||||
}
|
||||
/*! \brief whether need forward small to big search: default right */
|
||||
inline bool need_forward_search(float col_density = 0.0f) const {
|
||||
return this->default_direction == 2 ||
|
||||
(default_direction == 0 && (col_density < opt_dense_col));
|
||||
}
|
||||
/*! \brief whether need backward big to small search: default left */
|
||||
inline bool need_backward_search(float col_density = 0.0f) const {
|
||||
return this->default_direction != 2;
|
||||
}
|
||||
/*! \brief given the loss change, whether we need to invode prunning */
|
||||
inline bool need_prune(double loss_chg, int depth) const {
|
||||
return loss_chg < this->min_split_loss;
|
||||
}
|
||||
/*! \brief whether we can split with current hessian */
|
||||
inline bool cannot_split(double sum_hess, int depth) const {
|
||||
return sum_hess < this->min_child_weight * 2.0;
|
||||
}
|
||||
// code support for template data
|
||||
inline double CalcWeight(const GradStats &d) const {
|
||||
return this->CalcWeight(d.sum_grad, d.sum_hess);
|
||||
}
|
||||
inline double CalcGain(const GradStats &d) const {
|
||||
return this->CalcGain(d.sum_grad, d.sum_hess);
|
||||
}
|
||||
|
||||
protected:
|
||||
// functions for L1 cost
|
||||
inline static double ThresholdL1(double w, double lambda) {
|
||||
if (w > +lambda) return w - lambda;
|
||||
if (w < -lambda) return w + lambda;
|
||||
return 0.0;
|
||||
}
|
||||
inline static double Sqr(double a) {
|
||||
return a * a;
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief statistics that is helpful to store
|
||||
* and represent a split solution for the tree
|
||||
*/
|
||||
struct SplitEntry{
|
||||
/*! \brief loss change after split this node */
|
||||
bst_float loss_chg;
|
||||
/*! \brief split index */
|
||||
unsigned sindex;
|
||||
/*! \brief split value */
|
||||
float split_value;
|
||||
/*! \brief constructor */
|
||||
SplitEntry(void) : loss_chg(0.0f), sindex(0), split_value(0.0f) {}
|
||||
/*!
|
||||
* \brief decides whether a we can replace current entry with the statistics given
|
||||
* This function gives better priority to lower index when loss_chg equals
|
||||
* not the best way, but helps to give consistent result during multi-thread execution
|
||||
* \param loss_chg the loss reduction get through the split
|
||||
* \param split_index the feature index where the split is on
|
||||
*/
|
||||
inline bool NeedReplace(bst_float loss_chg, unsigned split_index) const {
|
||||
if (this->split_index() <= split_index) {
|
||||
return loss_chg > this->loss_chg;
|
||||
} else {
|
||||
return !(this->loss_chg > loss_chg);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief update the split entry, replace it if e is better
|
||||
* \param e candidate split solution
|
||||
* \return whether the proposed split is better and can replace current split
|
||||
*/
|
||||
inline bool Update(const SplitEntry &e) {
|
||||
if (this->NeedReplace(e.loss_chg, e.split_index())) {
|
||||
this->loss_chg = e.loss_chg;
|
||||
this->sindex = e.sindex;
|
||||
this->split_value = e.split_value;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief update the split entry, replace it if e is better
|
||||
* \param loss_chg loss reduction of new candidate
|
||||
* \param split_index feature index to split on
|
||||
* \param split_value the split point
|
||||
* \param default_left whether the missing value goes to left
|
||||
* \return whether the proposed split is better and can replace current split
|
||||
*/
|
||||
inline bool Update(bst_float loss_chg, unsigned split_index,
|
||||
float split_value, bool default_left) {
|
||||
if (this->NeedReplace(loss_chg, split_index)) {
|
||||
this->loss_chg = loss_chg;
|
||||
if (default_left) split_index |= (1U << 31);
|
||||
this->sindex = split_index;
|
||||
this->split_value = split_value;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*!\return feature index to split on */
|
||||
inline unsigned split_index(void) const {
|
||||
return sindex & ((1U << 31) - 1U);
|
||||
}
|
||||
/*!\return whether missing value goes to left branch */
|
||||
inline bool default_left(void) const {
|
||||
return (sindex >> 31) != 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_PARAM_H_
|
||||
70
tree/updater.h
Normal file
70
tree/updater.h
Normal file
@@ -0,0 +1,70 @@
|
||||
#ifndef XGBOOST_TREE_UPDATER_H_
|
||||
#define XGBOOST_TREE_UPDATER_H_
|
||||
/*!
|
||||
* \file updater.h
|
||||
* \brief interface to update the tree
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <vector>
|
||||
|
||||
#include "../data.h"
|
||||
#include "./model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*!
|
||||
* \brief interface of tree update module, that performs update of a tree
|
||||
* \tparam FMatrix the data type updater taking
|
||||
*/
|
||||
template<typename FMatrix>
|
||||
class IUpdater {
|
||||
public:
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
* \param name name of the parameter
|
||||
* \param val value of the parameter
|
||||
*/
|
||||
virtual void SetParam(const char *name, const char *val) = 0;
|
||||
/*!
|
||||
* \brief peform update to the tree models
|
||||
* \param gpair the gradient pair statistics of the data
|
||||
* \param fmat feature matrix that provide access to features
|
||||
* \param root_index pre-partitioned root_index of each instance,
|
||||
* root_index.size() can be 0 which indicates that no pre-partition involved
|
||||
* \param trees pointer to the trese to be updated, upater will change the content of the tree
|
||||
* note: all the trees in the vector are updated, with the same statistics,
|
||||
* but maybe different random seeds, usually one tree is passed in at a time,
|
||||
* there can be multiple trees when we train random forest style model
|
||||
*/
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
FMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index,
|
||||
const std::vector<RegTree*> &trees) = 0;
|
||||
// destructor
|
||||
virtual ~IUpdater(void) {}
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
#include "./updater_prune-inl.hpp"
|
||||
#include "./updater_colmaker-inl.hpp"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*!
|
||||
* \brief create a updater based on name
|
||||
* \param name name of updater
|
||||
* \return return the updater instance
|
||||
*/
|
||||
template<typename FMatrix>
|
||||
inline IUpdater<FMatrix>* CreateUpdater(const char *name) {
|
||||
if (!strcmp(name, "prune")) return new TreePruner<FMatrix>();
|
||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<FMatrix, GradStats>();
|
||||
utils::Error("unknown updater:%s", name);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_UPDATER_H_
|
||||
357
tree/updater_colmaker-inl.hpp
Normal file
357
tree/updater_colmaker-inl.hpp
Normal file
@@ -0,0 +1,357 @@
|
||||
#ifndef XGBOOST_TREE_UPDATER_COLMAKER_INL_HPP_
|
||||
#define XGBOOST_TREE_UPDATER_COLMAKER_INL_HPP_
|
||||
/*!
|
||||
* \file updater_colmaker-inl.hpp
|
||||
* \brief use columnwise update to construct a tree
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "./param.h"
|
||||
#include "./updater.h"
|
||||
#include "../utils/omp.h"
|
||||
#include "../utils/random.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*! \brief pruner that prunes a tree after growing finishs */
|
||||
template<typename FMatrix, typename TStats>
|
||||
class ColMaker: public IUpdater<FMatrix> {
|
||||
public:
|
||||
virtual ~ColMaker(void) {}
|
||||
// set training parameter
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
param.SetParam(name, val);
|
||||
}
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
FMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index,
|
||||
const std::vector<RegTree*> &trees) {
|
||||
fmat.InitColAccess();
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
Builder builder(param);
|
||||
builder.Update(gpair, fmat, root_index, trees[i]);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
// data structure
|
||||
/*! \brief per thread x per node entry to store tmp data */
|
||||
struct ThreadEntry {
|
||||
/*! \brief statistics of data*/
|
||||
TStats stats;
|
||||
/*! \brief last feature value scanned */
|
||||
float last_fvalue;
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
// constructor
|
||||
ThreadEntry(void) {
|
||||
stats.Clear();
|
||||
}
|
||||
};
|
||||
struct NodeEntry {
|
||||
/*! \brief statics for node entry */
|
||||
TStats stats;
|
||||
/*! \brief loss of this node, without split */
|
||||
bst_float root_gain;
|
||||
/*! \brief weight calculated related to current data */
|
||||
float weight;
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
// constructor
|
||||
NodeEntry(void) : root_gain(0.0f), weight(0.0f){
|
||||
stats.Clear();
|
||||
}
|
||||
};
|
||||
// actual builder that runs the algorithm
|
||||
struct Builder{
|
||||
public:
|
||||
// constructor
|
||||
explicit Builder(const TrainParam ¶m) : param(param) {}
|
||||
// update one tree, growing
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair, FMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index,
|
||||
RegTree *p_tree) {
|
||||
this->InitData(gpair, fmat, root_index, *p_tree);
|
||||
this->InitNewNode(qexpand, gpair, *p_tree);
|
||||
|
||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||
this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree);
|
||||
this->ResetPosition(this->qexpand, fmat, *p_tree);
|
||||
this->UpdateQueueExpand(*p_tree, &this->qexpand);
|
||||
this->InitNewNode(qexpand, gpair, *p_tree);
|
||||
// if nothing left to be expand, break
|
||||
if (qexpand.size() == 0) break;
|
||||
}
|
||||
// set all the rest expanding nodes to leaf
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
(*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
|
||||
}
|
||||
// remember auxiliary statistics in the tree node
|
||||
for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
|
||||
p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg;
|
||||
p_tree->stat(nid).base_weight = snode[nid].weight;
|
||||
p_tree->stat(nid).sum_hess = static_cast<float>(snode[nid].stats.sum_hess);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// initialize temp data structure
|
||||
inline void InitData(const std::vector<bst_gpair> &gpair, FMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index, const RegTree &tree) {
|
||||
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "ColMaker: can only grow new tree");
|
||||
{// setup position
|
||||
position.resize(gpair.size());
|
||||
if (root_index.size() == 0) {
|
||||
std::fill(position.begin(), position.end(), 0);
|
||||
} else {
|
||||
for (size_t i = 0; i < root_index.size(); ++i) {
|
||||
position[i] = root_index[i];
|
||||
utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, "root index exceed setting");
|
||||
}
|
||||
}
|
||||
// mark delete for the deleted datas
|
||||
for (size_t i = 0; i < gpair.size(); ++i) {
|
||||
if (gpair[i].hess < 0.0f) position[i] = -1;
|
||||
}
|
||||
// mark subsample
|
||||
if (param.subsample < 1.0f) {
|
||||
for (size_t i = 0; i < gpair.size(); ++i) {
|
||||
if (gpair[i].hess < 0.0f) continue;
|
||||
if (random::SampleBinary(param.subsample) == 0) position[i] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// initialize feature index
|
||||
unsigned ncol = static_cast<unsigned>(fmat.NumCol());
|
||||
for (unsigned i = 0; i < ncol; ++i) {
|
||||
if (fmat.GetColSize(i) != 0) feat_index.push_back(i);
|
||||
}
|
||||
unsigned n = static_cast<unsigned>(param.colsample_bytree * feat_index.size());
|
||||
random::Shuffle(feat_index);
|
||||
utils::Check(n > 0, "colsample_bytree is too small that no feature can be included");
|
||||
feat_index.resize(n);
|
||||
}
|
||||
{// setup temp space for each thread
|
||||
#pragma omp parallel
|
||||
{
|
||||
this->nthread = omp_get_num_threads();
|
||||
}
|
||||
// reserve a small space
|
||||
stemp.clear();
|
||||
stemp.resize(this->nthread, std::vector<ThreadEntry>());
|
||||
for (size_t i = 0; i < stemp.size(); ++i) {
|
||||
stemp[i].clear(); stemp[i].reserve(256);
|
||||
}
|
||||
snode.reserve(256);
|
||||
}
|
||||
{// expand query
|
||||
qexpand.reserve(256); qexpand.clear();
|
||||
for (int i = 0; i < tree.param.num_roots; ++i) {
|
||||
qexpand.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*! \brief initialize the base_weight, root_gain, and NodeEntry for all the new nodes in qexpand */
|
||||
inline void InitNewNode(const std::vector<int> &qexpand,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
const RegTree &tree) {
|
||||
{// setup statistics space for each tree node
|
||||
for (size_t i = 0; i < stemp.size(); ++i) {
|
||||
stemp[i].resize(tree.param.num_nodes, ThreadEntry());
|
||||
}
|
||||
snode.resize(tree.param.num_nodes, NodeEntry());
|
||||
}
|
||||
// setup position
|
||||
const unsigned ndata = static_cast<unsigned>(position.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
if (position[i] < 0) continue;
|
||||
stemp[tid][position[i]].stats.Add(gpair[i]);
|
||||
}
|
||||
// sum the per thread statistics together
|
||||
for (size_t j = 0; j < qexpand.size(); ++j) {
|
||||
const int nid = qexpand[j];
|
||||
TStats stats; stats.Clear();
|
||||
for (size_t tid = 0; tid < stemp.size(); ++tid) {
|
||||
stats.Add(stemp[tid][nid].stats);
|
||||
}
|
||||
// update node statistics
|
||||
snode[nid].stats = stats;
|
||||
snode[nid].root_gain = param.CalcGain(stats);
|
||||
snode[nid].weight = param.CalcWeight(stats);
|
||||
}
|
||||
}
|
||||
/*! \brief update queue expand add in new leaves */
|
||||
inline void UpdateQueueExpand(const RegTree &tree, std::vector<int> *p_qexpand) {
|
||||
std::vector<int> &qexpand = *p_qexpand;
|
||||
std::vector<int> newnodes;
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
if (!tree[ nid ].is_leaf()) {
|
||||
newnodes.push_back(tree[nid].cleft());
|
||||
newnodes.push_back(tree[nid].cright());
|
||||
}
|
||||
}
|
||||
// use new nodes for qexpand
|
||||
qexpand = newnodes;
|
||||
}
|
||||
// enumerate the split values of specific feature
|
||||
template<typename Iter>
|
||||
inline void EnumerateSplit(Iter it, unsigned fid,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
std::vector<ThreadEntry> &temp,
|
||||
bool is_forward_search) {
|
||||
// clear all the temp statistics
|
||||
for (size_t j = 0; j < qexpand.size(); ++j) {
|
||||
temp[qexpand[j]].stats.Clear();
|
||||
}
|
||||
while (it.Next()) {
|
||||
const bst_uint ridx = it.rindex();
|
||||
const int nid = position[ridx];
|
||||
if (nid < 0) continue;
|
||||
// start working
|
||||
const float fvalue = it.fvalue();
|
||||
// get the statistics of nid
|
||||
ThreadEntry &e = temp[nid];
|
||||
// test if first hit, this is fine, because we set 0 during init
|
||||
if (e.stats.Empty()) {
|
||||
e.stats.Add(gpair[ridx]);
|
||||
e.last_fvalue = fvalue;
|
||||
} else {
|
||||
// try to find a split
|
||||
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
|
||||
TStats c = snode[nid].stats.Substract(e.stats);
|
||||
if (c.sum_hess >= param.min_child_weight) {
|
||||
double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
|
||||
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
|
||||
}
|
||||
}
|
||||
// update the statistics
|
||||
e.stats.Add(gpair[ridx]);
|
||||
e.last_fvalue = fvalue;
|
||||
}
|
||||
}
|
||||
// finish updating all statistics, check if it is possible to include all sum statistics
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
ThreadEntry &e = temp[nid];
|
||||
TStats c = snode[nid].stats.Substract(e.stats);
|
||||
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
|
||||
const double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
|
||||
const float delta = is_forward_search ? rt_eps : -rt_eps;
|
||||
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
|
||||
}
|
||||
}
|
||||
}
|
||||
// find splits at current level, do split per level
|
||||
inline void FindSplit(int depth, const std::vector<int> &qexpand,
|
||||
const std::vector<bst_gpair> &gpair, const FMatrix &fmat,
|
||||
RegTree *p_tree) {
|
||||
std::vector<unsigned> feat_set = feat_index;
|
||||
if (param.colsample_bylevel != 1.0f) {
|
||||
random::Shuffle(feat_set);
|
||||
unsigned n = static_cast<unsigned>(param.colsample_bylevel * feat_index.size());
|
||||
utils::Check(n > 0, "colsample_bylevel is too small that no feature can be included");
|
||||
feat_set.resize(n);
|
||||
}
|
||||
// start enumeration
|
||||
const unsigned nsize = static_cast<unsigned>(feat_set.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (unsigned i = 0; i < nsize; ++i) {
|
||||
const unsigned fid = feat_set[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
if (param.need_forward_search(fmat.GetColDensity(fid))) {
|
||||
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, stemp[tid], true);
|
||||
}
|
||||
if (param.need_backward_search(fmat.GetColDensity(fid))) {
|
||||
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, stemp[tid], false);
|
||||
}
|
||||
}
|
||||
// after this each thread's stemp will get the best candidates, aggregate results
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
NodeEntry &e = snode[nid];
|
||||
for (int tid = 0; tid < this->nthread; ++tid) {
|
||||
e.best.Update(stemp[tid][nid].best);
|
||||
}
|
||||
// now we know the solution in snode[nid], set split
|
||||
if (e.best.loss_chg > rt_eps) {
|
||||
p_tree->AddChilds(nid);
|
||||
(*p_tree)[nid].set_split(e.best.split_index(), e.best.split_value, e.best.default_left());
|
||||
} else {
|
||||
(*p_tree)[nid].set_leaf(e.weight * param.learning_rate);
|
||||
}
|
||||
}
|
||||
}
|
||||
// reset position of each data points after split is created in the tree
|
||||
inline void ResetPosition(const std::vector<int> &qexpand, const FMatrix &fmat, const RegTree &tree) {
|
||||
// step 1, set default direct nodes to default, and leaf nodes to -1
|
||||
const unsigned ndata = static_cast<unsigned>(position.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
const int nid = position[i];
|
||||
if (nid >= 0) {
|
||||
if (tree[nid].is_leaf()) {
|
||||
position[i] = -1;
|
||||
} else {
|
||||
// push to default branch, correct latter
|
||||
position[i] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
|
||||
}
|
||||
}
|
||||
}
|
||||
// step 2, classify the non-default data into right places
|
||||
std::vector<unsigned> fsplits;
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
if (!tree[nid].is_leaf()) fsplits.push_back(tree[nid].split_index());
|
||||
}
|
||||
std::sort(fsplits.begin(), fsplits.end());
|
||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||
// start put things into right place
|
||||
const unsigned nfeats = static_cast<unsigned>(fsplits.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (unsigned i = 0; i < nfeats; ++i) {
|
||||
const unsigned fid = fsplits[i];
|
||||
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
|
||||
const bst_uint ridx = it.rindex();
|
||||
int nid = position[ridx];
|
||||
if (nid == -1) continue;
|
||||
// go back to parent, correct those who are not default
|
||||
nid = tree[nid].parent();
|
||||
if (tree[nid].split_index() == fid) {
|
||||
if (it.fvalue() < tree[nid].split_cond()) {
|
||||
position[ridx] = tree[nid].cleft();
|
||||
} else {
|
||||
position[ridx] = tree[nid].cright();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//--data fields--
|
||||
const TrainParam ¶m;
|
||||
// number of omp thread used during training
|
||||
int nthread;
|
||||
// Per feature: shuffle index of each feature index
|
||||
std::vector<unsigned> feat_index;
|
||||
// Instance Data: current node position in the tree of each instance
|
||||
std::vector<int> position;
|
||||
// PerThread x PerTreeNode: statistics for per thread construction
|
||||
std::vector< std::vector<ThreadEntry> > stemp;
|
||||
/*! \brief TreeNode Data: statistics for each constructed node */
|
||||
std::vector<NodeEntry> snode;
|
||||
/*! \brief queue of nodes to be expanded */
|
||||
std::vector<int> qexpand;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_UPDATER_COLMAKER_INL_HPP_
|
||||
67
tree/updater_prune-inl.hpp
Normal file
67
tree/updater_prune-inl.hpp
Normal file
@@ -0,0 +1,67 @@
|
||||
#ifndef XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
|
||||
#define XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
|
||||
/*!
|
||||
* \file updater_prune-inl.hpp
|
||||
* \brief prune a tree given the statistics
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <vector>
|
||||
#include "./param.h"
|
||||
#include "./updater.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*! \brief pruner that prunes a tree after growing finishs */
|
||||
template<typename FMatrix>
|
||||
class TreePruner: public IUpdater<FMatrix> {
|
||||
public:
|
||||
virtual ~TreePruner(void) {}
|
||||
// set training parameter
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
param.SetParam(name, val);
|
||||
}
|
||||
// update the tree, do pruning
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair, FMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index,
|
||||
const std::vector<RegTree*> &trees) {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
this->DoPrune(*trees[i]);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// try to prune off current leaf
|
||||
inline void TryPruneLeaf(RegTree &tree, int nid, int depth) {
|
||||
if (tree[nid].is_root()) return;
|
||||
int pid = tree[nid].parent();
|
||||
RegTree::NodeStat &s = tree.stat(pid);
|
||||
++s.leaf_child_cnt;
|
||||
|
||||
if (s.leaf_child_cnt >= 2 && param.need_prune(s.loss_chg, depth - 1)) {
|
||||
// need to be pruned
|
||||
tree.ChangeToLeaf(pid, param.learning_rate * s.base_weight);
|
||||
// tail recursion
|
||||
this->TryPruneLeaf(tree, pid, depth - 1);
|
||||
}
|
||||
}
|
||||
/*! \brief do prunning of a tree */
|
||||
inline void DoPrune(RegTree &tree) {
|
||||
// initialize auxiliary statistics
|
||||
for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
|
||||
tree.stat(nid).leaf_child_cnt = 0;
|
||||
}
|
||||
for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
|
||||
if (tree[nid].is_leaf()) {
|
||||
this->TryPruneLeaf(tree, nid, tree.GetDepth(nid));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_UPDATER_PRUNE_INL_HPP_
|
||||
Reference in New Issue
Block a user