From d08d8ed3edbc7c5c80ec0f1d72e3b806976d03ad Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 18 Aug 2014 21:32:48 -0700
Subject: [PATCH] add tree refresher, need review

---
 src/tree/updater_refresh-inl.hpp | 138 +++++++++++++++++++++++++++++++
 1 file changed, 138 insertions(+)
 create mode 100644 src/tree/updater_refresh-inl.hpp

diff --git a/src/tree/updater_refresh-inl.hpp b/src/tree/updater_refresh-inl.hpp
new file mode 100644
index 000000000..69f099e1d
--- /dev/null
+++ b/src/tree/updater_refresh-inl.hpp
@@ -0,0 +1,138 @@
+#ifndef XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
+#define XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
+/*!
+ * \file updater_refresh-inl.hpp
+ * \brief refresh the statistics and leaf values of a tree using the dataset
+ * \author Tianqi Chen
+ */
+#include <vector>
+#include <limits>
+#include "./param.h"
+#include "./updater.h"
+
+namespace xgboost {
+namespace tree {
+/*! \brief refresher that refreshes the statistics and leaf values of a tree */
+template<typename FMatrix>
+class TreeRefresher: public IUpdater<FMatrix> {
+ public:
+  virtual ~TreeRefresher(void) {}
+  // set training parameter
+  virtual void SetParam(const char *name, const char *val) {
+    param.SetParam(name, val);
+    if (!strcmp(name, "silent")) silent = atoi(val);
+  }
+  // update the trees: refresh their statistics and leaf values
+  virtual void Update(const std::vector<bst_gpair> &gpair,
+                      const FMatrix &fmat,
+                      const std::vector<unsigned> &root_index,
+                      const std::vector<RegTree*> &trees) {
+    if (trees.size() == 0) return;
+    // number of threads
+    int nthread;
+    // per-thread temporary space
+    std::vector< std::vector<GradStats> > stemp;
+    std::vector<RegTree::FVec> fvec_temp;
+    // setup temp space for each thread
+    #pragma omp parallel
+    {
+      nthread = omp_get_num_threads();
+    }
+    fvec_temp.resize(nthread, RegTree::FVec());
+    stemp.resize(trees.size() * nthread, std::vector<GradStats>());
+    #pragma omp parallel
+    {
+      int tid = omp_get_thread_num();
+      for (size_t i = 0; i < trees.size(); ++i) {
+        std::vector<GradStats> &vec = stemp[tid * trees.size() + i];
+        vec.resize(trees[i]->param.num_nodes);
+        std::fill(vec.begin(), vec.end(), GradStats());
+      }
+      fvec_temp[tid].Init(trees[0]->param.num_feature);
+    }
+    // start accumulating statistics
+    utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      const SparseBatch &batch = iter->Value();
+      utils::Check(batch.size < std::numeric_limits<unsigned>::max(),
+                   "batch size too large");
+      const unsigned nbatch = static_cast<unsigned>(batch.size);
+      #pragma omp parallel for schedule(static)
+      for (unsigned i = 0; i < nbatch; ++i) {
+        SparseBatch::Inst inst = batch[i];
+        const int tid = omp_get_thread_num();
+        const size_t ridx = batch.base_rowid + i;
+        RegTree::FVec &feats = fvec_temp[tid];
+        feats.Fill(inst);
+        for (size_t j = 0; j < trees.size(); ++j) {
+          AddStats(*trees[j], feats, gpair[ridx],
+                   root_index.size() == 0 ? 0 : root_index[ridx],
+                   &stemp[tid * trees.size() + j]);
+        }
+        feats.Drop(inst);
+      }
+    }
+    // start updating the trees using the accumulated statistics;
+    // rescale learning rate according to the number of trees
+    float lr = param.learning_rate;
+    param.learning_rate = lr / trees.size();
+    for (size_t i = 0; i < trees.size(); ++i) {
+      // aggregate the per-thread statistics into thread 0's buffer
+      #pragma omp parallel for schedule(static)
+      for (int nid = 0; nid < trees[i]->param.num_nodes; ++nid) {
+        for (int tid = 1; tid < nthread; ++tid) {
+          stemp[i][nid].Add(stemp[tid * trees.size() + i][nid]);
+        }
+      }
+      for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
+        this->Refresh(stemp[i], rid, trees[i]);
+      }
+    }
+    // set learning rate back
+    param.learning_rate = lr;
+  }
+
+ private:
+  inline static void AddStats(const RegTree &tree,
+                              const RegTree::FVec &feat,
+                              const bst_gpair &gpair, unsigned root_id,
+                              std::vector<GradStats> *p_gstats) {
+    std::vector<GradStats> &gstats = *p_gstats;
+    // start from the root group that the current instance belongs to
+    int pid = static_cast<int>(root_id);
+    gstats[pid].Add(gpair);
+    // traverse the tree, adding the gradient pair to every node on the path
+    while (!tree[pid].is_leaf()) {
+      unsigned split_index = tree[pid].split_index();
+      pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
+      gstats[pid].Add(gpair);
+    }
+  }
+  inline void Refresh(const std::vector<GradStats> &gstats,
+                      int nid, RegTree *p_tree) {
+    RegTree &tree = *p_tree;
+    tree.stat(nid).base_weight = param.CalcWeight(gstats[nid]);
+    tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
+    if (tree[nid].is_leaf()) {
+      tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
+    } else {
+      tree.stat(nid).loss_chg =
+          param.CalcGain(gstats[tree[nid].cleft()]) +
+          param.CalcGain(gstats[tree[nid].cright()]) -
+          param.CalcGain(gstats[nid]);
+      this->Refresh(gstats, tree[nid].cleft(), p_tree);
+      this->Refresh(gstats, tree[nid].cright(), p_tree);
+    }
+  }
+  // number of OpenMP threads
+  int nthread;
+  // whether to suppress messages
+  int silent;
+  // training parameter
+  TrainParam param;
+};

+}  // namespace tree
+}  // namespace xgboost
+#endif  // XGBOOST_TREE_UPDATER_REFRESH_INL_HPP_
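
Reviewer note on the parallel accumulation: Update() gives every OpenMP
thread its own buffer of per-node GradStats (stemp[tid * trees.size() + i]),
streams each instance down each tree while adding its gradient pair to every
node on the path (AddStats), and finally sums the thread buffers into thread
0's slot before calling Refresh(). Below is a minimal, self-contained sketch
of that accumulate-then-reduce pattern (compile with -fopenmp); the Stats
struct and the modulo "traversal" are hypothetical stand-ins for GradStats
and the real tree walk, not code from this patch:

    // accumulate_reduce.cc -- sketch of TreeRefresher::Update's threading pattern
    #include <omp.h>
    #include <cstdio>
    #include <vector>

    struct Stats {  // stand-in for GradStats
      double sum_grad, sum_hess;
      Stats() : sum_grad(0.0), sum_hess(0.0) {}
      void Add(double g, double h) { sum_grad += g; sum_hess += h; }
      void Add(const Stats &s) { sum_grad += s.sum_grad; sum_hess += s.sum_hess; }
    };

    int main() {
      const int num_nodes = 4, num_rows = 1000;
      int nthread = 1;
      #pragma omp parallel
      {
        nthread = omp_get_num_threads();  // same benign pattern as the patch
      }
      // one private buffer of per-node statistics per thread, like stemp
      std::vector< std::vector<Stats> > stemp(nthread, std::vector<Stats>(num_nodes));
      #pragma omp parallel for schedule(static)
      for (int i = 0; i < num_rows; ++i) {
        const int tid = omp_get_thread_num();
        const int nid = i % num_nodes;    // stand-in for walking the tree
        stemp[tid][nid].Add(0.5, 1.0);    // stand-in for gpair[ridx]
      }
      // reduce every thread's buffer into thread 0's, as Update() does per tree
      for (int nid = 0; nid < num_nodes; ++nid) {
        for (int tid = 1; tid < nthread; ++tid) {
          stemp[0][nid].Add(stemp[tid][nid]);
        }
        std::printf("node %d: sum_grad=%g sum_hess=%g\n",
                    nid, stemp[0][nid].sum_grad, stemp[0][nid].sum_hess);
      }
      return 0;
    }

The per-thread buffers keep the hot loop free of atomics or locks; the
single-threaded reduction afterwards is cheap because it is O(nodes), not
O(rows).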
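
Refresh() itself recomputes base_weight via param.CalcWeight and loss_chg
from the children's CalcGain minus the parent's, and stores
base_weight * learning_rate at the leaves. As a numeric sketch of what that
refresh computes, assuming the standard second-order forms
weight = -G / (H + lambda) and gain = G^2 / (H + lambda); the exact
regularization in param.h (e.g. max_delta_step handling) may differ:

    // refresh_math.cc -- assumed weight/gain forms behind Refresh();
    // the formulas and lambda are assumptions, not copied from param.h
    #include <cstdio>

    double CalcWeight(double G, double H, double lambda) {
      return -G / (H + lambda);     // leaf weight minimizing the 2nd-order loss
    }
    double CalcGain(double G, double H, double lambda) {
      return G * G / (H + lambda);  // loss reduction credited to a node
    }

    int main() {
      const double lambda = 1.0, lr = 0.3;
      // statistics accumulated at a parent and its two children; every
      // instance reaching the parent reaches exactly one child, so G = Gl + Gr
      double Gl = -3.0, Hl = 5.0, Gr = 1.0, Hr = 4.0;
      double G = Gl + Gr, H = Hl + Hr;
      double w = CalcWeight(G, H, lambda);
      std::printf("base_weight=%g, stored leaf value=%g\n", w, w * lr);
      // loss_chg = gain(left) + gain(right) - gain(parent), as in Refresh()
      double loss_chg = CalcGain(Gl, Hl, lambda) + CalcGain(Gr, Hr, lambda)
                      - CalcGain(G, H, lambda);
      std::printf("loss_chg=%g\n", loss_chg);
      return 0;
    }

This also shows why Update() divides learning_rate by trees.size() before
refreshing: each tree's leaves are rescaled so the refreshed ensemble applies
the learning rate once overall rather than once per tree.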