tstats now depend on param

tqchen 2014-08-24 16:08:58 -07:00
parent 49e6575c86
commit 4889b40abc
4 changed files with 51 additions and 27 deletions

View File

@@ -42,11 +42,17 @@ class TreeModel {
     int max_depth;
     /*! \brief number of features used for tree construction */
     int num_feature;
+    /*!
+     * \brief leaf vector size, used for vector tree
+     * used to store more than one dimensional information in tree
+     */
+    int size_leaf_vector;
     /*! \brief reserved part */
-    int reserved[32];
+    int reserved[31];
     /*! \brief constructor */
     Param(void) {
       max_depth = 0;
+      size_leaf_vector = 0;
       memset(reserved, 0, sizeof(reserved));
     }
     /*!
@@ -57,6 +63,7 @@ class TreeModel {
     inline void SetParam(const char *name, const char *val) {
       if (!strcmp("num_roots", name)) num_roots = atoi(val);
       if (!strcmp("num_feature", name)) num_feature = atoi(val);
+      if (!strcmp("size_leaf_vector", name)) size_leaf_vector = atoi(val);
     }
   };
   /*! \brief tree node */
@@ -166,10 +173,12 @@ class TreeModel {
  protected:
   // vector of nodes
   std::vector<Node> nodes;
-  // stats of nodes
-  std::vector<TNodeStat> stats;
   // free node space, used during training process
   std::vector<int> deleted_nodes;
+  // stats of nodes
+  std::vector<TNodeStat> stats;
+  // leaf vector, that is used to store additional information
+  std::vector<bst_float> leaf_vector;
   // allocate a new node,
   // !!!!!! NOTE: may cause BUG here, nodes.resize
   inline int AllocNode(void) {
@@ -184,6 +193,7 @@ class TreeModel {
                    "number of nodes in the tree exceed 2^31");
     nodes.resize(param.num_nodes);
     stats.resize(param.num_nodes);
+    leaf_vector.resize(param.num_nodes * param.size_leaf_vector);
     return nd;
   }
   // delete a tree node
@@ -247,6 +257,14 @@ class TreeModel {
   inline NodeStat &stat(int nid) {
     return stats[nid];
   }
+  /*! \brief get leaf vector given nid */
+  inline bst_float* leafvec(int nid) {
+    return &leaf_vector[nid * param.size_leaf_vector];
+  }
+  /*! \brief get leaf vector given nid */
+  inline const bst_float* leafvec(int nid) const{
+    return &leaf_vector[nid * param.size_leaf_vector];
+  }
   /*! \brief initialize the model */
   inline void InitModel(void) {
     param.num_nodes = param.num_roots;
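
Note: the leaf vector added above is stored as one flat buffer of param.num_nodes * param.size_leaf_vector floats, and leafvec(nid) hands back a pointer to node nid's slice. With the default size_leaf_vector = 0 the buffer stays empty, so ordinary single-output trees are unaffected. A minimal standalone sketch of that indexing scheme (plain C++, not xgboost code; num_nodes and size_leaf_vector here are local stand-ins for the Param fields):

// Flat leaf-vector layout: node nid owns the slice starting at
// nid * size_leaf_vector, which is what leafvec(nid) returns above.
#include <cstdio>
#include <vector>

typedef float bst_float;  // same typedef name the diff relies on

int main(void) {
  const int num_nodes = 3;
  const int size_leaf_vector = 2;  // stand-in for Param::size_leaf_vector
  std::vector<bst_float> leaf_vector(num_nodes * size_leaf_vector, 0.0f);
  // fill each node's slice, the same way leafvec(nid) exposes it
  for (int nid = 0; nid < num_nodes; ++nid) {
    bst_float *vec = &leaf_vector[nid * size_leaf_vector];
    for (int k = 0; k < size_leaf_vector; ++k) vec[k] = nid + 0.1f * k;
  }
  // read back node 2's slice
  const bst_float *v2 = &leaf_vector[2 * size_leaf_vector];
  std::printf("node 2: %.1f %.1f\n", v2[0], v2[1]);
  return 0;
}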

View File

@@ -145,8 +145,8 @@ struct GradStats {
   double sum_grad;
   /*! \brief sum hessian statistics */
   double sum_hess;
-  /*! \brief constructor */
-  GradStats(void) {
+  /*! \brief constructor, the object must be cleared during construction */
+  explicit GradStats(const TrainParam &param) {
     this->Clear();
   }
   /*! \brief clear the statistics */
@@ -169,29 +169,31 @@
   inline double CalcWeight(const TrainParam &param) const {
     return param.CalcWeight(sum_grad, sum_hess);
   }
-  /*!\brief calculate gain of the solution */
+  /*! \brief calculate gain of the solution */
   inline double CalcGain(const TrainParam &param) const {
     return param.CalcGain(sum_grad, sum_hess);
   }
   /*! \brief add statistics to the data */
-  inline void Add(double grad, double hess) {
-    sum_grad += grad; sum_hess += hess;
-  }
-  /*! \brief add statistics to the data */
   inline void Add(const GradStats &b) {
     this->Add(b.sum_grad, b.sum_hess);
   }
-  /*! \brief substract the statistics by b */
-  inline GradStats Substract(const GradStats &b) const {
-    GradStats res;
-    res.sum_grad = this->sum_grad - b.sum_grad;
-    res.sum_hess = this->sum_hess - b.sum_hess;
-    return res;
+  /*! \brief set current value to a - b */
+  inline void SetSubstract(const GradStats &a, const GradStats &b) {
+    sum_grad = a.sum_grad - b.sum_grad;
+    sum_hess = a.sum_hess - b.sum_hess;
   }
   /*! \return whether the statistics is not used yet */
   inline bool Empty(void) const {
     return sum_hess == 0.0;
   }
+  /*! \brief set leaf vector value based on statistics */
+  inline void SetLeafVec(const TrainParam &param, bst_float *vec) const{
+  }
+ protected:
+  /*! \brief add statistics to the data */
+  inline void Add(double grad, double hess) {
+    sum_grad += grad; sum_hess += hess;
+  }
 };
 /*!
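
Note: the interface change here is twofold: GradStats can no longer be default-constructed, callers must pass the TrainParam (which is presumably what the commit title "tstats now depend on param" refers to), and the copy-returning Substract becomes an in-place SetSubstract. A rough standalone sketch of the resulting call pattern, with a trivial TrainParam stand-in rather than the real one from param.h:

#include <cstdio>

struct TrainParam {};  // stand-in; the real struct lives elsewhere in param.h

struct GradStats {
  double sum_grad;  // sum gradient statistics
  double sum_hess;  // sum hessian statistics
  // the object must be cleared during construction, as in the diff
  explicit GradStats(const TrainParam &param) { this->Clear(); }
  inline void Clear(void) { sum_grad = sum_hess = 0.0; }
  inline void Add(const GradStats &b) { this->Add(b.sum_grad, b.sum_hess); }
  // set current value to a - b, writing into an existing object
  inline void SetSubstract(const GradStats &a, const GradStats &b) {
    sum_grad = a.sum_grad - b.sum_grad;
    sum_hess = a.sum_hess - b.sum_hess;
  }
 protected:
  inline void Add(double grad, double hess) { sum_grad += grad; sum_hess += hess; }
};

int main(void) {
  TrainParam param;
  GradStats total(param), left(param), right(param);
  total.sum_grad = 3.0; total.sum_hess = 5.0;  // pretend accumulated stats
  left.sum_grad = 1.0;  left.sum_hess = 2.0;
  right.SetSubstract(total, left);             // right = total - left, in place
  std::printf("right: grad=%.1f hess=%.1f\n", right.sum_grad, right.sum_hess);
  return 0;
}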

View File

@@ -51,8 +51,8 @@ class ColMaker: public IUpdater<FMatrix> {
     /*! \brief current best solution */
     SplitEntry best;
     // constructor
-    ThreadEntry(void) {
-      stats.Clear();
+    explicit ThreadEntry(const TrainParam &param)
+        : stats(param) {
     }
   };
   struct NodeEntry {
@@ -65,8 +65,8 @@ class ColMaker: public IUpdater<FMatrix> {
     /*! \brief current best solution */
     SplitEntry best;
     // constructor
-    NodeEntry(void) : root_gain(0.0f), weight(0.0f){
-      stats.Clear();
+    explicit NodeEntry(const TrainParam &param)
+        : stats(param), root_gain(0.0f), weight(0.0f){
     }
   };
   // actual builder that runs the algorithm
@@ -100,6 +100,7 @@ class ColMaker: public IUpdater<FMatrix> {
         p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg;
         p_tree->stat(nid).base_weight = snode[nid].weight;
         p_tree->stat(nid).sum_hess = static_cast<float>(snode[nid].stats.sum_hess);
+        snode[nid].stats.SetLeafVec(param, p_tree->leafvec(nid));
       }
     }
@@ -179,9 +180,9 @@ class ColMaker: public IUpdater<FMatrix> {
                        const RegTree &tree) {
       {// setup statistics space for each tree node
         for (size_t i = 0; i < stemp.size(); ++i) {
-          stemp[i].resize(tree.param.num_nodes, ThreadEntry());
+          stemp[i].resize(tree.param.num_nodes, ThreadEntry(param));
         }
-        snode.resize(tree.param.num_nodes, NodeEntry());
+        snode.resize(tree.param.num_nodes, NodeEntry(param));
       }
       const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
       // setup position
@@ -196,7 +197,7 @@ class ColMaker: public IUpdater<FMatrix> {
       // sum the per thread statistics together
       for (size_t j = 0; j < qexpand.size(); ++j) {
         const int nid = qexpand[j];
-        TStats stats; stats.Clear();
+        TStats stats(param);
         for (size_t tid = 0; tid < stemp.size(); ++tid) {
           stats.Add(stemp[tid][nid].stats);
         }
@@ -231,6 +232,8 @@ class ColMaker: public IUpdater<FMatrix> {
       for (size_t j = 0; j < qexpand.size(); ++j) {
        temp[qexpand[j]].stats.Clear();
       }
+      // left statistics
+      TStats c(param);
       while (it.Next()) {
         const bst_uint ridx = it.rindex();
         const int nid = position[ridx];
@@ -246,7 +249,7 @@ class ColMaker: public IUpdater<FMatrix> {
         } else {
           // try to find a split
           if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
-            TStats c = snode[nid].stats.Substract(e.stats);
+            c.SetSubstract(snode[nid].stats, e.stats);
             if (c.sum_hess >= param.min_child_weight) {
               double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
               e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
@@ -261,7 +264,7 @@ class ColMaker: public IUpdater<FMatrix> {
       for (size_t i = 0; i < qexpand.size(); ++i) {
         const int nid = qexpand[i];
         ThreadEntry &e = temp[nid];
-        TStats c = snode[nid].stats.Substract(e.stats);
+        c.SetSubstract(snode[nid].stats, e.stats);
         if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
           const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
           const float delta = is_forward_search ? rt_eps : -rt_eps;
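
Note: two consequences of the non-default-constructible TStats show up in this file: the per-thread and per-node containers must be grown with an explicit prototype value (resize(n, ThreadEntry(param)) copy-constructs each new slot from it), and the temporary statistics object c is hoisted out of the scan loop and refilled in place with SetSubstract instead of being rebuilt on every candidate split. A small sketch of the resize-with-prototype idiom, using hypothetical Stats/Entry types rather than the real ones:

#include <cstdio>
#include <vector>

struct TrainParam { int dummy; };  // hypothetical stand-in

struct Stats {
  double sum_hess;
  explicit Stats(const TrainParam &param) : sum_hess(0.0) {}  // no default ctor
};

struct Entry {
  Stats stats;
  explicit Entry(const TrainParam &param) : stats(param) {}
};

int main(void) {
  TrainParam param = {0};
  std::vector<Entry> snode;
  // Entry has no default constructor, so resize must be given a prototype;
  // each of the 4 new slots is copy-constructed from Entry(param).
  snode.resize(4, Entry(param));
  std::printf("nodes allocated: %zu\n", snode.size());
  return 0;
}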

View File

@@ -44,8 +44,8 @@ class TreeRefresher: public IUpdater<FMatrix> {
       int tid = omp_get_thread_num();
       for (size_t i = 0; i < trees.size(); ++i) {
         std::vector<TStats> &vec = stemp[tid * trees.size() + i];
-        vec.resize(trees[i]->param.num_nodes);
-        std::fill(vec.begin(), vec.end(), TStats());
+        vec.resize(trees[i]->param.num_nodes, TStats(param));
+        std::fill(vec.begin(), vec.end(), TStats(param));
       }
       fvec_temp[tid].Init(trees[0]->param.num_feature);
     }
@@ -114,6 +114,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
     RegTree &tree = *p_tree;
     tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
     tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
+    gstats[nid].SetLeafVec(param, tree.leafvec(nid));
     if (tree[nid].is_leaf()) {
       tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
     } else {
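
Note: both updaters now call SetLeafVec unconditionally after filling in a node's stats. For the plain GradStats above the hook has an empty body, so nothing is written; the apparent intent, hinted at by the size_leaf_vector comment ("used to store more than one dimensional information in tree"), is that a statistics type carrying several per-node values could copy them into the node's leaf-vector slice. A purely hypothetical illustration of such a statistics type (nothing like it exists in this commit):

#include <vector>

struct TrainParam { int size_leaf_vector; };  // stand-in with only the field used here
typedef float bst_float;

struct VectorStats {
  std::vector<double> sums;  // one accumulated value per output dimension
  explicit VectorStats(const TrainParam &param) : sums(param.size_leaf_vector, 0.0) {}
  // copy the accumulated values into the node's slice of the tree's leaf vector
  inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {
    for (int i = 0; i < param.size_leaf_vector; ++i) {
      vec[i] = static_cast<bst_float>(sums[i]);
    }
  }
};

int main(void) {
  TrainParam param; param.size_leaf_vector = 2;
  VectorStats s(param);
  s.sums[0] = 0.5; s.sums[1] = -0.25;
  std::vector<bst_float> leaf_vector(4 * param.size_leaf_vector, 0.0f);
  s.SetLeafVec(param, &leaf_vector[1 * param.size_leaf_vector]);  // write node 1's slice
  return 0;
}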