change row subsample to prob

This commit is contained in:
tqchen@graphlab.com
2014-08-19 12:07:52 -07:00
parent 91e70c76ff
commit 9caccd3b36
3 changed files with 69 additions and 55 deletions

View File

@@ -80,13 +80,13 @@ class ColMaker: public IUpdater<FMatrix> {
const std::vector<unsigned> &root_index,
RegTree *p_tree) {
this->InitData(gpair, fmat, root_index, *p_tree);
this->InitNewNode(qexpand, gpair, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) {
this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree);
this->ResetPosition(this->qexpand, fmat, *p_tree);
this->UpdateQueueExpand(*p_tree, &this->qexpand);
this->InitNewNode(qexpand, gpair, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, *p_tree);
// if nothing is left to expand, break
if (qexpand.size() == 0) break;
}
@@ -109,25 +109,31 @@ class ColMaker: public IUpdater<FMatrix> {
const FMatrix &fmat,
const std::vector<unsigned> &root_index, const RegTree &tree) {
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "ColMaker: can only grow new tree");
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
{// setup position
position.resize(fmat.NumBufferedRow());
position.resize(gpair.size());
if (root_index.size() == 0) {
std::fill(position.begin(), position.end(), 0);
for (size_t i = 0; i < rowset.size(); ++i) {
position[rowset[i]] = 0;
}
} else {
for (size_t i = 0; i < position.size(); ++i) {
position[i] = root_index[i];
utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, "root index exceed setting");
for (size_t i = 0; i < rowset.size(); ++i) {
const bst_uint ridx = rowset[i];
position[ridx] = root_index[ridx];
utils::Assert(root_index[ridx] < (unsigned)tree.param.num_roots, "root index exceed setting");
}
}
// mark deleted data points (negative hessian) with position -1
for (size_t i = 0; i < position.size(); ++i) {
if (gpair[i].hess < 0.0f) position[i] = -1;
for (size_t i = 0; i < rowset.size(); ++i) {
const bst_uint ridx = rowset[i];
if (gpair[ridx].hess < 0.0f) position[ridx] = -1;
}
// mark subsample
if (param.subsample < 1.0f) {
for (size_t i = 0; i < position.size(); ++i) {
if (gpair[i].hess < 0.0f) continue;
if (random::SampleBinary(param.subsample) == 0) position[i] = -1;
for (size_t i = 0; i < rowset.size(); ++i) {
const bst_uint ridx = rowset[i];
if (gpair[ridx].hess < 0.0f) continue;
if (random::SampleBinary(param.subsample) == 0) position[ridx] = -1;
}
}
}
@@ -168,6 +174,7 @@ class ColMaker: public IUpdater<FMatrix> {
/*! \brief initialize the base_weight, root_gain, and NodeEntry for all the new nodes in qexpand */
inline void InitNewNode(const std::vector<int> &qexpand,
const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const RegTree &tree) {
{// setup statistics space for each tree node
for (size_t i = 0; i < stemp.size(); ++i) {
@@ -175,13 +182,15 @@ class ColMaker: public IUpdater<FMatrix> {
}
snode.resize(tree.param.num_nodes, NodeEntry());
}
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
// accumulate each row's gradient pair into the per-thread statistics of its current node
const unsigned ndata = static_cast<unsigned>(position.size());
const unsigned ndata = static_cast<unsigned>(rowset.size());
#pragma omp parallel for schedule(static)
for (unsigned i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i];
const int tid = omp_get_thread_num();
if (position[i] < 0) continue;
stemp[tid][position[i]].stats.Add(gpair[i]);
if (position[ridx] < 0) continue;
stemp[tid][position[ridx]].stats.Add(gpair[ridx]);
}
// sum the per thread statistics together
for (size_t j = 0; j < qexpand.size(); ++j) {
@@ -303,17 +312,19 @@ class ColMaker: public IUpdater<FMatrix> {
}
// reset position of each data points after split is created in the tree
inline void ResetPosition(const std::vector<int> &qexpand, const FMatrix &fmat, const RegTree &tree) {
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
// step 1, move rows in non-leaf nodes to the default child, and set rows in leaf nodes to -1
const unsigned ndata = static_cast<unsigned>(position.size());
const unsigned ndata = static_cast<unsigned>(rowset.size());
#pragma omp parallel for schedule(static)
for (unsigned i = 0; i < ndata; ++i) {
const int nid = position[i];
for (unsigned i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i];
const int nid = position[ridx];
if (nid >= 0) {
if (tree[nid].is_leaf()) {
position[i] = -1;
position[ridx] = -1;
} else {
// push to default branch, correct later
position[i] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
position[ridx] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
}
}
}