Implement hist evaluator for multi-target tree. (#8908)

parent 95e2baf7c2
commit 8685556af2
@@ -7,23 +7,22 @@
#ifndef XGBOOST_COMMON_HIST_UTIL_H_
#define XGBOOST_COMMON_HIST_UTIL_H_

#include <xgboost/data.h>

#include <algorithm>
#include <cstdint>  // for uint32_t
#include <limits>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "algorithm.h"  // SegmentId
#include "categorical.h"
#include "common.h"
#include "quantile.h"
#include "row_set.h"
#include "threading_utils.h"
#include "timer.h"
#include "xgboost/base.h"  // bst_feature_t, bst_bin_t
#include "xgboost/base.h"  // for bst_feature_t, bst_bin_t
#include "xgboost/data.h"

namespace xgboost {
class GHistIndexMatrix;
@@ -392,15 +391,18 @@ class HistCollection {
  }

  // have we computed a histogram for i-th node?
  bool RowExists(bst_uint nid) const {
  [[nodiscard]] bool RowExists(bst_uint nid) const {
    const uint32_t k_max = std::numeric_limits<uint32_t>::max();
    return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
  }

  // initialize histogram collection
  void Init(uint32_t nbins) {
    if (nbins_ != nbins) {
      nbins_ = nbins;
  /**
   * \brief Initialize histogram collection.
   *
   * \param n_total_bins Number of bins across all features.
   */
  void Init(std::uint32_t n_total_bins) {
    if (nbins_ != n_total_bins) {
      nbins_ = n_total_bins;
      // quite expensive operation, so let's do this only once
      data_.clear();
    }

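For context, the lifecycle implied by the new Init signature is exercised by the multi-target test added later in this commit; a minimal sketch of that sequence (hist, n_bins, and n_features as in that test):

  common::HistCollection hist;
  hist.Init(n_bins * n_features);  // n_total_bins across all features
  hist.AddHistRow(0);              // add storage for the root node
  hist.AllocateAllData();
  auto node_hist = hist[0];        // per-bin gradient statistics for node 0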
@@ -99,22 +99,25 @@ class CommonRowPartitioner {

  void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
                           const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
    for (size_t i = 0; i < nodes.size(); ++i) {
      const int32_t nid = nodes[i].nid;
      const bst_uint fid = tree[nid].SplitIndex();
      const bst_float split_pt = tree[nid].SplitCond();
      const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
      const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
    auto const& ptrs = gmat.cut.Ptrs();
    auto const& vals = gmat.cut.Values();

    for (std::size_t i = 0; i < nodes.size(); ++i) {
      bst_node_t const nid = nodes[i].nid;
      bst_feature_t const fid = tree[nid].SplitIndex();
      const float split_pt = tree[nid].SplitCond();
      const uint32_t lower_bound = ptrs[fid];
      const uint32_t upper_bound = ptrs[fid + 1];
      bst_bin_t split_cond = -1;
      // convert floating-point split_pt into corresponding bin_id
      // split_cond = -1 indicates that split_pt is less than all known cut points
      CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
      for (auto bound = lower_bound; bound < upper_bound; ++bound) {
        if (split_pt == gmat.cut.Values()[bound]) {
          split_cond = static_cast<int32_t>(bound);
        if (split_pt == vals[bound]) {
          split_cond = static_cast<bst_bin_t>(bound);
        }
      }
      (*split_conditions).at(i) = split_cond;
      (*split_conditions)[i] = split_cond;
    }
  }

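The loop above maps a floating-point split value back to its bin index. A standalone sketch of the same search (SearchBin is a hypothetical helper name; the pointers and values correspond to HistogramCuts::Ptrs() and HistogramCuts::Values()):

  #include <cstdint>
  #include <vector>

  // Returns the bin whose cut value equals split_pt for feature fidx,
  // or -1 when split_pt is below every cut point of that feature.
  std::int32_t SearchBin(std::vector<std::uint32_t> const &ptrs,
                         std::vector<float> const &vals,
                         std::uint32_t fidx, float split_pt) {
    std::int32_t split_cond = -1;
    for (auto i = ptrs[fidx]; i < ptrs[fidx + 1]; ++i) {
      if (split_pt == vals[i]) {  // cut values are stored exactly, so == is intended
        split_cond = static_cast<std::int32_t>(i);
      }
    }
    return split_cond;
  }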
@@ -4,22 +4,25 @@
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_

#include <algorithm>
#include <algorithm>  // for copy
#include <cstddef>    // for size_t
#include <limits>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
#include <limits>   // for numeric_limits
#include <memory>   // for shared_ptr
#include <numeric>  // for accumulate
#include <utility>  // for move
#include <vector>   // for vector

#include "../../common/categorical.h"
#include "../../common/hist_util.h"
#include "../../common/random.h"
#include "../../data/gradient_index.h"
#include "../constraints.h"
#include "../../common/categorical.h"  // for CatBitField
#include "../../common/hist_util.h"    // for GHistRow, HistogramCuts
#include "../../common/linalg_op.h"    // for cbegin, cend, begin
#include "../../common/random.h"      // for ColumnSampler
#include "../constraints.h"           // for FeatureInteractionConstraintHost
#include "../param.h"                 // for TrainParam
#include "../split_evaluator.h"
#include "xgboost/context.h"
#include "../split_evaluator.h"  // for TreeEvaluator
#include "expand_entry.h"        // for MultiExpandEntry
#include "xgboost/base.h"        // for bst_node_t, bst_target_t, bst_feature_t
#include "xgboost/context.h"     // for Context
#include "xgboost/linalg.h"      // for Constants, Vector

namespace xgboost::tree {
template <typename ExpandEntry>
@@ -410,8 +413,6 @@ class HistEvaluator {
                       tree[candidate.nid].SplitIndex(), left_weight,
                       right_weight);

    auto max_node = std::max(left_child, tree[candidate.nid].RightChild());
    max_node = std::max(candidate.nid, max_node);
    snode_.resize(tree.GetNodes().size());
    snode_.at(left_child).stats = candidate.split.left_sum;
    snode_.at(left_child).root_gain =
@@ -456,6 +457,216 @@ class HistEvaluator {
  }
};

class HistMultiEvaluator {
  std::vector<double> gain_;
  linalg::Matrix<GradientPairPrecise> stats_;
  TrainParam const *param_;
  FeatureInteractionConstraintHost interaction_constraints_;
  std::shared_ptr<common::ColumnSampler> column_sampler_;
  Context const *ctx_;

 private:
  static double MultiCalcSplitGain(TrainParam const &param,
                                   linalg::VectorView<GradientPairPrecise const> left_sum,
                                   linalg::VectorView<GradientPairPrecise const> right_sum,
                                   linalg::VectorView<float> left_weight,
                                   linalg::VectorView<float> right_weight) {
    CalcWeight(param, left_sum, left_weight);
    CalcWeight(param, right_sum, right_weight);

    auto left_gain = CalcGainGivenWeight(param, left_sum, left_weight);
    auto right_gain = CalcGainGivenWeight(param, right_sum, right_weight);
    return left_gain + right_gain;
  }

  template <bst_bin_t d_step>
  bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
                      common::Span<common::GHistRow const> hist,
                      linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
                      SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
    auto const &cut_ptr = cut.Ptrs();
    auto const &cut_val = cut.Values();
    auto const &min_val = cut.MinValues();

    auto sum = linalg::Empty<GradientPairPrecise>(ctx_, 2, hist.size());
    auto left_sum = sum.Slice(0, linalg::All());
    auto right_sum = sum.Slice(1, linalg::All());

    bst_bin_t ibegin, iend;
    if (d_step > 0) {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
      iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
    } else {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
      iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
    }
    const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);

    auto n_targets = hist.size();
    auto weight = linalg::Empty<float>(ctx_, 2, n_targets);
    auto left_weight = weight.Slice(0, linalg::All());
    auto right_weight = weight.Slice(1, linalg::All());

    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
      for (bst_target_t t = 0; t < n_targets; ++t) {
        auto t_hist = hist[t];
        auto t_p = parent_sum(t);
        left_sum(t) += t_hist[i];
        right_sum(t) = t_p - left_sum(t);
      }

      if (d_step > 0) {
        auto split_pt = cut_val[i];
        auto loss_chg =
            MultiCalcSplitGain(*param_, right_sum, left_sum, right_weight, left_weight) -
            parent_gain;
        p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
      } else {
        float split_pt;
        if (i == imin) {
          split_pt = min_val[fidx];
        } else {
          split_pt = cut_val[i - 1];
        }
        auto loss_chg =
            MultiCalcSplitGain(*param_, right_sum, left_sum, left_weight, right_weight) -
            parent_gain;
        p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
      }
    }
    // Return true if there are missing values. Doesn't handle floating-point error well.
    if (d_step == +1) {
      return !std::equal(linalg::cbegin(left_sum), linalg::cend(left_sum),
                         linalg::cbegin(parent_sum));
    }
    return false;
  }

 public:
  void EvaluateSplits(RegTree const &tree, common::Span<const common::HistCollection *> hist,
                      common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
    auto &entries = *p_entries;
    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());

    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      auto nidx = entries[nidx_in_set].nid;
      features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
    }
    CHECK(!features.empty());

    std::int32_t n_threads = ctx_->Threads();
    std::size_t const grain_size = std::max<std::size_t>(1, features.front()->Size() / n_threads);
    common::BlockedSpace2d space(
        entries.size(), [&](std::size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
        grain_size);

    std::vector<MultiExpandEntry> tloc_candidates(n_threads * entries.size());
    for (std::size_t i = 0; i < entries.size(); ++i) {
      for (std::int32_t j = 0; j < n_threads; ++j) {
        tloc_candidates[i * n_threads + j] = entries[i];
      }
    }
    common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) {
      auto tidx = omp_get_thread_num();
      auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
      auto best = &entry->split;
      auto parent_sum = stats_.Slice(entry->nid, linalg::All());
      std::vector<common::GHistRow> node_hist;
      for (auto t_hist : hist) {
        node_hist.push_back((*t_hist)[entry->nid]);
      }
      auto features_set = features[nidx_in_set]->ConstHostSpan();

      for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
        auto fidx = features_set[fidx_in_set];
        if (!interaction_constraints_.Query(entry->nid, fidx)) {
          continue;
        }
        auto parent_gain = gain_[entry->nid];
        bool missing =
            this->EnumerateSplit<+1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
        if (missing) {
          this->EnumerateSplit<-1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
        }
      }
    });

    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      for (auto tidx = 0; tidx < n_threads; ++tidx) {
        entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
      }
    }
  }

  linalg::Vector<float> InitRoot(linalg::VectorView<GradientPairPrecise const> root_sum) {
    auto n_targets = root_sum.Size();
    stats_ = linalg::Constant(ctx_, GradientPairPrecise{}, 1, n_targets);
    gain_.resize(1);

    linalg::Vector<float> weight({n_targets}, ctx_->gpu_id);
    CalcWeight(*param_, root_sum, weight.HostView());
    auto root_gain = CalcGainGivenWeight(*param_, root_sum, weight.HostView());
    gain_.front() = root_gain;

    auto h_stats = stats_.HostView();
    std::copy(linalg::cbegin(root_sum), linalg::cend(root_sum), linalg::begin(h_stats));

    return weight;
  }

  void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
    auto n_targets = p_tree->NumTargets();
    auto parent_sum = stats_.Slice(candidate.nid, linalg::All());

    auto weight = linalg::Empty<float>(ctx_, 3, n_targets);
    auto base_weight = weight.Slice(0, linalg::All());
    CalcWeight(*param_, parent_sum, base_weight);

    auto left_weight = weight.Slice(1, linalg::All());
    auto left_sum =
        linalg::MakeVec(candidate.split.left_sum.data(), candidate.split.left_sum.size());
    CalcWeight(*param_, left_sum, param_->learning_rate, left_weight);

    auto right_weight = weight.Slice(2, linalg::All());
    auto right_sum =
        linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size());
    CalcWeight(*param_, right_sum, param_->learning_rate, right_weight);

    p_tree->ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                       candidate.split.DefaultLeft(), base_weight, left_weight, right_weight);
    CHECK(p_tree->IsMultiTarget());
    auto left_child = p_tree->LeftChild(candidate.nid);
    CHECK_GT(left_child, candidate.nid);
    auto right_child = p_tree->RightChild(candidate.nid);
    CHECK_GT(right_child, candidate.nid);

    std::size_t n_nodes = p_tree->Size();
    gain_.resize(n_nodes);
    gain_[left_child] = CalcGainGivenWeight(*param_, left_sum, left_weight);
    gain_[right_child] = CalcGainGivenWeight(*param_, right_sum, right_weight);

    if (n_nodes >= stats_.Shape(0)) {
      stats_.Reshape(n_nodes * 2, stats_.Shape(1));
    }
    CHECK_EQ(stats_.Shape(1), n_targets);
    auto left_sum_stat = stats_.Slice(left_child, linalg::All());
    std::copy(candidate.split.left_sum.cbegin(), candidate.split.left_sum.cend(),
              linalg::begin(left_sum_stat));
    auto right_sum_stat = stats_.Slice(right_child, linalg::All());
    std::copy(candidate.split.right_sum.cbegin(), candidate.split.right_sum.cend(),
              linalg::begin(right_sum_stat));
  }

  explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam const *param,
                              std::shared_ptr<common::ColumnSampler> sampler)
      : param_{param}, column_sampler_{std::move(sampler)}, ctx_{ctx} {
    interaction_constraints_.Configure(*param, info.num_col_);
    column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                          param_->colsample_bynode, param_->colsample_bylevel,
                          param_->colsample_bytree);
  }
};

/**
 * \brief CPU implementation of update prediction cache, which calculates the leaf value
 * for the last tree and accumulates it to prediction vector.
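A rough sketch of the intended call sequence for HistMultiEvaluator, condensed from the unit test added later in this commit (ctx, param, the per-target histogram pointers ptrs, cuts, and root_sum are set up as in that test; ApplyTreeSplit comes from the class above):

  auto sampler = std::make_shared<common::ColumnSampler>();
  HistMultiEvaluator evaluator{&ctx, p_fmat->Info(), &param, sampler};

  auto weight = evaluator.InitRoot(root_sum.HostView());  // root weights, one per target
  tree.SetLeaf(RegTree::kRoot, weight.HostView());

  std::vector<MultiExpandEntry> entries(1, {/*nidx=*/0, /*depth=*/0});
  evaluator.EvaluateSplits(tree, ptrs, cuts, &entries);  // best split per candidate node
  evaluator.ApplyTreeSplit(entries.front(), &tree);      // grow the multi-target tree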
@@ -14,10 +14,12 @@
#include <string>
#include <vector>

#include "xgboost/parameter.h"
#include "xgboost/data.h"
#include "../common/categorical.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "xgboost/data.h"
#include "xgboost/linalg.h"
#include "xgboost/parameter.h"

namespace xgboost {
namespace tree {
@@ -197,12 +199,11 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
  }

  /*! \brief given the loss change, whether we need to invoke pruning */
  bool NeedPrune(double loss_chg, int depth) const {
    return loss_chg < this->min_split_loss ||
           (this->max_depth != 0 && depth > this->max_depth);
  [[nodiscard]] bool NeedPrune(double loss_chg, int depth) const {
    return loss_chg < this->min_split_loss || (this->max_depth != 0 && depth > this->max_depth);
  }

  bst_node_t MaxNodes() const {
  [[nodiscard]] bst_node_t MaxNodes() const {
    if (this->max_depth == 0 && this->max_leaves == 0) {
      LOG(FATAL) << "Max leaves and max depth cannot both be unconstrained.";
    }
@@ -292,6 +293,34 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad)
  return CalcWeight(p, sum_grad.GetGrad(), sum_grad.GetHess());
}

/**
 * \brief multi-target weight, calculated with learning rate.
 */
inline void CalcWeight(TrainParam const &p, linalg::VectorView<GradientPairPrecise const> grad_sum,
                       float eta, linalg::VectorView<float> out_w) {
  for (bst_target_t i = 0; i < out_w.Size(); ++i) {
    out_w(i) = CalcWeight(p, grad_sum(i).GetGrad(), grad_sum(i).GetHess()) * eta;
  }
}

/**
 * \brief multi-target weight
 */
inline void CalcWeight(TrainParam const &p, linalg::VectorView<GradientPairPrecise const> grad_sum,
                       linalg::VectorView<float> out_w) {
  return CalcWeight(p, grad_sum, 1.0f, out_w);
}

inline double CalcGainGivenWeight(TrainParam const &p,
                                  linalg::VectorView<GradientPairPrecise const> sum_grad,
                                  linalg::VectorView<float const> weight) {
  double gain{0};
  for (bst_target_t i = 0; i < weight.Size(); ++i) {
    gain += -weight(i) * ThresholdL1(sum_grad(i).GetGrad(), p.reg_alpha);
  }
  return gain;
}

/*! \brief core statistics used for tree construction */
struct XGBOOST_ALIGNAS(16) GradStats {
  using GradType = double;
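In math form, for each target t with gradient sum G_t and Hessian sum H_t, these helpers compute (the gain term is exactly CalcGainGivenWeight above; the weight term assumes the usual single-target CalcWeight with soft threshold T_\alpha for reg_alpha and L2 penalty \lambda):

  w_t = -\eta \, \frac{T_\alpha(G_t)}{H_t + \lambda}, \qquad \text{gain} = \sum_t -\, w_t \, T_\alpha(G_t)

so the multi-target gain is simply the sum of the per-target gains.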
@@ -301,8 +330,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
  GradType sum_hess { 0 };

 public:
  XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
  XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }
  [[nodiscard]] XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
  [[nodiscard]] XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }

  friend std::ostream& operator<<(std::ostream& os, GradStats s) {
    os << s.GetGrad() << "/" << s.GetHess();
@@ -340,7 +369,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
    sum_hess = a.sum_hess - b.sum_hess;
  }
  /*! \return whether the statistics is not used yet */
  inline bool Empty() const { return sum_hess == 0.0; }
  [[nodiscard]] bool Empty() const { return sum_hess == 0.0; }
  /*! \brief add statistics to the data */
  inline void Add(GradType grad, GradType hess) {
    sum_grad += grad;
@@ -348,6 +377,19 @@
  }
};

// Helper functions for copying gradient statistic, one for vector leaf, another for normal scalar.
template <typename T, typename U>
std::vector<T> &CopyStats(linalg::VectorView<U> const &src, std::vector<T> *dst) {  // NOLINT
  dst->resize(src.Size());
  std::copy(linalg::cbegin(src), linalg::cend(src), dst->begin());
  return *dst;
}

inline GradStats &CopyStats(GradStats const &src, GradStats *dst) {  // NOLINT
  *dst = src;
  return *dst;
}

/*!
 * \brief statistics that is helpful to store
 * and represent a split solution for the tree
@@ -378,9 +420,9 @@ struct SplitEntryContainer {
    return os;
  }
  /*!\return feature index to split on */
  bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
  [[nodiscard]] bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
  /*!\return whether missing value goes to left branch */
  bool DefaultLeft() const { return (sindex >> 31) != 0; }
  [[nodiscard]] bool DefaultLeft() const { return (sindex >> 31) != 0; }
  /*!
   * \brief decides whether we can replace current entry with the given statistics
   *
@@ -391,7 +433,7 @@ struct SplitEntryContainer {
   * \param new_loss_chg the loss reduction obtained through the split
   * \param split_index the feature index where the split is on
   */
  bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
  [[nodiscard]] bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
    if (std::isinf(new_loss_chg)) {  // in some cases new_loss_chg can be NaN or Inf,
                                     // for example when lambda = 0 & min_child_weight = 0
                                     // skip value in this case
@@ -429,9 +471,10 @@ struct SplitEntryContainer {
   * \param default_left whether the missing value goes to left
   * \return whether the proposed split is better and can replace current split
   */
  bool Update(bst_float new_loss_chg, unsigned split_index,
              bst_float new_split_value, bool default_left, bool is_cat,
              const GradientT &left_sum, const GradientT &right_sum) {
  template <typename GradientSumT>
  bool Update(bst_float new_loss_chg, unsigned split_index, bst_float new_split_value,
              bool default_left, bool is_cat, GradientSumT const &left_sum,
              GradientSumT const &right_sum) {
    if (this->NeedReplace(new_loss_chg, split_index)) {
      this->loss_chg = new_loss_chg;
      if (default_left) {
@@ -440,8 +483,8 @@ struct SplitEntryContainer {
      this->sindex = split_index;
      this->split_value = new_split_value;
      this->is_cat = is_cat;
      this->left_sum = left_sum;
      this->right_sum = right_sum;
      CopyStats(left_sum, &this->left_sum);
      CopyStats(right_sum, &this->right_sum);
      return true;
    } else {
      return false;

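Why Update is now templated: the scalar evaluator instantiates SplitEntryContainer<GradStats>, where CopyStats is a plain assignment, while HistMultiEvaluator instantiates SplitEntryContainer<std::vector<GradientPairPrecise>> and passes linalg views, which the vector overload copies element-wise. A minimal standalone analogue of that overload dispatch (hypothetical names, no xgboost types):

  #include <algorithm>
  #include <cstddef>
  #include <vector>

  struct Scalar { double grad{0}, hess{0}; };

  // scalar path: plain assignment, as in CopyStats(GradStats const &, GradStats *)
  Scalar &CopyStatsDemo(Scalar const &src, Scalar *dst) { return *dst = src; }

  // vector path: resize the destination, then copy element-wise
  template <typename T>
  std::vector<T> &CopyStatsDemo(T const *src, std::size_t n, std::vector<T> *dst) {
    dst->resize(n);
    std::copy(src, src + n, dst->begin());
    return *dst;
  }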
@@ -304,7 +304,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
  thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};

  // Setup gradients so that second feature gets higher gain
  auto feature_histogram = ConvertToInteger({ {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
  auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});

  dh::device_vector<FeatureType> feature_types(feature_set.size(),
                                               FeatureType::kCategorical);

@@ -1,18 +1,27 @@
/**
 * Copyright 2021-2023 by XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <xgboost/base.h>

#include "../../../../src/common/hist_util.h"
#include "../../../../src/tree/common_row_partitioner.h"
#include "../../../../src/tree/hist/evaluate_splits.h"
#include "../test_evaluate_splits.h"
#include "../../helpers.h"
#include "xgboost/context.h"  // Context

namespace xgboost {
namespace tree {
#include <gtest/gtest.h>
#include <xgboost/base.h>        // for GradientPairPrecise, Args, Gradie...
#include <xgboost/context.h>     // for Context
#include <xgboost/data.h>        // for FeatureType, DMatrix, MetaInfo
#include <xgboost/logging.h>     // for CHECK_EQ
#include <xgboost/tree_model.h>  // for RegTree, RTreeNodeStat

#include <memory>  // for make_shared, shared_ptr, addressof

#include "../../../../src/common/hist_util.h"           // for HistCollection, HistogramCuts
#include "../../../../src/common/random.h"              // for ColumnSampler
#include "../../../../src/common/row_set.h"             // for RowSetCollection
#include "../../../../src/data/gradient_index.h"        // for GHistIndexMatrix
#include "../../../../src/tree/hist/evaluate_splits.h"  // for HistEvaluator
#include "../../../../src/tree/hist/expand_entry.h"     // for CPUExpandEntry
#include "../../../../src/tree/param.h"                 // for GradStats, TrainParam
#include "../../helpers.h"                              // for RandomDataGenerator, AllThreadsFo...

namespace xgboost::tree {
void TestEvaluateSplits(bool force_read_by_column) {
  Context ctx;
  ctx.nthread = 4;
@@ -87,6 +96,68 @@ TEST(HistEvaluator, Evaluate) {
  TestEvaluateSplits(true);
}

TEST(HistMultiEvaluator, Evaluate) {
  Context ctx;
  ctx.nthread = 1;

  TrainParam param;
  param.Init(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
  auto sampler = std::make_shared<common::ColumnSampler>();

  std::size_t n_samples = 3;
  bst_feature_t n_features = 2;
  bst_target_t n_targets = 2;
  bst_bin_t n_bins = 2;

  auto p_fmat =
      RandomDataGenerator{n_samples, n_features, 0.5}.Targets(n_targets).GenerateDMatrix(true);

  HistMultiEvaluator evaluator{&ctx, p_fmat->Info(), &param, sampler};
  std::vector<common::HistCollection> histogram(n_targets);
  linalg::Vector<GradientPairPrecise> root_sum({2}, Context::kCpuId);
  for (bst_target_t t{0}; t < n_targets; ++t) {
    auto &hist = histogram[t];
    hist.Init(n_bins * n_features);
    hist.AddHistRow(0);
    hist.AllocateAllData();
    auto node_hist = hist[0];
    node_hist[0] = {-0.5, 0.5};
    node_hist[1] = {2.0, 0.5};
    node_hist[2] = {0.5, 0.5};
    node_hist[3] = {1.0, 0.5};

    root_sum(t) += node_hist[0];
    root_sum(t) += node_hist[1];
  }

  RegTree tree{n_targets, n_features};
  auto weight = evaluator.InitRoot(root_sum.HostView());
  tree.SetLeaf(RegTree::kRoot, weight.HostView());
  auto w = weight.HostView();
  ASSERT_EQ(w.Size(), n_targets);
  ASSERT_EQ(w(0), -1.5);
  ASSERT_EQ(w(1), -1.5);

  common::HistogramCuts cuts;
  cuts.cut_ptrs_ = {0, 2, 4};
  cuts.cut_values_ = {0.5, 1.0, 2.0, 3.0};
  cuts.min_vals_ = {-0.2, 1.8};

  std::vector<MultiExpandEntry> entries(1, {/*nidx=*/0, /*depth=*/0});

  std::vector<common::HistCollection const *> ptrs;
  std::transform(histogram.cbegin(), histogram.cend(), std::back_inserter(ptrs),
                 [](auto const &h) { return std::addressof(h); });

  evaluator.EvaluateSplits(tree, ptrs, cuts, &entries);

  ASSERT_EQ(entries.front().split.loss_chg, 12.5);
  ASSERT_EQ(entries.front().split.split_value, 0.5);
  ASSERT_EQ(entries.front().split.SplitIndex(), 0);

  ASSERT_EQ(sampler->GetFeatureSet(0)->Size(), n_features);
}

TEST(HistEvaluator, Apply) {
  Context ctx;
  ctx.nthread = 4;
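The expected values in this test can be checked by hand. With reg_lambda = 0 and reg_alpha = 0, the weight reduces to w = -G / H and the gain to -w * G per target, and both targets share the same histogram. Root: G = -0.5 + 2.0 = 1.5, H = 1.0, so w = -1.5 (matching both weight ASSERTs) and parent_gain = 2 * (1.5 * 1.5) = 4.5. Splitting feature 0 at its first cut (value 0.5) puts {-0.5, 0.5} on the left and {2.0, 0.5} on the right: w_left = 1.0 with gain 0.5, w_right = -4.0 with gain 8.0, so the children contribute 2 * 8.5 = 17.0 and loss_chg = 17.0 - 4.5 = 12.5, exactly as asserted.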
@@ -211,12 +282,11 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
  std::vector<CPUExpandEntry> entries(1);
  RegTree tree;
  evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries);
  auto const& split = entries.front().split;
  auto const &split = entries.front().split;

  this->CheckResult(split.loss_chg, split.SplitIndex(), split.split_value, split.is_cat,
                    split.DefaultLeft(),
                    GradientPairPrecise{split.left_sum.GetGrad(), split.left_sum.GetHess()},
                    GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()});
}
}  // namespace tree
}  // namespace xgboost
}  // namespace xgboost::tree

@@ -2,15 +2,26 @@
 * Copyright 2022-2023 by XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/base.h>                // for GradientPairInternal, GradientPairPrecise
#include <xgboost/data.h>                // for MetaInfo
#include <xgboost/host_device_vector.h>  // for HostDeviceVector
#include <xgboost/span.h>                // for operator!=, Span, SpanIterator

#include <algorithm>  // next_permutation
#include <numeric>    // iota
#include <algorithm>  // for max, max_element, next_permutation, copy
#include <cmath>      // for isnan
#include <cstddef>    // for size_t
#include <cstdint>    // for int32_t, uint64_t, uint32_t
#include <limits>     // for numeric_limits
#include <numeric>    // for iota
#include <tuple>      // for make_tuple, tie, tuple
#include <utility>    // for pair
#include <vector>     // for vector

#include "../../../src/common/hist_util.h"  // HistogramCuts,HistCollection
#include "../../../src/tree/param.h"        // TrainParam
#include "../../../src/tree/split_evaluator.h"
#include "../helpers.h"
#include "../../../src/common/hist_util.h"      // for HistogramCuts, HistCollection, GHistRow
#include "../../../src/tree/param.h"            // for TrainParam, GradStats
#include "../../../src/tree/split_evaluator.h"  // for TreeEvaluator
#include "../helpers.h"                         // for SimpleLCG, SimpleRealUniformDistribution
#include "gtest/gtest_pred_impl.h"              // for AssertionResult, ASSERT_EQ, ASSERT_TRUE

namespace xgboost::tree {
/**