ok, now work on update position
This commit is contained in:
parent
aefe58a207
commit
47145a7fac
2
Makefile
2
Makefile
@ -14,7 +14,7 @@ endif
|
||||
BIN =
|
||||
OBJ = updater.o gbm.o io.o main.o
|
||||
MPIOBJ = sync.o
|
||||
MPIBIN = test/test xgboost
|
||||
MPIBIN = xgboost
|
||||
SLIB = #wrapper/libxgboostwrapper.so
|
||||
|
||||
.PHONY: clean all python Rpack
|
||||
|
||||
@ -132,17 +132,17 @@ class ColMaker: public IUpdater {
|
||||
// mark delete for the deleted datas
|
||||
for (size_t i = 0; i < rowset.size(); ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
if (gpair[ridx].hess < 0.0f) position[ridx] = -1;
|
||||
if (gpair[ridx].hess < 0.0f) position[ridx] = ~position[ridx];
|
||||
}
|
||||
// mark subsample
|
||||
if (param.subsample < 1.0f) {
|
||||
for (size_t i = 0; i < rowset.size(); ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
if (gpair[ridx].hess < 0.0f) continue;
|
||||
if (random::SampleBinary(param.subsample) == 0) position[ridx] = -1;
|
||||
if (random::SampleBinary(param.subsample) == 0) position[ridx] = ~position[ridx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
// initialize feature index
|
||||
unsigned ncol = static_cast<unsigned>(fmat.NumCol());
|
||||
@ -473,6 +473,9 @@ class ColMaker: public IUpdater {
|
||||
if (e.best.loss_chg > rt_eps) {
|
||||
p_tree->AddChilds(nid);
|
||||
(*p_tree)[nid].set_split(e.best.split_index(), e.best.split_value, e.best.default_left());
|
||||
// mark right child as 0, to indicate fresh leaf
|
||||
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
|
||||
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
|
||||
} else {
|
||||
(*p_tree)[nid].set_leaf(e.weight * param.learning_rate);
|
||||
}
|
||||
@ -480,28 +483,33 @@ class ColMaker: public IUpdater {
|
||||
}
|
||||
// reset position of each data points after split is created in the tree
|
||||
inline void ResetPosition(const std::vector<int> &qexpand, IFMatrix *p_fmat, const RegTree &tree) {
|
||||
// set the positions in the nondefault
|
||||
this->SetNonDefaultPosition(qexpand, p_fmat, tree);
|
||||
// set rest of instances to default position
|
||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||
// step 1, set default direct nodes to default, and leaf nodes to -1
|
||||
// set default direct nodes to default
|
||||
// for leaf nodes that are not fresh, mark then to ~nid,
|
||||
// so that they are ignored in future statistics collection
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
int nid = position[ridx];
|
||||
if (nid < 0) nid = ~nid;
|
||||
const int nid = this->DecodePosition(ridx);
|
||||
if (tree[nid].is_leaf()) {
|
||||
position[ridx] = ~nid;
|
||||
// mark finish when it is not a fresh leaf
|
||||
if (tree[nid].cright() == -1) {
|
||||
position[ridx] = ~nid;
|
||||
}
|
||||
} else {
|
||||
// push to default branch, correct latter
|
||||
int pid = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
|
||||
if (position[ridx] < 0) {
|
||||
position[ridx] = ~pid;
|
||||
// push to default branch
|
||||
if (tree[nid].default_left()) {
|
||||
this->SetEncodePosition(ridx, tree[nid].cleft());
|
||||
} else {
|
||||
position[ridx] = pid;
|
||||
this->SetEncodePosition(ridx, tree[nid].cright());
|
||||
}
|
||||
}
|
||||
}
|
||||
// set the positions in the nondefault places
|
||||
this->SetNonDefaultPosition(qexpand, p_fmat, tree);
|
||||
}
|
||||
// customization part
|
||||
// synchronize the best solution of each node
|
||||
@ -516,7 +524,7 @@ class ColMaker: public IUpdater {
|
||||
}
|
||||
virtual void SetNonDefaultPosition(const std::vector<int> &qexpand,
|
||||
IFMatrix *p_fmat, const RegTree &tree) {
|
||||
// step 2, classify the non-default data into right places
|
||||
// step 1, classify the non-default data into right places
|
||||
std::vector<unsigned> fsplits;
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
@ -538,22 +546,33 @@ class ColMaker: public IUpdater {
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
const bst_uint ridx = col[j].index;
|
||||
const float fvalue = col[j].fvalue;
|
||||
int nid = position[ridx];
|
||||
if (nid < 0) nid = ~nid;
|
||||
|
||||
const int nid = this->DecodePosition(ridx);
|
||||
// go back to parent, correct those who are not default
|
||||
nid = tree[nid].parent();
|
||||
if (tree[nid].split_index() == fid) {
|
||||
if (fvalue < tree[nid].split_cond()) {
|
||||
position[ridx] = tree[nid].cleft();
|
||||
if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {
|
||||
if(fvalue < tree[nid].split_cond()) {
|
||||
this->SetEncodePosition(ridx, tree[nid].cleft());
|
||||
} else {
|
||||
position[ridx] = tree[nid].cright();
|
||||
this->SetEncodePosition(ridx, tree[nid].cright());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// utils to get/set position, with encoded format
|
||||
// return decoded position
|
||||
inline int DecodePosition(bst_uint ridx) const{
|
||||
const int pid = position[ridx];
|
||||
return pid < 0 ? ~pid : pid;
|
||||
}
|
||||
// encode the encoded position value for ridx
|
||||
inline void SetEncodePosition(bst_uint ridx, int nid) {
|
||||
if (position[ridx] < 0) {
|
||||
position[ridx] = ~nid;
|
||||
} else {
|
||||
position[ridx] = nid;
|
||||
}
|
||||
}
|
||||
//--data fields--
|
||||
const TrainParam ¶m;
|
||||
// number of omp thread used during training
|
||||
|
||||
@ -100,11 +100,8 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
const bst_uint ridx = col[j].index;
|
||||
const float fvalue = col[j].fvalue;
|
||||
int nid = this->position[ridx];
|
||||
if (nid < 0) continue;
|
||||
// go back to parent, correct those who are not default
|
||||
nid = tree[nid].parent();
|
||||
if (tree[nid].split_index() == fid) {
|
||||
const int nid = this->DecodePosition(ridx);
|
||||
if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {
|
||||
if (fvalue < tree[nid].split_cond()) {
|
||||
if (!tree[nid].default_left()) bitmap.SetTrue(ridx);
|
||||
} else {
|
||||
@ -122,13 +119,13 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
int nid = this->position[ridx];
|
||||
if (nid >= 0 && bitmap.Get(ridx)) {
|
||||
nid = tree[nid].parent();
|
||||
const int nid = this->DecodePosition(ridx);
|
||||
if (bitmap.Get(ridx)) {
|
||||
utils::Assert(!tree[nid].is_leaf(), "inconsistent reduce information");
|
||||
if (tree[nid].default_left()) {
|
||||
this->position[ridx] = tree[nid].cright();
|
||||
this->SetEncodePosition(ridx, tree[nid].cright());
|
||||
} else {
|
||||
this->position[ridx] = tree[nid].cleft();
|
||||
this->SetEncodePosition(ridx, tree[nid].cright());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user