diff --git a/src/tree/hist/expand_entry.h b/src/tree/hist/expand_entry.h index 885a109bf..acd6edf2b 100644 --- a/src/tree/hist/expand_entry.h +++ b/src/tree/hist/expand_entry.h @@ -1,29 +1,51 @@ -/*! - * Copyright 2021 XGBoost contributors +/** + * Copyright 2021-2023 XGBoost contributors */ #ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ #define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ -#include -#include "../param.h" +#include // for all_of +#include // for ostream +#include // for move +#include // for vector -namespace xgboost { -namespace tree { +#include "../param.h" // for SplitEntry, SplitEntryContainer, TrainParam +#include "xgboost/base.h" // for GradientPairPrecise, bst_node_t -struct CPUExpandEntry { - int nid; - int depth; - SplitEntry split; - CPUExpandEntry() = default; - XGBOOST_DEVICE - CPUExpandEntry(int nid, int depth, SplitEntry split) - : nid(nid), depth(depth), split(std::move(split)) {} - CPUExpandEntry(int nid, int depth, float loss_chg) - : nid(nid), depth(depth) { - split.loss_chg = loss_chg; +namespace xgboost::tree { +/** + * \brief Structure for storing tree split candidate. + */ +template +struct ExpandEntryImpl { + bst_node_t nid; + bst_node_t depth; + + [[nodiscard]] float GetLossChange() const { + return static_cast(this)->split.loss_chg; + } + [[nodiscard]] bst_node_t GetNodeId() const { return nid; } + + static bool ChildIsValid(TrainParam const& param, bst_node_t depth, bst_node_t num_leaves) { + if (param.max_depth > 0 && depth >= param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; + return true; } - bool IsValid(const TrainParam& param, int num_leaves) const { + [[nodiscard]] bool IsValid(TrainParam const& param, bst_node_t num_leaves) const { + return static_cast(this)->IsValidImpl(param, num_leaves); + } +}; + +struct CPUExpandEntry : public ExpandEntryImpl { + SplitEntry split; + + CPUExpandEntry() = default; + CPUExpandEntry(bst_node_t nidx, bst_node_t depth, SplitEntry split) + : ExpandEntryImpl{nidx, depth}, split(std::move(split)) {} + CPUExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {} + + [[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const { if (split.loss_chg <= kRtEps) return false; if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { return false; @@ -40,16 +62,7 @@ struct CPUExpandEntry { return true; } - float GetLossChange() const { return split.loss_chg; } - bst_node_t GetNodeId() const { return nid; } - - static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) { - if (param.max_depth > 0 && depth >= param.max_depth) return false; - if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; - return true; - } - - friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) { + friend std::ostream& operator<<(std::ostream& os, CPUExpandEntry const& e) { os << "ExpandEntry:\n"; os << "nidx: " << e.nid << "\n"; os << "depth: " << e.depth << "\n"; @@ -58,6 +71,54 @@ struct CPUExpandEntry { return os; } }; -} // namespace tree -} // namespace xgboost + +struct MultiExpandEntry : public ExpandEntryImpl { + SplitEntryContainer> split; + + MultiExpandEntry() = default; + MultiExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {} + + [[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const { + if (split.loss_chg <= kRtEps) return false; + auto is_zero = [](auto const& sum) { + return std::all_of(sum.cbegin(), sum.cend(), + [&](auto const& g) { return g.GetHess() - .0 == .0; }); + }; + if (is_zero(split.left_sum) || is_zero(split.right_sum)) { + return false; + } + if (split.loss_chg < param.min_split_loss) { + return false; + } + if (param.max_depth > 0 && depth == param.max_depth) { + return false; + } + if (param.max_leaves > 0 && num_leaves == param.max_leaves) { + return false; + } + return true; + } + + friend std::ostream& operator<<(std::ostream& os, MultiExpandEntry const& e) { + os << "ExpandEntry: \n"; + os << "nidx: " << e.nid << "\n"; + os << "depth: " << e.depth << "\n"; + os << "loss: " << e.split.loss_chg << "\n"; + os << "split cond:" << e.split.split_value << "\n"; + os << "split ind:" << e.split.SplitIndex() << "\n"; + os << "left_sum: ["; + for (auto v : e.split.left_sum) { + os << v << ", "; + } + os << "]\n"; + + os << "right_sum: ["; + for (auto v : e.split.right_sum) { + os << v << ", "; + } + os << "]\n"; + return os; + } +}; +} // namespace xgboost::tree #endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 5af2721a6..fd636d3a3 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -226,8 +226,8 @@ class GloablApproxBuilder { for (auto const &candidate : valid_candidates) { int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); - CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}}; - CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}}; + CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx)}; + CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx)}; best_splits.push_back(l_best); best_splits.push_back(r_best); } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 76c402ff5..7d5f6efb3 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -57,7 +57,7 @@ bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data, CPUExpandEntry QuantileHistMaker::Builder::InitRoot( DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h) { - CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f); + CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0)); size_t page_id = 0; auto space = ConstructHistSpace(partitioner_, {node}); @@ -197,8 +197,8 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree, for (auto const &candidate : valid_candidates) { int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); - CPUExpandEntry l_best{left_child_nidx, depth, 0.0}; - CPUExpandEntry r_best{right_child_nidx, depth, 0.0}; + CPUExpandEntry l_best{left_child_nidx, depth}; + CPUExpandEntry r_best{right_child_nidx, depth}; best_splits.push_back(l_best); best_splits.push_back(r_best); } diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index fc94f3130..cf9d78f52 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -98,7 +98,8 @@ TEST(HistEvaluator, Apply) { auto sampler = std::make_shared(); auto evaluator_ = HistEvaluator{&ctx, ¶m, dmat->Info(), sampler}; - CPUExpandEntry entry{0, 0, 10.0f}; + CPUExpandEntry entry{0, 0}; + entry.split.loss_chg = 10.0f; entry.split.left_sum = GradStats{0.4, 0.6f}; entry.split.right_sum = GradStats{0.5, 0.5f}; diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index 8462fa7d5..3b354bebb 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -41,10 +41,10 @@ void TestAddHistRows(bool is_distributed) { tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); - nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3), 0.0f); - nodes_for_explicit_hist_build_.emplace_back(4, tree.GetDepth(4), 0.0f); - nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5), 0.0f); - nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); + nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3)); + nodes_for_explicit_hist_build_.emplace_back(4, tree.GetDepth(4)); + nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5)); + nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6)); HistogramBuilder histogram_builder; histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1, @@ -98,7 +98,7 @@ void TestSyncHist(bool is_distributed) { } // level 0 - nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0), 0.0f); + nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0)); histogram.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, &tree); @@ -108,10 +108,8 @@ void TestSyncHist(bool is_distributed) { nodes_for_subtraction_trick_.clear(); // level 1 - nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), - tree.GetDepth(1), 0.0f); - nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), - tree.GetDepth(2), 0.0f); + nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1)); + nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2)); histogram.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, @@ -123,10 +121,10 @@ void TestSyncHist(bool is_distributed) { nodes_for_explicit_hist_build_.clear(); nodes_for_subtraction_trick_.clear(); // level 2 - nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3), 0.0f); - nodes_for_subtraction_trick_.emplace_back(4, tree.GetDepth(4), 0.0f); - nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5), 0.0f); - nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); + nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3)); + nodes_for_subtraction_trick_.emplace_back(4, tree.GetDepth(4)); + nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5)); + nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6)); histogram.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, @@ -256,7 +254,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_ std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection.Init(); - CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); + CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)}; std::vector nodes_for_explicit_hist_build; nodes_for_explicit_hist_build.push_back(node); for (auto const &gidx : p_fmat->GetBatches({kMaxBins, 0.5})) { @@ -330,7 +328,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) { BatchParam batch_param{0, static_cast(kBins)}; RegTree tree; - CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); + CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)}; std::vector nodes_for_explicit_hist_build; nodes_for_explicit_hist_build.push_back(node); @@ -403,7 +401,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo RegTree tree; std::vector nodes; - nodes.emplace_back(0, tree.GetDepth(0), 0.0f); + nodes.emplace_back(0, tree.GetDepth(0)); common::GHistRow multi_page; HistogramBuilder multi_build; diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index cae76c373..308ae0823 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2021-2022, XGBoost contributors. +/** + * Copyright 2021-2023 by XGBoost contributors. */ #include @@ -10,7 +10,6 @@ namespace xgboost { namespace tree { - namespace { std::vector GenerateHess(size_t n_samples) { auto grad = GenerateRandomGradients(n_samples); @@ -32,7 +31,8 @@ TEST(Approx, Partitioner) { auto const Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); auto hess = GenerateHess(n_samples); - std::vector candidates{{0, 0, 0.4}}; + std::vector candidates{{0, 0}}; + candidates.front().split.loss_chg = 0.4; for (auto const& page : Xy->GetBatches({64, hess, true})) { bst_feature_t const split_ind = 0; @@ -79,7 +79,9 @@ void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared CommonRowPartitioner const& expected_mid_partitioner) { auto dmat = std::unique_ptr{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())}; - std::vector candidates{{0, 0, 0.4}}; + std::vector candidates{{0, 0}}; + candidates.front().split.loss_chg = 0.4; + Context ctx; ctx.InitAllowUnknown(Args{}); for (auto const& page : dmat->GetBatches({64, *hess, true})) { @@ -124,7 +126,8 @@ TEST(Approx, PartitionerColSplit) { size_t n_samples = 1024, n_features = 16, base_rowid = 0; auto const Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); auto hess = GenerateHess(n_samples); - std::vector candidates{{0, 0, 0.4}}; + std::vector candidates{{0, 0}}; + candidates.front().split.loss_chg = 0.4; float min_value, mid_value; Context ctx; @@ -154,7 +157,8 @@ void TestLeafPartition(size_t n_samples) { CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false}; auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); - std::vector candidates{{0, 0, 0.4}}; + std::vector candidates{{0, 0}}; + candidates.front().split.loss_chg = 0.4; RegTree tree; std::vector hess(n_samples, 0); // emulate sampling diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index ad98d1d6b..42edc2124 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -29,7 +29,8 @@ TEST(QuantileHist, Partitioner) { ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples); auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); - std::vector candidates{{0, 0, 0.4}}; + std::vector candidates{{0, 0}}; + candidates.front().split.loss_chg = 0.4; auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());