This commit is contained in:
parent
b52db87d5c
commit
da1629e848
@ -73,3 +73,25 @@ test_that("updating the model works", {
|
||||
expect_gt(sum(abs(tr1[Feature != 'Leaf']$Quality - tr1ut[Feature != 'Leaf']$Quality)), 100)
|
||||
expect_lt(sum(tr1ut$Cover) / sum(tr1$Cover), 0.5)
|
||||
})
|
||||
|
||||
test_that("updating works for multiclass & multitree", {
|
||||
dtr <- xgb.DMatrix(as.matrix(iris[, -5]), label = as.numeric(iris$Species) - 1)
|
||||
watchlist <- list(train = dtr)
|
||||
p0 <- list(max_depth = 2, eta = 0.5, nthread = 2, subsample = 0.6,
|
||||
objective = "multi:softprob", num_class = 3, num_parallel_tree = 2,
|
||||
base_score = 0)
|
||||
set.seed(121)
|
||||
bst0 <- xgb.train(p0, dtr, 5, watchlist, verbose = 0)
|
||||
tr0 <- xgb.model.dt.tree(model = bst0)
|
||||
|
||||
# run update process for an original model with subsampling
|
||||
p0u <- modifyList(p0, list(process_type='update', updater='refresh', refresh_leaf=FALSE))
|
||||
bst0u <- xgb.train(p0u, dtr, nrounds = bst0$niter, watchlist, xgb_model = bst0, verbose = 0)
|
||||
tr0u <- xgb.model.dt.tree(model = bst0u)
|
||||
|
||||
# should be the same evaluation but different gains and larger cover
|
||||
expect_equal(bst0$evaluation_log, bst0u$evaluation_log)
|
||||
expect_equal(tr0[Feature == 'Leaf']$Quality, tr0u[Feature == 'Leaf']$Quality)
|
||||
expect_gt(sum(abs(tr0[Feature != 'Leaf']$Quality - tr0u[Feature != 'Leaf']$Quality)), 100)
|
||||
expect_gt(sum(tr0u$Cover) / sum(tr0$Cover), 1.5)
|
||||
})
|
||||
|
||||
@ -246,12 +246,12 @@ class GBTree : public GradientBooster {
|
||||
ObjFunction* obj) override {
|
||||
const std::vector<bst_gpair>& gpair = *in_gpair;
|
||||
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
|
||||
if (mparam.num_output_group == 1) {
|
||||
const int ngroup = mparam.num_output_group;
|
||||
if (ngroup == 1) {
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(gpair, p_fmat, 0, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
} else {
|
||||
const int ngroup = mparam.num_output_group;
|
||||
CHECK_EQ(gpair.size() % ngroup, 0U)
|
||||
<< "must have exactly ngroup*nrow gpairs";
|
||||
std::vector<bst_gpair> tmp(gpair.size() / ngroup);
|
||||
@ -267,7 +267,7 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
double tstart = dmlc::GetTime();
|
||||
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
this->CommitModel(std::move(new_trees[gid]), gid);
|
||||
}
|
||||
if (tparam.debug_verbose > 0) {
|
||||
@ -468,7 +468,8 @@ class GBTree : public GradientBooster {
|
||||
} else if (tparam.process_type == kUpdate) {
|
||||
CHECK_LT(trees.size(), trees_to_update.size());
|
||||
// move an existing tree from trees_to_update
|
||||
auto t = std::move(trees_to_update[trees.size()]);
|
||||
auto t = std::move(trees_to_update[trees.size() +
|
||||
bst_group * tparam.num_parallel_tree + i]);
|
||||
new_trees.push_back(t.get());
|
||||
ret->push_back(std::move(t));
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user