add bitmap .

This commit is contained in:
tqchen 2014-10-15 14:30:09 -07:00
parent d0daecb4d3
commit e295128973
2 changed files with 31 additions and 12 deletions

View File

@ -5,6 +5,7 @@
#include "./updater_prune-inl.hpp"
#include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp"
#include "./updater_distcol-inl.hpp"
namespace xgboost {
namespace tree {
@ -12,6 +13,7 @@ IUpdater* CreateUpdater(const char *name) {
using namespace std;
if (!strcmp(name, "prune")) return new TreePruner();
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();
if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >();

View File

@ -36,10 +36,11 @@ class ColMaker: public IUpdater {
Builder builder(param);
builder.Update(gpair, p_fmat, info, trees[i]);
}
param.learning_rate = lr;
}
private:
protected:
// training parameter
TrainParam param;
// data structure
@ -108,7 +109,7 @@ class ColMaker: public IUpdater {
}
}
private:
protected:
// initialize temp data structure
inline void InitData(const std::vector<bst_gpair> &gpair,
const IFMatrix &fmat,
@ -409,7 +410,7 @@ class ColMaker: public IUpdater {
}
}
// update the solution candidate
virtual void UpdateSolution(const ColBatch &batch,
virtual void UpdateSolution(const ColBatch &batch,
const std::vector<bst_gpair> &gpair,
const IFMatrix &fmat,
const BoosterInfo &info) {
@ -463,12 +464,11 @@ class ColMaker: public IUpdater {
this->UpdateSolution(iter->Value(), gpair, *p_fmat, info);
}
// after this each thread's stemp will get the best candidates, aggregate results
this->SyncBestSolution(qexpand);
// get the best result, we can synchronize the solution
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
NodeEntry &e = snode[nid];
for (int tid = 0; tid < this->nthread; ++tid) {
e.best.Update(stemp[tid][nid].best);
}
NodeEntry &e = snode[nid];
// now we know the solution in snode[nid], set split
if (e.best.loss_chg > rt_eps) {
p_tree->AddChilds(nid);
@ -476,9 +476,8 @@ class ColMaker: public IUpdater {
} else {
(*p_tree)[nid].set_leaf(e.weight * param.learning_rate);
}
}
}
}
// reset position of each data points after split is created in the tree
inline void ResetPosition(const std::vector<int> &qexpand, IFMatrix *p_fmat, const RegTree &tree) {
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
@ -490,18 +489,36 @@ class ColMaker: public IUpdater {
const int nid = position[ridx];
if (nid >= 0) {
if (tree[nid].is_leaf()) {
position[ridx] = -1;
position[ridx] = - nid - 1;
} else {
// push to default branch, correct later
position[ridx] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
}
}
}
// set the positions in the nondefault places
this->SetNonDefaultPosition(qexpand, p_fmat, tree);
}
// customization part
// synchronize the best solution of each node
virtual void SyncBestSolution(const std::vector<int> &qexpand) {
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
NodeEntry &e = snode[nid];
for (int tid = 0; tid < this->nthread; ++tid) {
e.best.Update(stemp[tid][nid].best);
}
}
}
virtual void SetNonDefaultPosition(const std::vector<int> &qexpand,
IFMatrix *p_fmat, const RegTree &tree) {
// step 2, classify the non-default data into right places
std::vector<unsigned> fsplits;
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
if (!tree[nid].is_leaf()) fsplits.push_back(tree[nid].split_index());
if (!tree[nid].is_leaf()) {
fsplits.push_back(tree[nid].split_index());
}
}
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
@ -518,7 +535,7 @@ class ColMaker: public IUpdater {
const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue;
int nid = position[ridx];
if (nid == -1) continue;
if (nid < 0) continue;
// go back to parent, correct those who are not default
nid = tree[nid].parent();
if (tree[nid].split_index() == fid) {