Multiclass prediction caching for CPU Hist (#6550)

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
ShvetsKS
2021-01-12 23:42:07 +03:00
committed by GitHub
parent 03cd087da1
commit 7f4d3a91b9
5 changed files with 94 additions and 21 deletions

View File

@@ -196,11 +196,17 @@ void GBTree::DoBoost(DMatrix* p_fmat,
const int ngroup = model_.learner_model_param->num_output_group;
ConfigureWithKnownData(this->cfg_, p_fmat);
monitor_.Start("BoostNewTrees");
auto* out = &predt->predictions;
CHECK_NE(ngroup, 0);
if (ngroup == 1) {
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
const size_t num_new_trees = ret.size();
new_trees.push_back(std::move(ret));
if (updaters_.size() > 0 && num_new_trees == 1 && out->Size() > 0 &&
updaters_.back()->UpdatePredictionCache(p_fmat, out)) {
predt->Update(1);
}
} else {
CHECK_EQ(in_gpair->Size() % ngroup, 0U)
<< "must have exactly ngroup * nrow gpairs";
@@ -210,6 +216,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
in_gpair->DeviceIdx());
const auto& gpair_h = in_gpair->ConstHostVector();
auto nsize = static_cast<bst_omp_uint>(tmp.Size());
bool update_predict = true;
for (int gid = 0; gid < ngroup; ++gid) {
std::vector<GradientPair>& tmp_h = tmp.HostVector();
#pragma omp parallel for schedule(static)
@@ -218,7 +225,16 @@ void GBTree::DoBoost(DMatrix* p_fmat,
}
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(&tmp, p_fmat, gid, &ret);
const size_t num_new_trees = ret.size();
new_trees.push_back(std::move(ret));
auto* out = &predt->predictions;
if (!(updaters_.size() > 0 && out->Size() > 0 && num_new_trees == 1 &&
updaters_.back()->UpdatePredictionCacheMulticlass(p_fmat, out, gid, ngroup))) {
update_predict = false;
}
}
if (update_predict) {
predt->Update(1);
}
}
monitor_.Stop("BoostNewTrees");
@@ -314,20 +330,9 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
DMatrix* m,
PredictionCacheEntry* predts) {
monitor_.Start("CommitModel");
int num_new_trees = 0;
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
num_new_trees += new_trees[gid].size();
model_.CommitModel(std::move(new_trees[gid]), gid);
}
auto* out = &predts->predictions;
if (model_.learner_model_param->num_output_group == 1 &&
updaters_.size() > 0 &&
num_new_trees == 1 &&
out->Size() > 0 &&
updaters_.back()->UpdatePredictionCache(m, out)) {
auto delta = num_new_trees / model_.learner_model_param->num_output_group;
predts->Update(delta);
}
monitor_.Stop("CommitModel");
}

View File

@@ -110,8 +110,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
}
bool QuantileHistMaker::UpdatePredictionCache(
const DMatrix* data,
HostDeviceVector<bst_float>* out_preds) {
const DMatrix* data, HostDeviceVector<bst_float>* out_preds) {
if (param_.subsample < 1.0f) {
return false;
} else {
@@ -125,6 +124,23 @@ bool QuantileHistMaker::UpdatePredictionCache(
}
}
// Multiclass counterpart of UpdatePredictionCache(): hands the request to the
// histogram builder currently in use, passing along the output-group index
// `gid` and the total number of groups `ngroup`.
//
// Returns false (cache not updated) when row subsampling is enabled
// (param_.subsample < 1.0f) or when no suitable builder has been created;
// otherwise returns whatever the builder's UpdatePredictionCache() reports.
bool QuantileHistMaker::UpdatePredictionCacheMulticlass(
    const DMatrix* data,
    HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) {
  // NOTE(review): presumably the cache cannot be refreshed under subsampling
  // because not every row was assigned to a leaf — confirm against the builder.
  if (param_.subsample < 1.0f) {
    return false;
  }

  // The single-precision builder is used only when it was requested and exists.
  if (hist_maker_param_.single_precision_histogram && float_builder_) {
    return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
  }

  // Otherwise use the double-precision builder, if there is one.
  if (double_builder_) {
    return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
  }

  // No builder available: nothing was updated.
  return false;
}
template <typename GradientSumT>
void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder,
int,
@@ -620,7 +636,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
template<typename GradientSumT>
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
const DMatrix* data,
HostDeviceVector<bst_float>* p_out_preds) {
HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
@@ -659,7 +675,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
leaf_value = (*p_last_tree_)[nid].LeafValue();
for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds[*it] += leaf_value;
out_preds[*it * ngroup + gid] += leaf_value;
}
}
});

View File

@@ -121,6 +121,9 @@ class QuantileHistMaker: public TreeUpdater {
bool UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* out_preds) override;
// Multiclass variant of UpdatePredictionCache(): forwards `data` and
// `out_preds` to the active histogram builder together with the output-group
// index `gid` and the group count `ngroup`.
// Returns false when param_.subsample < 1.0f or when no builder is available;
// otherwise returns the builder's result.
bool UpdatePredictionCacheMulticlass(const DMatrix* data,
HostDeviceVector<bst_float>* out_preds,
const int gid, const int ngroup) override;
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
@@ -243,7 +246,9 @@ class QuantileHistMaker: public TreeUpdater {
}
bool UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* p_out_preds);
HostDeviceVector<bst_float>* p_out_preds,
const int gid = 0, const int ngroup = 1);
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);