add tree refresher, need review

tqchen 2014-08-18 21:32:48 -07:00
parent f757520c02
commit d08d8ed3ed


@ -0,0 +1,138 @@
#ifndef XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
#define XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
/*!
* \file updater_refresh-inl.hpp
* \brief refresh the statistics and leaf values of a tree based on the dataset
* \author Tianqi Chen
*/
#include <vector>
#include <limits>
#include <cstring>
#include <cstdlib>
#include <omp.h>
#include "./param.h"
#include "./updater.h"
namespace xgboost {
namespace tree {
/*! \brief updater that refreshes the statistics and leaf values of a tree based on the dataset */
template<typename FMatrix>
class TreeRefresher: public IUpdater<FMatrix> {
public:
virtual ~TreeRefresher(void) {}
// set training parameter
virtual void SetParam(const char *name, const char *val) {
param.SetParam(name, val);
if (!strcmp(name, "silent")) silent = atoi(val);
}
// update the trees by refreshing their statistics and leaf values
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const std::vector<RegTree*> &trees) {
if (trees.size() == 0) return;
// number of threads
int nthread;
// per-thread temporary space
std::vector< std::vector<GradStats> > stemp;
std::vector<RegTree::FVec> fvec_temp;
// setup temp space for each thread
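// the empty parallel region below is used only to query how many threads OpenMP will use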
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
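// one dense feature vector per thread and one statistics buffer per (thread, tree) pair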
fvec_temp.resize(nthread, RegTree::FVec());
stemp.resize(trees.size() * nthread, std::vector<GradStats>());
#pragma omp parallel
{
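// each thread zero-initializes its own statistics buffers and sizes its feature vector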
int tid = omp_get_thread_num();
for (size_t i = 0; i < trees.size(); ++i) {
std::vector<GradStats> &vec = stemp[tid * trees.size() + i];
vec.resize(trees[i]->param.num_nodes);
std::fill(vec.begin(), vec.end(), GradStats());
}
fvec_temp[tid].Init(trees[0]->param.num_feature);
}
// start accumulating statistics
utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const SparseBatch &batch = iter->Value();
utils::Check(batch.size < std::numeric_limits<unsigned>::max(),
"batch size exceeds the range of unsigned");
const unsigned nbatch = static_cast<unsigned>(batch.size);
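// each row is handled by one thread: fill the dense feature view, walk every tree
// adding the row's gradient pair along its decision path, then clear the view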
#pragma omp parallel for schedule(static)
for (unsigned i = 0; i < nbatch; ++i) {
SparseBatch::Inst inst = batch[i];
const int tid = omp_get_thread_num();
const size_t ridx = batch.base_rowid + i;
RegTree::FVec &feats = fvec_temp[tid];
feats.Fill(inst);
for (size_t j = 0; j < trees.size(); ++j) {
AddStats(*trees[j], feats, gpair[ridx],
root_index.size() == 0 ? 0 : root_index[ridx],
&stemp[tid * trees.size() + j]);
}
feats.Drop(inst);
}
}
// start updating the trees using the accumulated statistics
// rescale the learning rate according to the number of trees
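// each refreshed leaf is then scaled by learning_rate divided by the number of trees,
// so the combined update across all trees stays on the scale of a single learning_rate step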
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
for (size_t i = 0; i < trees.size(); ++i) {
// aggregate per-thread statistics into thread 0's buffer (stemp[i])
#pragma omp parallel for schedule(static)
for (int nid = 0; nid < trees[i]->param.num_nodes; ++nid) {
for (int tid = 1; tid < nthread; ++tid) {
stemp[i][nid].Add(stemp[tid * trees.size() + i][nid]);
}
}
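// recursively refresh node statistics and leaf values starting from each root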
for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
this->Refresh(stemp[i], rid, trees[i]);
}
}
// set learning rate back
param.learning_rate = lr;
}
private:
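// add the gradient statistics of one instance to every node on its decision path in the tree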
inline static void AddStats(const RegTree &tree,
const RegTree::FVec &feat,
const bst_gpair &gpair, unsigned root_id,
std::vector<GradStats> *p_gstats) {
std::vector<GradStats> &gstats = *p_gstats;
// start from the root node of the group this instance belongs to
int pid = static_cast<int>(root_id);
gstats[pid].Add(gpair);
// traverse the tree along the instance's decision path, accumulating statistics at each node
while (!tree[pid].is_leaf()) {
unsigned split_index = tree[pid].split_index();
pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
gstats[pid].Add(gpair);
}
}
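// recompute statistics, loss change, and leaf values for the subtree rooted at nid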
inline void Refresh(const std::vector<GradStats> &gstats,
int nid, RegTree *p_tree) {
RegTree &tree = *p_tree;
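// update the node's base weight and hessian sum from the freshly accumulated statistics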
tree.stat(nid).base_weight = param.CalcWeight(gstats[nid]);
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
if (tree[nid].is_leaf()) {
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
} else {
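// internal node: the refreshed split gain is the children's gain minus the parent's gain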
tree.stat(nid).loss_chg =
param.CalcGain(gstats[tree[nid].cleft()]) +
param.CalcGain(gstats[tree[nid].cright()]) -
param.CalcGain(gstats[nid]);
this->Refresh(gstats, tree[nid].cleft(), p_tree);
this->Refresh(gstats, tree[nid].cright(), p_tree);
}
}
// number of OpenMP threads
int nthread;
// whether to silence output messages
int silent;
// training parameter
TrainParam param;
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_