refactor grad stats to be like visitor

This commit is contained in:
tqchen 2014-08-24 15:17:22 -07:00
parent d49c6e6e84
commit f71b732e7a
3 changed files with 83 additions and 69 deletions

View File

@ -11,45 +11,6 @@
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
/*! \brief running gradient / hessian sums used for tree construction */
struct GradStats {
  /*! \brief accumulated gradient */
  double sum_grad;
  /*! \brief accumulated hessian */
  double sum_hess;
  /*! \brief constructor, starts with both sums at zero */
  GradStats(void) : sum_grad(0.0), sum_hess(0.0) {}
  /*! \brief reset both sums to zero */
  inline void Clear(void) {
    sum_grad = 0.0;
    sum_hess = 0.0;
  }
  /*!
   * \brief accumulate one gradient / hessian contribution
   * \param grad value added to sum_grad
   * \param hess value added to sum_hess
   */
  inline void Add(double grad, double hess) {
    sum_grad += grad;
    sum_hess += hess;
  }
  /*! \brief accumulate the grad and hess fields of a gradient pair */
  inline void Add(const bst_gpair& b) {
    const double grad_part = b.grad;
    const double hess_part = b.hess;
    this->Add(grad_part, hess_part);
  }
  /*! \brief accumulate the sums held by another statistics object */
  inline void Add(const GradStats &b) {
    sum_grad += b.sum_grad;
    sum_hess += b.sum_hess;
  }
  /*!
   * \brief subtract the statistics in b from this one
   * \return a new object holding (this - b); neither operand is modified
   */
  inline GradStats Substract(const GradStats &b) const {
    GradStats diff(*this);
    diff.sum_grad -= b.sum_grad;
    diff.sum_hess -= b.sum_hess;
    return diff;
  }
  /*! \return true when sum_hess is exactly zero, i.e. the statistics is treated as not used yet */
  inline bool Empty(void) const {
    return 0.0 == sum_hess;
  }
};
/*! \brief training parameters for regression tree */ /*! \brief training parameters for regression tree */
struct TrainParam{ struct TrainParam{
// learning step size for a time // learning step size for a time
@ -165,13 +126,6 @@ struct TrainParam{
inline bool cannot_split(double sum_hess, int depth) const { inline bool cannot_split(double sum_hess, int depth) const {
return sum_hess < this->min_child_weight * 2.0; return sum_hess < this->min_child_weight * 2.0;
} }
// overloads that accept a statistics object, used by templated code
/*! \brief forward the sums stored in d to the two-argument CalcWeight overload */
inline double CalcWeight(const GradStats &d) const {
  const double grad_total = d.sum_grad;
  const double hess_total = d.sum_hess;
  return this->CalcWeight(grad_total, hess_total);
}
/*! \brief forward the sums stored in d to the two-argument CalcGain overload */
inline double CalcGain(const GradStats &d) const {
  const double grad_total = d.sum_grad;
  const double hess_total = d.sum_hess;
  return this->CalcGain(grad_total, hess_total);
}
protected: protected:
// functions for L1 cost // functions for L1 cost
@ -185,6 +139,61 @@ struct TrainParam{
} }
}; };
/*! \brief core statistics used for tree construction */
struct GradStats {
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief constructor */
GradStats(void) {
this->Clear();
}
/*! \brief clear the statistics */
inline void Clear(void) {
sum_grad = sum_hess = 0.0f;
}
/*!
* \brief accumulate statistics,
* \param gpair the vector storing the gradient statistics
* \param info the additional information
* \param ridx instance index of this instance
*/
inline void Add(const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
bst_uint ridx) {
const bst_gpair &b = gpair[ridx];
this->Add(b.grad, b.hess);
}
/*! \brief caculate leaf weight */
inline double CalcWeight(const TrainParam &param) const {
return param.CalcWeight(sum_grad, sum_hess);
}
/*!\brief calculate gain of the solution */
inline double CalcGain(const TrainParam &param) const {
return param.CalcGain(sum_grad, sum_hess);
}
/*! \brief add statistics to the data */
inline void Add(double grad, double hess) {
sum_grad += grad; sum_hess += hess;
}
/*! \brief add statistics to the data */
inline void Add(const GradStats &b) {
this->Add(b.sum_grad, b.sum_hess);
}
/*! \brief substract the statistics by b */
inline GradStats Substract(const GradStats &b) const {
GradStats res;
res.sum_grad = this->sum_grad - b.sum_grad;
res.sum_hess = this->sum_hess - b.sum_hess;
return res;
}
/*! \return whether the statistics is not used yet */
inline bool Empty(void) const {
return sum_hess == 0.0;
}
};
/*! /*!
* \brief statistics that is helpful to store * \brief statistics that is helpful to store
* and represent a split solution for the tree * and represent a split solution for the tree

View File

@ -80,13 +80,13 @@ class ColMaker: public IUpdater<FMatrix> {
const BoosterInfo &info, const BoosterInfo &info,
RegTree *p_tree) { RegTree *p_tree) {
this->InitData(gpair, fmat, info.root_index, *p_tree); this->InitData(gpair, fmat, info.root_index, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, *p_tree); this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) { for (int depth = 0; depth < param.max_depth; ++depth) {
this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree); this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree);
this->ResetPosition(this->qexpand, fmat, *p_tree); this->ResetPosition(this->qexpand, fmat, *p_tree);
this->UpdateQueueExpand(*p_tree, &this->qexpand); this->UpdateQueueExpand(*p_tree, &this->qexpand);
this->InitNewNode(qexpand, gpair, fmat, *p_tree); this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
// if nothing left to be expand, break // if nothing left to be expand, break
if (qexpand.size() == 0) break; if (qexpand.size() == 0) break;
} }
@ -175,6 +175,7 @@ class ColMaker: public IUpdater<FMatrix> {
inline void InitNewNode(const std::vector<int> &qexpand, inline void InitNewNode(const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair, const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const BoosterInfo &info,
const RegTree &tree) { const RegTree &tree) {
{// setup statistics space for each tree node {// setup statistics space for each tree node
for (size_t i = 0; i < stemp.size(); ++i) { for (size_t i = 0; i < stemp.size(); ++i) {
@ -190,7 +191,7 @@ class ColMaker: public IUpdater<FMatrix> {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
if (position[ridx] < 0) continue; if (position[ridx] < 0) continue;
stemp[tid][position[ridx]].stats.Add(gpair[ridx]); stemp[tid][position[ridx]].stats.Add(gpair, info, ridx);
} }
// sum the per thread statistics together // sum the per thread statistics together
for (size_t j = 0; j < qexpand.size(); ++j) { for (size_t j = 0; j < qexpand.size(); ++j) {
@ -201,8 +202,8 @@ class ColMaker: public IUpdater<FMatrix> {
} }
// update node statistics // update node statistics
snode[nid].stats = stats; snode[nid].stats = stats;
snode[nid].root_gain = param.CalcGain(stats); snode[nid].root_gain = stats.CalcGain(param);
snode[nid].weight = param.CalcWeight(stats); snode[nid].weight = stats.CalcWeight(param);
} }
} }
/*! \brief update queue expand add in new leaves */ /*! \brief update queue expand add in new leaves */
@ -223,6 +224,7 @@ class ColMaker: public IUpdater<FMatrix> {
template<typename Iter> template<typename Iter>
inline void EnumerateSplit(Iter it, unsigned fid, inline void EnumerateSplit(Iter it, unsigned fid,
const std::vector<bst_gpair> &gpair, const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
std::vector<ThreadEntry> &temp, std::vector<ThreadEntry> &temp,
bool is_forward_search) { bool is_forward_search) {
// clear all the temp statistics // clear all the temp statistics
@ -239,19 +241,19 @@ class ColMaker: public IUpdater<FMatrix> {
ThreadEntry &e = temp[nid]; ThreadEntry &e = temp[nid];
// test if first hit, this is fine, because we set 0 during init // test if first hit, this is fine, because we set 0 during init
if (e.stats.Empty()) { if (e.stats.Empty()) {
e.stats.Add(gpair[ridx]); e.stats.Add(gpair, info, ridx);
e.last_fvalue = fvalue; e.last_fvalue = fvalue;
} else { } else {
// try to find a split // try to find a split
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) { if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
TStats c = snode[nid].stats.Substract(e.stats); TStats c = snode[nid].stats.Substract(e.stats);
if (c.sum_hess >= param.min_child_weight) { if (c.sum_hess >= param.min_child_weight) {
double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain; double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search); e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
} }
} }
// update the statistics // update the statistics
e.stats.Add(gpair[ridx]); e.stats.Add(gpair, info, ridx);
e.last_fvalue = fvalue; e.last_fvalue = fvalue;
} }
} }
@ -261,7 +263,7 @@ class ColMaker: public IUpdater<FMatrix> {
ThreadEntry &e = temp[nid]; ThreadEntry &e = temp[nid];
TStats c = snode[nid].stats.Substract(e.stats); TStats c = snode[nid].stats.Substract(e.stats);
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) { if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
const double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain; const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
const float delta = is_forward_search ? rt_eps : -rt_eps; const float delta = is_forward_search ? rt_eps : -rt_eps;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search); e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
} }
@ -269,7 +271,9 @@ class ColMaker: public IUpdater<FMatrix> {
} }
// find splits at current level, do split per level // find splits at current level, do split per level
inline void FindSplit(int depth, const std::vector<int> &qexpand, inline void FindSplit(int depth, const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair, const FMatrix &fmat, const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const BoosterInfo &info,
RegTree *p_tree) { RegTree *p_tree) {
std::vector<unsigned> feat_set = feat_index; std::vector<unsigned> feat_set = feat_index;
if (param.colsample_bylevel != 1.0f) { if (param.colsample_bylevel != 1.0f) {
@ -288,10 +292,10 @@ class ColMaker: public IUpdater<FMatrix> {
const unsigned fid = feat_set[i]; const unsigned fid = feat_set[i];
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
if (param.need_forward_search(fmat.GetColDensity(fid))) { if (param.need_forward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, stemp[tid], true); this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, info, stemp[tid], true);
} }
if (param.need_backward_search(fmat.GetColDensity(fid))) { if (param.need_backward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, stemp[tid], false); this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, info, stemp[tid], false);
} }
} }
// after this each thread's stemp will get the best candidates, aggregate results // after this each thread's stemp will get the best candidates, aggregate results

View File

@ -65,8 +65,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
RegTree::FVec &feats = fvec_temp[tid]; RegTree::FVec &feats = fvec_temp[tid];
feats.Fill(inst); feats.Fill(inst);
for (size_t j = 0; j < trees.size(); ++j) { for (size_t j = 0; j < trees.size(); ++j) {
AddStats(*trees[j], feats, gpair[ridx], AddStats(*trees[j], feats, gpair, info, ridx,
info.GetRoot(j),
&stemp[tid * trees.size() + j]); &stemp[tid * trees.size() + j]);
} }
feats.Drop(inst); feats.Drop(inst);
@ -95,31 +94,33 @@ class TreeRefresher: public IUpdater<FMatrix> {
private: private:
inline static void AddStats(const RegTree &tree, inline static void AddStats(const RegTree &tree,
const RegTree::FVec &feat, const RegTree::FVec &feat,
const bst_gpair &gpair, unsigned root_id, const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
const bst_uint ridx,
std::vector<GradStats> *p_gstats) { std::vector<GradStats> *p_gstats) {
std::vector<GradStats> &gstats = *p_gstats; std::vector<GradStats> &gstats = *p_gstats;
// start from groups that belongs to current data // start from groups that belongs to current data
int pid = static_cast<int>(root_id); int pid = static_cast<int>(info.GetRoot(ridx));
gstats[pid].Add(gpair); gstats[pid].Add(gpair, info, ridx);
// tranverse tree // tranverse tree
while (!tree[pid].is_leaf()) { while (!tree[pid].is_leaf()) {
unsigned split_index = tree[pid].split_index(); unsigned split_index = tree[pid].split_index();
pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index)); pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
gstats[pid].Add(gpair); gstats[pid].Add(gpair, info, ridx);
} }
} }
inline void Refresh(const std::vector<GradStats> &gstats, inline void Refresh(const std::vector<GradStats> &gstats,
int nid, RegTree *p_tree) { int nid, RegTree *p_tree) {
RegTree &tree = *p_tree; RegTree &tree = *p_tree;
tree.stat(nid).base_weight = param.CalcWeight(gstats[nid]); tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess); tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
if (tree[nid].is_leaf()) { if (tree[nid].is_leaf()) {
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate); tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
} else { } else {
tree.stat(nid).loss_chg = tree.stat(nid).loss_chg =
param.CalcGain(gstats[tree[nid].cleft()]) + gstats[tree[nid].cleft()].CalcGain(param) +
param.CalcGain(gstats[tree[nid].cright()]) - gstats[tree[nid].cright()].CalcGain(param) -
param.CalcGain(gstats[nid]); gstats[nid].CalcGain(param);
this->Refresh(gstats, tree[nid].cleft(), p_tree); this->Refresh(gstats, tree[nid].cleft(), p_tree);
this->Refresh(gstats, tree[nid].cright(), p_tree); this->Refresh(gstats, tree[nid].cright(), p_tree);
} }