templatize refresher
This commit is contained in:
parent
f71b732e7a
commit
ba9fbd380c
@ -60,7 +60,7 @@ namespace tree {
|
|||||||
template<typename FMatrix>
|
template<typename FMatrix>
|
||||||
inline IUpdater<FMatrix>* CreateUpdater(const char *name) {
|
inline IUpdater<FMatrix>* CreateUpdater(const char *name) {
|
||||||
if (!strcmp(name, "prune")) return new TreePruner<FMatrix>();
|
if (!strcmp(name, "prune")) return new TreePruner<FMatrix>();
|
||||||
if (!strcmp(name, "refresh")) return new TreeRefresher<FMatrix>();
|
if (!strcmp(name, "refresh")) return new TreeRefresher<FMatrix, GradStats>();
|
||||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<FMatrix, GradStats>();
|
if (!strcmp(name, "grow_colmaker")) return new ColMaker<FMatrix, GradStats>();
|
||||||
utils::Error("unknown updater:%s", name);
|
utils::Error("unknown updater:%s", name);
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|||||||
@ -13,7 +13,7 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
/*! \brief pruner that prunes a tree after growing finishs */
|
/*! \brief pruner that prunes a tree after growing finishs */
|
||||||
template<typename FMatrix>
|
template<typename FMatrix, typename TStats>
|
||||||
class TreeRefresher: public IUpdater<FMatrix> {
|
class TreeRefresher: public IUpdater<FMatrix> {
|
||||||
public:
|
public:
|
||||||
virtual ~TreeRefresher(void) {}
|
virtual ~TreeRefresher(void) {}
|
||||||
@ -30,7 +30,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
// number of threads
|
// number of threads
|
||||||
int nthread;
|
int nthread;
|
||||||
// thread temporal space
|
// thread temporal space
|
||||||
std::vector< std::vector<GradStats> > stemp;
|
std::vector< std::vector<TStats> > stemp;
|
||||||
std::vector<RegTree::FVec> fvec_temp;
|
std::vector<RegTree::FVec> fvec_temp;
|
||||||
// setup temp space for each thread
|
// setup temp space for each thread
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
@ -38,14 +38,14 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
nthread = omp_get_num_threads();
|
nthread = omp_get_num_threads();
|
||||||
}
|
}
|
||||||
fvec_temp.resize(nthread, RegTree::FVec());
|
fvec_temp.resize(nthread, RegTree::FVec());
|
||||||
stemp.resize(trees.size() * nthread, std::vector<GradStats>());
|
stemp.resize(trees.size() * nthread, std::vector<TStats>());
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
int tid = omp_get_thread_num();
|
int tid = omp_get_thread_num();
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
std::vector<GradStats> &vec = stemp[tid * trees.size() + i];
|
std::vector<TStats> &vec = stemp[tid * trees.size() + i];
|
||||||
vec.resize(trees[i]->param.num_nodes);
|
vec.resize(trees[i]->param.num_nodes);
|
||||||
std::fill(vec.begin(), vec.end(), GradStats());
|
std::fill(vec.begin(), vec.end(), TStats());
|
||||||
}
|
}
|
||||||
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
||||||
}
|
}
|
||||||
@ -97,8 +97,8 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
const std::vector<bst_gpair> &gpair,
|
const std::vector<bst_gpair> &gpair,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
const bst_uint ridx,
|
const bst_uint ridx,
|
||||||
std::vector<GradStats> *p_gstats) {
|
std::vector<TStats> *p_gstats) {
|
||||||
std::vector<GradStats> &gstats = *p_gstats;
|
std::vector<TStats> &gstats = *p_gstats;
|
||||||
// start from groups that belongs to current data
|
// start from groups that belongs to current data
|
||||||
int pid = static_cast<int>(info.GetRoot(ridx));
|
int pid = static_cast<int>(info.GetRoot(ridx));
|
||||||
gstats[pid].Add(gpair, info, ridx);
|
gstats[pid].Add(gpair, info, ridx);
|
||||||
@ -109,7 +109,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
gstats[pid].Add(gpair, info, ridx);
|
gstats[pid].Add(gpair, info, ridx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline void Refresh(const std::vector<GradStats> &gstats,
|
inline void Refresh(const std::vector<TStats> &gstats,
|
||||||
int nid, RegTree *p_tree) {
|
int nid, RegTree *p_tree) {
|
||||||
RegTree &tree = *p_tree;
|
RegTree &tree = *p_tree;
|
||||||
tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
|
tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user