fix num parallel tree

This commit is contained in:
tqchen@graphlab.com 2014-08-18 13:33:58 -07:00
parent c4b21775fa
commit 3b02fb26b0
2 changed files with 9 additions and 0 deletions

View File

@ -27,10 +27,15 @@ class ColMaker: public IUpdater<FMatrix> {
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const std::vector<RegTree*> &trees) {
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param);
builder.Update(gpair, fmat, root_index, trees[i]);
}
param.learning_rate = lr;
}
private:

View File

@ -26,9 +26,13 @@ class TreePruner: public IUpdater<FMatrix> {
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const std::vector<RegTree*> &trees) {
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
for (size_t i = 0; i < trees.size(); ++i) {
this->DoPrune(*trees[i]);
}
param.learning_rate = lr;
}
private: