From da1629e84842972560e1ed67cb2a2569ff5a8476 Mon Sep 17 00:00:00 2001 From: Vadim Khotilovich Date: Sun, 21 May 2017 23:47:57 -0500 Subject: [PATCH] [gbtree] fix update process to work with multiclass and multitree; fixes #2315 (#2332) --- R-package/tests/testthat/test_update.R | 22 ++++++++++++++++++++++ src/gbm/gbtree.cc | 9 +++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/R-package/tests/testthat/test_update.R b/R-package/tests/testthat/test_update.R index 8518711fb..e1803029b 100644 --- a/R-package/tests/testthat/test_update.R +++ b/R-package/tests/testthat/test_update.R @@ -73,3 +73,25 @@ test_that("updating the model works", { expect_gt(sum(abs(tr1[Feature != 'Leaf']$Quality - tr1ut[Feature != 'Leaf']$Quality)), 100) expect_lt(sum(tr1ut$Cover) / sum(tr1$Cover), 0.5) }) + +test_that("updating works for multiclass & multitree", { + dtr <- xgb.DMatrix(as.matrix(iris[, -5]), label = as.numeric(iris$Species) - 1) + watchlist <- list(train = dtr) + p0 <- list(max_depth = 2, eta = 0.5, nthread = 2, subsample = 0.6, + objective = "multi:softprob", num_class = 3, num_parallel_tree = 2, + base_score = 0) + set.seed(121) + bst0 <- xgb.train(p0, dtr, 5, watchlist, verbose = 0) + tr0 <- xgb.model.dt.tree(model = bst0) + + # run update process for an original model with subsampling + p0u <- modifyList(p0, list(process_type='update', updater='refresh', refresh_leaf=FALSE)) + bst0u <- xgb.train(p0u, dtr, nrounds = bst0$niter, watchlist, xgb_model = bst0, verbose = 0) + tr0u <- xgb.model.dt.tree(model = bst0u) + + # should be the same evaluation but different gains and larger cover + expect_equal(bst0$evaluation_log, bst0u$evaluation_log) + expect_equal(tr0[Feature == 'Leaf']$Quality, tr0u[Feature == 'Leaf']$Quality) + expect_gt(sum(abs(tr0[Feature != 'Leaf']$Quality - tr0u[Feature != 'Leaf']$Quality)), 100) + expect_gt(sum(tr0u$Cover) / sum(tr0$Cover), 1.5) +}) diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index bca210953..ed1593333 100644 --- 
a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -246,12 +246,12 @@ class GBTree : public GradientBooster { ObjFunction* obj) override { const std::vector<bst_gpair>& gpair = *in_gpair; std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees; - if (mparam.num_output_group == 1) { + const int ngroup = mparam.num_output_group; + if (ngroup == 1) { std::vector<std::unique_ptr<RegTree> > ret; BoostNewTrees(gpair, p_fmat, 0, &ret); new_trees.push_back(std::move(ret)); } else { - const int ngroup = mparam.num_output_group; CHECK_EQ(gpair.size() % ngroup, 0U) << "must have exactly ngroup*nrow gpairs"; std::vector<bst_gpair> tmp(gpair.size() / ngroup); @@ -267,7 +267,7 @@ class GBTree : public GradientBooster { } } double tstart = dmlc::GetTime(); - for (int gid = 0; gid < mparam.num_output_group; ++gid) { + for (int gid = 0; gid < ngroup; ++gid) { this->CommitModel(std::move(new_trees[gid]), gid); } if (tparam.debug_verbose > 0) { @@ -468,7 +468,8 @@ class GBTree : public GradientBooster { } else if (tparam.process_type == kUpdate) { CHECK_LT(trees.size(), trees_to_update.size()); // move an existing tree from trees_to_update - auto t = std::move(trees_to_update[trees.size()]); + auto t = std::move(trees_to_update[trees.size() + + bst_group * tparam.num_parallel_tree + i]); new_trees.push_back(t.get()); ret->push_back(std::move(t)); }