Small cleanup to hist tree method. (#7735)
* Remove special optimization using number of bins. * Remove 1-based index for column sampling. * Remove data layout. * Unify update prediction cache.
This commit is contained in:
@@ -363,19 +363,54 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
// The column sampler must be constructed by caller since we need to preserve the rng
|
||||
// for the entire training session.
|
||||
explicit HistEvaluator(TrainParam const ¶m, MetaInfo const &info, int32_t n_threads,
|
||||
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task,
|
||||
bool skip_0_index = false)
|
||||
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task)
|
||||
: param_{param},
|
||||
column_sampler_{std::move(sampler)},
|
||||
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
|
||||
n_threads_{n_threads},
|
||||
task_{task} {
|
||||
interaction_constraints_.Configure(param, info.num_col_);
|
||||
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, skip_0_index);
|
||||
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
}
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
/**
|
||||
* \brief CPU implementation of update prediction cache, which calculates the leaf value
|
||||
* for the last tree and accumulates it to prediction vector.
|
||||
*
|
||||
* \param p_last_tree The last tree being updated by tree updater
|
||||
*/
|
||||
template <typename Partitioner, typename GradientSumT, typename ExpandEntry>
|
||||
void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_last_tree,
|
||||
std::vector<Partitioner> const &partitioner,
|
||||
HistEvaluator<GradientSumT, ExpandEntry> const &hist_evaluator,
|
||||
TrainParam const ¶m, linalg::VectorView<float> out_preds) {
|
||||
CHECK_GT(out_preds.Size(), 0U);
|
||||
|
||||
CHECK(p_last_tree);
|
||||
auto const &tree = *p_last_tree;
|
||||
auto const &snode = hist_evaluator.Stats();
|
||||
auto evaluator = hist_evaluator.Evaluator();
|
||||
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
|
||||
size_t n_nodes = p_last_tree->GetNodes().size();
|
||||
for (auto &part : partitioner) {
|
||||
CHECK_EQ(part.Size(), n_nodes);
|
||||
common::BlockedSpace2d space(
|
||||
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
|
||||
common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) {
|
||||
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
|
||||
auto const &rowset = part[nidx];
|
||||
auto const &stats = snode[nidx];
|
||||
auto leaf_value =
|
||||
evaluator.CalcWeight(nidx, param, GradStats{stats.stats}) * param.learning_rate;
|
||||
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
out_preds(*it) += leaf_value;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
|
||||
|
||||
Reference in New Issue
Block a user