refactor grad stats to be like visitor
This commit is contained in:
parent
d49c6e6e84
commit
f71b732e7a
101
src/tree/param.h
101
src/tree/param.h
@ -11,45 +11,6 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
/*! \brief core statistics used for tree construction */
|
|
||||||
struct GradStats {
|
|
||||||
/*! \brief sum gradient statistics */
|
|
||||||
double sum_grad;
|
|
||||||
/*! \brief sum hessian statistics */
|
|
||||||
double sum_hess;
|
|
||||||
/*! \brief constructor */
|
|
||||||
GradStats(void) {
|
|
||||||
this->Clear();
|
|
||||||
}
|
|
||||||
/*! \brief clear the statistics */
|
|
||||||
inline void Clear(void) {
|
|
||||||
sum_grad = sum_hess = 0.0f;
|
|
||||||
}
|
|
||||||
/*! \brief add statistics to the data */
|
|
||||||
inline void Add(double grad, double hess) {
|
|
||||||
sum_grad += grad; sum_hess += hess;
|
|
||||||
}
|
|
||||||
/*! \brief add statistics to the data */
|
|
||||||
inline void Add(const bst_gpair& b) {
|
|
||||||
this->Add(b.grad, b.hess);
|
|
||||||
}
|
|
||||||
/*! \brief add statistics to the data */
|
|
||||||
inline void Add(const GradStats &b) {
|
|
||||||
this->Add(b.sum_grad, b.sum_hess);
|
|
||||||
}
|
|
||||||
/*! \brief substract the statistics by b */
|
|
||||||
inline GradStats Substract(const GradStats &b) const {
|
|
||||||
GradStats res;
|
|
||||||
res.sum_grad = this->sum_grad - b.sum_grad;
|
|
||||||
res.sum_hess = this->sum_hess - b.sum_hess;
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
/*! \return whether the statistics is not used yet */
|
|
||||||
inline bool Empty(void) const {
|
|
||||||
return sum_hess == 0.0;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief training parameters for regression tree */
|
/*! \brief training parameters for regression tree */
|
||||||
struct TrainParam{
|
struct TrainParam{
|
||||||
// learning step size for a time
|
// learning step size for a time
|
||||||
@ -165,13 +126,6 @@ struct TrainParam{
|
|||||||
inline bool cannot_split(double sum_hess, int depth) const {
|
inline bool cannot_split(double sum_hess, int depth) const {
|
||||||
return sum_hess < this->min_child_weight * 2.0;
|
return sum_hess < this->min_child_weight * 2.0;
|
||||||
}
|
}
|
||||||
// code support for template data
|
|
||||||
inline double CalcWeight(const GradStats &d) const {
|
|
||||||
return this->CalcWeight(d.sum_grad, d.sum_hess);
|
|
||||||
}
|
|
||||||
inline double CalcGain(const GradStats &d) const {
|
|
||||||
return this->CalcGain(d.sum_grad, d.sum_hess);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// functions for L1 cost
|
// functions for L1 cost
|
||||||
@ -185,6 +139,61 @@ struct TrainParam{
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*! \brief core statistics used for tree construction */
|
||||||
|
struct GradStats {
|
||||||
|
/*! \brief sum gradient statistics */
|
||||||
|
double sum_grad;
|
||||||
|
/*! \brief sum hessian statistics */
|
||||||
|
double sum_hess;
|
||||||
|
/*! \brief constructor */
|
||||||
|
GradStats(void) {
|
||||||
|
this->Clear();
|
||||||
|
}
|
||||||
|
/*! \brief clear the statistics */
|
||||||
|
inline void Clear(void) {
|
||||||
|
sum_grad = sum_hess = 0.0f;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief accumulate statistics,
|
||||||
|
* \param gpair the vector storing the gradient statistics
|
||||||
|
* \param info the additional information
|
||||||
|
* \param ridx instance index of this instance
|
||||||
|
*/
|
||||||
|
inline void Add(const std::vector<bst_gpair> &gpair,
|
||||||
|
const BoosterInfo &info,
|
||||||
|
bst_uint ridx) {
|
||||||
|
const bst_gpair &b = gpair[ridx];
|
||||||
|
this->Add(b.grad, b.hess);
|
||||||
|
}
|
||||||
|
/*! \brief caculate leaf weight */
|
||||||
|
inline double CalcWeight(const TrainParam &param) const {
|
||||||
|
return param.CalcWeight(sum_grad, sum_hess);
|
||||||
|
}
|
||||||
|
/*!\brief calculate gain of the solution */
|
||||||
|
inline double CalcGain(const TrainParam &param) const {
|
||||||
|
return param.CalcGain(sum_grad, sum_hess);
|
||||||
|
}
|
||||||
|
/*! \brief add statistics to the data */
|
||||||
|
inline void Add(double grad, double hess) {
|
||||||
|
sum_grad += grad; sum_hess += hess;
|
||||||
|
}
|
||||||
|
/*! \brief add statistics to the data */
|
||||||
|
inline void Add(const GradStats &b) {
|
||||||
|
this->Add(b.sum_grad, b.sum_hess);
|
||||||
|
}
|
||||||
|
/*! \brief substract the statistics by b */
|
||||||
|
inline GradStats Substract(const GradStats &b) const {
|
||||||
|
GradStats res;
|
||||||
|
res.sum_grad = this->sum_grad - b.sum_grad;
|
||||||
|
res.sum_hess = this->sum_hess - b.sum_hess;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
/*! \return whether the statistics is not used yet */
|
||||||
|
inline bool Empty(void) const {
|
||||||
|
return sum_hess == 0.0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief statistics that is helpful to store
|
* \brief statistics that is helpful to store
|
||||||
* and represent a split solution for the tree
|
* and represent a split solution for the tree
|
||||||
|
|||||||
@ -80,13 +80,13 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
RegTree *p_tree) {
|
RegTree *p_tree) {
|
||||||
this->InitData(gpair, fmat, info.root_index, *p_tree);
|
this->InitData(gpair, fmat, info.root_index, *p_tree);
|
||||||
this->InitNewNode(qexpand, gpair, fmat, *p_tree);
|
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
|
||||||
|
|
||||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||||
this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree);
|
this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree);
|
||||||
this->ResetPosition(this->qexpand, fmat, *p_tree);
|
this->ResetPosition(this->qexpand, fmat, *p_tree);
|
||||||
this->UpdateQueueExpand(*p_tree, &this->qexpand);
|
this->UpdateQueueExpand(*p_tree, &this->qexpand);
|
||||||
this->InitNewNode(qexpand, gpair, fmat, *p_tree);
|
this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
|
||||||
// if nothing left to be expand, break
|
// if nothing left to be expand, break
|
||||||
if (qexpand.size() == 0) break;
|
if (qexpand.size() == 0) break;
|
||||||
}
|
}
|
||||||
@ -175,6 +175,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
inline void InitNewNode(const std::vector<int> &qexpand,
|
inline void InitNewNode(const std::vector<int> &qexpand,
|
||||||
const std::vector<bst_gpair> &gpair,
|
const std::vector<bst_gpair> &gpair,
|
||||||
const FMatrix &fmat,
|
const FMatrix &fmat,
|
||||||
|
const BoosterInfo &info,
|
||||||
const RegTree &tree) {
|
const RegTree &tree) {
|
||||||
{// setup statistics space for each tree node
|
{// setup statistics space for each tree node
|
||||||
for (size_t i = 0; i < stemp.size(); ++i) {
|
for (size_t i = 0; i < stemp.size(); ++i) {
|
||||||
@ -190,7 +191,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
const bst_uint ridx = rowset[i];
|
const bst_uint ridx = rowset[i];
|
||||||
const int tid = omp_get_thread_num();
|
const int tid = omp_get_thread_num();
|
||||||
if (position[ridx] < 0) continue;
|
if (position[ridx] < 0) continue;
|
||||||
stemp[tid][position[ridx]].stats.Add(gpair[ridx]);
|
stemp[tid][position[ridx]].stats.Add(gpair, info, ridx);
|
||||||
}
|
}
|
||||||
// sum the per thread statistics together
|
// sum the per thread statistics together
|
||||||
for (size_t j = 0; j < qexpand.size(); ++j) {
|
for (size_t j = 0; j < qexpand.size(); ++j) {
|
||||||
@ -201,8 +202,8 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
}
|
}
|
||||||
// update node statistics
|
// update node statistics
|
||||||
snode[nid].stats = stats;
|
snode[nid].stats = stats;
|
||||||
snode[nid].root_gain = param.CalcGain(stats);
|
snode[nid].root_gain = stats.CalcGain(param);
|
||||||
snode[nid].weight = param.CalcWeight(stats);
|
snode[nid].weight = stats.CalcWeight(param);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*! \brief update queue expand add in new leaves */
|
/*! \brief update queue expand add in new leaves */
|
||||||
@ -223,6 +224,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
template<typename Iter>
|
template<typename Iter>
|
||||||
inline void EnumerateSplit(Iter it, unsigned fid,
|
inline void EnumerateSplit(Iter it, unsigned fid,
|
||||||
const std::vector<bst_gpair> &gpair,
|
const std::vector<bst_gpair> &gpair,
|
||||||
|
const BoosterInfo &info,
|
||||||
std::vector<ThreadEntry> &temp,
|
std::vector<ThreadEntry> &temp,
|
||||||
bool is_forward_search) {
|
bool is_forward_search) {
|
||||||
// clear all the temp statistics
|
// clear all the temp statistics
|
||||||
@ -239,19 +241,19 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
ThreadEntry &e = temp[nid];
|
ThreadEntry &e = temp[nid];
|
||||||
// test if first hit, this is fine, because we set 0 during init
|
// test if first hit, this is fine, because we set 0 during init
|
||||||
if (e.stats.Empty()) {
|
if (e.stats.Empty()) {
|
||||||
e.stats.Add(gpair[ridx]);
|
e.stats.Add(gpair, info, ridx);
|
||||||
e.last_fvalue = fvalue;
|
e.last_fvalue = fvalue;
|
||||||
} else {
|
} else {
|
||||||
// try to find a split
|
// try to find a split
|
||||||
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
|
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
|
||||||
TStats c = snode[nid].stats.Substract(e.stats);
|
TStats c = snode[nid].stats.Substract(e.stats);
|
||||||
if (c.sum_hess >= param.min_child_weight) {
|
if (c.sum_hess >= param.min_child_weight) {
|
||||||
double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
|
double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
|
||||||
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
|
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// update the statistics
|
// update the statistics
|
||||||
e.stats.Add(gpair[ridx]);
|
e.stats.Add(gpair, info, ridx);
|
||||||
e.last_fvalue = fvalue;
|
e.last_fvalue = fvalue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -261,7 +263,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
ThreadEntry &e = temp[nid];
|
ThreadEntry &e = temp[nid];
|
||||||
TStats c = snode[nid].stats.Substract(e.stats);
|
TStats c = snode[nid].stats.Substract(e.stats);
|
||||||
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
|
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
|
||||||
const double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain;
|
const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
|
||||||
const float delta = is_forward_search ? rt_eps : -rt_eps;
|
const float delta = is_forward_search ? rt_eps : -rt_eps;
|
||||||
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
|
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
|
||||||
}
|
}
|
||||||
@ -269,7 +271,9 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
}
|
}
|
||||||
// find splits at current level, do split per level
|
// find splits at current level, do split per level
|
||||||
inline void FindSplit(int depth, const std::vector<int> &qexpand,
|
inline void FindSplit(int depth, const std::vector<int> &qexpand,
|
||||||
const std::vector<bst_gpair> &gpair, const FMatrix &fmat,
|
const std::vector<bst_gpair> &gpair,
|
||||||
|
const FMatrix &fmat,
|
||||||
|
const BoosterInfo &info,
|
||||||
RegTree *p_tree) {
|
RegTree *p_tree) {
|
||||||
std::vector<unsigned> feat_set = feat_index;
|
std::vector<unsigned> feat_set = feat_index;
|
||||||
if (param.colsample_bylevel != 1.0f) {
|
if (param.colsample_bylevel != 1.0f) {
|
||||||
@ -288,10 +292,10 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
const unsigned fid = feat_set[i];
|
const unsigned fid = feat_set[i];
|
||||||
const int tid = omp_get_thread_num();
|
const int tid = omp_get_thread_num();
|
||||||
if (param.need_forward_search(fmat.GetColDensity(fid))) {
|
if (param.need_forward_search(fmat.GetColDensity(fid))) {
|
||||||
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, stemp[tid], true);
|
this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, info, stemp[tid], true);
|
||||||
}
|
}
|
||||||
if (param.need_backward_search(fmat.GetColDensity(fid))) {
|
if (param.need_backward_search(fmat.GetColDensity(fid))) {
|
||||||
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, stemp[tid], false);
|
this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, info, stemp[tid], false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// after this each thread's stemp will get the best candidates, aggregate results
|
// after this each thread's stemp will get the best candidates, aggregate results
|
||||||
|
|||||||
@ -65,8 +65,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
RegTree::FVec &feats = fvec_temp[tid];
|
RegTree::FVec &feats = fvec_temp[tid];
|
||||||
feats.Fill(inst);
|
feats.Fill(inst);
|
||||||
for (size_t j = 0; j < trees.size(); ++j) {
|
for (size_t j = 0; j < trees.size(); ++j) {
|
||||||
AddStats(*trees[j], feats, gpair[ridx],
|
AddStats(*trees[j], feats, gpair, info, ridx,
|
||||||
info.GetRoot(j),
|
|
||||||
&stemp[tid * trees.size() + j]);
|
&stemp[tid * trees.size() + j]);
|
||||||
}
|
}
|
||||||
feats.Drop(inst);
|
feats.Drop(inst);
|
||||||
@ -95,31 +94,33 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
private:
|
private:
|
||||||
inline static void AddStats(const RegTree &tree,
|
inline static void AddStats(const RegTree &tree,
|
||||||
const RegTree::FVec &feat,
|
const RegTree::FVec &feat,
|
||||||
const bst_gpair &gpair, unsigned root_id,
|
const std::vector<bst_gpair> &gpair,
|
||||||
|
const BoosterInfo &info,
|
||||||
|
const bst_uint ridx,
|
||||||
std::vector<GradStats> *p_gstats) {
|
std::vector<GradStats> *p_gstats) {
|
||||||
std::vector<GradStats> &gstats = *p_gstats;
|
std::vector<GradStats> &gstats = *p_gstats;
|
||||||
// start from groups that belongs to current data
|
// start from groups that belongs to current data
|
||||||
int pid = static_cast<int>(root_id);
|
int pid = static_cast<int>(info.GetRoot(ridx));
|
||||||
gstats[pid].Add(gpair);
|
gstats[pid].Add(gpair, info, ridx);
|
||||||
// tranverse tree
|
// tranverse tree
|
||||||
while (!tree[pid].is_leaf()) {
|
while (!tree[pid].is_leaf()) {
|
||||||
unsigned split_index = tree[pid].split_index();
|
unsigned split_index = tree[pid].split_index();
|
||||||
pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
|
pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
|
||||||
gstats[pid].Add(gpair);
|
gstats[pid].Add(gpair, info, ridx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline void Refresh(const std::vector<GradStats> &gstats,
|
inline void Refresh(const std::vector<GradStats> &gstats,
|
||||||
int nid, RegTree *p_tree) {
|
int nid, RegTree *p_tree) {
|
||||||
RegTree &tree = *p_tree;
|
RegTree &tree = *p_tree;
|
||||||
tree.stat(nid).base_weight = param.CalcWeight(gstats[nid]);
|
tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
|
||||||
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
|
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
|
||||||
if (tree[nid].is_leaf()) {
|
if (tree[nid].is_leaf()) {
|
||||||
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
||||||
} else {
|
} else {
|
||||||
tree.stat(nid).loss_chg =
|
tree.stat(nid).loss_chg =
|
||||||
param.CalcGain(gstats[tree[nid].cleft()]) +
|
gstats[tree[nid].cleft()].CalcGain(param) +
|
||||||
param.CalcGain(gstats[tree[nid].cright()]) -
|
gstats[tree[nid].cright()].CalcGain(param) -
|
||||||
param.CalcGain(gstats[nid]);
|
gstats[nid].CalcGain(param);
|
||||||
this->Refresh(gstats, tree[nid].cleft(), p_tree);
|
this->Refresh(gstats, tree[nid].cleft(), p_tree);
|
||||||
this->Refresh(gstats, tree[nid].cright(), p_tree);
|
this->Refresh(gstats, tree[nid].cright(), p_tree);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user