Small cleanup to hist tree method. (#7735)

* Remove special optimization using number of bins.
* Remove 1-based index for column sampling.
* Remove data layout.
* Unify update prediction cache.
This commit is contained in:
Jiaming Yuan
2022-03-20 03:44:55 +08:00
committed by GitHub
parent 718472dbe2
commit 996cc705af
9 changed files with 140 additions and 205 deletions

View File

@@ -363,19 +363,54 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
// The column sampler must be constructed by caller since we need to preserve the rng
// for the entire training session.
explicit HistEvaluator(TrainParam const &param, MetaInfo const &info, int32_t n_threads,
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task,
bool skip_0_index = false)
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task)
: param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
n_threads_{n_threads},
task_{task} {
interaction_constraints_.Configure(param, info.num_col_);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, skip_0_index);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
}
};
} // namespace tree
} // namespace xgboost
/**
* \brief CPU implementation of update prediction cache, which calculates the leaf value
* for the last tree and accumulates it to prediction vector.
*
* \param p_last_tree The last tree being updated by tree updater
*/
template <typename Partitioner, typename GradientSumT, typename ExpandEntry>
void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
HistEvaluator<GradientSumT, ExpandEntry> const &hist_evaluator,
TrainParam const &param, linalg::VectorView<float> out_preds) {
CHECK_GT(out_preds.Size(), 0U);
CHECK(p_last_tree);
auto const &tree = *p_last_tree;
auto const &snode = hist_evaluator.Stats();
auto evaluator = hist_evaluator.Evaluator();
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
size_t n_nodes = p_last_tree->GetNodes().size();
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) {
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
auto const &rowset = part[nidx];
auto const &stats = snode[nidx];
auto leaf_value =
evaluator.CalcWeight(nidx, param, GradStats{stats.stats}) * param.learning_rate;
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
}
}
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_