GPU Plugin: Bug fix #2048 (#2155)

This commit is contained in:
Rory Mitchell 2017-03-30 06:10:57 +13:00 committed by Tianqi Chen
parent d45cf240a9
commit a33fa05bda
2 changed files with 6 additions and 0 deletions

View File

@ -19,6 +19,11 @@ class GPUBuilder {
void Init(const TrainParam &param); void Init(const TrainParam &param);
~GPUBuilder(); ~GPUBuilder();
void UpdateParam(const TrainParam &param)
{
this->param = param;
}
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat, void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
RegTree *p_tree); RegTree *p_tree);

View File

@ -27,6 +27,7 @@ template <typename TStats> class GPUMaker : public TreeUpdater {
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
float lr = param.learning_rate; float lr = param.learning_rate;
param.learning_rate = lr / trees.size(); param.learning_rate = lr / trees.size();
builder.UpdateParam(param);
// build tree // build tree
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
builder.Update(gpair, dmat, trees[i]); builder.Update(gpair, dmat, trees[i]);