Define multi expand entry. (#8895)
This commit is contained in:
parent
bbee355b45
commit
5ba3509dd3
@ -1,29 +1,51 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 XGBoost contributors
|
* Copyright 2021-2023 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||||
#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||||
|
|
||||||
#include <utility>
|
#include <algorithm> // for all_of
|
||||||
#include "../param.h"
|
#include <ostream> // for ostream
|
||||||
|
#include <utility> // for move
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
namespace xgboost {
|
#include "../param.h" // for SplitEntry, SplitEntryContainer, TrainParam
|
||||||
namespace tree {
|
#include "xgboost/base.h" // for GradientPairPrecise, bst_node_t
|
||||||
|
|
||||||
struct CPUExpandEntry {
|
namespace xgboost::tree {
|
||||||
int nid;
|
/**
|
||||||
int depth;
|
* \brief Structure for storing tree split candidate.
|
||||||
SplitEntry split;
|
*/
|
||||||
CPUExpandEntry() = default;
|
template <typename Impl>
|
||||||
XGBOOST_DEVICE
|
struct ExpandEntryImpl {
|
||||||
CPUExpandEntry(int nid, int depth, SplitEntry split)
|
bst_node_t nid;
|
||||||
: nid(nid), depth(depth), split(std::move(split)) {}
|
bst_node_t depth;
|
||||||
CPUExpandEntry(int nid, int depth, float loss_chg)
|
|
||||||
: nid(nid), depth(depth) {
|
[[nodiscard]] float GetLossChange() const {
|
||||||
split.loss_chg = loss_chg;
|
return static_cast<Impl const*>(this)->split.loss_chg;
|
||||||
|
}
|
||||||
|
[[nodiscard]] bst_node_t GetNodeId() const { return nid; }
|
||||||
|
|
||||||
|
static bool ChildIsValid(TrainParam const& param, bst_node_t depth, bst_node_t num_leaves) {
|
||||||
|
if (param.max_depth > 0 && depth >= param.max_depth) return false;
|
||||||
|
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsValid(const TrainParam& param, int num_leaves) const {
|
[[nodiscard]] bool IsValid(TrainParam const& param, bst_node_t num_leaves) const {
|
||||||
|
return static_cast<Impl const*>(this)->IsValidImpl(param, num_leaves);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
|
||||||
|
SplitEntry split;
|
||||||
|
|
||||||
|
CPUExpandEntry() = default;
|
||||||
|
CPUExpandEntry(bst_node_t nidx, bst_node_t depth, SplitEntry split)
|
||||||
|
: ExpandEntryImpl{nidx, depth}, split(std::move(split)) {}
|
||||||
|
CPUExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
|
||||||
|
|
||||||
|
[[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
|
||||||
if (split.loss_chg <= kRtEps) return false;
|
if (split.loss_chg <= kRtEps) return false;
|
||||||
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
||||||
return false;
|
return false;
|
||||||
@ -40,16 +62,7 @@ struct CPUExpandEntry {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
float GetLossChange() const { return split.loss_chg; }
|
friend std::ostream& operator<<(std::ostream& os, CPUExpandEntry const& e) {
|
||||||
bst_node_t GetNodeId() const { return nid; }
|
|
||||||
|
|
||||||
static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
|
|
||||||
if (param.max_depth > 0 && depth >= param.max_depth) return false;
|
|
||||||
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
|
|
||||||
os << "ExpandEntry:\n";
|
os << "ExpandEntry:\n";
|
||||||
os << "nidx: " << e.nid << "\n";
|
os << "nidx: " << e.nid << "\n";
|
||||||
os << "depth: " << e.depth << "\n";
|
os << "depth: " << e.depth << "\n";
|
||||||
@ -58,6 +71,54 @@ struct CPUExpandEntry {
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace tree
|
|
||||||
} // namespace xgboost
|
struct MultiExpandEntry : public ExpandEntryImpl<MultiExpandEntry> {
|
||||||
|
SplitEntryContainer<std::vector<GradientPairPrecise>> split;
|
||||||
|
|
||||||
|
MultiExpandEntry() = default;
|
||||||
|
MultiExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
|
||||||
|
|
||||||
|
[[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
|
||||||
|
if (split.loss_chg <= kRtEps) return false;
|
||||||
|
auto is_zero = [](auto const& sum) {
|
||||||
|
return std::all_of(sum.cbegin(), sum.cend(),
|
||||||
|
[&](auto const& g) { return g.GetHess() - .0 == .0; });
|
||||||
|
};
|
||||||
|
if (is_zero(split.left_sum) || is_zero(split.right_sum)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (split.loss_chg < param.min_split_loss) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (param.max_depth > 0 && depth == param.max_depth) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, MultiExpandEntry const& e) {
|
||||||
|
os << "ExpandEntry: \n";
|
||||||
|
os << "nidx: " << e.nid << "\n";
|
||||||
|
os << "depth: " << e.depth << "\n";
|
||||||
|
os << "loss: " << e.split.loss_chg << "\n";
|
||||||
|
os << "split cond:" << e.split.split_value << "\n";
|
||||||
|
os << "split ind:" << e.split.SplitIndex() << "\n";
|
||||||
|
os << "left_sum: [";
|
||||||
|
for (auto v : e.split.left_sum) {
|
||||||
|
os << v << ", ";
|
||||||
|
}
|
||||||
|
os << "]\n";
|
||||||
|
|
||||||
|
os << "right_sum: [";
|
||||||
|
for (auto v : e.split.right_sum) {
|
||||||
|
os << v << ", ";
|
||||||
|
}
|
||||||
|
os << "]\n";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace xgboost::tree
|
||||||
#endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
#endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||||
|
|||||||
@ -226,8 +226,8 @@ class GloablApproxBuilder {
|
|||||||
for (auto const &candidate : valid_candidates) {
|
for (auto const &candidate : valid_candidates) {
|
||||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||||
int right_child_nidx = tree[candidate.nid].RightChild();
|
int right_child_nidx = tree[candidate.nid].RightChild();
|
||||||
CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}};
|
CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx)};
|
||||||
CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}};
|
CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx)};
|
||||||
best_splits.push_back(l_best);
|
best_splits.push_back(l_best);
|
||||||
best_splits.push_back(r_best);
|
best_splits.push_back(r_best);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,7 +57,7 @@ bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
|
|||||||
|
|
||||||
CPUExpandEntry QuantileHistMaker::Builder::InitRoot(
|
CPUExpandEntry QuantileHistMaker::Builder::InitRoot(
|
||||||
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h) {
|
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h) {
|
||||||
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f);
|
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0));
|
||||||
|
|
||||||
size_t page_id = 0;
|
size_t page_id = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, {node});
|
auto space = ConstructHistSpace(partitioner_, {node});
|
||||||
@ -197,8 +197,8 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
|
|||||||
for (auto const &candidate : valid_candidates) {
|
for (auto const &candidate : valid_candidates) {
|
||||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||||
int right_child_nidx = tree[candidate.nid].RightChild();
|
int right_child_nidx = tree[candidate.nid].RightChild();
|
||||||
CPUExpandEntry l_best{left_child_nidx, depth, 0.0};
|
CPUExpandEntry l_best{left_child_nidx, depth};
|
||||||
CPUExpandEntry r_best{right_child_nidx, depth, 0.0};
|
CPUExpandEntry r_best{right_child_nidx, depth};
|
||||||
best_splits.push_back(l_best);
|
best_splits.push_back(l_best);
|
||||||
best_splits.push_back(r_best);
|
best_splits.push_back(r_best);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -98,7 +98,8 @@ TEST(HistEvaluator, Apply) {
|
|||||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||||
auto evaluator_ = HistEvaluator<CPUExpandEntry>{&ctx, ¶m, dmat->Info(), sampler};
|
auto evaluator_ = HistEvaluator<CPUExpandEntry>{&ctx, ¶m, dmat->Info(), sampler};
|
||||||
|
|
||||||
CPUExpandEntry entry{0, 0, 10.0f};
|
CPUExpandEntry entry{0, 0};
|
||||||
|
entry.split.loss_chg = 10.0f;
|
||||||
entry.split.left_sum = GradStats{0.4, 0.6f};
|
entry.split.left_sum = GradStats{0.4, 0.6f};
|
||||||
entry.split.right_sum = GradStats{0.5, 0.5f};
|
entry.split.right_sum = GradStats{0.5, 0.5f};
|
||||||
|
|
||||||
|
|||||||
@ -41,10 +41,10 @@ void TestAddHistRows(bool is_distributed) {
|
|||||||
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3), 0.0f);
|
nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3));
|
||||||
nodes_for_explicit_hist_build_.emplace_back(4, tree.GetDepth(4), 0.0f);
|
nodes_for_explicit_hist_build_.emplace_back(4, tree.GetDepth(4));
|
||||||
nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5), 0.0f);
|
nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5));
|
||||||
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f);
|
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));
|
||||||
|
|
||||||
HistogramBuilder<CPUExpandEntry> histogram_builder;
|
HistogramBuilder<CPUExpandEntry> histogram_builder;
|
||||||
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
|
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
|
||||||
@ -98,7 +98,7 @@ void TestSyncHist(bool is_distributed) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// level 0
|
// level 0
|
||||||
nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0), 0.0f);
|
nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0));
|
||||||
histogram.AddHistRows(&starting_index, &sync_count,
|
histogram.AddHistRows(&starting_index, &sync_count,
|
||||||
nodes_for_explicit_hist_build_,
|
nodes_for_explicit_hist_build_,
|
||||||
nodes_for_subtraction_trick_, &tree);
|
nodes_for_subtraction_trick_, &tree);
|
||||||
@ -108,10 +108,8 @@ void TestSyncHist(bool is_distributed) {
|
|||||||
nodes_for_subtraction_trick_.clear();
|
nodes_for_subtraction_trick_.clear();
|
||||||
|
|
||||||
// level 1
|
// level 1
|
||||||
nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(),
|
nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1));
|
||||||
tree.GetDepth(1), 0.0f);
|
nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2));
|
||||||
nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(),
|
|
||||||
tree.GetDepth(2), 0.0f);
|
|
||||||
|
|
||||||
histogram.AddHistRows(&starting_index, &sync_count,
|
histogram.AddHistRows(&starting_index, &sync_count,
|
||||||
nodes_for_explicit_hist_build_,
|
nodes_for_explicit_hist_build_,
|
||||||
@ -123,10 +121,10 @@ void TestSyncHist(bool is_distributed) {
|
|||||||
nodes_for_explicit_hist_build_.clear();
|
nodes_for_explicit_hist_build_.clear();
|
||||||
nodes_for_subtraction_trick_.clear();
|
nodes_for_subtraction_trick_.clear();
|
||||||
// level 2
|
// level 2
|
||||||
nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3), 0.0f);
|
nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3));
|
||||||
nodes_for_subtraction_trick_.emplace_back(4, tree.GetDepth(4), 0.0f);
|
nodes_for_subtraction_trick_.emplace_back(4, tree.GetDepth(4));
|
||||||
nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5), 0.0f);
|
nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5));
|
||||||
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f);
|
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));
|
||||||
|
|
||||||
histogram.AddHistRows(&starting_index, &sync_count,
|
histogram.AddHistRows(&starting_index, &sync_count,
|
||||||
nodes_for_explicit_hist_build_,
|
nodes_for_explicit_hist_build_,
|
||||||
@ -256,7 +254,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
|
|||||||
std::iota(row_indices.begin(), row_indices.end(), 0);
|
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||||
row_set_collection.Init();
|
row_set_collection.Init();
|
||||||
|
|
||||||
CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
|
CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)};
|
||||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
|
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
|
||||||
nodes_for_explicit_hist_build.push_back(node);
|
nodes_for_explicit_hist_build.push_back(node);
|
||||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>({kMaxBins, 0.5})) {
|
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>({kMaxBins, 0.5})) {
|
||||||
@ -330,7 +328,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
|||||||
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
|
CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)};
|
||||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
|
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
|
||||||
nodes_for_explicit_hist_build.push_back(node);
|
nodes_for_explicit_hist_build.push_back(node);
|
||||||
|
|
||||||
@ -403,7 +401,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
|||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
std::vector<CPUExpandEntry> nodes;
|
std::vector<CPUExpandEntry> nodes;
|
||||||
nodes.emplace_back(0, tree.GetDepth(0), 0.0f);
|
nodes.emplace_back(0, tree.GetDepth(0));
|
||||||
|
|
||||||
common::GHistRow multi_page;
|
common::GHistRow multi_page;
|
||||||
HistogramBuilder<CPUExpandEntry> multi_build;
|
HistogramBuilder<CPUExpandEntry> multi_build;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021-2022, XGBoost contributors.
|
* Copyright 2021-2023 by XGBoost contributors.
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
@ -10,7 +10,6 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
std::vector<float> GenerateHess(size_t n_samples) {
|
std::vector<float> GenerateHess(size_t n_samples) {
|
||||||
auto grad = GenerateRandomGradients(n_samples);
|
auto grad = GenerateRandomGradients(n_samples);
|
||||||
@ -32,7 +31,8 @@ TEST(Approx, Partitioner) {
|
|||||||
|
|
||||||
auto const Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
auto const Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
||||||
auto hess = GenerateHess(n_samples);
|
auto hess = GenerateHess(n_samples);
|
||||||
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
|
std::vector<CPUExpandEntry> candidates{{0, 0}};
|
||||||
|
candidates.front().split.loss_chg = 0.4;
|
||||||
|
|
||||||
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
|
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({64, hess, true})) {
|
||||||
bst_feature_t const split_ind = 0;
|
bst_feature_t const split_ind = 0;
|
||||||
@ -79,7 +79,9 @@ void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared
|
|||||||
CommonRowPartitioner const& expected_mid_partitioner) {
|
CommonRowPartitioner const& expected_mid_partitioner) {
|
||||||
auto dmat =
|
auto dmat =
|
||||||
std::unique_ptr<DMatrix>{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
std::unique_ptr<DMatrix>{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
|
std::vector<CPUExpandEntry> candidates{{0, 0}};
|
||||||
|
candidates.front().split.loss_chg = 0.4;
|
||||||
|
|
||||||
Context ctx;
|
Context ctx;
|
||||||
ctx.InitAllowUnknown(Args{});
|
ctx.InitAllowUnknown(Args{});
|
||||||
for (auto const& page : dmat->GetBatches<GHistIndexMatrix>({64, *hess, true})) {
|
for (auto const& page : dmat->GetBatches<GHistIndexMatrix>({64, *hess, true})) {
|
||||||
@ -124,7 +126,8 @@ TEST(Approx, PartitionerColSplit) {
|
|||||||
size_t n_samples = 1024, n_features = 16, base_rowid = 0;
|
size_t n_samples = 1024, n_features = 16, base_rowid = 0;
|
||||||
auto const Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
auto const Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
||||||
auto hess = GenerateHess(n_samples);
|
auto hess = GenerateHess(n_samples);
|
||||||
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
|
std::vector<CPUExpandEntry> candidates{{0, 0}};
|
||||||
|
candidates.front().split.loss_chg = 0.4;
|
||||||
|
|
||||||
float min_value, mid_value;
|
float min_value, mid_value;
|
||||||
Context ctx;
|
Context ctx;
|
||||||
@ -154,7 +157,8 @@ void TestLeafPartition(size_t n_samples) {
|
|||||||
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
|
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
|
||||||
|
|
||||||
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
||||||
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
|
std::vector<CPUExpandEntry> candidates{{0, 0}};
|
||||||
|
candidates.front().split.loss_chg = 0.4;
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
std::vector<float> hess(n_samples, 0);
|
std::vector<float> hess(n_samples, 0);
|
||||||
// emulate sampling
|
// emulate sampling
|
||||||
|
|||||||
@ -29,7 +29,8 @@ TEST(QuantileHist, Partitioner) {
|
|||||||
ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
|
ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
|
||||||
|
|
||||||
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
||||||
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
|
std::vector<CPUExpandEntry> candidates{{0, 0}};
|
||||||
|
candidates.front().split.loss_chg = 0.4;
|
||||||
|
|
||||||
auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
|
auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user