add bitmap .

This commit is contained in:
tqchen 2014-10-15 14:30:09 -07:00
parent d0daecb4d3
commit e295128973
2 changed files with 31 additions and 12 deletions

View File

@ -5,6 +5,7 @@
#include "./updater_prune-inl.hpp" #include "./updater_prune-inl.hpp"
#include "./updater_refresh-inl.hpp" #include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp" #include "./updater_colmaker-inl.hpp"
#include "./updater_distcol-inl.hpp"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -12,6 +13,7 @@ IUpdater* CreateUpdater(const char *name) {
using namespace std; using namespace std;
if (!strcmp(name, "prune")) return new TreePruner(); if (!strcmp(name, "prune")) return new TreePruner();
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>(); if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>(); if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >(); if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();
if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >(); if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >();

View File

@ -36,10 +36,11 @@ class ColMaker: public IUpdater {
Builder builder(param); Builder builder(param);
builder.Update(gpair, p_fmat, info, trees[i]); builder.Update(gpair, p_fmat, info, trees[i]);
} }
param.learning_rate = lr; param.learning_rate = lr;
} }
private: protected:
// training parameter // training parameter
TrainParam param; TrainParam param;
// data structure // data structure
@ -108,7 +109,7 @@ class ColMaker: public IUpdater {
} }
} }
private: protected:
// initialize temp data structure // initialize temp data structure
inline void InitData(const std::vector<bst_gpair> &gpair, inline void InitData(const std::vector<bst_gpair> &gpair,
const IFMatrix &fmat, const IFMatrix &fmat,
@ -463,12 +464,11 @@ class ColMaker: public IUpdater {
this->UpdateSolution(iter->Value(), gpair, *p_fmat, info); this->UpdateSolution(iter->Value(), gpair, *p_fmat, info);
} }
// after this each thread's stemp will get the best candidates, aggregate results // after this each thread's stemp will get the best candidates, aggregate results
this->SyncBestSolution(qexpand);
// get the best result, we can synchronize the solution
for (size_t i = 0; i < qexpand.size(); ++i) { for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i]; const int nid = qexpand[i];
NodeEntry &e = snode[nid]; NodeEntry &e = snode[nid];
for (int tid = 0; tid < this->nthread; ++tid) {
e.best.Update(stemp[tid][nid].best);
}
// now we know the solution in snode[nid], set split // now we know the solution in snode[nid], set split
if (e.best.loss_chg > rt_eps) { if (e.best.loss_chg > rt_eps) {
p_tree->AddChilds(nid); p_tree->AddChilds(nid);
@ -478,7 +478,6 @@ class ColMaker: public IUpdater {
} }
} }
} }
// reset position of each data points after split is created in the tree // reset position of each data points after split is created in the tree
inline void ResetPosition(const std::vector<int> &qexpand, IFMatrix *p_fmat, const RegTree &tree) { inline void ResetPosition(const std::vector<int> &qexpand, IFMatrix *p_fmat, const RegTree &tree) {
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset(); const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
@ -490,18 +489,36 @@ class ColMaker: public IUpdater {
const int nid = position[ridx]; const int nid = position[ridx];
if (nid >= 0) { if (nid >= 0) {
if (tree[nid].is_leaf()) { if (tree[nid].is_leaf()) {
position[ridx] = -1; position[ridx] = - nid - 1;
} else { } else {
// push to default branch, correct later // push to default branch, correct later
position[ridx] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright(); position[ridx] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
} }
} }
} }
// set the positions in the nondefault places
this->SetNonDefaultPosition(qexpand, p_fmat, tree);
}
// customization part
// synchronize the best solution of each node
virtual void SyncBestSolution(const std::vector<int> &qexpand) {
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
NodeEntry &e = snode[nid];
for (int tid = 0; tid < this->nthread; ++tid) {
e.best.Update(stemp[tid][nid].best);
}
}
}
virtual void SetNonDefaultPosition(const std::vector<int> &qexpand,
IFMatrix *p_fmat, const RegTree &tree) {
// step 2, classify the non-default data into right places // step 2, classify the non-default data into right places
std::vector<unsigned> fsplits; std::vector<unsigned> fsplits;
for (size_t i = 0; i < qexpand.size(); ++i) { for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i]; const int nid = qexpand[i];
if (!tree[nid].is_leaf()) fsplits.push_back(tree[nid].split_index()); if (!tree[nid].is_leaf()) {
fsplits.push_back(tree[nid].split_index());
}
} }
std::sort(fsplits.begin(), fsplits.end()); std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
@ -518,7 +535,7 @@ class ColMaker: public IUpdater {
const bst_uint ridx = col[j].index; const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue; const float fvalue = col[j].fvalue;
int nid = position[ridx]; int nid = position[ridx];
if (nid == -1) continue; if (nid < 0) continue;
// go back to parent, correct those who are not default // go back to parent, correct those who are not default
nid = tree[nid].parent(); nid = tree[nid].parent();
if (tree[nid].split_index() == fid) { if (tree[nid].split_index() == fid) {