[TREE] move tree model
This commit is contained in:
@@ -1,573 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file model.h
|
||||
* \brief model structure for tree
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_MODEL_H_
|
||||
#define XGBOOST_TREE_MODEL_H_
|
||||
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "../utils/io.h"
|
||||
#include "../utils/fmap.h"
|
||||
#include "../utils/utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
/*!
|
||||
* \brief template class of TreeModel
|
||||
* \tparam TSplitCond data type to indicate split condition
|
||||
* \tparam TNodeStat auxiliary statistics of node to help tree building
|
||||
*/
|
||||
template<typename TSplitCond, typename TNodeStat>
|
||||
class TreeModel {
|
||||
public:
|
||||
/*! \brief data type to indicate split condition */
|
||||
typedef TNodeStat NodeStat;
|
||||
/*! \brief auxiliary statistics of node to help tree building */
|
||||
typedef TSplitCond SplitCond;
|
||||
/*! \brief parameters of the tree */
|
||||
struct Param{
|
||||
/*! \brief number of start root */
|
||||
int num_roots;
|
||||
/*! \brief total number of nodes */
|
||||
int num_nodes;
|
||||
/*!\brief number of deleted nodes */
|
||||
int num_deleted;
|
||||
/*! \brief maximum depth, this is a statistics of the tree */
|
||||
int max_depth;
|
||||
/*! \brief number of features used for tree construction */
|
||||
int num_feature;
|
||||
/*!
|
||||
* \brief leaf vector size, used for vector tree
|
||||
* used to store more than one dimensional information in tree
|
||||
*/
|
||||
int size_leaf_vector;
|
||||
/*! \brief reserved part */
|
||||
int reserved[31];
|
||||
/*! \brief constructor */
|
||||
Param(void) {
|
||||
max_depth = 0;
|
||||
size_leaf_vector = 0;
|
||||
std::memset(reserved, 0, sizeof(reserved));
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
* \param name name of the parameter
|
||||
* \param val value of the parameter
|
||||
*/
|
||||
inline void SetParam(const char *name, const char *val) {
|
||||
using namespace std;
|
||||
if (!strcmp("num_roots", name)) num_roots = atoi(val);
|
||||
if (!strcmp("num_feature", name)) num_feature = atoi(val);
|
||||
if (!strcmp("size_leaf_vector", name)) size_leaf_vector = atoi(val);
|
||||
}
|
||||
};
|
||||
/*! \brief tree node */
|
||||
class Node {
|
||||
public:
|
||||
Node(void) : sindex_(0) {}
|
||||
/*! \brief index of left child */
|
||||
inline int cleft(void) const {
|
||||
return this->cleft_;
|
||||
}
|
||||
/*! \brief index of right child */
|
||||
inline int cright(void) const {
|
||||
return this->cright_;
|
||||
}
|
||||
/*! \brief index of default child when feature is missing */
|
||||
inline int cdefault(void) const {
|
||||
return this->default_left() ? this->cleft() : this->cright();
|
||||
}
|
||||
/*! \brief feature index of split condition */
|
||||
inline unsigned split_index(void) const {
|
||||
return sindex_ & ((1U << 31) - 1U);
|
||||
}
|
||||
/*! \brief when feature is unknown, whether goes to left child */
|
||||
inline bool default_left(void) const {
|
||||
return (sindex_ >> 31) != 0;
|
||||
}
|
||||
/*! \brief whether current node is leaf node */
|
||||
inline bool is_leaf(void) const {
|
||||
return cleft_ == -1;
|
||||
}
|
||||
/*! \brief get leaf value of leaf node */
|
||||
inline float leaf_value(void) const {
|
||||
return (this->info_).leaf_value;
|
||||
}
|
||||
/*! \brief get split condition of the node */
|
||||
inline TSplitCond split_cond(void) const {
|
||||
return (this->info_).split_cond;
|
||||
}
|
||||
/*! \brief get parent of the node */
|
||||
inline int parent(void) const {
|
||||
return parent_ & ((1U << 31) - 1);
|
||||
}
|
||||
/*! \brief whether current node is left child */
|
||||
inline bool is_left_child(void) const {
|
||||
return (parent_ & (1U << 31)) != 0;
|
||||
}
|
||||
/*! \brief whether this node is deleted */
|
||||
inline bool is_deleted(void) const {
|
||||
return sindex_ == std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
/*! \brief whether current node is root */
|
||||
inline bool is_root(void) const {
|
||||
return parent_ == -1;
|
||||
}
|
||||
/*!
|
||||
* \brief set the right child
|
||||
* \param nide node id to right child
|
||||
*/
|
||||
inline void set_right_child(int nid) {
|
||||
this->cright_ = nid;
|
||||
}
|
||||
/*!
|
||||
* \brief set split condition of current node
|
||||
* \param split_index feature index to split
|
||||
* \param split_cond split condition
|
||||
* \param default_left the default direction when feature is unknown
|
||||
*/
|
||||
inline void set_split(unsigned split_index, TSplitCond split_cond,
|
||||
bool default_left = false) {
|
||||
if (default_left) split_index |= (1U << 31);
|
||||
this->sindex_ = split_index;
|
||||
(this->info_).split_cond = split_cond;
|
||||
}
|
||||
/*!
|
||||
* \brief set the leaf value of the node
|
||||
* \param value leaf value
|
||||
* \param right right index, could be used to store
|
||||
* additional information
|
||||
*/
|
||||
inline void set_leaf(float value, int right = -1) {
|
||||
(this->info_).leaf_value = value;
|
||||
this->cleft_ = -1;
|
||||
this->cright_ = right;
|
||||
}
|
||||
/*! \brief mark that this node is deleted */
|
||||
inline void mark_delete(void) {
|
||||
this->sindex_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
private:
|
||||
friend class TreeModel<TSplitCond, TNodeStat>;
|
||||
/*!
|
||||
* \brief in leaf node, we have weights, in non-leaf nodes,
|
||||
* we have split condition
|
||||
*/
|
||||
union Info{
|
||||
float leaf_value;
|
||||
TSplitCond split_cond;
|
||||
};
|
||||
// pointer to parent, highest bit is used to
|
||||
// indicate whether it's a left child or not
|
||||
int parent_;
|
||||
// pointer to left, right
|
||||
int cleft_, cright_;
|
||||
// split feature index, left split or right split depends on the highest bit
|
||||
unsigned sindex_;
|
||||
// extra info
|
||||
Info info_;
|
||||
// set parent
|
||||
inline void set_parent(int pidx, bool is_left_child = true) {
|
||||
if (is_left_child) pidx |= (1U << 31);
|
||||
this->parent_ = pidx;
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
// vector of nodes
|
||||
std::vector<Node> nodes;
|
||||
// free node space, used during training process
|
||||
std::vector<int> deleted_nodes;
|
||||
// stats of nodes
|
||||
std::vector<TNodeStat> stats;
|
||||
// leaf vector, that is used to store additional information
|
||||
std::vector<bst_float> leaf_vector;
|
||||
// allocate a new node,
|
||||
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
||||
inline int AllocNode(void) {
|
||||
if (param.num_deleted != 0) {
|
||||
int nd = deleted_nodes.back();
|
||||
deleted_nodes.pop_back();
|
||||
--param.num_deleted;
|
||||
return nd;
|
||||
}
|
||||
int nd = param.num_nodes++;
|
||||
utils::Check(param.num_nodes < std::numeric_limits<int>::max(),
|
||||
"number of nodes in the tree exceed 2^31");
|
||||
nodes.resize(param.num_nodes);
|
||||
stats.resize(param.num_nodes);
|
||||
leaf_vector.resize(param.num_nodes * param.size_leaf_vector);
|
||||
return nd;
|
||||
}
|
||||
// delete a tree node, keep the parent field to allow trace back
|
||||
inline void DeleteNode(int nid) {
|
||||
utils::Assert(nid >= param.num_roots, "can not delete root");
|
||||
deleted_nodes.push_back(nid);
|
||||
nodes[nid].mark_delete();
|
||||
++param.num_deleted;
|
||||
}
|
||||
|
||||
public:
|
||||
/*!
|
||||
* \brief change a non leaf node to a leaf node, delete its children
|
||||
* \param rid node id of the node
|
||||
* \param new leaf value
|
||||
*/
|
||||
inline void ChangeToLeaf(int rid, float value) {
|
||||
utils::Assert(nodes[nodes[rid].cleft() ].is_leaf(),
|
||||
"can not delete a non termial child");
|
||||
utils::Assert(nodes[nodes[rid].cright()].is_leaf(),
|
||||
"can not delete a non termial child");
|
||||
this->DeleteNode(nodes[rid].cleft());
|
||||
this->DeleteNode(nodes[rid].cright());
|
||||
nodes[rid].set_leaf(value);
|
||||
}
|
||||
/*!
|
||||
* \brief collapse a non leaf node to a leaf node, delete its children
|
||||
* \param rid node id of the node
|
||||
* \param new leaf value
|
||||
*/
|
||||
inline void CollapseToLeaf(int rid, float value) {
|
||||
if (nodes[rid].is_leaf()) return;
|
||||
if (!nodes[nodes[rid].cleft() ].is_leaf()) {
|
||||
CollapseToLeaf(nodes[rid].cleft(), 0.0f);
|
||||
}
|
||||
if (!nodes[nodes[rid].cright() ].is_leaf()) {
|
||||
CollapseToLeaf(nodes[rid].cright(), 0.0f);
|
||||
}
|
||||
this->ChangeToLeaf(rid, value);
|
||||
}
|
||||
|
||||
public:
|
||||
/*! \brief model parameter */
|
||||
Param param;
|
||||
/*! \brief constructor */
|
||||
TreeModel(void) {
|
||||
param.num_nodes = 1;
|
||||
param.num_roots = 1;
|
||||
param.num_deleted = 0;
|
||||
nodes.resize(1);
|
||||
}
|
||||
/*! \brief get node given nid */
|
||||
inline Node &operator[](int nid) {
|
||||
return nodes[nid];
|
||||
}
|
||||
/*! \brief get node given nid */
|
||||
inline const Node &operator[](int nid) const {
|
||||
return nodes[nid];
|
||||
}
|
||||
/*! \brief get node statistics given nid */
|
||||
inline NodeStat &stat(int nid) {
|
||||
return stats[nid];
|
||||
}
|
||||
/*! \brief get leaf vector given nid */
|
||||
inline bst_float* leafvec(int nid) {
|
||||
if (leaf_vector.size() == 0) return NULL;
|
||||
return &leaf_vector[nid * param.size_leaf_vector];
|
||||
}
|
||||
/*! \brief get leaf vector given nid */
|
||||
inline const bst_float* leafvec(int nid) const {
|
||||
if (leaf_vector.size() == 0) return NULL;
|
||||
return &leaf_vector[nid * param.size_leaf_vector];
|
||||
}
|
||||
/*! \brief initialize the model */
|
||||
inline void InitModel(void) {
|
||||
param.num_nodes = param.num_roots;
|
||||
nodes.resize(param.num_nodes);
|
||||
stats.resize(param.num_nodes);
|
||||
leaf_vector.resize(param.num_nodes * param.size_leaf_vector, 0.0f);
|
||||
for (int i = 0; i < param.num_nodes; i ++) {
|
||||
nodes[i].set_leaf(0.0f);
|
||||
nodes[i].set_parent(-1);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream
|
||||
*/
|
||||
inline void LoadModel(utils::IStream &fi) { // NOLINT(*)
|
||||
utils::Check(fi.Read(¶m, sizeof(Param)) > 0,
|
||||
"TreeModel: wrong format");
|
||||
nodes.resize(param.num_nodes); stats.resize(param.num_nodes);
|
||||
utils::Assert(param.num_nodes != 0, "invalid model");
|
||||
utils::Check(fi.Read(BeginPtr(nodes), sizeof(Node) * nodes.size()) > 0,
|
||||
"TreeModel: wrong format");
|
||||
utils::Check(fi.Read(BeginPtr(stats), sizeof(NodeStat) * stats.size()) > 0,
|
||||
"TreeModel: wrong format");
|
||||
if (param.size_leaf_vector != 0) {
|
||||
utils::Check(fi.Read(&leaf_vector), "TreeModel: wrong format");
|
||||
}
|
||||
// chg deleted nodes
|
||||
deleted_nodes.resize(0);
|
||||
for (int i = param.num_roots; i < param.num_nodes; ++i) {
|
||||
if (nodes[i].is_deleted()) deleted_nodes.push_back(i);
|
||||
}
|
||||
utils::Assert(static_cast<int>(deleted_nodes.size()) == param.num_deleted,
|
||||
"number of deleted nodes do not match, num_deleted=%d, dnsize=%lu, num_nodes=%d",
|
||||
param.num_deleted, deleted_nodes.size(), param.num_nodes);
|
||||
}
|
||||
/*!
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
*/
|
||||
inline void SaveModel(utils::IStream &fo) const { // NOLINT(*)
|
||||
utils::Assert(param.num_nodes == static_cast<int>(nodes.size()),
|
||||
"TreeModel::SaveModel");
|
||||
utils::Assert(param.num_nodes == static_cast<int>(stats.size()),
|
||||
"TreeModel::SaveModel");
|
||||
fo.Write(¶m, sizeof(Param));
|
||||
utils::Assert(param.num_nodes != 0, "invalid model");
|
||||
fo.Write(BeginPtr(nodes), sizeof(Node) * nodes.size());
|
||||
fo.Write(BeginPtr(stats), sizeof(NodeStat) * nodes.size());
|
||||
if (param.size_leaf_vector != 0) fo.Write(leaf_vector);
|
||||
}
|
||||
/*!
|
||||
* \brief add child nodes to node
|
||||
* \param nid node id to add childs
|
||||
*/
|
||||
inline void AddChilds(int nid) {
|
||||
int pleft = this->AllocNode();
|
||||
int pright = this->AllocNode();
|
||||
nodes[nid].cleft_ = pleft;
|
||||
nodes[nid].cright_ = pright;
|
||||
nodes[nodes[nid].cleft() ].set_parent(nid, true);
|
||||
nodes[nodes[nid].cright()].set_parent(nid, false);
|
||||
}
|
||||
/*!
|
||||
* \brief only add a right child to a leaf node
|
||||
* \param node id to add right child
|
||||
*/
|
||||
inline void AddRightChild(int nid) {
|
||||
int pright = this->AllocNode();
|
||||
nodes[nid].right = pright;
|
||||
nodes[nodes[nid].right].set_parent(nid, false);
|
||||
}
|
||||
/*!
|
||||
* \brief get current depth
|
||||
* \param nid node id
|
||||
* \param pass_rchild whether right child is not counted in depth
|
||||
*/
|
||||
inline int GetDepth(int nid, bool pass_rchild = false) const {
|
||||
int depth = 0;
|
||||
while (!nodes[nid].is_root()) {
|
||||
if (!pass_rchild || nodes[nid].is_left_child()) ++depth;
|
||||
nid = nodes[nid].parent();
|
||||
}
|
||||
return depth;
|
||||
}
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
* \param nid node id
|
||||
*/
|
||||
inline int MaxDepth(int nid) const {
|
||||
if (nodes[nid].is_leaf()) return 0;
|
||||
return std::max(MaxDepth(nodes[nid].cleft())+1,
|
||||
MaxDepth(nodes[nid].cright())+1);
|
||||
}
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
*/
|
||||
inline int MaxDepth(void) {
|
||||
int maxd = 0;
|
||||
for (int i = 0; i < param.num_roots; ++i) {
|
||||
maxd = std::max(maxd, MaxDepth(i));
|
||||
}
|
||||
return maxd;
|
||||
}
|
||||
/*! \brief number of extra nodes besides the root */
|
||||
inline int num_extra_nodes(void) const {
|
||||
return param.num_nodes - param.num_roots - param.num_deleted;
|
||||
}
|
||||
/*!
|
||||
* \brief dump model to text string
|
||||
* \param fmap feature map of feature types
|
||||
* \param with_stats whether dump out statistics as well
|
||||
* \return the string of dumped model
|
||||
*/
|
||||
inline std::string DumpModel(const utils::FeatMap& fmap, bool with_stats) {
|
||||
std::stringstream fo("");
|
||||
for (int i = 0; i < param.num_roots; ++i) {
|
||||
this->Dump(i, fo, fmap, 0, with_stats);
|
||||
}
|
||||
return fo.str();
|
||||
}
|
||||
|
||||
private:
|
||||
void Dump(int nid, std::stringstream &fo, // NOLINT(*)
|
||||
const utils::FeatMap& fmap, int depth, bool with_stats) {
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
fo << '\t';
|
||||
}
|
||||
if (nodes[nid].is_leaf()) {
|
||||
fo << nid << ":leaf=" << nodes[nid].leaf_value();
|
||||
if (with_stats) {
|
||||
stat(nid).Print(fo, true);
|
||||
}
|
||||
fo << '\n';
|
||||
} else {
|
||||
// right then left,
|
||||
TSplitCond cond = nodes[nid].split_cond();
|
||||
const unsigned split_index = nodes[nid].split_index();
|
||||
if (split_index < fmap.size()) {
|
||||
switch (fmap.type(split_index)) {
|
||||
case utils::FeatMap::kIndicator: {
|
||||
int nyes = nodes[nid].default_left() ?
|
||||
nodes[nid].cright() : nodes[nid].cleft();
|
||||
fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes
|
||||
<< ",no=" << nodes[nid].cdefault();
|
||||
break;
|
||||
}
|
||||
case utils::FeatMap::kInteger: {
|
||||
fo << nid << ":[" << fmap.name(split_index) << "<"
|
||||
<< int(float(cond)+1.0f)
|
||||
<< "] yes=" << nodes[nid].cleft()
|
||||
<< ",no=" << nodes[nid].cright()
|
||||
<< ",missing=" << nodes[nid].cdefault();
|
||||
break;
|
||||
}
|
||||
case utils::FeatMap::kFloat:
|
||||
case utils::FeatMap::kQuantitive: {
|
||||
fo << nid << ":[" << fmap.name(split_index) << "<"<< float(cond)
|
||||
<< "] yes=" << nodes[nid].cleft()
|
||||
<< ",no=" << nodes[nid].cright()
|
||||
<< ",missing=" << nodes[nid].cdefault();
|
||||
break;
|
||||
}
|
||||
default: utils::Error("unknown fmap type");
|
||||
}
|
||||
} else {
|
||||
fo << nid << ":[f" << split_index << "<"<< float(cond)
|
||||
<< "] yes=" << nodes[nid].cleft()
|
||||
<< ",no=" << nodes[nid].cright()
|
||||
<< ",missing=" << nodes[nid].cdefault();
|
||||
}
|
||||
if (with_stats) {
|
||||
stat(nid).Print(fo, false);
|
||||
}
|
||||
fo << '\n';
|
||||
this->Dump(nodes[nid].cleft(), fo, fmap, depth+1, with_stats);
|
||||
this->Dump(nodes[nid].cright(), fo, fmap, depth+1, with_stats);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief node statistics used in regression tree */
|
||||
struct RTreeNodeStat {
|
||||
/*! \brief loss change caused by current split */
|
||||
float loss_chg;
|
||||
/*! \brief sum of hessian values, used to measure coverage of data */
|
||||
float sum_hess;
|
||||
/*! \brief weight of current node */
|
||||
float base_weight;
|
||||
/*! \brief number of child that is leaf node known up to now */
|
||||
int leaf_child_cnt;
|
||||
/*! \brief print information of current stats to fo */
|
||||
inline void Print(std::stringstream &fo, bool is_leaf) const { // NOLINT(*)
|
||||
if (!is_leaf) {
|
||||
fo << ",gain=" << loss_chg << ",cover=" << sum_hess;
|
||||
} else {
|
||||
fo << ",cover=" << sum_hess;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief define regression tree to be the most common tree model */
|
||||
class RegTree: public TreeModel<bst_float, RTreeNodeStat>{
|
||||
public:
|
||||
/*!
|
||||
* \brief dense feature vector that can be taken by RegTree
|
||||
* to do traverse efficiently
|
||||
* and can be construct from sparse feature vector
|
||||
*/
|
||||
struct FVec {
|
||||
/*!
|
||||
* \brief a union value of value and flag
|
||||
* when flag == -1, this indicate the value is missing
|
||||
*/
|
||||
union Entry{
|
||||
float fvalue;
|
||||
int flag;
|
||||
};
|
||||
std::vector<Entry> data;
|
||||
/*! \brief initialize the vector with size vector */
|
||||
inline void Init(size_t size) {
|
||||
Entry e; e.flag = -1;
|
||||
data.resize(size);
|
||||
std::fill(data.begin(), data.end(), e);
|
||||
}
|
||||
/*! \brief fill the vector with sparse vector */
|
||||
inline void Fill(const RowBatch::Inst &inst) {
|
||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||
if (inst[i].index >= data.size()) continue;
|
||||
data[inst[i].index].fvalue = inst[i].fvalue;
|
||||
}
|
||||
}
|
||||
/*! \brief drop the trace after fill, must be called after fill */
|
||||
inline void Drop(const RowBatch::Inst &inst) {
|
||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||
if (inst[i].index >= data.size()) continue;
|
||||
data[inst[i].index].flag = -1;
|
||||
}
|
||||
}
|
||||
/*! \brief get ith value */
|
||||
inline float fvalue(size_t i) const {
|
||||
return data[i].fvalue;
|
||||
}
|
||||
/*! \brief check whether i-th entry is missing */
|
||||
inline bool is_missing(size_t i) const {
|
||||
return data[i].flag == -1;
|
||||
}
|
||||
};
|
||||
/*!
|
||||
* \brief get the leaf index
|
||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \param root_id starting root index of the instance
|
||||
* \return the leaf index of the given feature
|
||||
*/
|
||||
inline int GetLeafIndex(const FVec &feat, unsigned root_id = 0) const {
|
||||
// start from groups that belongs to current data
|
||||
int pid = static_cast<int>(root_id);
|
||||
// traverse tree
|
||||
while (!(*this)[ pid ].is_leaf()) {
|
||||
unsigned split_index = (*this)[pid].split_index();
|
||||
pid = this->GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
|
||||
}
|
||||
return pid;
|
||||
}
|
||||
/*!
|
||||
* \brief get the prediction of regression tree, only accepts dense feature vector
|
||||
* \param feats dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \param root_id starting root index of the instance
|
||||
* \return the leaf index of the given feature
|
||||
*/
|
||||
inline float Predict(const FVec &feat, unsigned root_id = 0) const {
|
||||
int pid = this->GetLeafIndex(feat, root_id);
|
||||
return (*this)[pid].leaf_value();
|
||||
}
|
||||
/*! \brief get next position of the tree given current pid */
|
||||
inline int GetNext(int pid, float fvalue, bool is_unknown) const {
|
||||
float split_value = (*this)[pid].split_cond();
|
||||
if (is_unknown) {
|
||||
return (*this)[pid].cdefault();
|
||||
} else {
|
||||
if (fvalue < split_value) {
|
||||
return (*this)[pid].cleft();
|
||||
} else {
|
||||
return (*this)[pid].cright();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_MODEL_H_
|
||||
@@ -1,114 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file group_data.h
|
||||
* \brief this file defines utils to group data by integer keys
|
||||
* Input: given input sequence (key,value), (k1,v1), (k2,v2)
|
||||
* Ouptupt: an array of values data = [v1,v2,v3 .. vn]
|
||||
* and a group pointer ptr,
|
||||
* data[ptr[k]:ptr[k+1]] contains values that corresponds to key k
|
||||
*
|
||||
* This can be used to construct CSR/CSC matrix from un-ordered input
|
||||
* The major algorithm is a two pass linear scan algorithm that requires two pass scan over the data
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_UTILS_GROUP_DATA_H_
|
||||
#define XGBOOST_UTILS_GROUP_DATA_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
namespace utils {
|
||||
/*!
|
||||
* \brief multi-thread version of group builder
|
||||
* \tparam ValueType type of entries in the sparse matrix
|
||||
* \tparam SizeType type of the index range holder
|
||||
*/
|
||||
template<typename ValueType, typename SizeType = size_t>
|
||||
struct ParallelGroupBuilder {
|
||||
public:
|
||||
// parallel group builder of data
|
||||
ParallelGroupBuilder(std::vector<SizeType> *p_rptr,
|
||||
std::vector<ValueType> *p_data)
|
||||
: rptr(*p_rptr), data(*p_data), thread_rptr(tmp_thread_rptr) {
|
||||
}
|
||||
ParallelGroupBuilder(std::vector<SizeType> *p_rptr,
|
||||
std::vector<ValueType> *p_data,
|
||||
std::vector< std::vector<SizeType> > *p_thread_rptr)
|
||||
: rptr(*p_rptr), data(*p_data), thread_rptr(*p_thread_rptr) {
|
||||
}
|
||||
|
||||
public:
|
||||
/*!
|
||||
* \brief step 1: initialize the helper, with hint of number keys
|
||||
* and thread used in the construction
|
||||
* \param nkeys number of keys in the matrix, can be smaller than expected
|
||||
* \param nthread number of thread that will be used in construction
|
||||
*/
|
||||
inline void InitBudget(size_t nkeys, int nthread) {
|
||||
thread_rptr.resize(nthread);
|
||||
for (size_t i = 0; i < thread_rptr.size(); ++i) {
|
||||
thread_rptr[i].resize(nkeys);
|
||||
std::fill(thread_rptr[i].begin(), thread_rptr[i].end(), 0);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief step 2: add budget to each key
|
||||
* \param key the key
|
||||
* \param threadid the id of thread that calls this function
|
||||
* \param nelem number of element budget add to this row
|
||||
*/
|
||||
inline void AddBudget(size_t key, int threadid, SizeType nelem = 1) {
|
||||
std::vector<SizeType> &trptr = thread_rptr[threadid];
|
||||
if (trptr.size() < key + 1) {
|
||||
trptr.resize(key + 1, 0);
|
||||
}
|
||||
trptr[key] += nelem;
|
||||
}
|
||||
/*! \brief step 3: initialize the necessary storage */
|
||||
inline void InitStorage(void) {
|
||||
// set rptr to correct size
|
||||
for (size_t tid = 0; tid < thread_rptr.size(); ++tid) {
|
||||
if (rptr.size() <= thread_rptr[tid].size()) {
|
||||
rptr.resize(thread_rptr[tid].size() + 1);
|
||||
}
|
||||
}
|
||||
// initialize rptr to be beginning of each segment
|
||||
size_t start = 0;
|
||||
for (size_t i = 0; i + 1 < rptr.size(); ++i) {
|
||||
for (size_t tid = 0; tid < thread_rptr.size(); ++tid) {
|
||||
std::vector<SizeType> &trptr = thread_rptr[tid];
|
||||
if (i < trptr.size()) {
|
||||
size_t ncnt = trptr[i];
|
||||
trptr[i] = start;
|
||||
start += ncnt;
|
||||
}
|
||||
}
|
||||
rptr[i + 1] = start;
|
||||
}
|
||||
data.resize(start);
|
||||
}
|
||||
/*!
|
||||
* \brief step 4: add data to the allocated space,
|
||||
* the calls to this function should be exactly match previous call to AddBudget
|
||||
*
|
||||
* \param key the key of
|
||||
* \param threadid the id of thread that calls this function
|
||||
*/
|
||||
inline void Push(size_t key, ValueType value, int threadid) {
|
||||
SizeType &rp = thread_rptr[threadid][key];
|
||||
data[rp++] = value;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief pointer to the beginning and end of each continuous key */
|
||||
std::vector<SizeType> &rptr;
|
||||
/*! \brief index of nonzero entries in each row */
|
||||
std::vector<ValueType> &data;
|
||||
/*! \brief thread local data structure */
|
||||
std::vector< std::vector<SizeType> > &thread_rptr;
|
||||
/*! \brief local temp thread ptr, use this if not specified by the constructor */
|
||||
std::vector< std::vector<SizeType> > tmp_thread_rptr;
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_UTILS_GROUP_DATA_H_
|
||||
Reference in New Issue
Block a user