refactor grad stats to be like visitor

This commit is contained in:
tqchen 2014-08-24 15:17:22 -07:00
parent d49c6e6e84
commit f71b732e7a
3 changed files with 83 additions and 69 deletions

View File

@ -11,45 +11,6 @@
namespace xgboost {
namespace tree {
/*! \brief core statistics used for tree construction */
struct GradStats {
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief constructor */
GradStats(void) {
this->Clear();
}
/*! \brief clear the statistics */
inline void Clear(void) {
sum_grad = sum_hess = 0.0f;
}
/*! \brief add statistics to the data */
inline void Add(double grad, double hess) {
sum_grad += grad; sum_hess += hess;
}
/*! \brief add statistics to the data */
inline void Add(const bst_gpair& b) {
this->Add(b.grad, b.hess);
}
/*! \brief add statistics to the data */
inline void Add(const GradStats &b) {
this->Add(b.sum_grad, b.sum_hess);
}
/*! \brief substract the statistics by b */
inline GradStats Substract(const GradStats &b) const {
GradStats res;
res.sum_grad = this->sum_grad - b.sum_grad;
res.sum_hess = this->sum_hess - b.sum_hess;
return res;
}
/*! \return whether the statistics is not used yet */
inline bool Empty(void) const {
return sum_hess == 0.0;
}
};
/*! \brief training parameters for regression tree */
struct TrainParam{
// learning step size for a time
@ -165,13 +126,6 @@ struct TrainParam{
inline bool cannot_split(double sum_hess, int depth) const {
return sum_hess < this->min_child_weight * 2.0;
}
// code support for template data
inline double CalcWeight(const GradStats &d) const {
return this->CalcWeight(d.sum_grad, d.sum_hess);
}
inline double CalcGain(const GradStats &d) const {
return this->CalcGain(d.sum_grad, d.sum_hess);
}
protected:
// functions for L1 cost
@ -185,6 +139,61 @@ struct TrainParam{
}
};
/*! \brief core statistics used for tree construction */
struct GradStats {
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief constructor */
GradStats(void) {
this->Clear();
}
/*! \brief clear the statistics */
inline void Clear(void) {
sum_grad = sum_hess = 0.0f;
}
/*!
* \brief accumulate statistics,
* \param gpair the vector storing the gradient statistics
* \param info the additional information
* \param ridx instance index of this instance
*/
inline void Add(const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
bst_uint ridx) {
const bst_gpair &b = gpair[ridx];
this->Add(b.grad, b.hess);
}
/*! \brief caculate leaf weight */
inline double CalcWeight(const TrainParam &param) const {
return param.CalcWeight(sum_grad, sum_hess);
}
/*!\brief calculate gain of the solution */
inline double CalcGain(const TrainParam &param) const {
return param.CalcGain(sum_grad, sum_hess);
}
/*! \brief add statistics to the data */
inline void Add(double grad, double hess) {
sum_grad += grad; sum_hess += hess;
}
/*! \brief add statistics to the data */
inline void Add(const GradStats &b) {
this->Add(b.sum_grad, b.sum_hess);
}
/*! \brief substract the statistics by b */
inline GradStats Substract(const GradStats &b) const {
GradStats res;
res.sum_grad = this->sum_grad - b.sum_grad;
res.sum_hess = this->sum_hess - b.sum_hess;
return res;
}
/*! \return whether the statistics is not used yet */
inline bool Empty(void) const {
return sum_hess == 0.0;
}
};
/*!
* \brief statistics that is helpful to store
* and represent a split solution for the tree

View File

@ -80,13 +80,13 @@ class ColMaker: public IUpdater<FMatrix> {
const BoosterInfo &info,
RegTree *p_tree) {
this->InitData(gpair, fmat, info.root_index, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) {
this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree);
this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree);
this->ResetPosition(this->qexpand, fmat, *p_tree);
this->UpdateQueueExpand(*p_tree, &this->qexpand);
this->InitNewNode(qexpand, gpair, fmat, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
// if nothing left to be expand, break
if (qexpand.size() == 0) break;
}
@ -175,6 +175,7 @@ class ColMaker: public IUpdater<FMatrix> {
inline void InitNewNode(const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const BoosterInfo &info,
const RegTree &tree) {
{// setup statistics space for each tree node
for (size_t i = 0; i < stemp.size(); ++i) {
@ -190,7 +191,7 @@ class ColMaker: public IUpdater<FMatrix> {
const bst_uint ridx = rowset[i];
const int tid = omp_get_thread_num();
if (position[ridx] < 0) continue;
stemp[tid][position[ridx]].stats.Add(gpair[ridx]);
stemp[tid][position[ridx]].stats.Add(gpair, info, ridx);
}
// sum the per thread statistics together
for (size_t j = 0; j < qexpand.size(); ++j) {
@ -201,8 +202,8 @@ class ColMaker: public IUpdater<FMatrix> {
}
// update node statistics
snode[nid].stats = stats;
snode[nid].root_gain = param.CalcGain(stats);
snode[nid].weight = param.CalcWeight(stats);
snode[nid].root_gain = stats.CalcGain(param);
snode[nid].weight = stats.CalcWeight(param);
}
}
/*! \brief update queue expand add in new leaves */
@ -223,6 +224,7 @@ class ColMaker: public IUpdater<FMatrix> {
template<typename Iter>
inline void EnumerateSplit(Iter it, unsigned fid,
const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
std::vector<ThreadEntry> &temp,
bool is_forward_search) {
// clear all the temp statistics
@ -239,19 +241,19 @@ class ColMaker: public IUpdater<FMatrix> {
ThreadEntry &e = temp[nid];
// test if first hit, this is fine, because we set 0 during init
if (e.stats.Empty()) {
e.stats.Add(gpair[ridx]);
e.stats.Add(gpair, info, ridx);
e.last_fvalue = fvalue;
} else {
// try to find a split
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
TStats c = snode[nid].stats.Substract(e.stats);
if (c.sum_hess >= param.min_child_weight) {
double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
}
}
// update the statistics
e.stats.Add(gpair[ridx]);
e.stats.Add(gpair, info, ridx);
e.last_fvalue = fvalue;
}
}
@ -261,7 +263,7 @@ class ColMaker: public IUpdater<FMatrix> {
ThreadEntry &e = temp[nid];
TStats c = snode[nid].stats.Substract(e.stats);
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
const double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
const float delta = is_forward_search ? rt_eps : -rt_eps;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
}
@ -269,7 +271,9 @@ class ColMaker: public IUpdater<FMatrix> {
}
// find splits at current level, do split per level
inline void FindSplit(int depth, const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair, const FMatrix &fmat,
const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const BoosterInfo &info,
RegTree *p_tree) {
std::vector<unsigned> feat_set = feat_index;
if (param.colsample_bylevel != 1.0f) {
@ -288,10 +292,10 @@ class ColMaker: public IUpdater<FMatrix> {
const unsigned fid = feat_set[i];
const int tid = omp_get_thread_num();
if (param.need_forward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, stemp[tid], true);
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, info, stemp[tid], true);
}
if (param.need_backward_search(fmat.GetColDensity(fid))) {
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, stemp[tid], false);
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, info, stemp[tid], false);
}
}
// after this each thread's stemp will get the best candidates, aggregate results

View File

@ -65,8 +65,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
RegTree::FVec &feats = fvec_temp[tid];
feats.Fill(inst);
for (size_t j = 0; j < trees.size(); ++j) {
AddStats(*trees[j], feats, gpair[ridx],
info.GetRoot(j),
AddStats(*trees[j], feats, gpair, info, ridx,
&stemp[tid * trees.size() + j]);
}
feats.Drop(inst);
@ -95,31 +94,33 @@ class TreeRefresher: public IUpdater<FMatrix> {
private:
inline static void AddStats(const RegTree &tree,
const RegTree::FVec &feat,
const bst_gpair &gpair, unsigned root_id,
const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
const bst_uint ridx,
std::vector<GradStats> *p_gstats) {
std::vector<GradStats> &gstats = *p_gstats;
// start from groups that belongs to current data
int pid = static_cast<int>(root_id);
gstats[pid].Add(gpair);
int pid = static_cast<int>(info.GetRoot(ridx));
gstats[pid].Add(gpair, info, ridx);
// tranverse tree
while (!tree[pid].is_leaf()) {
unsigned split_index = tree[pid].split_index();
pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
gstats[pid].Add(gpair);
gstats[pid].Add(gpair, info, ridx);
}
}
inline void Refresh(const std::vector<GradStats> &gstats,
int nid, RegTree *p_tree) {
RegTree &tree = *p_tree;
tree.stat(nid).base_weight = param.CalcWeight(gstats[nid]);
tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
if (tree[nid].is_leaf()) {
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
} else {
tree.stat(nid).loss_chg =
param.CalcGain(gstats[tree[nid].cleft()]) +
param.CalcGain(gstats[tree[nid].cright()]) -
param.CalcGain(gstats[nid]);
gstats[tree[nid].cleft()].CalcGain(param) +
gstats[tree[nid].cright()].CalcGain(param) -
gstats[nid].CalcGain(param);
this->Refresh(gstats, tree[nid].cleft(), p_tree);
this->Refresh(gstats, tree[nid].cright(), p_tree);
}