ok, now work on update position

This commit is contained in:
tqchen 2014-10-16 11:56:55 -07:00
parent aefe58a207
commit 47145a7fac
3 changed files with 50 additions and 34 deletions

View File

@ -14,7 +14,7 @@ endif
BIN = BIN =
OBJ = updater.o gbm.o io.o main.o OBJ = updater.o gbm.o io.o main.o
MPIOBJ = sync.o MPIOBJ = sync.o
MPIBIN = test/test xgboost MPIBIN = xgboost
SLIB = #wrapper/libxgboostwrapper.so SLIB = #wrapper/libxgboostwrapper.so
.PHONY: clean all python Rpack .PHONY: clean all python Rpack

View File

@ -132,17 +132,17 @@ class ColMaker: public IUpdater {
// mark delete for the deleted datas // mark delete for the deleted datas
for (size_t i = 0; i < rowset.size(); ++i) { for (size_t i = 0; i < rowset.size(); ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
if (gpair[ridx].hess < 0.0f) position[ridx] = -1; if (gpair[ridx].hess < 0.0f) position[ridx] = ~position[ridx];
} }
// mark subsample // mark subsample
if (param.subsample < 1.0f) { if (param.subsample < 1.0f) {
for (size_t i = 0; i < rowset.size(); ++i) { for (size_t i = 0; i < rowset.size(); ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
if (gpair[ridx].hess < 0.0f) continue; if (gpair[ridx].hess < 0.0f) continue;
if (random::SampleBinary(param.subsample) == 0) position[ridx] = -1; if (random::SampleBinary(param.subsample) == 0) position[ridx] = ~position[ridx];
} }
} }
} }
{ {
// initialize feature index // initialize feature index
unsigned ncol = static_cast<unsigned>(fmat.NumCol()); unsigned ncol = static_cast<unsigned>(fmat.NumCol());
@ -473,6 +473,9 @@ class ColMaker: public IUpdater {
if (e.best.loss_chg > rt_eps) { if (e.best.loss_chg > rt_eps) {
p_tree->AddChilds(nid); p_tree->AddChilds(nid);
(*p_tree)[nid].set_split(e.best.split_index(), e.best.split_value, e.best.default_left()); (*p_tree)[nid].set_split(e.best.split_index(), e.best.split_value, e.best.default_left());
// mark right child as 0, to indicate fresh leaf
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
} else { } else {
(*p_tree)[nid].set_leaf(e.weight * param.learning_rate); (*p_tree)[nid].set_leaf(e.weight * param.learning_rate);
} }
@ -480,28 +483,33 @@ class ColMaker: public IUpdater {
} }
// reset position of each data points after split is created in the tree // reset position of each data points after split is created in the tree
inline void ResetPosition(const std::vector<int> &qexpand, IFMatrix *p_fmat, const RegTree &tree) { inline void ResetPosition(const std::vector<int> &qexpand, IFMatrix *p_fmat, const RegTree &tree) {
// set the positions in the nondefault
this->SetNonDefaultPosition(qexpand, p_fmat, tree);
// set rest of instances to default position
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset(); const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
// step 1, set default direct nodes to default, and leaf nodes to -1 // set default direct nodes to default
// for leaf nodes that are not fresh, mark them as ~nid,
// so that they are ignored in future statistics collection
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size()); const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) { for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
int nid = position[ridx]; const int nid = this->DecodePosition(ridx);
if (nid < 0) nid = ~nid;
if (tree[nid].is_leaf()) { if (tree[nid].is_leaf()) {
position[ridx] = ~nid; // mark finish when it is not a fresh leaf
if (tree[nid].cright() == -1) {
position[ridx] = ~nid;
}
} else { } else {
// push to default branch, correct latter // push to default branch
int pid = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright(); if (tree[nid].default_left()) {
if (position[ridx] < 0) { this->SetEncodePosition(ridx, tree[nid].cleft());
position[ridx] = ~pid;
} else { } else {
position[ridx] = pid; this->SetEncodePosition(ridx, tree[nid].cright());
} }
} }
} }
// set the positions in the nondefault places
this->SetNonDefaultPosition(qexpand, p_fmat, tree);
} }
// customization part // customization part
// synchronize the best solution of each node // synchronize the best solution of each node
@ -516,7 +524,7 @@ class ColMaker: public IUpdater {
} }
virtual void SetNonDefaultPosition(const std::vector<int> &qexpand, virtual void SetNonDefaultPosition(const std::vector<int> &qexpand,
IFMatrix *p_fmat, const RegTree &tree) { IFMatrix *p_fmat, const RegTree &tree) {
// step 2, classify the non-default data into right places // step 1, classify the non-default data into right places
std::vector<unsigned> fsplits; std::vector<unsigned> fsplits;
for (size_t i = 0; i < qexpand.size(); ++i) { for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i]; const int nid = qexpand[i];
@ -538,22 +546,33 @@ class ColMaker: public IUpdater {
for (bst_omp_uint j = 0; j < ndata; ++j) { for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index; const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue; const float fvalue = col[j].fvalue;
int nid = position[ridx]; const int nid = this->DecodePosition(ridx);
if (nid < 0) nid = ~nid;
// go back to parent, correct those who are not default // go back to parent, correct those who are not default
nid = tree[nid].parent(); if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {
if (tree[nid].split_index() == fid) { if(fvalue < tree[nid].split_cond()) {
if (fvalue < tree[nid].split_cond()) { this->SetEncodePosition(ridx, tree[nid].cleft());
position[ridx] = tree[nid].cleft();
} else { } else {
position[ridx] = tree[nid].cright(); this->SetEncodePosition(ridx, tree[nid].cright());
} }
} }
} }
} }
} }
} }
// utils to get/set position, with encoded format
// return decoded position
inline int DecodePosition(bst_uint ridx) const {
  // a negative entry holds the bitwise complement of the node id;
  // complementing it again recovers the plain node id
  const int encoded = position[ridx];
  if (encoded < 0) {
    return ~encoded;
  }
  return encoded;
}
// set the position of ridx to nid, keeping its current encoding (complemented if it was negative)
inline void SetEncodePosition(bst_uint ridx, int nid) {
  // preserve the sign convention of the stored value: an entry that is
  // currently complemented (negative) is rewritten as ~nid, otherwise as nid
  const bool complemented = position[ridx] < 0;
  position[ridx] = complemented ? ~nid : nid;
}
//--data fields-- //--data fields--
const TrainParam &param; const TrainParam &param;
// number of omp thread used during training // number of omp thread used during training

View File

@ -100,11 +100,8 @@ class DistColMaker : public ColMaker<TStats> {
for (bst_omp_uint j = 0; j < ndata; ++j) { for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index; const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue; const float fvalue = col[j].fvalue;
int nid = this->position[ridx]; const int nid = this->DecodePosition(ridx);
if (nid < 0) continue; if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {
// go back to parent, correct those who are not default
nid = tree[nid].parent();
if (tree[nid].split_index() == fid) {
if (fvalue < tree[nid].split_cond()) { if (fvalue < tree[nid].split_cond()) {
if (!tree[nid].default_left()) bitmap.SetTrue(ridx); if (!tree[nid].default_left()) bitmap.SetTrue(ridx);
} else { } else {
@ -122,13 +119,13 @@ class DistColMaker : public ColMaker<TStats> {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) { for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
int nid = this->position[ridx]; const int nid = this->DecodePosition(ridx);
if (nid >= 0 && bitmap.Get(ridx)) { if (bitmap.Get(ridx)) {
nid = tree[nid].parent(); utils::Assert(!tree[nid].is_leaf(), "inconsistent reduce information");
if (tree[nid].default_left()) { if (tree[nid].default_left()) {
this->position[ridx] = tree[nid].cright(); this->SetEncodePosition(ridx, tree[nid].cright());
} else { } else {
this->position[ridx] = tree[nid].cleft(); this->SetEncodePosition(ridx, tree[nid].cleft());
} }
} }
} }