From 3b02fb26b088c44969def737b228c79160152cce Mon Sep 17 00:00:00 2001 From: "tqchen@graphlab.com" Date: Mon, 18 Aug 2014 13:33:58 -0700 Subject: [PATCH] fix num parallel tree --- src/tree/updater_colmaker-inl.hpp | 5 +++++ src/tree/updater_prune-inl.hpp | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp index f0624bdeb..eb5ff85fc 100644 --- a/src/tree/updater_colmaker-inl.hpp +++ b/src/tree/updater_colmaker-inl.hpp @@ -27,10 +27,15 @@ class ColMaker: public IUpdater { const FMatrix &fmat, const std::vector &root_index, const std::vector &trees) { + // rescale learning rate according to size of trees + float lr = param.learning_rate; + param.learning_rate = lr / trees.size(); + // build tree for (size_t i = 0; i < trees.size(); ++i) { Builder builder(param); builder.Update(gpair, fmat, root_index, trees[i]); } + param.learning_rate = lr; } private: diff --git a/src/tree/updater_prune-inl.hpp b/src/tree/updater_prune-inl.hpp index b5205080b..363d6eec1 100644 --- a/src/tree/updater_prune-inl.hpp +++ b/src/tree/updater_prune-inl.hpp @@ -26,9 +26,13 @@ class TreePruner: public IUpdater { const FMatrix &fmat, const std::vector &root_index, const std::vector &trees) { + // rescale learning rate according to size of trees + float lr = param.learning_rate; + param.learning_rate = lr / trees.size(); for (size_t i = 0; i < trees.size(); ++i) { this->DoPrune(*trees[i]); } + param.learning_rate = lr; } private: