[gbtree] fix update process to work with multiclass and multitree; fixes #2315 (#2332)

This commit is contained in:
Vadim Khotilovich
2017-05-21 23:47:57 -05:00
committed by GitHub
parent b52db87d5c
commit da1629e848
2 changed files with 27 additions and 4 deletions

View File

@@ -246,12 +246,12 @@ class GBTree : public GradientBooster {
ObjFunction* obj) override {
const std::vector<bst_gpair>& gpair = *in_gpair;
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
if (mparam.num_output_group == 1) {
const int ngroup = mparam.num_output_group;
if (ngroup == 1) {
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(gpair, p_fmat, 0, &ret);
new_trees.push_back(std::move(ret));
} else {
const int ngroup = mparam.num_output_group;
CHECK_EQ(gpair.size() % ngroup, 0U)
<< "must have exactly ngroup*nrow gpairs";
std::vector<bst_gpair> tmp(gpair.size() / ngroup);
@@ -267,7 +267,7 @@ class GBTree : public GradientBooster {
}
}
double tstart = dmlc::GetTime();
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
for (int gid = 0; gid < ngroup; ++gid) {
this->CommitModel(std::move(new_trees[gid]), gid);
}
if (tparam.debug_verbose > 0) {
@@ -468,7 +468,8 @@ class GBTree : public GradientBooster {
} else if (tparam.process_type == kUpdate) {
CHECK_LT(trees.size(), trees_to_update.size());
// move an existing tree from trees_to_update
auto t = std::move(trees_to_update[trees.size()]);
auto t = std::move(trees_to_update[trees.size() +
bst_group * tparam.num_parallel_tree + i]);
new_trees.push_back(t.get());
ret->push_back(std::move(t));
}