Multiclass prediction caching for CPU Hist (#6550)
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
@@ -196,11 +196,17 @@ void GBTree::DoBoost(DMatrix* p_fmat,
|
||||
const int ngroup = model_.learner_model_param->num_output_group;
|
||||
ConfigureWithKnownData(this->cfg_, p_fmat);
|
||||
monitor_.Start("BoostNewTrees");
|
||||
auto* out = &predt->predictions;
|
||||
CHECK_NE(ngroup, 0);
|
||||
if (ngroup == 1) {
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
|
||||
const size_t num_new_trees = ret.size();
|
||||
new_trees.push_back(std::move(ret));
|
||||
if (updaters_.size() > 0 && num_new_trees == 1 && out->Size() > 0 &&
|
||||
updaters_.back()->UpdatePredictionCache(p_fmat, out)) {
|
||||
predt->Update(1);
|
||||
}
|
||||
} else {
|
||||
CHECK_EQ(in_gpair->Size() % ngroup, 0U)
|
||||
<< "must have exactly ngroup * nrow gpairs";
|
||||
@@ -210,6 +216,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
|
||||
in_gpair->DeviceIdx());
|
||||
const auto& gpair_h = in_gpair->ConstHostVector();
|
||||
auto nsize = static_cast<bst_omp_uint>(tmp.Size());
|
||||
bool update_predict = true;
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
std::vector<GradientPair>& tmp_h = tmp.HostVector();
|
||||
#pragma omp parallel for schedule(static)
|
||||
@@ -218,7 +225,16 @@ void GBTree::DoBoost(DMatrix* p_fmat,
|
||||
}
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(&tmp, p_fmat, gid, &ret);
|
||||
const size_t num_new_trees = ret.size();
|
||||
new_trees.push_back(std::move(ret));
|
||||
auto* out = &predt->predictions;
|
||||
if (!(updaters_.size() > 0 && out->Size() > 0 && num_new_trees == 1 &&
|
||||
updaters_.back()->UpdatePredictionCacheMulticlass(p_fmat, out, gid, ngroup))) {
|
||||
update_predict = false;
|
||||
}
|
||||
}
|
||||
if (update_predict) {
|
||||
predt->Update(1);
|
||||
}
|
||||
}
|
||||
monitor_.Stop("BoostNewTrees");
|
||||
@@ -314,20 +330,9 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
|
||||
DMatrix* m,
|
||||
PredictionCacheEntry* predts) {
|
||||
monitor_.Start("CommitModel");
|
||||
int num_new_trees = 0;
|
||||
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
|
||||
num_new_trees += new_trees[gid].size();
|
||||
model_.CommitModel(std::move(new_trees[gid]), gid);
|
||||
}
|
||||
auto* out = &predts->predictions;
|
||||
if (model_.learner_model_param->num_output_group == 1 &&
|
||||
updaters_.size() > 0 &&
|
||||
num_new_trees == 1 &&
|
||||
out->Size() > 0 &&
|
||||
updaters_.back()->UpdatePredictionCache(m, out)) {
|
||||
auto delta = num_new_trees / model_.learner_model_param->num_output_group;
|
||||
predts->Update(delta);
|
||||
}
|
||||
monitor_.Stop("CommitModel");
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user