Use realloc for histogram cache and expose the cache limit. (#9455)
This commit is contained in:
@@ -51,7 +51,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
||||
row_set_collection.Init();
|
||||
|
||||
HistMakerTrainParam hist_param;
|
||||
hist.Reset(gmat.cut.Ptrs().back(), hist_param.internal_max_cached_hist_node);
|
||||
hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node);
|
||||
hist.AllocateHistograms({0});
|
||||
common::BuildHist<false>(row_gpairs, row_set_collection[0], gmat, hist[0], force_read_by_column);
|
||||
|
||||
@@ -118,7 +118,7 @@ TEST(HistMultiEvaluator, Evaluate) {
|
||||
linalg::Vector<GradientPairPrecise> root_sum({2}, Context::kCpuId);
|
||||
for (bst_target_t t{0}; t < n_targets; ++t) {
|
||||
auto &hist = histogram[t];
|
||||
hist.Reset(n_bins * n_features, hist_param.internal_max_cached_hist_node);
|
||||
hist.Reset(n_bins * n_features, hist_param.max_cached_hist_node);
|
||||
hist.AllocateHistograms({0});
|
||||
auto node_hist = hist[0];
|
||||
node_hist[0] = {-0.5, 0.5};
|
||||
@@ -235,7 +235,7 @@ auto CompareOneHotAndPartition(bool onehot) {
|
||||
entries.front().nid = 0;
|
||||
entries.front().depth = 0;
|
||||
|
||||
hist.Reset(gmat.cut.TotalBins(), hist_param.internal_max_cached_hist_node);
|
||||
hist.Reset(gmat.cut.TotalBins(), hist_param.max_cached_hist_node);
|
||||
hist.AllocateHistograms({0});
|
||||
auto node_hist = hist[0];
|
||||
|
||||
@@ -265,7 +265,7 @@ TEST(HistEvaluator, Categorical) {
|
||||
TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
|
||||
BoundedHistCollection hist;
|
||||
HistMakerTrainParam hist_param;
|
||||
hist.Reset(cuts_.TotalBins(), hist_param.internal_max_cached_hist_node);
|
||||
hist.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node);
|
||||
hist.AllocateHistograms({0});
|
||||
auto node_hist = hist[0];
|
||||
ASSERT_EQ(node_hist.size(), feature_histogram_.size());
|
||||
|
||||
@@ -516,7 +516,7 @@ class OverflowTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {
|
||||
Context ctx;
|
||||
HistMakerTrainParam hist_param;
|
||||
if (limit) {
|
||||
hist_param.Init(Args{{"internal_max_cached_hist_node", "1"}});
|
||||
hist_param.Init(Args{{"max_cached_hist_node", "1"}});
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> Xy =
|
||||
|
||||
Reference in New Issue
Block a user