Define multi expand entry. (#8895)
This commit is contained in:
@@ -1,29 +1,51 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2021-2023 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||
#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||
|
||||
#include <utility>
|
||||
#include "../param.h"
|
||||
#include <algorithm> // for all_of
|
||||
#include <ostream> // for ostream
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
#include "../param.h" // for SplitEntry, SplitEntryContainer, TrainParam
|
||||
#include "xgboost/base.h" // for GradientPairPrecise, bst_node_t
|
||||
|
||||
struct CPUExpandEntry {
|
||||
int nid;
|
||||
int depth;
|
||||
SplitEntry split;
|
||||
CPUExpandEntry() = default;
|
||||
XGBOOST_DEVICE
|
||||
CPUExpandEntry(int nid, int depth, SplitEntry split)
|
||||
: nid(nid), depth(depth), split(std::move(split)) {}
|
||||
CPUExpandEntry(int nid, int depth, float loss_chg)
|
||||
: nid(nid), depth(depth) {
|
||||
split.loss_chg = loss_chg;
|
||||
namespace xgboost::tree {
|
||||
/**
|
||||
* \brief Structure for storing tree split candidate.
|
||||
*/
|
||||
template <typename Impl>
|
||||
struct ExpandEntryImpl {
|
||||
bst_node_t nid;
|
||||
bst_node_t depth;
|
||||
|
||||
[[nodiscard]] float GetLossChange() const {
|
||||
return static_cast<Impl const*>(this)->split.loss_chg;
|
||||
}
|
||||
[[nodiscard]] bst_node_t GetNodeId() const { return nid; }
|
||||
|
||||
static bool ChildIsValid(TrainParam const& param, bst_node_t depth, bst_node_t num_leaves) {
|
||||
if (param.max_depth > 0 && depth >= param.max_depth) return false;
|
||||
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsValid(const TrainParam& param, int num_leaves) const {
|
||||
[[nodiscard]] bool IsValid(TrainParam const& param, bst_node_t num_leaves) const {
|
||||
return static_cast<Impl const*>(this)->IsValidImpl(param, num_leaves);
|
||||
}
|
||||
};
|
||||
|
||||
struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
|
||||
SplitEntry split;
|
||||
|
||||
CPUExpandEntry() = default;
|
||||
CPUExpandEntry(bst_node_t nidx, bst_node_t depth, SplitEntry split)
|
||||
: ExpandEntryImpl{nidx, depth}, split(std::move(split)) {}
|
||||
CPUExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
|
||||
|
||||
[[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
|
||||
if (split.loss_chg <= kRtEps) return false;
|
||||
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
||||
return false;
|
||||
@@ -40,16 +62,7 @@ struct CPUExpandEntry {
|
||||
return true;
|
||||
}
|
||||
|
||||
float GetLossChange() const { return split.loss_chg; }
|
||||
bst_node_t GetNodeId() const { return nid; }
|
||||
|
||||
static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
|
||||
if (param.max_depth > 0 && depth >= param.max_depth) return false;
|
||||
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
|
||||
friend std::ostream& operator<<(std::ostream& os, CPUExpandEntry const& e) {
|
||||
os << "ExpandEntry:\n";
|
||||
os << "nidx: " << e.nid << "\n";
|
||||
os << "depth: " << e.depth << "\n";
|
||||
@@ -58,6 +71,54 @@ struct CPUExpandEntry {
|
||||
return os;
|
||||
}
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
struct MultiExpandEntry : public ExpandEntryImpl<MultiExpandEntry> {
|
||||
SplitEntryContainer<std::vector<GradientPairPrecise>> split;
|
||||
|
||||
MultiExpandEntry() = default;
|
||||
MultiExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
|
||||
|
||||
[[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
|
||||
if (split.loss_chg <= kRtEps) return false;
|
||||
auto is_zero = [](auto const& sum) {
|
||||
return std::all_of(sum.cbegin(), sum.cend(),
|
||||
[&](auto const& g) { return g.GetHess() - .0 == .0; });
|
||||
};
|
||||
if (is_zero(split.left_sum) || is_zero(split.right_sum)) {
|
||||
return false;
|
||||
}
|
||||
if (split.loss_chg < param.min_split_loss) {
|
||||
return false;
|
||||
}
|
||||
if (param.max_depth > 0 && depth == param.max_depth) {
|
||||
return false;
|
||||
}
|
||||
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, MultiExpandEntry const& e) {
|
||||
os << "ExpandEntry: \n";
|
||||
os << "nidx: " << e.nid << "\n";
|
||||
os << "depth: " << e.depth << "\n";
|
||||
os << "loss: " << e.split.loss_chg << "\n";
|
||||
os << "split cond:" << e.split.split_value << "\n";
|
||||
os << "split ind:" << e.split.SplitIndex() << "\n";
|
||||
os << "left_sum: [";
|
||||
for (auto v : e.split.left_sum) {
|
||||
os << v << ", ";
|
||||
}
|
||||
os << "]\n";
|
||||
|
||||
os << "right_sum: [";
|
||||
for (auto v : e.split.right_sum) {
|
||||
os << v << ", ";
|
||||
}
|
||||
os << "]\n";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::tree
|
||||
#endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||
|
||||
@@ -226,8 +226,8 @@ class GloablApproxBuilder {
|
||||
for (auto const &candidate : valid_candidates) {
|
||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||
int right_child_nidx = tree[candidate.nid].RightChild();
|
||||
CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}};
|
||||
CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}};
|
||||
CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx)};
|
||||
CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx)};
|
||||
best_splits.push_back(l_best);
|
||||
best_splits.push_back(r_best);
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
|
||||
|
||||
CPUExpandEntry QuantileHistMaker::Builder::InitRoot(
|
||||
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h) {
|
||||
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f);
|
||||
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0));
|
||||
|
||||
size_t page_id = 0;
|
||||
auto space = ConstructHistSpace(partitioner_, {node});
|
||||
@@ -197,8 +197,8 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
|
||||
for (auto const &candidate : valid_candidates) {
|
||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||
int right_child_nidx = tree[candidate.nid].RightChild();
|
||||
CPUExpandEntry l_best{left_child_nidx, depth, 0.0};
|
||||
CPUExpandEntry r_best{right_child_nidx, depth, 0.0};
|
||||
CPUExpandEntry l_best{left_child_nidx, depth};
|
||||
CPUExpandEntry r_best{right_child_nidx, depth};
|
||||
best_splits.push_back(l_best);
|
||||
best_splits.push_back(r_best);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user