From e29512897315cbc6092e668c4f48bb6ef62f3fd7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 15 Oct 2014 14:30:09 -0700 Subject: [PATCH] add bitmap . --- src/tree/updater.cpp | 2 ++ src/tree/updater_colmaker-inl.hpp | 41 ++++++++++++++++++++++--------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/tree/updater.cpp b/src/tree/updater.cpp index 5879b2bbd..e2c530142 100644 --- a/src/tree/updater.cpp +++ b/src/tree/updater.cpp @@ -5,6 +5,7 @@ #include "./updater_prune-inl.hpp" #include "./updater_refresh-inl.hpp" #include "./updater_colmaker-inl.hpp" +#include "./updater_distcol-inl.hpp" namespace xgboost { namespace tree { @@ -12,6 +13,7 @@ IUpdater* CreateUpdater(const char *name) { using namespace std; if (!strcmp(name, "prune")) return new TreePruner(); if (!strcmp(name, "refresh")) return new TreeRefresher(); + if (!strcmp(name, "distcol")) return new DistColMaker(); if (!strcmp(name, "grow_colmaker")) return new ColMaker(); if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >(); if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >(); diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp index 9c2740264..596c8c8f5 100644 --- a/src/tree/updater_colmaker-inl.hpp +++ b/src/tree/updater_colmaker-inl.hpp @@ -36,10 +36,11 @@ class ColMaker: public IUpdater { Builder builder(param); builder.Update(gpair, p_fmat, info, trees[i]); } + param.learning_rate = lr; } - private: + protected: // training parameter TrainParam param; // data structure @@ -108,7 +109,7 @@ class ColMaker: public IUpdater { } } - private: + protected: // initialize temp data structure inline void InitData(const std::vector &gpair, const IFMatrix &fmat, @@ -409,7 +410,7 @@ class ColMaker: public IUpdater { } } // update the solution candidate - virtual void UpdateSolution(const ColBatch &batch, + virtual void UpdateSolution(const ColBatch &batch, const std::vector &gpair, const IFMatrix &fmat, const BoosterInfo &info) { @@ -463,12 +464,11 @@ class ColMaker: public IUpdater { this->UpdateSolution(iter->Value(), gpair, *p_fmat, info); } // after this each thread's stemp will get the best candidates, aggregate results + this->SyncBestSolution(qexpand); + // get the best result, we can synchronize the solution for (size_t i = 0; i < qexpand.size(); ++i) { const int nid = qexpand[i]; - NodeEntry &e = snode[nid]; - for (int tid = 0; tid < this->nthread; ++tid) { - e.best.Update(stemp[tid][nid].best); - } + NodeEntry &e = snode[nid]; // now we know the solution in snode[nid], set split if (e.best.loss_chg > rt_eps) { p_tree->AddChilds(nid); @@ -476,9 +476,8 @@ class ColMaker: public IUpdater { } else { (*p_tree)[nid].set_leaf(e.weight * param.learning_rate); } - } + } } - // reset position of each data points after split is created in the tree inline void ResetPosition(const std::vector &qexpand, IFMatrix *p_fmat, const RegTree &tree) { const std::vector &rowset = p_fmat->buffered_rowset(); @@ -490,18 +489,36 @@ class ColMaker: public IUpdater { const int nid = position[ridx]; if (nid >= 0) { if (tree[nid].is_leaf()) { - position[ridx] = -1; + position[ridx] = - nid - 1; } else { // push to default branch, correct latter position[ridx] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright(); } } } + // set the positions in the nondefault places + this->SetNonDefaultPosition(qexpand, p_fmat, tree); + } + // customization part + // synchronize the best solution of each node + virtual void SyncBestSolution(const std::vector &qexpand) { + for (size_t i = 0; i < qexpand.size(); ++i) { + const int nid = qexpand[i]; + NodeEntry &e = snode[nid]; + for (int tid = 0; tid < this->nthread; ++tid) { + e.best.Update(stemp[tid][nid].best); + } + } + } + virtual void SetNonDefaultPosition(const std::vector &qexpand, + IFMatrix *p_fmat, const RegTree &tree) { // step 2, classify the non-default data into right places std::vector fsplits; for (size_t i = 0; i < qexpand.size(); ++i) { const int nid = qexpand[i]; - if (!tree[nid].is_leaf()) fsplits.push_back(tree[nid].split_index()); + if (!tree[nid].is_leaf()) { + fsplits.push_back(tree[nid].split_index()); + } } std::sort(fsplits.begin(), fsplits.end()); fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); @@ -518,7 +535,7 @@ class ColMaker: public IUpdater { const bst_uint ridx = col[j].index; const float fvalue = col[j].fvalue; int nid = position[ridx]; - if (nid == -1) continue; + if (nid < 0) continue; // go back to parent, correct those who are not default nid = tree[nid].parent(); if (tree[nid].split_index() == fid) {