refactor grad stats to be like visitor

parent d49c6e6e84
commit f71b732e7a

src/tree/param.h (101 changed lines)
@@ -11,45 +11,6 @@
 namespace xgboost {
 namespace tree {
 
-/*! \brief core statistics used for tree construction */
-struct GradStats {
-  /*! \brief sum gradient statistics */
-  double sum_grad;
-  /*! \brief sum hessian statistics */
-  double sum_hess;
-  /*! \brief constructor */
-  GradStats(void) {
-    this->Clear();
-  }
-  /*! \brief clear the statistics */
-  inline void Clear(void) {
-    sum_grad = sum_hess = 0.0f;
-  }
-  /*! \brief add statistics to the data */
-  inline void Add(double grad, double hess) {
-    sum_grad += grad; sum_hess += hess;
-  }
-  /*! \brief add statistics to the data */
-  inline void Add(const bst_gpair& b) {
-    this->Add(b.grad, b.hess);
-  }
-  /*! \brief add statistics to the data */
-  inline void Add(const GradStats &b) {
-    this->Add(b.sum_grad, b.sum_hess);
-  }
-  /*! \brief subtract the statistics by b */
-  inline GradStats Substract(const GradStats &b) const {
-    GradStats res;
-    res.sum_grad = this->sum_grad - b.sum_grad;
-    res.sum_hess = this->sum_hess - b.sum_hess;
-    return res;
-  }
-  /*! \return whether the statistics is not used yet */
-  inline bool Empty(void) const {
-    return sum_hess == 0.0;
-  }
-};
-
 /*! \brief training parameters for regression tree */
 struct TrainParam{
   // learning step size for a time
@@ -165,13 +126,6 @@ struct TrainParam{
   inline bool cannot_split(double sum_hess, int depth) const {
     return sum_hess < this->min_child_weight * 2.0;
   }
-  // code support for template data
-  inline double CalcWeight(const GradStats &d) const {
-    return this->CalcWeight(d.sum_grad, d.sum_hess);
-  }
-  inline double CalcGain(const GradStats &d) const {
-    return this->CalcGain(d.sum_grad, d.sum_hess);
-  }
 
  protected:
   // functions for L1 cost
@@ -185,6 +139,61 @@ struct TrainParam{
   }
 };
 
+/*! \brief core statistics used for tree construction */
+struct GradStats {
+  /*! \brief sum gradient statistics */
+  double sum_grad;
+  /*! \brief sum hessian statistics */
+  double sum_hess;
+  /*! \brief constructor */
+  GradStats(void) {
+    this->Clear();
+  }
+  /*! \brief clear the statistics */
+  inline void Clear(void) {
+    sum_grad = sum_hess = 0.0f;
+  }
+  /*!
+   * \brief accumulate statistics
+   * \param gpair the vector storing the gradient statistics
+   * \param info the additional information
+   * \param ridx instance index of this instance
+   */
+  inline void Add(const std::vector<bst_gpair> &gpair,
+                  const BoosterInfo &info,
+                  bst_uint ridx) {
+    const bst_gpair &b = gpair[ridx];
+    this->Add(b.grad, b.hess);
+  }
+  /*! \brief calculate leaf weight */
+  inline double CalcWeight(const TrainParam &param) const {
+    return param.CalcWeight(sum_grad, sum_hess);
+  }
+  /*! \brief calculate gain of the solution */
+  inline double CalcGain(const TrainParam &param) const {
+    return param.CalcGain(sum_grad, sum_hess);
+  }
+  /*! \brief add statistics to the data */
+  inline void Add(double grad, double hess) {
+    sum_grad += grad; sum_hess += hess;
+  }
+  /*! \brief add statistics to the data */
+  inline void Add(const GradStats &b) {
+    this->Add(b.sum_grad, b.sum_hess);
+  }
+  /*! \brief subtract the statistics by b */
+  inline GradStats Substract(const GradStats &b) const {
+    GradStats res;
+    res.sum_grad = this->sum_grad - b.sum_grad;
+    res.sum_hess = this->sum_hess - b.sum_hess;
+    return res;
+  }
+  /*! \return whether the statistics is not used yet */
+  inline bool Empty(void) const {
+    return sum_hess == 0.0;
+  }
+};
+
 /*!
  * \brief statistics that is helpful to store
  * and represent a split solution for the tree
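Note on the pattern (this explanation and sketch are not part of the diff): moving Add/CalcWeight/CalcGain onto the statistics struct inverts the old dependency, where TrainParam had to know about GradStats. A builder can now be written against any statistics type that exposes this interface. A minimal sketch of such a caller, assuming the types declared above; the helper name NodeGain is hypothetical:

    // Sketch only: accumulate statistics over a set of rows and score the node.
    // Works for GradStats or any TStats with the same member functions.
    template<typename TStats>
    inline double NodeGain(const std::vector<bst_gpair> &gpair,
                           const BoosterInfo &info,
                           const std::vector<bst_uint> &rows,
                           const TrainParam &param) {
      TStats sum;  // constructor calls Clear(), so the sums start at zero
      for (size_t i = 0; i < rows.size(); ++i) {
        // the statistics object decides what to read from gpair/info (visitor-style)
        sum.Add(gpair, info, rows[i]);
      }
      return sum.CalcGain(param);
    }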
@@ -80,13 +80,13 @@ class ColMaker: public IUpdater<FMatrix> {
                       const BoosterInfo &info,
                       RegTree *p_tree) {
     this->InitData(gpair, fmat, info.root_index, *p_tree);
-    this->InitNewNode(qexpand, gpair, fmat, *p_tree);
+    this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
 
     for (int depth = 0; depth < param.max_depth; ++depth) {
-      this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree);
+      this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree);
       this->ResetPosition(this->qexpand, fmat, *p_tree);
       this->UpdateQueueExpand(*p_tree, &this->qexpand);
-      this->InitNewNode(qexpand, gpair, fmat, *p_tree);
+      this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
       // if nothing is left to be expanded, break
       if (qexpand.size() == 0) break;
     }
@@ -175,6 +175,7 @@ class ColMaker: public IUpdater<FMatrix> {
   inline void InitNewNode(const std::vector<int> &qexpand,
                           const std::vector<bst_gpair> &gpair,
                           const FMatrix &fmat,
+                          const BoosterInfo &info,
                           const RegTree &tree) {
     {// setup statistics space for each tree node
       for (size_t i = 0; i < stemp.size(); ++i) {
@@ -190,7 +191,7 @@ class ColMaker: public IUpdater<FMatrix> {
         const bst_uint ridx = rowset[i];
         const int tid = omp_get_thread_num();
         if (position[ridx] < 0) continue;
-        stemp[tid][position[ridx]].stats.Add(gpair[ridx]);
+        stemp[tid][position[ridx]].stats.Add(gpair, info, ridx);
       }
       // sum the per thread statistics together
       for (size_t j = 0; j < qexpand.size(); ++j) {
@@ -201,8 +202,8 @@ class ColMaker: public IUpdater<FMatrix> {
       }
       // update node statistics
       snode[nid].stats = stats;
-      snode[nid].root_gain = param.CalcGain(stats);
-      snode[nid].weight = param.CalcWeight(stats);
+      snode[nid].root_gain = stats.CalcGain(param);
+      snode[nid].weight = stats.CalcWeight(param);
     }
   }
   /*! \brief update queue expand add in new leaves */
@@ -223,6 +224,7 @@ class ColMaker: public IUpdater<FMatrix> {
   template<typename Iter>
   inline void EnumerateSplit(Iter it, unsigned fid,
                              const std::vector<bst_gpair> &gpair,
+                             const BoosterInfo &info,
                              std::vector<ThreadEntry> &temp,
                              bool is_forward_search) {
     // clear all the temp statistics
@@ -239,19 +241,19 @@ class ColMaker: public IUpdater<FMatrix> {
       ThreadEntry &e = temp[nid];
       // test if first hit, this is fine, because we set 0 during init
       if (e.stats.Empty()) {
-        e.stats.Add(gpair[ridx]);
+        e.stats.Add(gpair, info, ridx);
         e.last_fvalue = fvalue;
       } else {
         // try to find a split
         if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
           TStats c = snode[nid].stats.Substract(e.stats);
           if (c.sum_hess >= param.min_child_weight) {
-            double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
+            double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
             e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
           }
         }
         // update the statistics
-        e.stats.Add(gpair[ridx]);
+        e.stats.Add(gpair, info, ridx);
         e.last_fvalue = fvalue;
       }
     }
@@ -261,7 +263,7 @@ class ColMaker: public IUpdater<FMatrix> {
       ThreadEntry &e = temp[nid];
       TStats c = snode[nid].stats.Substract(e.stats);
       if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
-        const double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
+        const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
         const float delta = is_forward_search ? rt_eps : -rt_eps;
         e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
       }
@@ -269,7 +271,9 @@ class ColMaker: public IUpdater<FMatrix> {
   }
   // find splits at current level, do split per level
   inline void FindSplit(int depth, const std::vector<int> &qexpand,
-                        const std::vector<bst_gpair> &gpair, const FMatrix &fmat,
+                        const std::vector<bst_gpair> &gpair,
+                        const FMatrix &fmat,
+                        const BoosterInfo &info,
                         RegTree *p_tree) {
     std::vector<unsigned> feat_set = feat_index;
     if (param.colsample_bylevel != 1.0f) {
@@ -288,10 +292,10 @@ class ColMaker: public IUpdater<FMatrix> {
       const unsigned fid = feat_set[i];
       const int tid = omp_get_thread_num();
       if (param.need_forward_search(fmat.GetColDensity(fid))) {
-        this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, stemp[tid], true);
+        this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, info, stemp[tid], true);
       }
       if (param.need_backward_search(fmat.GetColDensity(fid))) {
-        this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, stemp[tid], false);
+        this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, info, stemp[tid], false);
       }
     }
     // after this each thread's stemp will get the best candidates, aggregate results
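For reference, and assuming only the L2 penalty lambda from TrainParam (no L1 term, which the actual CalcWeight/CalcGain also handle), the scoring functions used by the ColMaker hunks above reduce to

    CalcWeight(G, H) = -G / (H + lambda)
    CalcGain(G, H)   = G^2 / (H + lambda)

so the split gain in EnumerateSplit works out to loss_chg = G_L^2/(H_L + lambda) + G_R^2/(H_R + lambda) - G^2/(H + lambda). The Substract call supplies the complementary child for free: scanning a sorted feature column accumulates the left-side totals in e.stats, and the other side is recovered as the parent totals minus e.stats, so each feature needs only one pass per search direction.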
@@ -65,8 +65,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
       RegTree::FVec &feats = fvec_temp[tid];
       feats.Fill(inst);
       for (size_t j = 0; j < trees.size(); ++j) {
-        AddStats(*trees[j], feats, gpair[ridx],
-                 info.GetRoot(j),
+        AddStats(*trees[j], feats, gpair, info, ridx,
                  &stemp[tid * trees.size() + j]);
       }
       feats.Drop(inst);
@@ -95,31 +94,33 @@ class TreeRefresher: public IUpdater<FMatrix> {
  private:
   inline static void AddStats(const RegTree &tree,
                               const RegTree::FVec &feat,
-                              const bst_gpair &gpair, unsigned root_id,
+                              const std::vector<bst_gpair> &gpair,
+                              const BoosterInfo &info,
+                              const bst_uint ridx,
                               std::vector<GradStats> *p_gstats) {
     std::vector<GradStats> &gstats = *p_gstats;
     // start from groups that belong to current data
-    int pid = static_cast<int>(root_id);
-    gstats[pid].Add(gpair);
+    int pid = static_cast<int>(info.GetRoot(ridx));
+    gstats[pid].Add(gpair, info, ridx);
     // traverse tree
     while (!tree[pid].is_leaf()) {
       unsigned split_index = tree[pid].split_index();
       pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
-      gstats[pid].Add(gpair);
+      gstats[pid].Add(gpair, info, ridx);
     }
   }
   inline void Refresh(const std::vector<GradStats> &gstats,
                       int nid, RegTree *p_tree) {
     RegTree &tree = *p_tree;
-    tree.stat(nid).base_weight = param.CalcWeight(gstats[nid]);
+    tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
     tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
     if (tree[nid].is_leaf()) {
       tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
     } else {
       tree.stat(nid).loss_chg =
-          param.CalcGain(gstats[tree[nid].cleft()]) +
-          param.CalcGain(gstats[tree[nid].cright()]) -
-          param.CalcGain(gstats[nid]);
+          gstats[tree[nid].cleft()].CalcGain(param) +
+          gstats[tree[nid].cright()].CalcGain(param) -
+          gstats[nid].CalcGain(param);
       this->Refresh(gstats, tree[nid].cleft(), p_tree);
       this->Refresh(gstats, tree[nid].cright(), p_tree);
     }
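Taken together, the ColMaker and TreeRefresher updaters touch the statistics type only through a small surface. A summary of that interface, written as a hypothetical concept (this declaration does not appear in the codebase; it simply collects the calls made in the hunks above):

    // Hypothetical "concept" for a statistics type TStats usable by the updaters:
    struct TStatsConcept {
      double sum_hess;                                   // read by min_child_weight checks
      TStatsConcept(void);                               // must zero-initialize (via Clear)
      void Clear(void);                                  // reset the accumulated sums
      void Add(const std::vector<bst_gpair> &gpair,      // visitor-style accumulation
               const BoosterInfo &info, bst_uint ridx);
      void Add(const TStatsConcept &b);                  // merge per-thread partial sums
      TStatsConcept Substract(const TStatsConcept &b) const;  // parent minus one child
      bool Empty(void) const;                            // true before the first Add
      double CalcWeight(const TrainParam &param) const;  // leaf weight
      double CalcGain(const TrainParam &param) const;    // structure score
    };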