diff --git a/plugin/updater_gpu/src/gpu_builder.cuh b/plugin/updater_gpu/src/gpu_builder.cuh index ba1521d35..bfdfa6d38 100644 --- a/plugin/updater_gpu/src/gpu_builder.cuh +++ b/plugin/updater_gpu/src/gpu_builder.cuh @@ -19,6 +19,11 @@ class GPUBuilder { void Init(const TrainParam ¶m); ~GPUBuilder(); + void UpdateParam(const TrainParam ¶m) + { + this->param = param; + } + void Update(const std::vector &gpair, DMatrix *p_fmat, RegTree *p_tree); diff --git a/plugin/updater_gpu/src/updater_gpu.cc b/plugin/updater_gpu/src/updater_gpu.cc index 4083b8bd5..3c7badee2 100644 --- a/plugin/updater_gpu/src/updater_gpu.cc +++ b/plugin/updater_gpu/src/updater_gpu.cc @@ -27,6 +27,7 @@ template class GPUMaker : public TreeUpdater { // rescale learning rate according to size of trees float lr = param.learning_rate; param.learning_rate = lr / trees.size(); + builder.UpdateParam(param); // build tree for (size_t i = 0; i < trees.size(); ++i) { builder.Update(gpair, dmat, trees[i]);