fix num parallel tree
This commit is contained in:
parent
c4b21775fa
commit
3b02fb26b0
@ -27,10 +27,15 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
const FMatrix &fmat,
|
const FMatrix &fmat,
|
||||||
const std::vector<unsigned> &root_index,
|
const std::vector<unsigned> &root_index,
|
||||||
const std::vector<RegTree*> &trees) {
|
const std::vector<RegTree*> &trees) {
|
||||||
|
// rescale learning rate according to size of trees
|
||||||
|
float lr = param.learning_rate;
|
||||||
|
param.learning_rate = lr / trees.size();
|
||||||
|
// build tree
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
Builder builder(param);
|
Builder builder(param);
|
||||||
builder.Update(gpair, fmat, root_index, trees[i]);
|
builder.Update(gpair, fmat, root_index, trees[i]);
|
||||||
}
|
}
|
||||||
|
param.learning_rate = lr;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -26,9 +26,13 @@ class TreePruner: public IUpdater<FMatrix> {
|
|||||||
const FMatrix &fmat,
|
const FMatrix &fmat,
|
||||||
const std::vector<unsigned> &root_index,
|
const std::vector<unsigned> &root_index,
|
||||||
const std::vector<RegTree*> &trees) {
|
const std::vector<RegTree*> &trees) {
|
||||||
|
// rescale learning rate according to size of trees
|
||||||
|
float lr = param.learning_rate;
|
||||||
|
param.learning_rate = lr / trees.size();
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
this->DoPrune(*trees[i]);
|
this->DoPrune(*trees[i]);
|
||||||
}
|
}
|
||||||
|
param.learning_rate = lr;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user