Implement hist evaluator for multi-target tree. (#8908)

Jiaming Yuan 2023-03-15 01:42:51 +08:00 committed by GitHub
parent 95e2baf7c2
commit 8685556af2
7 changed files with 416 additions and 76 deletions

View File

@@ -7,23 +7,22 @@
 #ifndef XGBOOST_COMMON_HIST_UTIL_H_
 #define XGBOOST_COMMON_HIST_UTIL_H_
 
-#include <xgboost/data.h>
-
 #include <algorithm>
+#include <cstdint>  // for uint32_t
 #include <limits>
 #include <map>
 #include <memory>
 #include <utility>
 #include <vector>
 
+#include "algorithm.h"  // SegmentId
 #include "categorical.h"
 #include "common.h"
 #include "quantile.h"
 #include "row_set.h"
 #include "threading_utils.h"
 #include "timer.h"
-#include "xgboost/base.h"  // bst_feature_t, bst_bin_t
+#include "xgboost/base.h"  // for bst_feature_t, bst_bin_t
+#include "xgboost/data.h"
 
 namespace xgboost {
 class GHistIndexMatrix;
@@ -392,15 +391,18 @@ class HistCollection {
   }
   // have we computed a histogram for i-th node?
-  bool RowExists(bst_uint nid) const {
+  [[nodiscard]] bool RowExists(bst_uint nid) const {
     const uint32_t k_max = std::numeric_limits<uint32_t>::max();
     return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
   }
-  // initialize histogram collection
-  void Init(uint32_t nbins) {
-    if (nbins_ != nbins) {
-      nbins_ = nbins;
+  /**
+   * \brief Initialize histogram collection.
+   *
+   * \param n_total_bins Number of bins across all features.
+   */
+  void Init(std::uint32_t n_total_bins) {
+    if (nbins_ != n_total_bins) {
+      nbins_ = n_total_bins;
       // quite expensive operation, so let's do this only once
       data_.clear();
     }
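For orientation, a minimal usage sketch of the renamed interface, not part of the commit. It assumes `HistogramCuts::TotalBins()` reports the bin count summed over all features, and that a `cuts` object has already been built from the training matrix; the `AddHistRow`/`AllocateAllData` calls mirror the unit test further down:

    // Sketch under the assumptions above: size the collection from the cuts,
    // then reserve and allocate the root node's histogram row.
    common::HistCollection hist;
    hist.Init(cuts.TotalBins());  // n_total_bins across all features
    hist.AddHistRow(0);           // reserve a row for the root node
    hist.AllocateAllData();
    auto root_hist = hist[0];     // common::GHistRow, n_total_bins wide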

View File

@@ -99,22 +99,25 @@ class CommonRowPartitioner {
   void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
                            const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
-    for (size_t i = 0; i < nodes.size(); ++i) {
-      const int32_t nid = nodes[i].nid;
-      const bst_uint fid = tree[nid].SplitIndex();
-      const bst_float split_pt = tree[nid].SplitCond();
-      const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
-      const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
+    auto const& ptrs = gmat.cut.Ptrs();
+    auto const& vals = gmat.cut.Values();
+
+    for (std::size_t i = 0; i < nodes.size(); ++i) {
+      bst_node_t const nid = nodes[i].nid;
+      bst_feature_t const fid = tree[nid].SplitIndex();
+      const float split_pt = tree[nid].SplitCond();
+      const uint32_t lower_bound = ptrs[fid];
+      const uint32_t upper_bound = ptrs[fid + 1];
       bst_bin_t split_cond = -1;
       // convert floating-point split_pt into corresponding bin_id
       // split_cond = -1 indicates that split_pt is less than all known cut points
       CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
       for (auto bound = lower_bound; bound < upper_bound; ++bound) {
-        if (split_pt == gmat.cut.Values()[bound]) {
-          split_cond = static_cast<int32_t>(bound);
+        if (split_pt == vals[bound]) {
+          split_cond = static_cast<bst_bin_t>(bound);
         }
       }
-      (*split_conditions).at(i) = split_cond;
+      (*split_conditions)[i] = split_cond;
     }
   }
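The linear scan is O(bins per feature); since `HistogramCuts` stores each feature's cut values in ascending order, the same bin can be found with a binary search. A sketch of the equivalent lookup, reusing the locals from the new code above (not part of the commit):

    // Equivalent bin lookup via binary search over the sorted cut values
    // vals[lower_bound, upper_bound); requires <algorithm>.
    auto first = vals.cbegin() + lower_bound;
    auto last = vals.cbegin() + upper_bound;
    auto it = std::lower_bound(first, last, split_pt);  // first value >= split_pt
    bst_bin_t split_cond =
        (it != last && *it == split_pt) ? static_cast<bst_bin_t>(it - vals.cbegin()) : -1;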

View File

@@ -4,22 +4,25 @@
 #ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
 #define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
 
-#include <algorithm>
-#include <cstddef>  // for size_t
-#include <limits>
-#include <memory>
-#include <numeric>
-#include <utility>
-#include <vector>
+#include <algorithm>  // for copy
+#include <cstddef>    // for size_t
+#include <limits>     // for numeric_limits
+#include <memory>     // for shared_ptr
+#include <numeric>    // for accumulate
+#include <utility>    // for move
+#include <vector>     // for vector
 
-#include "../../common/categorical.h"
-#include "../../common/hist_util.h"
-#include "../../common/random.h"
-#include "../../data/gradient_index.h"
-#include "../constraints.h"
-#include "../param.h"  // for TrainParam
-#include "../split_evaluator.h"
-#include "xgboost/context.h"
+#include "../../common/categorical.h"  // for CatBitField
+#include "../../common/hist_util.h"    // for GHistRow, HistogramCuts
+#include "../../common/linalg_op.h"    // for cbegin, cend, begin
+#include "../../common/random.h"       // for ColumnSampler
+#include "../constraints.h"            // for FeatureInteractionConstraintHost
+#include "../param.h"                  // for TrainParam
+#include "../split_evaluator.h"        // for TreeEvaluator
+#include "expand_entry.h"              // for MultiExpandEntry
+#include "xgboost/base.h"              // for bst_node_t, bst_target_t, bst_feature_t
+#include "xgboost/context.h"           // for Context
+#include "xgboost/linalg.h"            // for Constants, Vector
 
 namespace xgboost::tree {
 template <typename ExpandEntry>
@@ -410,8 +413,6 @@ class HistEvaluator {
                        tree[candidate.nid].SplitIndex(), left_weight,
                        right_weight);
 
-    auto max_node = std::max(left_child, tree[candidate.nid].RightChild());
-    max_node = std::max(candidate.nid, max_node);
     snode_.resize(tree.GetNodes().size());
     snode_.at(left_child).stats = candidate.split.left_sum;
     snode_.at(left_child).root_gain =
@@ -456,6 +457,216 @@
   }
 };
 
+class HistMultiEvaluator {
+  std::vector<double> gain_;
+  linalg::Matrix<GradientPairPrecise> stats_;
+  TrainParam const *param_;
+  FeatureInteractionConstraintHost interaction_constraints_;
+  std::shared_ptr<common::ColumnSampler> column_sampler_;
+  Context const *ctx_;
+
+ private:
+  static double MultiCalcSplitGain(TrainParam const &param,
+                                   linalg::VectorView<GradientPairPrecise const> left_sum,
+                                   linalg::VectorView<GradientPairPrecise const> right_sum,
+                                   linalg::VectorView<float> left_weight,
+                                   linalg::VectorView<float> right_weight) {
+    CalcWeight(param, left_sum, left_weight);
+    CalcWeight(param, right_sum, right_weight);
+
+    auto left_gain = CalcGainGivenWeight(param, left_sum, left_weight);
+    auto right_gain = CalcGainGivenWeight(param, right_sum, right_weight);
+    return left_gain + right_gain;
+  }
+
+  template <bst_bin_t d_step>
+  bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
+                      common::Span<common::GHistRow const> hist,
+                      linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
+                      SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
+    auto const &cut_ptr = cut.Ptrs();
+    auto const &cut_val = cut.Values();
+    auto const &min_val = cut.MinValues();
+
+    auto sum = linalg::Empty<GradientPairPrecise>(ctx_, 2, hist.size());
+    auto left_sum = sum.Slice(0, linalg::All());
+    auto right_sum = sum.Slice(1, linalg::All());
+
+    bst_bin_t ibegin, iend;
+    if (d_step > 0) {
+      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
+      iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
+    } else {
+      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
+      iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
+    }
+    const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
+
+    auto n_targets = hist.size();
+    auto weight = linalg::Empty<float>(ctx_, 2, n_targets);
+    auto left_weight = weight.Slice(0, linalg::All());
+    auto right_weight = weight.Slice(1, linalg::All());
+
+    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
+      for (bst_target_t t = 0; t < n_targets; ++t) {
+        auto t_hist = hist[t];
+        auto t_p = parent_sum(t);
+        left_sum(t) += t_hist[i];
+        right_sum(t) = t_p - left_sum(t);
+      }
+
+      if (d_step > 0) {
+        auto split_pt = cut_val[i];
+        auto loss_chg =
+            MultiCalcSplitGain(*param_, right_sum, left_sum, right_weight, left_weight) -
+            parent_gain;
+        p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
+      } else {
+        float split_pt;
+        if (i == imin) {
+          split_pt = min_val[fidx];
+        } else {
+          split_pt = cut_val[i - 1];
+        }
+        auto loss_chg =
+            MultiCalcSplitGain(*param_, right_sum, left_sum, left_weight, right_weight) -
+            parent_gain;
+        p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
+      }
+    }
+
+    // Return true if there are missing values. Doesn't handle floating-point error well.
+    if (d_step == +1) {
+      return !std::equal(linalg::cbegin(left_sum), linalg::cend(left_sum),
+                         linalg::cbegin(parent_sum));
+    }
+    return false;
+  }
+
+ public:
+  void EvaluateSplits(RegTree const &tree, common::Span<const common::HistCollection *> hist,
+                      common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
+    auto &entries = *p_entries;
+    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());
+
+    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
+      auto nidx = entries[nidx_in_set].nid;
+      features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
+    }
+    CHECK(!features.empty());
+
+    std::int32_t n_threads = ctx_->Threads();
+    std::size_t const grain_size = std::max<std::size_t>(1, features.front()->Size() / n_threads);
+    common::BlockedSpace2d space(
+        entries.size(), [&](std::size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
+        grain_size);
+
+    std::vector<MultiExpandEntry> tloc_candidates(n_threads * entries.size());
+    for (std::size_t i = 0; i < entries.size(); ++i) {
+      for (std::int32_t j = 0; j < n_threads; ++j) {
+        tloc_candidates[i * n_threads + j] = entries[i];
+      }
+    }
+
+    common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) {
+      auto tidx = omp_get_thread_num();
+      auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
+      auto best = &entry->split;
+      auto parent_sum = stats_.Slice(entry->nid, linalg::All());
+
+      std::vector<common::GHistRow> node_hist;
+      for (auto t_hist : hist) {
+        node_hist.push_back((*t_hist)[entry->nid]);
+      }
+      auto features_set = features[nidx_in_set]->ConstHostSpan();
+
+      for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
+        auto fidx = features_set[fidx_in_set];
+        if (!interaction_constraints_.Query(entry->nid, fidx)) {
+          continue;
+        }
+        auto parent_gain = gain_[entry->nid];
+        bool missing =
+            this->EnumerateSplit<+1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
+        if (missing) {
+          this->EnumerateSplit<-1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
+        }
+      }
+    });
+
+    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
+      for (auto tidx = 0; tidx < n_threads; ++tidx) {
+        entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
+      }
+    }
+  }
+
+  linalg::Vector<float> InitRoot(linalg::VectorView<GradientPairPrecise const> root_sum) {
+    auto n_targets = root_sum.Size();
+    stats_ = linalg::Constant(ctx_, GradientPairPrecise{}, 1, n_targets);
+    gain_.resize(1);
+
+    linalg::Vector<float> weight({n_targets}, ctx_->gpu_id);
+    CalcWeight(*param_, root_sum, weight.HostView());
+    auto root_gain = CalcGainGivenWeight(*param_, root_sum, weight.HostView());
+    gain_.front() = root_gain;
+
+    auto h_stats = stats_.HostView();
+    std::copy(linalg::cbegin(root_sum), linalg::cend(root_sum), linalg::begin(h_stats));
+
+    return weight;
+  }
+
+  void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
+    auto n_targets = p_tree->NumTargets();
+    auto parent_sum = stats_.Slice(candidate.nid, linalg::All());
+
+    auto weight = linalg::Empty<float>(ctx_, 3, n_targets);
+    auto base_weight = weight.Slice(0, linalg::All());
+    CalcWeight(*param_, parent_sum, base_weight);
+
+    auto left_weight = weight.Slice(1, linalg::All());
+    auto left_sum =
+        linalg::MakeVec(candidate.split.left_sum.data(), candidate.split.left_sum.size());
+    CalcWeight(*param_, left_sum, param_->learning_rate, left_weight);
+
+    auto right_weight = weight.Slice(2, linalg::All());
+    auto right_sum =
+        linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size());
+    CalcWeight(*param_, right_sum, param_->learning_rate, right_weight);
+
+    p_tree->ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
+                       candidate.split.DefaultLeft(), base_weight, left_weight, right_weight);
+    CHECK(p_tree->IsMultiTarget());
+    auto left_child = p_tree->LeftChild(candidate.nid);
+    CHECK_GT(left_child, candidate.nid);
+    auto right_child = p_tree->RightChild(candidate.nid);
+    CHECK_GT(right_child, candidate.nid);
+
+    std::size_t n_nodes = p_tree->Size();
+    gain_.resize(n_nodes);
+    gain_[left_child] = CalcGainGivenWeight(*param_, left_sum, left_weight);
+    gain_[right_child] = CalcGainGivenWeight(*param_, right_sum, right_weight);
+
+    if (n_nodes >= stats_.Shape(0)) {
+      stats_.Reshape(n_nodes * 2, stats_.Shape(1));
+    }
+    CHECK_EQ(stats_.Shape(1), n_targets);
+    auto left_sum_stat = stats_.Slice(left_child, linalg::All());
+    std::copy(candidate.split.left_sum.cbegin(), candidate.split.left_sum.cend(),
+              linalg::begin(left_sum_stat));
+    auto right_sum_stat = stats_.Slice(right_child, linalg::All());
+    std::copy(candidate.split.right_sum.cbegin(), candidate.split.right_sum.cend(),
+              linalg::begin(right_sum_stat));
+  }
+
+  explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam const *param,
+                              std::shared_ptr<common::ColumnSampler> sampler)
+      : param_{param}, column_sampler_{std::move(sampler)}, ctx_{ctx} {
+    interaction_constraints_.Configure(*param, info.num_col_);
+    column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
+                          param_->colsample_bynode, param_->colsample_bylevel,
+                          param_->colsample_bytree);
+  }
+};
+
 /**
  * \brief CPU implementation of update prediction cache, which calculates the leaf value
  * for the last tree and accumulates it to prediction vector.
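Taken together, the intended call sequence for `HistMultiEvaluator` looks roughly like the following sketch, which mirrors the unit test added later in this commit. Histogram construction and row partitioning are omitted, and `hist_ptrs` (a span of per-target `HistCollection` pointers) and `root_sum` are assumed to be built by the caller:

    // Sketch under the assumptions above, not a complete updater.
    HistMultiEvaluator evaluator{&ctx, p_fmat->Info(), &param, sampler};
    auto weight = evaluator.InitRoot(root_sum.HostView());  // root weight + root gain
    tree.SetLeaf(RegTree::kRoot, weight.HostView());
    std::vector<MultiExpandEntry> entries(1, {/*nidx=*/0, /*depth=*/0});
    evaluator.EvaluateSplits(tree, hist_ptrs, cuts, &entries);  // best split per entry
    evaluator.ApplyTreeSplit(entries.front(), &tree);  // expand node, cache child stats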

View File

@@ -14,10 +14,12 @@
 #include <string>
 #include <vector>
 
-#include "xgboost/parameter.h"
-#include "xgboost/data.h"
 #include "../common/categorical.h"
+#include "../common/linalg_op.h"
 #include "../common/math.h"
+#include "xgboost/data.h"
+#include "xgboost/linalg.h"
+#include "xgboost/parameter.h"
 
 namespace xgboost {
 namespace tree {
@@ -197,12 +199,11 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
   }
 
   /*! \brief given the loss change, whether we need to invoke pruning */
-  bool NeedPrune(double loss_chg, int depth) const {
-    return loss_chg < this->min_split_loss ||
-           (this->max_depth != 0 && depth > this->max_depth);
+  [[nodiscard]] bool NeedPrune(double loss_chg, int depth) const {
+    return loss_chg < this->min_split_loss || (this->max_depth != 0 && depth > this->max_depth);
   }
 
-  bst_node_t MaxNodes() const {
+  [[nodiscard]] bst_node_t MaxNodes() const {
     if (this->max_depth == 0 && this->max_leaves == 0) {
       LOG(FATAL) << "Max leaves and max depth cannot both be unconstrained.";
     }
@@ -292,6 +293,34 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad)
   return CalcWeight(p, sum_grad.GetGrad(), sum_grad.GetHess());
 }
 
+/**
+ * \brief multi-target weight, calculated with learning rate.
+ */
+inline void CalcWeight(TrainParam const &p, linalg::VectorView<GradientPairPrecise const> grad_sum,
+                       float eta, linalg::VectorView<float> out_w) {
+  for (bst_target_t i = 0; i < out_w.Size(); ++i) {
+    out_w(i) = CalcWeight(p, grad_sum(i).GetGrad(), grad_sum(i).GetHess()) * eta;
+  }
+}
+
+/**
+ * \brief multi-target weight
+ */
+inline void CalcWeight(TrainParam const &p, linalg::VectorView<GradientPairPrecise const> grad_sum,
+                       linalg::VectorView<float> out_w) {
+  return CalcWeight(p, grad_sum, 1.0f, out_w);
+}
+
+inline double CalcGainGivenWeight(TrainParam const &p,
+                                  linalg::VectorView<GradientPairPrecise const> sum_grad,
+                                  linalg::VectorView<float const> weight) {
+  double gain{0};
+  for (bst_target_t i = 0; i < weight.Size(); ++i) {
+    gain += -weight(i) * ThresholdL1(sum_grad(i).GetGrad(), p.reg_alpha);
+  }
+  return gain;
+}
+
 /*! \brief core statistics used for tree construction */
 struct XGBOOST_ALIGNAS(16) GradStats {
   using GradType = double;
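For intuition, a worked example with the regularization terms disabled (reg_lambda = 0, reg_alpha = 0, so ThresholdL1 is the identity), using the per-target sums that appear in the unit test below:

    // One target with summed gradient/hessian (G, H) = (1.5, 1.0):
    //   CalcWeight:           w = -G / H                      = -1.5
    //   CalcGainGivenWeight:  gain = -w * ThresholdL1(G, 0.0) = 1.5 * 1.5 = 2.25
    // CalcGainGivenWeight sums over targets, so two identical targets give 4.5.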
@@ -301,8 +330,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   GradType sum_hess { 0 };
 
  public:
-  XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
-  XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }
+  [[nodiscard]] XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
+  [[nodiscard]] XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }
 
   friend std::ostream& operator<<(std::ostream& os, GradStats s) {
     os << s.GetGrad() << "/" << s.GetHess();
@@ -340,7 +369,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
     sum_hess = a.sum_hess - b.sum_hess;
   }
   /*! \return whether the statistics is not used yet */
-  inline bool Empty() const { return sum_hess == 0.0; }
+  [[nodiscard]] bool Empty() const { return sum_hess == 0.0; }
   /*! \brief add statistics to the data */
   inline void Add(GradType grad, GradType hess) {
     sum_grad += grad;
@@ -348,6 +377,19 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   }
 };
 
+// Helper functions for copying gradient statistics: one for vector leaf, another for normal scalar.
+template <typename T, typename U>
+std::vector<T> &CopyStats(linalg::VectorView<U> const &src, std::vector<T> *dst) {  // NOLINT
+  dst->resize(src.Size());
+  std::copy(linalg::cbegin(src), linalg::cend(src), dst->begin());
+  return *dst;
+}
+
+inline GradStats &CopyStats(GradStats const &src, GradStats *dst) {  // NOLINT
+  *dst = src;
+  return *dst;
+}
+
 /*!
  * \brief statistics that is helpful to store
  * and represent a split solution for the tree
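The two overloads let `SplitEntryContainer::Update` (below) handle scalar and vector leaves through one call site. A usage sketch, assuming the two-argument `GradStats` constructor and a `left_sum` view built by the caller (not part of the commit):

    // Scalar leaf: plain assignment.
    GradStats scalar_dst;
    CopyStats(GradStats{/*grad=*/1.0, /*hess=*/2.0}, &scalar_dst);

    // Vector leaf: resize the destination, then copy element-wise from the view.
    std::vector<GradientPairPrecise> vector_dst;
    CopyStats(left_sum, &vector_dst);  // left_sum: linalg::VectorView<GradientPairPrecise const>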
@@ -378,9 +420,9 @@ struct SplitEntryContainer {
     return os;
   }
 
   /*!\return feature index to split on */
-  bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
+  [[nodiscard]] bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
   /*!\return whether missing value goes to left branch */
-  bool DefaultLeft() const { return (sindex >> 31) != 0; }
+  [[nodiscard]] bool DefaultLeft() const { return (sindex >> 31) != 0; }
   /*!
    * \brief decides whether we can replace current entry with the given statistics
    *
@@ -391,10 +433,10 @@
    * \param new_loss_chg the loss reduction gained through the split
    * \param split_index the feature index where the split is on
    */
-  bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
+  [[nodiscard]] bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
     if (std::isinf(new_loss_chg)) {  // in some cases new_loss_chg can be NaN or Inf,
                                      // for example when lambda = 0 & min_child_weight = 0;
                                      // skip value in this case
       return false;
     } else if (this->SplitIndex() <= split_index) {
       return new_loss_chg > this->loss_chg;
@@ -429,9 +471,10 @@
    * \param default_left whether the missing value goes to left
    * \return whether the proposed split is better and can replace current split
    */
-  bool Update(bst_float new_loss_chg, unsigned split_index,
-              bst_float new_split_value, bool default_left, bool is_cat,
-              const GradientT &left_sum, const GradientT &right_sum) {
+  template <typename GradientSumT>
+  bool Update(bst_float new_loss_chg, unsigned split_index, bst_float new_split_value,
+              bool default_left, bool is_cat, GradientSumT const &left_sum,
+              GradientSumT const &right_sum) {
     if (this->NeedReplace(new_loss_chg, split_index)) {
       this->loss_chg = new_loss_chg;
       if (default_left) {
@@ -440,8 +483,8 @@
       this->sindex = split_index;
       this->split_value = new_split_value;
       this->is_cat = is_cat;
-      this->left_sum = left_sum;
-      this->right_sum = right_sum;
+      CopyStats(left_sum, &this->left_sum);
+      CopyStats(right_sum, &this->right_sum);
       return true;
     } else {
       return false;

View File

@@ -304,7 +304,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
   thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
 
   // Setup gradients so that second feature gets higher gain
-  auto feature_histogram = ConvertToInteger({ {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
+  auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
 
   dh::device_vector<FeatureType> feature_types(feature_set.size(),
                                                FeatureType::kCategorical);

View File

@@ -1,18 +1,27 @@
 /**
  * Copyright 2021-2023 by XGBoost Contributors
  */
-#include <gtest/gtest.h>
-#include <xgboost/base.h>
-
-#include "../../../../src/common/hist_util.h"
-#include "../../../../src/tree/common_row_partitioner.h"
-#include "../../../../src/tree/hist/evaluate_splits.h"
 #include "../test_evaluate_splits.h"
-#include "../../helpers.h"
-#include "xgboost/context.h"  // Context
 
-namespace xgboost {
-namespace tree {
+#include <gtest/gtest.h>
+#include <xgboost/base.h>        // for GradientPairPrecise, Args, Gradie...
+#include <xgboost/context.h>     // for Context
+#include <xgboost/data.h>        // for FeatureType, DMatrix, MetaInfo
+#include <xgboost/logging.h>     // for CHECK_EQ
+#include <xgboost/tree_model.h>  // for RegTree, RTreeNodeStat
+
+#include <memory>  // for make_shared, shared_ptr, addressof
+
+#include "../../../../src/common/hist_util.h"           // for HistCollection, HistogramCuts
+#include "../../../../src/common/random.h"              // for ColumnSampler
+#include "../../../../src/common/row_set.h"             // for RowSetCollection
+#include "../../../../src/data/gradient_index.h"        // for GHistIndexMatrix
+#include "../../../../src/tree/hist/evaluate_splits.h"  // for HistEvaluator
+#include "../../../../src/tree/hist/expand_entry.h"     // for CPUExpandEntry
+#include "../../../../src/tree/param.h"                 // for GradStats, TrainParam
+#include "../../helpers.h"                              // for RandomDataGenerator, AllThreadsFo...
+
+namespace xgboost::tree {
 void TestEvaluateSplits(bool force_read_by_column) {
   Context ctx;
   ctx.nthread = 4;
@@ -87,6 +96,68 @@ TEST(HistEvaluator, Evaluate) {
   TestEvaluateSplits(true);
 }
 
+TEST(HistMultiEvaluator, Evaluate) {
+  Context ctx;
+  ctx.nthread = 1;
+
+  TrainParam param;
+  param.Init(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
+  auto sampler = std::make_shared<common::ColumnSampler>();
+
+  std::size_t n_samples = 3;
+  bst_feature_t n_features = 2;
+  bst_target_t n_targets = 2;
+  bst_bin_t n_bins = 2;
+
+  auto p_fmat =
+      RandomDataGenerator{n_samples, n_features, 0.5}.Targets(n_targets).GenerateDMatrix(true);
+  HistMultiEvaluator evaluator{&ctx, p_fmat->Info(), &param, sampler};
+
+  std::vector<common::HistCollection> histogram(n_targets);
+  linalg::Vector<GradientPairPrecise> root_sum({2}, Context::kCpuId);
+  for (bst_target_t t{0}; t < n_targets; ++t) {
+    auto &hist = histogram[t];
+    hist.Init(n_bins * n_features);
+    hist.AddHistRow(0);
+    hist.AllocateAllData();
+
+    auto node_hist = hist[0];
+    node_hist[0] = {-0.5, 0.5};
+    node_hist[1] = {2.0, 0.5};
+    node_hist[2] = {0.5, 0.5};
+    node_hist[3] = {1.0, 0.5};
+
+    root_sum(t) += node_hist[0];
+    root_sum(t) += node_hist[1];
+  }
+
+  RegTree tree{n_targets, n_features};
+  auto weight = evaluator.InitRoot(root_sum.HostView());
+  tree.SetLeaf(RegTree::kRoot, weight.HostView());
+
+  auto w = weight.HostView();
+  ASSERT_EQ(w.Size(), n_targets);
+  ASSERT_EQ(w(0), -1.5);
+  ASSERT_EQ(w(1), -1.5);
+
+  common::HistogramCuts cuts;
+  cuts.cut_ptrs_ = {0, 2, 4};
+  cuts.cut_values_ = {0.5, 1.0, 2.0, 3.0};
+  cuts.min_vals_ = {-0.2, 1.8};
+
+  std::vector<MultiExpandEntry> entries(1, {/*nidx=*/0, /*depth=*/0});
+  std::vector<common::HistCollection const *> ptrs;
+  std::transform(histogram.cbegin(), histogram.cend(), std::back_inserter(ptrs),
+                 [](auto const &h) { return std::addressof(h); });
+
+  evaluator.EvaluateSplits(tree, ptrs, cuts, &entries);
+
+  ASSERT_EQ(entries.front().split.loss_chg, 12.5);
+  ASSERT_EQ(entries.front().split.split_value, 0.5);
+  ASSERT_EQ(entries.front().split.SplitIndex(), 0);
+
+  ASSERT_EQ(sampler->GetFeatureSet(0)->Size(), n_features);
+}
+
 TEST(HistEvaluator, Apply) {
   Context ctx;
   ctx.nthread = 4;
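The values asserted in the new HistMultiEvaluator test above can be checked by hand with the gain formulas from param.h (reg_lambda = 0, so per target w = -G/H and gain = G^2/H):

    // Per target, feature 0 has bins (-0.5, 0.5) and (2.0, 0.5); root sum is (1.5, 1.0).
    //   root gain:           1.5^2 / 1.0       = 2.25
    //   split after bin 0:
    //     left  gain:      (-0.5)^2 / 0.5      = 0.5
    //     right gain:        2.0^2 / 0.5       = 8.0
    //   per-target loss_chg: 0.5 + 8.0 - 2.25  = 6.25
    // Both targets share the same histogram, so loss_chg = 2 * 6.25 = 12.5, with the
    // split at feature 0's first cut value, 0.5 (feature 1 only reaches 0.25 per target).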
@@ -211,12 +282,11 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
   std::vector<CPUExpandEntry> entries(1);
   RegTree tree;
   evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries);
-  auto const& split = entries.front().split;
+  auto const &split = entries.front().split;
   this->CheckResult(split.loss_chg, split.SplitIndex(), split.split_value, split.is_cat,
                     split.DefaultLeft(),
                     GradientPairPrecise{split.left_sum.GetGrad(), split.left_sum.GetHess()},
                     GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()});
 }
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree

View File

@@ -2,15 +2,26 @@
  * Copyright 2022-2023 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
-#include <xgboost/data.h>
+#include <xgboost/base.h>                // for GradientPairInternal, GradientPairPrecise
+#include <xgboost/data.h>                // for MetaInfo
+#include <xgboost/host_device_vector.h>  // for HostDeviceVector
+#include <xgboost/span.h>                // for operator!=, Span, SpanIterator
 
-#include <algorithm>  // next_permutation
-#include <numeric>    // iota
+#include <algorithm>  // for max, max_element, next_permutation, copy
+#include <cmath>      // for isnan
+#include <cstddef>    // for size_t
+#include <cstdint>    // for int32_t, uint64_t, uint32_t
+#include <limits>     // for numeric_limits
+#include <numeric>    // for iota
+#include <tuple>      // for make_tuple, tie, tuple
+#include <utility>    // for pair
+#include <vector>     // for vector
 
-#include "../../../src/common/hist_util.h"  // HistogramCuts,HistCollection
-#include "../../../src/tree/param.h"        // TrainParam
-#include "../../../src/tree/split_evaluator.h"
-#include "../helpers.h"
+#include "../../../src/common/hist_util.h"      // for HistogramCuts, HistCollection, GHistRow
+#include "../../../src/tree/param.h"            // for TrainParam, GradStats
+#include "../../../src/tree/split_evaluator.h"  // for TreeEvaluator
+#include "../helpers.h"                         // for SimpleLCG, SimpleRealUniformDistribution
+#include "gtest/gtest_pred_impl.h"              // for AssertionResult, ASSERT_EQ, ASSERT_TRUE
 
 namespace xgboost::tree {
 /**