[core] fix slow predict-caching with many classes (#3109)

* fix prediction caching inefficiency for multiclass

* silence some warnings

* redundant if

* workaround for R v3.4.3 bug; fixes #3081
This commit is contained in:
Vadim Khotilovich
2018-02-15 18:31:42 -06:00
committed by GitHub
parent cf19caa46a
commit 9ffe8596f2
5 changed files with 31 additions and 32 deletions

View File

@@ -281,9 +281,7 @@ class GBTree : public GradientBooster {
}
monitor.Stop("BoostNewTrees");
monitor.Start("CommitModel");
for (int gid = 0; gid < ngroup; ++gid) {
this->CommitModel(std::move(new_trees[gid]), gid);
}
this->CommitModel(std::move(new_trees));
monitor.Stop("CommitModel");
}
@@ -338,11 +336,13 @@ class GBTree : public GradientBooster {
// commit new trees all at once
virtual void
CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
int bst_group) {
model_.CommitModel(std::move(new_trees), bst_group);
predictor->UpdatePredictionCache(model_, &updaters, new_trees.size());
CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
int num_new_trees = 0;
for (int gid = 0; gid < model_.param.num_output_group; ++gid) {
num_new_trees += new_trees[gid].size();
model_.CommitModel(std::move(new_trees[gid]), gid);
}
predictor->UpdatePredictionCache(model_, &updaters, num_new_trees);
}
// --- data structure ---
@@ -514,20 +514,22 @@ class Dart : public GBTree {
}
}
}
// commit new trees all at once
void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
int bst_group) override {
for (size_t i = 0; i < new_trees.size(); ++i) {
model_.trees.push_back(std::move(new_trees[i]));
model_.tree_info.push_back(bst_group);
void
CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) override {
int num_new_trees = 0;
for (int gid = 0; gid < model_.param.num_output_group; ++gid) {
num_new_trees += new_trees[gid].size();
model_.CommitModel(std::move(new_trees[gid]), gid);
}
model_.param.num_trees += static_cast<int>(new_trees.size());
size_t num_drop = NormalizeTrees(new_trees.size());
size_t num_drop = NormalizeTrees(num_new_trees);
if (dparam.silent != 1) {
LOG(INFO) << "drop " << num_drop << " trees, "
<< "weight = " << weight_drop.back();
}
}
// predict the leaf scores without dropped trees
inline bst_float PredValue(const RowBatch::Inst &inst,
int bst_group,
@@ -550,16 +552,17 @@ class Dart : public GBTree {
return psum;
}
// select dropped trees
// select which trees to drop
inline void DropTrees(unsigned ntree_limit_drop) {
idx_drop.clear();
if (ntree_limit_drop > 0) return;
std::uniform_real_distribution<> runif(0.0, 1.0);
auto& rnd = common::GlobalRandom();
// reset
idx_drop.clear();
// sample dropped trees
bool skip = false;
if (dparam.skip_drop > 0.0) skip = (runif(rnd) < dparam.skip_drop);
if (ntree_limit_drop == 0 && !skip) {
// sample some trees to drop
if (!skip) {
if (dparam.sample_type == 1) {
bst_float sum_weight = 0.0;
for (size_t i = 0; i < weight_drop.size(); ++i) {
@@ -594,6 +597,7 @@ class Dart : public GBTree {
}
}
}
// set normalization factors
inline size_t NormalizeTrees(size_t size_new_trees) {
float lr = 1.0 * dparam.learning_rate / size_new_trees;