From a33fa05bdafc4d46fae29b690dc30cff8e954f83 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 30 Mar 2017 06:10:57 +1300 Subject: [PATCH] GPU Plugin: Bug fix #2048 (#2155) --- plugin/updater_gpu/src/gpu_builder.cuh | 5 +++++ plugin/updater_gpu/src/updater_gpu.cc | 1 + 2 files changed, 6 insertions(+) diff --git a/plugin/updater_gpu/src/gpu_builder.cuh b/plugin/updater_gpu/src/gpu_builder.cuh index ba1521d35..bfdfa6d38 100644 --- a/plugin/updater_gpu/src/gpu_builder.cuh +++ b/plugin/updater_gpu/src/gpu_builder.cuh @@ -19,6 +19,11 @@ class GPUBuilder { void Init(const TrainParam ¶m); ~GPUBuilder(); + void UpdateParam(const TrainParam ¶m) + { + this->param = param; + } + void Update(const std::vector &gpair, DMatrix *p_fmat, RegTree *p_tree); diff --git a/plugin/updater_gpu/src/updater_gpu.cc b/plugin/updater_gpu/src/updater_gpu.cc index 4083b8bd5..3c7badee2 100644 --- a/plugin/updater_gpu/src/updater_gpu.cc +++ b/plugin/updater_gpu/src/updater_gpu.cc @@ -27,6 +27,7 @@ template class GPUMaker : public TreeUpdater { // rescale learning rate according to size of trees float lr = param.learning_rate; param.learning_rate = lr / trees.size(); + builder.UpdateParam(param); // build tree for (size_t i = 0; i < trees.size(); ++i) { builder.Update(gpair, dmat, trees[i]);