Accept histogram cut instead of gradient index in evaluation. (#7336)

This commit is contained in:
Jiaming Yuan
2021-10-20 18:04:46 +08:00
committed by GitHub
parent 15685996fc
commit 8d7c6366d7
4 changed files with 13 additions and 11 deletions

View File

@@ -56,15 +56,15 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
// a non-missing value for the particular feature fid.
template <int d_step>
GradStats EnumerateSplit(
const GHistIndexMatrix &gmat, const common::GHistRow<GradientSumT> &hist,
common::HistogramCuts const &cut, const common::GHistRow<GradientSumT> &hist,
const NodeEntry &snode, SplitEntry *p_best, bst_feature_t fidx,
bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
// aliases
const std::vector<uint32_t> &cut_ptr = gmat.cut.Ptrs();
const std::vector<bst_float> &cut_val = gmat.cut.Values();
const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
const std::vector<bst_float> &cut_val = cut.Values();
// statistics on both sides of split
GradStats c;
@@ -116,7 +116,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
snode.root_gain);
if (i == imin) {
// for leftmost bin, left bound is the smallest feature value
split_pt = gmat.cut.MinValues()[fidx];
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
@@ -132,7 +132,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
public:
void EvaluateSplits(const common::HistCollection<GradientSumT> &hist,
GHistIndexMatrix const &gidx, const RegTree &tree,
common::HistogramCuts const &cut, const RegTree &tree,
std::vector<ExpandEntry>* p_entries) {
auto& entries = *p_entries;
// All nodes are on the same level, so we can store the shared ptr.
@@ -168,10 +168,10 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
if (interaction_constraints_.Query(nidx, fidx)) {
auto grad_stats = EnumerateSplit<+1>(gidx, histogram, snode_[nidx],
auto grad_stats = EnumerateSplit<+1>(cut, histogram, snode_[nidx],
best, fidx, nidx, evaluator);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1>(gidx, histogram, snode_[nidx], best, fidx, nidx,
EnumerateSplit<-1>(cut, histogram, snode_[nidx], best, fidx, nidx,
evaluator);
}
}

View File

@@ -170,7 +170,8 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
builder_monitor_.Start("EvaluateSplits");
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(
BatchParam{GenericParameter::kCpuId, param_.max_bin})) {
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat, *p_tree, &entries);
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, *p_tree, &entries);
break;
}
builder_monitor_.Stop("EvaluateSplits");
node = entries.front();
@@ -271,7 +272,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
}
builder_monitor_.Start("EvaluateSplits");
evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat,
evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat.cut,
*p_tree, &nodes_to_evaluate);
builder_monitor_.Stop("EvaluateSplits");