refresher test

This commit is contained in:
tqchen@graphlab.com
2014-08-19 11:41:35 -07:00
parent 762b360739
commit 91e70c76ff
4 changed files with 21 additions and 10 deletions

View File

@@ -110,22 +110,22 @@ class ColMaker: public IUpdater<FMatrix> {
const std::vector<unsigned> &root_index, const RegTree &tree) {
utils::Assert(tree.param.num_nodes == tree.param.num_roots, "ColMaker: can only grow new tree");
{// setup position
position.resize(gpair.size());
position.resize(fmat.NumBufferedRow());
if (root_index.size() == 0) {
std::fill(position.begin(), position.end(), 0);
} else {
for (size_t i = 0; i < root_index.size(); ++i) {
for (size_t i = 0; i < position.size(); ++i) {
position[i] = root_index[i];
utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, "root index exceed setting");
}
}
// mark delete for the deleted datas
for (size_t i = 0; i < gpair.size(); ++i) {
for (size_t i = 0; i < position.size(); ++i) {
if (gpair[i].hess < 0.0f) position[i] = -1;
}
// mark subsample
if (param.subsample < 1.0f) {
for (size_t i = 0; i < gpair.size(); ++i) {
for (size_t i = 0; i < position.size(); ++i) {
if (gpair[i].hess < 0.0f) continue;
if (random::SampleBinary(param.subsample) == 0) position[i] = -1;
}
@@ -271,7 +271,9 @@ class ColMaker: public IUpdater<FMatrix> {
}
// start enumeration
const unsigned nsize = static_cast<unsigned>(feat_set.size());
#if defined(_OPENMP)
const int batch_size = std::max(static_cast<int>(nsize / this->nthread / 32), 1);
#endif
#pragma omp parallel for schedule(dynamic, batch_size)
for (unsigned i = 0; i < nsize; ++i) {
const unsigned fid = feat_set[i];

View File

@@ -20,7 +20,6 @@ class TreeRefresher: public IUpdater<FMatrix> {
// set training parameter
virtual void SetParam(const char *name, const char *val) {
param.SetParam(name, val);
if (!strcmp(name, "silent")) silent = atoi(val);
}
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair,
@@ -127,8 +126,6 @@ class TreeRefresher: public IUpdater<FMatrix> {
}
// number of thread in the data
int nthread;
// shutup
int silent;
// training parameter
TrainParam param;
};