From ba9fbd380c9ad3024e73ae8b25cd0ac20297883a Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 24 Aug 2014 15:22:11 -0700 Subject: [PATCH] templatize refresher --- src/tree/updater.h | 2 +- src/tree/updater_refresh-inl.hpp | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/tree/updater.h b/src/tree/updater.h index cdb625266..b33ee1833 100644 --- a/src/tree/updater.h +++ b/src/tree/updater.h @@ -60,7 +60,7 @@ namespace tree { template inline IUpdater* CreateUpdater(const char *name) { if (!strcmp(name, "prune")) return new TreePruner(); - if (!strcmp(name, "refresh")) return new TreeRefresher(); + if (!strcmp(name, "refresh")) return new TreeRefresher(); if (!strcmp(name, "grow_colmaker")) return new ColMaker(); utils::Error("unknown updater:%s", name); return NULL; diff --git a/src/tree/updater_refresh-inl.hpp b/src/tree/updater_refresh-inl.hpp index 3ccf217f6..e0e7ab520 100644 --- a/src/tree/updater_refresh-inl.hpp +++ b/src/tree/updater_refresh-inl.hpp @@ -13,7 +13,7 @@ namespace xgboost { namespace tree { /*! \brief pruner that prunes a tree after growing finishs */ -template +template class TreeRefresher: public IUpdater { public: virtual ~TreeRefresher(void) {} @@ -30,7 +30,7 @@ class TreeRefresher: public IUpdater { // number of threads int nthread; // thread temporal space - std::vector< std::vector > stemp; + std::vector< std::vector > stemp; std::vector fvec_temp; // setup temp space for each thread #pragma omp parallel @@ -38,14 +38,14 @@ class TreeRefresher: public IUpdater { nthread = omp_get_num_threads(); } fvec_temp.resize(nthread, RegTree::FVec()); - stemp.resize(trees.size() * nthread, std::vector()); + stemp.resize(trees.size() * nthread, std::vector()); #pragma omp parallel { int tid = omp_get_thread_num(); for (size_t i = 0; i < trees.size(); ++i) { - std::vector &vec = stemp[tid * trees.size() + i]; + std::vector &vec = stemp[tid * trees.size() + i]; vec.resize(trees[i]->param.num_nodes); - std::fill(vec.begin(), vec.end(), GradStats()); + std::fill(vec.begin(), vec.end(), TStats()); } fvec_temp[tid].Init(trees[0]->param.num_feature); } @@ -97,8 +97,8 @@ class TreeRefresher: public IUpdater { const std::vector &gpair, const BoosterInfo &info, const bst_uint ridx, - std::vector *p_gstats) { - std::vector &gstats = *p_gstats; + std::vector *p_gstats) { + std::vector &gstats = *p_gstats; // start from groups that belongs to current data int pid = static_cast(info.GetRoot(ridx)); gstats[pid].Add(gpair, info, ridx); @@ -109,7 +109,7 @@ class TreeRefresher: public IUpdater { gstats[pid].Add(gpair, info, ridx); } } - inline void Refresh(const std::vector &gstats, + inline void Refresh(const std::vector &gstats, int nid, RegTree *p_tree) { RegTree &tree = *p_tree; tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);