#ifndef XGBOOST_TREE_MODEL_H_ #define XGBOOST_TREE_MODEL_H_ /*! * \file model.h * \brief model structure for tree * \author Tianqi Chen */ #include #include #include #include #include #include #include #include "../utils/io.h" #include "../utils/fmap.h" #include "../utils/utils.h" namespace xgboost { namespace tree { /*! * \brief template class of TreeModel * \tparam TSplitCond data type to indicate split condition * \tparam TNodeStat auxiliary statistics of node to help tree building */ template class TreeModel { public: /*! \brief data type to indicate split condition */ typedef TNodeStat NodeStat; /*! \brief auxiliary statistics of node to help tree building */ typedef TSplitCond SplitCond; /*! \brief parameters of the tree */ struct Param{ /*! \brief number of start root */ int num_roots; /*! \brief total number of nodes */ int num_nodes; /*!\brief number of deleted nodes */ int num_deleted; /*! \brief maximum depth, this is a statistics of the tree */ int max_depth; /*! \brief number of features used for tree construction */ int num_feature; /*! * \brief leaf vector size, used for vector tree * used to store more than one dimensional information in tree */ int size_leaf_vector; /*! \brief reserved part */ int reserved[31]; /*! \brief constructor */ Param(void) { max_depth = 0; size_leaf_vector = 0; std::memset(reserved, 0, sizeof(reserved)); } /*! * \brief set parameters from outside * \param name name of the parameter * \param val value of the parameter */ inline void SetParam(const char *name, const char *val) { using namespace std; if (!strcmp("num_roots", name)) num_roots = atoi(val); if (!strcmp("num_feature", name)) num_feature = atoi(val); if (!strcmp("size_leaf_vector", name)) size_leaf_vector = atoi(val); } }; /*! \brief tree node */ class Node{ public: /*! \brief index of left child */ inline int cleft(void) const { return this->cleft_; } /*! \brief index of right child */ inline int cright(void) const { return this->cright_; } /*! \brief index of default child when feature is missing */ inline int cdefault(void) const { return this->default_left() ? this->cleft() : this->cright(); } /*! \brief feature index of split condition */ inline unsigned split_index(void) const { return sindex_ & ((1U << 31) - 1U); } /*! \brief when feature is unknown, whether goes to left child */ inline bool default_left(void) const { return (sindex_ >> 31) != 0; } /*! \brief whether current node is leaf node */ inline bool is_leaf(void) const { return cleft_ == -1; } /*! \brief get leaf value of leaf node */ inline float leaf_value(void) const { return (this->info_).leaf_value; } /*! \brief get split condition of the node */ inline TSplitCond split_cond(void) const { return (this->info_).split_cond; } /*! \brief get parent of the node */ inline int parent(void) const { return parent_ & ((1U << 31) - 1); } /*! \brief whether current node is left child */ inline bool is_left_child(void) const { return (parent_ & (1U << 31)) != 0; } /*! \brief whether current node is root */ inline bool is_root(void) const { return parent_ == -1; } /*! * \brief set the right child * \param nide node id to right child */ inline void set_right_child(int nid) { this->cright_ = nid; } /*! * \brief set split condition of current node * \param split_index feature index to split * \param split_cond split condition * \param default_left the default direction when feature is unknown */ inline void set_split(unsigned split_index, TSplitCond split_cond, bool default_left = false) { if (default_left) split_index |= (1U << 31); this->sindex_ = split_index; (this->info_).split_cond = split_cond; } /*! * \brief set the leaf value of the node * \param value leaf value * \param right right index, could be used to store * additional information */ inline void set_leaf(float value, int right = -1) { (this->info_).leaf_value = value; this->cleft_ = -1; this->cright_ = right; } private: friend class TreeModel; /*! * \brief in leaf node, we have weights, in non-leaf nodes, * we have split condition */ union Info{ float leaf_value; TSplitCond split_cond; }; // pointer to parent, highest bit is used to // indicate whether it's a left child or not int parent_; // pointer to left, right int cleft_, cright_; // split feature index, left split or right split depends on the highest bit unsigned sindex_; // extra info Info info_; // set parent inline void set_parent(int pidx, bool is_left_child = true) { if (is_left_child) pidx |= (1U << 31); this->parent_ = pidx; } }; protected: // vector of nodes std::vector nodes; // free node space, used during training process std::vector deleted_nodes; // stats of nodes std::vector stats; // leaf vector, that is used to store additional information std::vector leaf_vector; // allocate a new node, // !!!!!! NOTE: may cause BUG here, nodes.resize inline int AllocNode(void) { if (param.num_deleted != 0) { int nd = deleted_nodes.back(); deleted_nodes.pop_back(); --param.num_deleted; return nd; } int nd = param.num_nodes++; utils::Check(param.num_nodes < std::numeric_limits::max(), "number of nodes in the tree exceed 2^31"); nodes.resize(param.num_nodes); stats.resize(param.num_nodes); leaf_vector.resize(param.num_nodes * param.size_leaf_vector); return nd; } // delete a tree node inline void DeleteNode(int nid) { utils::Assert(nid >= param.num_roots, "can not delete root"); deleted_nodes.push_back(nid); nodes[nid].set_parent(-1); ++param.num_deleted; } public: /*! * \brief change a non leaf node to a leaf node, delete its children * \param rid node id of the node * \param new leaf value */ inline void ChangeToLeaf(int rid, float value) { utils::Assert(nodes[nodes[rid].cleft() ].is_leaf(), "can not delete a non termial child"); utils::Assert(nodes[nodes[rid].cright()].is_leaf(), "can not delete a non termial child"); this->DeleteNode(nodes[rid].cleft()); this->DeleteNode(nodes[rid].cright()); nodes[rid].set_leaf(value); } /*! * \brief collapse a non leaf node to a leaf node, delete its children * \param rid node id of the node * \param new leaf value */ inline void CollapseToLeaf(int rid, float value) { if (nodes[rid].is_leaf()) return; if (!nodes[nodes[rid].cleft() ].is_leaf()) { CollapseToLeaf(nodes[rid].cleft(), 0.0f); } if (!nodes[nodes[rid].cright() ].is_leaf()) { CollapseToLeaf(nodes[rid].cright(), 0.0f); } this->ChangeToLeaf(rid, value); } public: /*! \brief model parameter */ Param param; /*! \brief constructor */ TreeModel(void) { param.num_nodes = 1; param.num_roots = 1; param.num_deleted = 0; nodes.resize(1); } /*! \brief get node given nid */ inline Node &operator[](int nid) { return nodes[nid]; } /*! \brief get node given nid */ inline const Node &operator[](int nid) const { return nodes[nid]; } /*! \brief get node statistics given nid */ inline NodeStat &stat(int nid) { return stats[nid]; } /*! \brief get leaf vector given nid */ inline bst_float* leafvec(int nid) { if (leaf_vector.size() == 0) return NULL; return &leaf_vector[nid * param.size_leaf_vector]; } /*! \brief get leaf vector given nid */ inline const bst_float* leafvec(int nid) const{ if (leaf_vector.size() == 0) return NULL; return &leaf_vector[nid * param.size_leaf_vector]; } /*! \brief initialize the model */ inline void InitModel(void) { param.num_nodes = param.num_roots; nodes.resize(param.num_nodes); stats.resize(param.num_nodes); leaf_vector.resize(param.num_nodes * param.size_leaf_vector, 0.0f); for (int i = 0; i < param.num_nodes; i ++) { nodes[i].set_leaf(0.0f); nodes[i].set_parent(-1); } } /*! * \brief load model from stream * \param fi input stream */ inline void LoadModel(utils::IStream &fi) { utils::Check(fi.Read(¶m, sizeof(Param)) > 0, "TreeModel: wrong format"); nodes.resize(param.num_nodes); stats.resize(param.num_nodes); utils::Check(fi.Read(&nodes[0], sizeof(Node) * nodes.size()) > 0, "TreeModel: wrong format"); utils::Check(fi.Read(&stats[0], sizeof(NodeStat) * stats.size()) > 0, "TreeModel: wrong format"); if (param.size_leaf_vector != 0) { utils::Check(fi.Read(&leaf_vector), "TreeModel: wrong format"); } // chg deleted nodes deleted_nodes.resize(0); for (int i = param.num_roots; i < param.num_nodes; i ++) { if (nodes[i].is_root()) deleted_nodes.push_back(i); } utils::Assert(static_cast(deleted_nodes.size()) == param.num_deleted, "number of deleted nodes do not match"); } /*! * \brief save model to stream * \param fo output stream */ inline void SaveModel(utils::IStream &fo) const { utils::Assert(param.num_nodes == static_cast(nodes.size()), "Tree::SaveModel"); utils::Assert(param.num_nodes == static_cast(stats.size()), "Tree::SaveModel"); fo.Write(¶m, sizeof(Param)); fo.Write(&nodes[0], sizeof(Node) * nodes.size()); fo.Write(&stats[0], sizeof(NodeStat) * nodes.size()); if (param.size_leaf_vector != 0) fo.Write(leaf_vector); } /*! * \brief add child nodes to node * \param nid node id to add childs */ inline void AddChilds(int nid) { int pleft = this->AllocNode(); int pright = this->AllocNode(); nodes[nid].cleft_ = pleft; nodes[nid].cright_ = pright; nodes[nodes[nid].cleft() ].set_parent(nid, true); nodes[nodes[nid].cright()].set_parent(nid, false); } /*! * \brief only add a right child to a leaf node * \param node id to add right child */ inline void AddRightChild(int nid) { int pright = this->AllocNode(); nodes[nid].right = pright; nodes[nodes[nid].right].set_parent(nid, false); } /*! * \brief get current depth * \param nid node id * \param pass_rchild whether right child is not counted in depth */ inline int GetDepth(int nid, bool pass_rchild = false) const { int depth = 0; while (!nodes[nid].is_root()) { if (!pass_rchild || nodes[nid].is_left_child()) ++depth; nid = nodes[nid].parent(); } return depth; } /*! * \brief get maximum depth * \param nid node id */ inline int MaxDepth(int nid) const { if (nodes[nid].is_leaf()) return 0; return std::max(MaxDepth(nodes[nid].cleft())+1, MaxDepth(nodes[nid].cright())+1); } /*! * \brief get maximum depth */ inline int MaxDepth(void) { int maxd = 0; for (int i = 0; i < param.num_roots; ++i) { maxd = std::max(maxd, MaxDepth(i)); } return maxd; } /*! \brief number of extra nodes besides the root */ inline int num_extra_nodes(void) const { return param.num_nodes - param.num_roots - param.num_deleted; } /*! * \brief dump model to text string * \param fmap feature map of feature types * \param with_stats whether dump out statistics as well * \return the string of dumped model */ inline std::string DumpModel(const utils::FeatMap& fmap, bool with_stats) { std::stringstream fo(""); for (int i = 0; i < param.num_roots; ++i) { this->Dump(i, fo, fmap, 0, with_stats); } return fo.str(); } private: void Dump(int nid, std::stringstream &fo, const utils::FeatMap& fmap, int depth, bool with_stats) { for (int i = 0; i < depth; ++i) { fo << '\t'; } if (nodes[nid].is_leaf()) { fo << nid << ":leaf=" << nodes[nid].leaf_value(); if (with_stats) { stat(nid).Print(fo, true); } fo << '\n'; } else { // right then left, TSplitCond cond = nodes[nid].split_cond(); const unsigned split_index = nodes[nid].split_index(); if (split_index < fmap.size()) { switch (fmap.type(split_index)) { case utils::FeatMap::kIndicator: { int nyes = nodes[nid].default_left() ? nodes[nid].cright() : nodes[nid].cleft(); fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes << ",no=" << nodes[nid].cdefault(); break; } case utils::FeatMap::kInteger: { fo << nid << ":[" << fmap.name(split_index) << "<" << int(float(cond)+1.0f) << "] yes=" << nodes[nid].cleft() << ",no=" << nodes[nid].cright() << ",missing=" << nodes[nid].cdefault(); break; } case utils::FeatMap::kFloat: case utils::FeatMap::kQuantitive: { fo << nid << ":[" << fmap.name(split_index) << "<"<< float(cond) << "] yes=" << nodes[nid].cleft() << ",no=" << nodes[nid].cright() << ",missing=" << nodes[nid].cdefault(); break; } default: utils::Error("unknown fmap type"); } } else { fo << nid << ":[f" << split_index << "<"<< float(cond) << "] yes=" << nodes[nid].cleft() << ",no=" << nodes[nid].cright() << ",missing=" << nodes[nid].cdefault(); } if (with_stats) { fo << ' '; stat(nid).Print(fo, false); } fo << '\n'; this->Dump(nodes[nid].cleft(), fo, fmap, depth+1, with_stats); this->Dump(nodes[nid].cright(), fo, fmap, depth+1, with_stats); } } }; /*! \brief node statistics used in regression tree */ struct RTreeNodeStat { /*! \brief loss chg caused by current split */ float loss_chg; /*! \brief sum of hessian values, used to measure coverage of data */ float sum_hess; /*! \brief weight of current node */ float base_weight; /*! \brief number of child that is leaf node known up to now */ int leaf_child_cnt; /*! \brief print information of current stats to fo */ inline void Print(std::stringstream &fo, bool is_leaf) const { if (!is_leaf) { fo << "gain=" << loss_chg << ",cover=" << sum_hess; } else { fo << "cover=" << sum_hess; } } }; /*! \brief define regression tree to be the most common tree model */ class RegTree: public TreeModel{ public: /*! * \brief dense feature vector that can be taken by RegTree * to do tranverse efficiently * and can be construct from sparse feature vector */ struct FVec { /*! * \brief a union value of value and flag * when flag == -1, this indicate the value is missing */ union Entry{ float fvalue; int flag; }; std::vector data; /*! \brief intialize the vector with size vector */ inline void Init(size_t size) { Entry e; e.flag = -1; data.resize(size); std::fill(data.begin(), data.end(), e); } /*! \brief fill the vector with sparse vector */ inline void Fill(const RowBatch::Inst &inst) { for (bst_uint i = 0; i < inst.length; ++i) { data[inst[i].index].fvalue = inst[i].fvalue; } } /*! \brief drop the trace after fill, must be called after fill */ inline void Drop(const RowBatch::Inst &inst) { for (bst_uint i = 0; i < inst.length; ++i) { data[inst[i].index].flag = -1; } } /*! \brief get ith value */ inline float fvalue(size_t i) const { return data[i].fvalue; } /*! \brief check whether i-th entry is missing */ inline bool is_missing(size_t i) const { return data[i].flag == -1; } }; /*! * \brief get the leaf index * \param feats dense feature vector, if the feature is missing the field is set to NaN * \param root_gid starting root index of the instance * \return the leaf index of the given feature */ inline int GetLeafIndex(const FVec&feat, unsigned root_id = 0) const { // start from groups that belongs to current data int pid = static_cast(root_id); // tranverse tree while (!(*this)[ pid ].is_leaf()) { unsigned split_index = (*this)[pid].split_index(); pid = this->GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index)); } return pid; } /*! * \brief get the prediction of regression tree, only accepts dense feature vector * \param feats dense feature vector, if the feature is missing the field is set to NaN * \param root_gid starting root index of the instance * \return the leaf index of the given feature */ inline float Predict(const FVec &feat, unsigned root_id = 0) const { int pid = this->GetLeafIndex(feat, root_id); return (*this)[pid].leaf_value(); } /*! \brief get next position of the tree given current pid */ inline int GetNext(int pid, float fvalue, bool is_unknown) const { float split_value = (*this)[pid].split_cond(); if (is_unknown) { return (*this)[pid].cdefault(); } else { if (fvalue < split_value) { return (*this)[pid].cleft(); } else { return (*this)[pid].cright(); } } } }; } // namespace tree } // namespace xgboost #endif // XGBOOST_TREE_MODEL_H_