// xgboost/src/tree/hist/evaluate_splits.h

/**
* Copyright 2021-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#include <algorithm> // for copy
#include <cstddef> // for size_t
#include <limits> // for numeric_limits
#include <memory> // for shared_ptr
#include <numeric> // for accumulate
#include <utility> // for move
#include <vector> // for vector
#include "../../common/categorical.h" // for CatBitField
#include "../../common/hist_util.h" // for GHistRow, HistogramCuts
#include "../../common/linalg_op.h" // for cbegin, cend, begin
#include "../../common/random.h" // for ColumnSampler
#include "../constraints.h" // for FeatureInteractionConstraintHost
#include "../param.h" // for TrainParam
#include "../split_evaluator.h" // for TreeEvaluator
#include "expand_entry.h" // for MultiExpandEntry
#include "hist_cache.h" // for BoundedHistCollection
#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_feature_t
#include "xgboost/context.h" // for COntext
#include "xgboost/linalg.h" // for Constants, Vector
namespace xgboost::tree {
class HistEvaluator {
private:
struct NodeEntry {
/*! \brief statistics for the node entry */
GradStats stats;
/*! \brief loss of this node, without split */
bst_float root_gain{0.0f};
};
private:
Context const* ctx_;
TrainParam const* param_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
TreeEvaluator tree_evaluator_;
bool is_col_split_{false};
FeatureInteractionConstraintHost interaction_constraints_;
std::vector<NodeEntry> snode_;
// If the sum of statistics for non-missing values in the node equals the sum of
// statistics for all values, there are no missing values for this feature;
// otherwise some values are missing.
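// For example, if the parent node's statistics are {grad = 1.5, hess = 10.0} and the
// statistics accumulated over this feature's histogram are also {1.5, 10.0}, then every
// row in the node has a recorded value for the feature and the backward scan is skipped.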
static bool SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
return false;
} else {
return true;
}
}
[[nodiscard]] bool IsValid(GradStats const &left, GradStats const &right) const {
return left.GetHess() >= param_->min_child_weight &&
right.GetHess() >= param_->min_child_weight;
}
/**
* \brief Use the learned direction with one-hot splits. Other implementations (e.g. LightGBM)
* create a pseudo-category for missing values, but here we just do a complete scan to avoid
* making a specialized histogram bin.
*/
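// Illustration: for a categorical feature with bins {A, B, C}, each bin yields a one-hot
// candidate "{bin} vs. the rest". The loop below evaluates every candidate twice: once
// with the missing-value statistics grouped with the remaining categories (missing goes
// left) and once with them grouped with the chosen category (missing goes right), keeping
// whichever direction gives the larger gain.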
void EnumerateOneHot(common::HistogramCuts const &cut, common::ConstGHistRow hist,
bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) const {
const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
const std::vector<bst_float> &cut_val = cut.Values();
bst_bin_t ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
bst_bin_t iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
bst_bin_t n_bins = iend - ibegin;
GradStats left_sum;
GradStats right_sum;
// best split so far
SplitEntry best;
best.is_cat = false; // marker for whether it's updated or not.
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
auto feature_sum = GradStats{
std::accumulate(f_hist.data(), f_hist.data() + f_hist.size(), GradientPairPrecise{})};
GradStats missing;
auto const &parent = snode_[nidx];
missing.SetSubstract(parent.stats, feature_sum);
for (bst_bin_t i = ibegin; i != iend; i += 1) {
auto split_pt = cut_val[i];
// missing on left (treat missing as other categories)
right_sum = GradStats{hist[i]};
left_sum.SetSubstract(parent.stats, right_sum);
if (IsValid(left_sum, right_sum)) {
auto missing_left_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) -
parent.root_gain);
best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
}
// missing on right (treat missing as chosen category)
right_sum.Add(missing);
left_sum.SetSubstract(parent.stats, right_sum);
if (IsValid(left_sum, right_sum)) {
auto missing_right_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) -
parent.root_gain);
best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
}
}
if (best.is_cat) {
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
best.cat_bits.resize(n, 0);
common::CatBitField cat_bits{best.cat_bits};
cat_bits.Set(best.split_value);
}
p_best->Update(best);
}
/**
* \brief Enumerate with partition-based splits.
*
* The implementation is different from LightGBM. Firstly, we don't have a
* pseudo-category for missing values; instead we make two complete scans over the
* histogram. Secondly, both scan directions generate splits in the same
* order. The following table depicts the scan process; square brackets mean the gradient
* of missing values resides on that partition:
*
* | Forward | Backward |
* |----------+----------|
* | [BCDE] A | E [ABCD] |
* | [CDE] AB | DE [ABC] |
* | [DE] ABC | CDE [AB] |
* | [E] ABCD | BCDE [A] |
*/
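// Note: the caller (EvaluateSplits) sorts this feature's bins by the weight computed from
// each bin's statistics before invoking this function; A..E in the table above denote the
// bins in that sorted order, so each scan direction enumerates contiguous partitions of
// the sorted bins.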
template <int d_step>
void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
common::ConstGHistRow hist, bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
auto const &cut_ptr = cut.Ptrs();
auto const &cut_val = cut.Values();
auto const &parent = snode_[nidx];
bst_bin_t f_begin = cut_ptr[fidx];
bst_bin_t f_end = cut_ptr[fidx + 1];
bst_bin_t n_bins_feature{f_end - f_begin};
auto n_bins = std::min(param_->max_cat_threshold, n_bins_feature);
// statistics on both sides of split
GradStats left_sum;
GradStats right_sum;
// best split so far
SplitEntry best;
auto f_hist = hist.subspan(f_begin, n_bins_feature);
bst_bin_t it_begin, it_end;
if (d_step > 0) {
it_begin = f_begin;
it_end = it_begin + n_bins - 1;
} else {
it_begin = f_end - 1;
it_end = it_begin - n_bins + 1;
}
bst_bin_t best_thresh{-1};
for (bst_bin_t i = it_begin; i != it_end; i += d_step) {
auto j = i - f_begin; // index local to current feature
if (d_step == 1) {
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
left_sum.SetSubstract(parent.stats, right_sum); // missing on left
} else {
left_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
right_sum.SetSubstract(parent.stats, left_sum); // missing on right
}
if (IsValid(left_sum, right_sum)) {
auto loss_chg = evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) -
parent.root_gain;
// We don't have a numeric split point; NaN here is a dummy split value.
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
left_sum, right_sum)) {
best_thresh = i;
}
}
}
if (best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
best.cat_bits = decltype(best.cat_bits)(n, 0);
common::CatBitField cat_bits{best.cat_bits};
bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
CHECK_GT(partition, 0);
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](size_t c) {
auto cat = cut_val[c + f_begin];
cat_bits.Set(cat);
});
}
p_best->Update(best);
}
// Enumerate/Scan the split values of a specific feature.
// Returns the sum of gradients corresponding to the data points that contain a
// non-missing value for the particular feature fidx.
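// Forward scan (d_step == +1): the split value is the upper bound of each bin and missing
// values default to the right child. Backward scan (d_step == -1): the split value is the
// lower bound of each bin and missing values default to the left child.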
template <int d_step>
GradStats EnumerateSplit(common::HistogramCuts const &cut, common::ConstGHistRow hist,
bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) const {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
// aliases
const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
const std::vector<bst_float> &cut_val = cut.Values();
auto const &parent = snode_[nidx];
// statistics on both sides of split
GradStats left_sum;
GradStats right_sum;
// best split so far
SplitEntry best;
// bin boundaries
CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
// imin: index (offset) of the minimum value for feature fidx;
// needed for backward enumeration
const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
// ibegin, iend: smallest/largest cut points for feature fidx; use a signed type to
// allow the value -1
bst_bin_t ibegin, iend;
if (d_step > 0) {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
iend = static_cast<bst_bin_t>(cut_ptr.at(fidx + 1));
} else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
}
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
// start working
// try to find a split
left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
right_sum.SetSubstract(parent.stats, left_sum);
if (IsValid(left_sum, right_sum)) {
bst_float loss_chg;
bst_float split_pt;
if (d_step > 0) {
// forward enumeration: split at right bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) -
parent.root_gain);
split_pt = cut_val[i]; // not used for partition-based splits
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
// backward enumeration: split at left bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) -
parent.root_gain);
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
}
}
p_best->Update(best);
return left_sum;
}
/**
* @brief Gather the expand entries from all the workers.
* @param entries Local expand entries on this worker.
* @return Global expand entries gathered from all workers.
*/
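// The fixed-size fields of each entry are gathered directly, while the variable-length
// categorical bitfields (cat_bits) are flattened into one buffer, gathered with
// SpecialAllgatherV, and then copied back into the corresponding entries.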
std::vector<CPUExpandEntry> Allgather(std::vector<CPUExpandEntry> const &entries) {
auto const world = collective::GetWorldSize();
auto const num_entries = entries.size();
// First, gather all the primitive fields.
std::vector<CPUExpandEntry> local_entries(num_entries);
std::vector<uint32_t> cat_bits;
std::vector<std::size_t> cat_bits_sizes;
for (std::size_t i = 0; i < num_entries; i++) {
local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
}
auto all_entries = collective::Allgather(local_entries);
// Gather all the cat_bits.
auto gathered = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes);
common::ParallelFor(num_entries * world, ctx_->Threads(), [&] (auto i) {
// Copy the cat_bits back into all expand entries.
all_entries[i].split.cat_bits.resize(gathered.sizes[i]);
std::copy_n(gathered.result.cbegin() + gathered.offsets[i], gathered.sizes[i],
all_entries[i].split.cat_bits.begin());
});
return all_entries;
}
public:
void EvaluateSplits(const BoundedHistCollection &hist, common::HistogramCuts const &cut,
common::Span<FeatureType const> feature_types, const RegTree &tree,
std::vector<CPUExpandEntry> *p_entries) {
auto n_threads = ctx_->Threads();
auto& entries = *p_entries;
// All nodes are on the same level, so we can store the shared ptr.
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(
entries.size());
for (size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
auto nidx = entries[nidx_in_set].nid;
features[nidx_in_set] =
column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
}
CHECK(!features.empty());
const size_t grain_size =
std::max<size_t>(1, features.front()->Size() / n_threads);
common::BlockedSpace2d space(entries.size(), [&](size_t nidx_in_set) {
return features[nidx_in_set]->Size();
}, grain_size);
std::vector<CPUExpandEntry> tloc_candidates(n_threads * entries.size());
for (size_t i = 0; i < entries.size(); ++i) {
for (decltype(n_threads) j = 0; j < n_threads; ++j) {
tloc_candidates[i * n_threads + j] = entries[i];
}
}
auto evaluator = tree_evaluator_.GetEvaluator();
auto const& cut_ptrs = cut.Ptrs();
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto nidx = entry->nid;
auto histogram = hist[nidx];
auto features_set = features[nidx_in_set]->ConstHostSpan();
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
bool is_cat = common::IsCat(feature_types, fidx);
if (!interaction_constraints_.Query(nidx, fidx)) {
continue;
}
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
// Sort the histogram to get contiguous partitions.
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
evaluator.CalcWeightCat(*param_, feat_hist[r]);
return ret;
});
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
}
} else {
auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
}
}
}
});
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size();
++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(
tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
auto all_entries = Allgather(entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update(
all_entries[worker * entries.size() + nidx_in_set].split);
}
}
}
}
// Add the split to the tree and update all the statistics.
void ApplyTreeSplit(CPUExpandEntry const& candidate, RegTree *p_tree) {
auto evaluator = tree_evaluator_.GetEvaluator();
RegTree &tree = *p_tree;
GradStats parent_sum = candidate.split.left_sum;
parent_sum.Add(candidate.split.right_sum);
auto base_weight = evaluator.CalcWeight(candidate.nid, *param_, GradStats{parent_sum});
auto left_weight =
evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.left_sum});
auto right_weight =
evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.right_sum});
if (candidate.split.is_cat) {
tree.ExpandCategorical(
candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
candidate.split.DefaultLeft(), base_weight, left_weight * param_->learning_rate,
right_weight * param_->learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
} else {
tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
candidate.split.DefaultLeft(), base_weight,
left_weight * param_->learning_rate, right_weight * param_->learning_rate,
candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
}
// Set up child constraints
auto left_child = tree[candidate.nid].LeftChild();
auto right_child = tree[candidate.nid].RightChild();
tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
tree[candidate.nid].SplitIndex(), left_weight,
right_weight);
evaluator = tree_evaluator_.GetEvaluator();
snode_.resize(tree.GetNodes().size());
snode_.at(left_child).stats = candidate.split.left_sum;
snode_.at(left_child).root_gain =
evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.left_sum});
snode_.at(right_child).stats = candidate.split.right_sum;
snode_.at(right_child).root_gain =
evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.right_sum});
interaction_constraints_.Split(candidate.nid,
tree[candidate.nid].SplitIndex(), left_child,
right_child);
}
[[nodiscard]] auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
[[nodiscard]] auto const &Stats() const { return snode_; }
float InitRoot(GradStats const &root_sum) {
snode_.resize(1);
auto root_evaluator = tree_evaluator_.GetEvaluator();
snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
snode_[0].root_gain =
root_evaluator.CalcGain(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
auto weight = root_evaluator.CalcWeight(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
return weight;
}
public:
// The column sampler must be constructed by the caller since we need to preserve the
// RNG for the entire training session.
explicit HistEvaluator(Context const *ctx, TrainParam const *param, MetaInfo const &info,
std::shared_ptr<common::ColumnSampler> sampler)
: ctx_{ctx},
param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), DeviceOrd::CPU()},
is_col_split_{info.IsColumnSplit()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
param_->colsample_bytree);
}
};
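// A rough usage sketch of HistEvaluator (hypothetical caller code; the real driver is the
// hist-based tree updater, and `hist`, `cut`, `ft`, `tree`, and `entries` below are
// placeholders for objects it owns):
//
//   std::shared_ptr<common::ColumnSampler> sampler = /* constructed by the caller */;
//   HistEvaluator evaluator{ctx, &param, info, sampler};
//   evaluator.InitRoot(root_sum);                              // seed the root statistics
//   evaluator.EvaluateSplits(hist, cut, ft, tree, &entries);   // fill the best split per entry
//   evaluator.ApplyTreeSplit(entries.front(), &tree);          // grow the tree by one split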
class HistMultiEvaluator {
std::vector<double> gain_;
linalg::Matrix<GradientPairPrecise> stats_;
TrainParam const *param_;
FeatureInteractionConstraintHost interaction_constraints_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
Context const *ctx_;
bool is_col_split_{false};
private:
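// Sum of the per-target split gains for one candidate split. As a side effect, the optimal
// leaf weights for both children are written into left_weight and right_weight.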
static double MultiCalcSplitGain(TrainParam const &param,
linalg::VectorView<GradientPairPrecise const> left_sum,
linalg::VectorView<GradientPairPrecise const> right_sum,
linalg::VectorView<float> left_weight,
linalg::VectorView<float> right_weight) {
CalcWeight(param, left_sum, left_weight);
CalcWeight(param, right_sum, right_weight);
auto left_gain = CalcGainGivenWeight(param, left_sum, left_weight);
auto right_gain = CalcGainGivenWeight(param, right_sum, right_weight);
return left_gain + right_gain;
}
template <bst_bin_t d_step>
bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
common::Span<common::ConstGHistRow> hist,
linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
auto const &cut_ptr = cut.Ptrs();
auto const &cut_val = cut.Values();
auto const &min_val = cut.MinValues();
auto sum = linalg::Empty<GradientPairPrecise>(ctx_, 2, hist.size());
auto left_sum = sum.Slice(0, linalg::All());
auto right_sum = sum.Slice(1, linalg::All());
bst_bin_t ibegin, iend;
if (d_step > 0) {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
} else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
}
const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
auto n_targets = hist.size();
auto weight = linalg::Empty<float>(ctx_, 2, n_targets);
auto left_weight = weight.Slice(0, linalg::All());
auto right_weight = weight.Slice(1, linalg::All());
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
for (bst_target_t t = 0; t < n_targets; ++t) {
auto t_hist = hist[t];
auto t_p = parent_sum(t);
left_sum(t) += t_hist[i];
right_sum(t) = t_p - left_sum(t);
}
if (d_step > 0) {
auto split_pt = cut_val[i];
auto loss_chg =
MultiCalcSplitGain(*param_, right_sum, left_sum, right_weight, left_weight) -
parent_gain;
p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
float split_pt;
if (i == imin) {
split_pt = min_val[fidx];
} else {
split_pt = cut_val[i - 1];
}
auto loss_chg =
MultiCalcSplitGain(*param_, right_sum, left_sum, left_weight, right_weight) -
parent_gain;
p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
}
// Return true if there are missing values. This doesn't handle floating-point error well.
if (d_step == +1) {
return !std::equal(linalg::cbegin(left_sum), linalg::cend(left_sum),
linalg::cbegin(parent_sum));
}
return false;
}
/**
* @brief Gather the expand entries from all the workers.
* @param entries Local expand entries on this worker.
* @return Global expand entries gathered from all workers.
*/
std::vector<MultiExpandEntry> Allgather(std::vector<MultiExpandEntry> const &entries) {
auto const world = collective::GetWorldSize();
auto const num_entries = entries.size();
// First, gather all the primitive fields.
std::vector<MultiExpandEntry> local_entries(num_entries);
std::vector<uint32_t> cat_bits;
std::vector<std::size_t> cat_bits_sizes;
std::vector<GradientPairPrecise> gradients;
for (std::size_t i = 0; i < num_entries; i++) {
local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes, &gradients);
}
auto all_entries = collective::Allgather(local_entries);
// Gather all the cat_bits.
auto gathered_cat_bits = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes);
// Gather all the gradients.
auto const num_gradients = gradients.size();
auto const all_gradients = collective::Allgather(gradients);
auto const total_entries = num_entries * world;
auto const gradients_per_entry = num_gradients / num_entries;
auto const gradients_per_side = gradients_per_entry / 2;
common::ParallelFor(total_entries, ctx_->Threads(), [&] (auto i) {
// Copy the cat_bits back into all expand entries.
all_entries[i].split.cat_bits.resize(gathered_cat_bits.sizes[i]);
std::copy_n(gathered_cat_bits.result.cbegin() + gathered_cat_bits.offsets[i],
gathered_cat_bits.sizes[i], all_entries[i].split.cat_bits.begin());
// Copy the gradients back into all expand entries.
all_entries[i].split.left_sum.resize(gradients_per_side);
std::copy_n(all_gradients.cbegin() + i * gradients_per_entry, gradients_per_side,
all_entries[i].split.left_sum.begin());
all_entries[i].split.right_sum.resize(gradients_per_side);
std::copy_n(all_gradients.cbegin() + i * gradients_per_entry + gradients_per_side,
gradients_per_side, all_entries[i].split.right_sum.begin());
});
return all_entries;
}
public:
void EvaluateSplits(RegTree const &tree, common::Span<const BoundedHistCollection *> hist,
common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
auto &entries = *p_entries;
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
auto nidx = entries[nidx_in_set].nid;
features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
}
CHECK(!features.empty());
std::int32_t n_threads = ctx_->Threads();
std::size_t const grain_size = std::max<std::size_t>(1, features.front()->Size() / n_threads);
common::BlockedSpace2d space(
entries.size(), [&](std::size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
grain_size);
std::vector<MultiExpandEntry> tloc_candidates(n_threads * entries.size());
for (std::size_t i = 0; i < entries.size(); ++i) {
for (std::int32_t j = 0; j < n_threads; ++j) {
tloc_candidates[i * n_threads + j] = entries[i];
}
}
common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto parent_sum = stats_.Slice(entry->nid, linalg::All());
std::vector<common::ConstGHistRow> node_hist;
for (auto t_hist : hist) {
node_hist.emplace_back((*t_hist)[entry->nid]);
}
auto features_set = features[nidx_in_set]->ConstHostSpan();
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
if (!interaction_constraints_.Query(entry->nid, fidx)) {
continue;
}
auto parent_gain = gain_[entry->nid];
bool missing =
this->EnumerateSplit<+1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
if (missing) {
this->EnumerateSplit<-1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
}
}
});
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
auto all_entries = Allgather(entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update(
all_entries[worker * entries.size() + nidx_in_set].split);
}
}
}
}
linalg::Vector<float> InitRoot(linalg::VectorView<GradientPairPrecise const> root_sum) {
auto n_targets = root_sum.Size();
stats_ = linalg::Constant(ctx_, GradientPairPrecise{}, 1, n_targets);
gain_.resize(1);
linalg::Vector<float> weight({n_targets}, ctx_->Device());
CalcWeight(*param_, root_sum, weight.HostView());
auto root_gain = CalcGainGivenWeight(*param_, root_sum, weight.HostView());
gain_.front() = root_gain;
auto h_stats = stats_.HostView();
std::copy(linalg::cbegin(root_sum), linalg::cend(root_sum), linalg::begin(h_stats));
return weight;
}
void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
auto n_targets = p_tree->NumTargets();
auto parent_sum = stats_.Slice(candidate.nid, linalg::All());
auto weight = linalg::Empty<float>(ctx_, 3, n_targets);
auto base_weight = weight.Slice(0, linalg::All());
CalcWeight(*param_, parent_sum, base_weight);
auto left_weight = weight.Slice(1, linalg::All());
auto left_sum =
linalg::MakeVec(candidate.split.left_sum.data(), candidate.split.left_sum.size());
CalcWeight(*param_, left_sum, param_->learning_rate, left_weight);
auto right_weight = weight.Slice(2, linalg::All());
auto right_sum =
linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size());
CalcWeight(*param_, right_sum, param_->learning_rate, right_weight);
p_tree->ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
candidate.split.DefaultLeft(), base_weight, left_weight, right_weight);
CHECK(p_tree->IsMultiTarget());
auto left_child = p_tree->LeftChild(candidate.nid);
CHECK_GT(left_child, candidate.nid);
auto right_child = p_tree->RightChild(candidate.nid);
CHECK_GT(right_child, candidate.nid);
std::size_t n_nodes = p_tree->Size();
gain_.resize(n_nodes);
// Re-calculate weight without learning rate.
CalcWeight(*param_, left_sum, left_weight);
CalcWeight(*param_, right_sum, right_weight);
gain_[left_child] = CalcGainGivenWeight(*param_, left_sum, left_weight);
gain_[right_child] = CalcGainGivenWeight(*param_, right_sum, right_weight);
if (n_nodes >= stats_.Shape(0)) {
stats_.Reshape(n_nodes * 2, stats_.Shape(1));
}
CHECK_EQ(stats_.Shape(1), n_targets);
auto left_sum_stat = stats_.Slice(left_child, linalg::All());
std::copy(candidate.split.left_sum.cbegin(), candidate.split.left_sum.cend(),
linalg::begin(left_sum_stat));
auto right_sum_stat = stats_.Slice(right_child, linalg::All());
std::copy(candidate.split.right_sum.cbegin(), candidate.split.right_sum.cend(),
linalg::begin(right_sum_stat));
}
explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam const *param,
std::shared_ptr<common::ColumnSampler> sampler)
: param_{param},
column_sampler_{std::move(sampler)},
ctx_{ctx},
is_col_split_{info.IsColumnSplit()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
param_->colsample_bytree);
}
};
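// Rough multi-target counterpart of the sketch above (hypothetical caller; `hists` holds
// one histogram collection per target and `root_sum` one GradientPairPrecise per target):
//
//   HistMultiEvaluator evaluator{ctx, info, &param, sampler};
//   evaluator.InitRoot(root_sum);
//   evaluator.EvaluateSplits(tree, hists, cut, &entries);
//   evaluator.ApplyTreeSplit(entries.front(), &tree);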
/**
* \brief CPU implementation of the prediction cache update, which calculates the leaf value
* for the last tree and accumulates it into the prediction vector.
*
* \param p_last_tree The last tree being updated by the tree updater
*/
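// Partitioner is expected to expose Size() (the number of nodes) and operator[](nidx)
// returning the row set assigned to that node, as used in the loop below.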
template <typename Partitioner>
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
linalg::VectorView<float> out_preds) {
auto const &tree = *p_last_tree;
CHECK(out_preds.Device().IsCPU());
size_t n_nodes = p_last_tree->GetNodes().size();
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) {
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
auto const &rowset = part[nidx];
auto leaf_value = tree[nidx].LeafValue();
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
}
}
template <typename Partitioner>
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
linalg::MatrixView<float> out_preds) {
CHECK_GT(out_preds.Size(), 0U);
CHECK(p_last_tree);
auto const &tree = *p_last_tree;
if (!tree.IsMultiTarget()) {
UpdatePredictionCacheImpl(ctx, p_last_tree, partitioner, out_preds.Slice(linalg::All(), 0));
return;
}
auto const *mttree = tree.GetMultiTargetTree();
auto n_nodes = mttree->Size();
auto n_targets = tree.NumTargets();
CHECK_EQ(out_preds.Shape(1), n_targets);
CHECK(out_preds.Device().IsCPU());
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) {
if (tree.IsLeaf(nidx)) {
auto const &rowset = part[nidx];
auto leaf_value = mttree->LeafValue(nidx);
for (std::size_t const *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
for (std::size_t i = 0; i < n_targets; ++i) {
out_preds(*it, i) += leaf_value(i);
}
}
}
});
}
}
} // namespace xgboost::tree
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_