From 54029a59af1632bd953e058855e0277292d8bee0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 8 Aug 2023 03:21:26 +0800 Subject: [PATCH] Bound the size of the histogram cache. (#9440) - A new histogram collection with a limit in size. - Unify histogram building logic between hist, multi-hist, and approx. --- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win | 1 + include/xgboost/base.h | 8 +- python-package/xgboost/testing/data_iter.py | 34 ++ python-package/xgboost/testing/params.py | 4 + src/common/hist_util.cc | 11 - src/common/hist_util.h | 28 +- src/common/threading_utils.h | 37 +- src/data/adapter.h | 11 +- src/tree/hist/evaluate_splits.h | 19 +- src/tree/hist/expand_entry.h | 4 +- src/tree/hist/hist_cache.h | 109 +++++ src/tree/hist/histogram.cc | 63 +++ src/tree/hist/histogram.h | 399 ++++++++++++------ src/tree/hist/param.h | 15 +- src/tree/param.h | 2 +- src/tree/updater_approx.cc | 104 ++--- src/tree/updater_quantile_hist.cc | 155 ++----- tests/cpp/common/test_hist_util.cc | 6 +- tests/cpp/test_learner.cc | 2 +- tests/cpp/tree/hist/test_evaluate_splits.cc | 48 ++- tests/cpp/tree/hist/test_histogram.cc | 377 +++++++++++------ tests/cpp/tree/test_evaluate_splits.h | 36 +- .../test_device_quantile_dmatrix.py | 4 + tests/python/test_quantile_dmatrix.py | 4 + tests/python/test_updaters.py | 45 +- .../test_with_dask/test_with_dask.py | 32 +- 27 files changed, 994 insertions(+), 565 deletions(-) create mode 100644 python-package/xgboost/testing/data_iter.py create mode 100644 src/tree/hist/hist_cache.h create mode 100644 src/tree/hist/histogram.cc diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index f03bbc73f..a93f773f9 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -69,6 +69,7 @@ OBJECTS= \ $(PKGROOT)/src/tree/updater_refresh.o \ $(PKGROOT)/src/tree/updater_sync.o \ $(PKGROOT)/src/tree/hist/param.o \ + $(PKGROOT)/src/tree/hist/histogram.o \ $(PKGROOT)/src/linear/linear_updater.o \ 
$(PKGROOT)/src/linear/updater_coordinate.o \ $(PKGROOT)/src/linear/updater_shotgun.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 9f4d0d5f3..d2f47b2aa 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -69,6 +69,7 @@ OBJECTS= \ $(PKGROOT)/src/tree/updater_refresh.o \ $(PKGROOT)/src/tree/updater_sync.o \ $(PKGROOT)/src/tree/hist/param.o \ + $(PKGROOT)/src/tree/hist/histogram.o \ $(PKGROOT)/src/linear/linear_updater.o \ $(PKGROOT)/src/linear/updater_coordinate.o \ $(PKGROOT)/src/linear/updater_shotgun.o \ diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 9a61151f4..a5edadb6c 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -91,8 +91,6 @@ namespace xgboost { /*! \brief unsigned integer type used for feature index. */ using bst_uint = uint32_t; // NOLINT -/*! \brief integer type. */ -using bst_int = int32_t; // NOLINT /*! \brief unsigned long integers */ using bst_ulong = uint64_t; // NOLINT /*! \brief float type, used for storing statistics */ @@ -138,9 +136,9 @@ namespace detail { template class GradientPairInternal { /*! \brief gradient statistics */ - T grad_; + T grad_{0}; /*! 
\brief second order gradient statistics */ - T hess_; + T hess_{0}; XGBOOST_DEVICE void SetGrad(T g) { grad_ = g; } XGBOOST_DEVICE void SetHess(T h) { hess_ = h; } @@ -157,7 +155,7 @@ class GradientPairInternal { a += b; } - XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {} + GradientPairInternal() = default; XGBOOST_DEVICE GradientPairInternal(T grad, T hess) { SetGrad(grad); diff --git a/python-package/xgboost/testing/data_iter.py b/python-package/xgboost/testing/data_iter.py new file mode 100644 index 000000000..18f8eb378 --- /dev/null +++ b/python-package/xgboost/testing/data_iter.py @@ -0,0 +1,34 @@ +"""Tests related to the `DataIter` interface.""" +import numpy as np + +import xgboost +from xgboost import testing as tm + + +def run_mixed_sparsity(device: str) -> None: + """Check QDM with mixed batches.""" + X_0, y_0, _ = tm.make_regression(128, 16, False) + if device.startswith("cuda"): + X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True) + else: + X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, False) + X_2, y_2 = tm.make_sparse_regression(512, 16, 0.9, True) + X = [X_0, X_1, X_2] + y = [y_0, y_1, y_2] + + if device.startswith("cuda"): + import cupy as cp # pylint: disable=import-error + + X = [cp.array(batch) for batch in X] + + it = tm.IteratorForTest(X, y, None, None) + Xy_0 = xgboost.QuantileDMatrix(it) + + X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True) + X = [X_0, X_1, X_2] + y = [y_0, y_1, y_2] + X_arr = np.concatenate(X, axis=0) + y_arr = np.concatenate(y, axis=0) + Xy_1 = xgboost.QuantileDMatrix(X_arr, y_arr) + + assert tm.predictor_equal(Xy_0, Xy_1) diff --git a/python-package/xgboost/testing/params.py b/python-package/xgboost/testing/params.py index 8dc91b601..4ed8f4c4e 100644 --- a/python-package/xgboost/testing/params.py +++ b/python-package/xgboost/testing/params.py @@ -41,6 +41,10 @@ hist_parameter_strategy = strategies.fixed_dictionaries( and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide") ) 
+hist_cache_strategy = strategies.fixed_dictionaries( + {"internal_max_cached_hist_node": strategies.sampled_from([1, 4, 1024, 2**31])} +) + hist_multi_parameter_strategy = strategies.fixed_dictionaries( { "max_depth": strategies.integers(1, 11), diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index e52ce1f66..65ab18630 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -67,17 +67,6 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins return out; } -/*! - * \brief fill a histogram by zeros in range [begin, end) - */ -void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) { -#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 - std::fill(hist.begin() + begin, hist.begin() + end, xgboost::GradientPairPrecise()); -#else // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 - memset(hist.data() + begin, '\0', (end - begin) * sizeof(xgboost::GradientPairPrecise)); -#endif // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 -} - /*! * \brief Increment hist as dst += add in range [begin, end) */ diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 9bc44409e..fbbd15b49 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -364,11 +364,6 @@ bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t using GHistRow = Span; using ConstGHistRow = Span; -/*! - * \brief fill a histogram by zeros - */ -void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end); - /*! 
* \brief Increment hist as dst += add in range [begin, end) */ @@ -395,12 +390,7 @@ class HistCollection { constexpr uint32_t kMax = std::numeric_limits::max(); const size_t id = row_ptr_.at(nid); CHECK_NE(id, kMax); - GradientPairPrecise* ptr = nullptr; - if (contiguous_allocation_) { - ptr = const_cast(data_[0].data() + nbins_*id); - } else { - ptr = const_cast(data_[id].data()); - } + GradientPairPrecise* ptr = const_cast(data_[id].data()); return {ptr, nbins_}; } @@ -445,24 +435,12 @@ class HistCollection { data_[row_ptr_[nid]].resize(nbins_, {0, 0}); } } - // allocate common buffer contiguously for all nodes, need for single Allreduce call - void AllocateAllData() { - const size_t new_size = nbins_*data_.size(); - contiguous_allocation_ = true; - if (data_[0].size() != new_size) { - data_[0].resize(new_size); - } - } - [[nodiscard]] bool IsContiguous() const { return contiguous_allocation_; } private: /*! \brief number of all bins over all features */ uint32_t nbins_ = 0; /*! \brief amount of active nodes in hist collection */ uint32_t n_nodes_added_ = 0; - /*! \brief flag to identify contiguous memory allocation */ - bool contiguous_allocation_ = false; - std::vector> data_; /*! \brief row_ptr_[nid] locates bin for histogram of node nid */ @@ -518,7 +496,7 @@ class ParallelGHistBuilder { GHistRow hist = idx == -1 ? 
targeted_hists_[nid] : hist_buffer_[idx]; if (!hist_was_used_[tid * nodes_ + nid]) { - InitilizeHistByZeroes(hist, 0, hist.size()); + std::fill_n(hist.data(), hist.size(), GradientPairPrecise{}); hist_was_used_[tid * nodes_ + nid] = static_cast(true); } @@ -548,7 +526,7 @@ class ParallelGHistBuilder { if (!is_updated) { // In distributed mode - some tree nodes can be empty on local machines, // So we need just set local hist by zeros in this case - InitilizeHistByZeroes(dst, begin, end); + std::fill(dst.data() + begin, dst.data() + end, GradientPairPrecise{}); } } diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index 9c7483847..3c1636906 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -7,13 +7,14 @@ #include #include -#include -#include // for int32_t -#include // for malloc, free -#include +#include // for min +#include // for size_t +#include // for int32_t +#include // for malloc, free +#include // for function #include // for bad_alloc -#include // for is_signed -#include +#include // for is_signed, conditional_t +#include // for vector #include "xgboost/logging.h" @@ -25,6 +26,8 @@ inline int32_t omp_get_thread_limit() __GOMP_NOTHROW { return 1; } // NOLINT // MSVC doesn't implement the thread limit. 
#if defined(_OPENMP) && defined(_MSC_VER) +#include + extern "C" { inline int32_t omp_get_thread_limit() { return std::numeric_limits::max(); } // NOLINT } @@ -84,8 +87,8 @@ class BlockedSpace2d { // dim1 - size of the first dimension in the space // getter_size_dim2 - functor to get the second dimensions for each 'row' by row-index // grain_size - max size of produced blocks - template - BlockedSpace2d(std::size_t dim1, Func getter_size_dim2, std::size_t grain_size) { + BlockedSpace2d(std::size_t dim1, std::function getter_size_dim2, + std::size_t grain_size) { for (std::size_t i = 0; i < dim1; ++i) { std::size_t size = getter_size_dim2(i); // Each row (second dim) is divided into n_blocks @@ -104,13 +107,13 @@ class BlockedSpace2d { } // get index of the first dimension of i-th block(task) - [[nodiscard]] std::size_t GetFirstDimension(size_t i) const { + [[nodiscard]] std::size_t GetFirstDimension(std::size_t i) const { CHECK_LT(i, first_dimension_.size()); return first_dimension_[i]; } // get a range of indexes for the second dimension of i-th block(task) - [[nodiscard]] Range1d GetRange(size_t i) const { + [[nodiscard]] Range1d GetRange(std::size_t i) const { CHECK_LT(i, ranges_.size()); return ranges_[i]; } @@ -129,22 +132,22 @@ class BlockedSpace2d { } std::vector ranges_; - std::vector first_dimension_; + std::vector first_dimension_; }; // Wrapper to implement nested parallelism with simple omp parallel for -template -void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) { +inline void ParallelFor2d(BlockedSpace2d const& space, std::int32_t n_threads, + std::function func) { std::size_t n_blocks_in_space = space.Size(); - CHECK_GE(nthreads, 1); + CHECK_GE(n_threads, 1); dmlc::OMPException exc; -#pragma omp parallel num_threads(nthreads) +#pragma omp parallel num_threads(n_threads) { exc.Run([&]() { - size_t tid = omp_get_thread_num(); - size_t chunck_size = n_blocks_in_space / nthreads + !!(n_blocks_in_space % nthreads); + std::size_t tid 
= omp_get_thread_num(); + std::size_t chunck_size = n_blocks_in_space / n_threads + !!(n_blocks_in_space % n_threads); std::size_t begin = chunck_size * tid; std::size_t end = std::min(begin + chunck_size, n_blocks_in_space); diff --git a/src/data/adapter.h b/src/data/adapter.h index 7776177ab..1463a13a7 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -477,7 +477,6 @@ class CSCArrayAdapterBatch : public detail::NoMetaInfo { ArrayInterface<1> indptr_; ArrayInterface<1> indices_; ArrayInterface<1> values_; - bst_row_t n_rows_; class Line { std::size_t column_idx_; @@ -503,11 +502,8 @@ class CSCArrayAdapterBatch : public detail::NoMetaInfo { static constexpr bool kIsRowMajor = false; CSCArrayAdapterBatch(ArrayInterface<1> indptr, ArrayInterface<1> indices, - ArrayInterface<1> values, bst_row_t n_rows) - : indptr_{std::move(indptr)}, - indices_{std::move(indices)}, - values_{std::move(values)}, - n_rows_{n_rows} {} + ArrayInterface<1> values) + : indptr_{std::move(indptr)}, indices_{std::move(indices)}, values_{std::move(values)} {} std::size_t Size() const { return indptr_.n - 1; } Line GetLine(std::size_t idx) const { @@ -542,8 +538,7 @@ class CSCArrayAdapter : public detail::SingleBatchDataIter indices_{indices}, values_{values}, num_rows_{num_rows}, - batch_{ - CSCArrayAdapterBatch{indptr_, indices_, values_, static_cast(num_rows_)}} {} + batch_{CSCArrayAdapterBatch{indptr_, indices_, values_}} {} // JVM package sends 0 as unknown size_t NumRows() const { return num_rows_ == 0 ? 
kAdapterUnknownSize : num_rows_; } diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index f4e44fa52..82dc99b12 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -4,13 +4,13 @@ #ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ #define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ -#include // for copy -#include // for size_t -#include // for numeric_limits -#include // for shared_ptr -#include // for accumulate -#include // for move -#include // for vector +#include // for copy +#include // for size_t +#include // for numeric_limits +#include // for shared_ptr +#include // for accumulate +#include // for move +#include // for vector #include "../../common/categorical.h" // for CatBitField #include "../../common/hist_util.h" // for GHistRow, HistogramCuts @@ -20,6 +20,7 @@ #include "../param.h" // for TrainParam #include "../split_evaluator.h" // for TreeEvaluator #include "expand_entry.h" // for MultiExpandEntry +#include "hist_cache.h" // for BoundedHistCollection #include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_feature_t #include "xgboost/context.h" // for COntext #include "xgboost/linalg.h" // for Constants, Vector @@ -317,7 +318,7 @@ class HistEvaluator { } public: - void EvaluateSplits(const common::HistCollection &hist, common::HistogramCuts const &cut, + void EvaluateSplits(const BoundedHistCollection &hist, common::HistogramCuts const &cut, common::Span feature_types, const RegTree &tree, std::vector *p_entries) { auto n_threads = ctx_->Threads(); @@ -623,7 +624,7 @@ class HistMultiEvaluator { } public: - void EvaluateSplits(RegTree const &tree, common::Span hist, + void EvaluateSplits(RegTree const &tree, common::Span hist, common::HistogramCuts const &cut, std::vector *p_entries) { auto &entries = *p_entries; std::vector>> features(entries.size()); diff --git a/src/tree/hist/expand_entry.h b/src/tree/hist/expand_entry.h index e7e19be06..0225a5110 100644 --- a/src/tree/hist/expand_entry.h +++ 
b/src/tree/hist/expand_entry.h @@ -18,8 +18,8 @@ namespace xgboost::tree { */ template struct ExpandEntryImpl { - bst_node_t nid; - bst_node_t depth; + bst_node_t nid{0}; + bst_node_t depth{0}; [[nodiscard]] float GetLossChange() const { return static_cast(this)->split.loss_chg; diff --git a/src/tree/hist/hist_cache.h b/src/tree/hist/hist_cache.h new file mode 100644 index 000000000..79e5d9bad --- /dev/null +++ b/src/tree/hist/hist_cache.h @@ -0,0 +1,109 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#ifndef XGBOOST_TREE_HIST_HIST_CACHE_H_ +#define XGBOOST_TREE_HIST_HIST_CACHE_H_ +#include // for size_t +#include // for map +#include // for vector + +#include "../../common/hist_util.h" // for GHistRow, ConstGHistRow +#include "xgboost/base.h" // for bst_node_t, bst_bin_t +#include "xgboost/logging.h" // for CHECK_GT +#include "xgboost/span.h" // for Span + +namespace xgboost::tree { +/** + * @brief A persistent cache for CPU histogram. + * + * The size of the cache is first bounded by the `Driver` class then by this cache + * implementation. The former limits the number of nodes that can be built for each node + * batch, while this cache limits the number of all nodes up to the size of + * max(|node_batch|, n_cached_node). + * + * The caller is responsible for clearing up the cache as it needs to rearrange the + * nodes before making overflowed allocations. The struct only reports whether the size + * limit has been reached. + */ +class BoundedHistCollection { + // maps node index to offset in `data_`. + std::map node_map_; + // currently allocated bins, used for tracking consistency. 
+ std::size_t current_size_{0}; + + // stores the histograms in a contiguous buffer + std::vector data_; + + // number of histogram bins across all features + bst_bin_t n_total_bins_{0}; + // limits the number of nodes that can be in the cache for each tree + std::size_t n_cached_nodes_{0}; + // whether the tree has grown beyond the cache limit + bool has_exceeded_{false}; + + public: + common::GHistRow operator[](std::size_t idx) { + auto offset = node_map_.at(idx); + return common::Span{data_.data(), data_.size()}.subspan(offset, n_total_bins_); + } + common::ConstGHistRow operator[](std::size_t idx) const { + auto offset = node_map_.at(idx); + return common::Span{data_.data(), data_.size()}.subspan(offset, n_total_bins_); + } + void Reset(bst_bin_t n_total_bins, std::size_t n_cached_nodes) { + n_total_bins_ = n_total_bins; + n_cached_nodes_ = n_cached_nodes; + this->Clear(false); + } + /** + * @brief Clear the cache, mark whether the cache is exceeded the limit. + */ + void Clear(bool exceeded) { + node_map_.clear(); + current_size_ = 0; + has_exceeded_ = exceeded; + } + + [[nodiscard]] bool CanHost(common::Span nodes_to_build, + common::Span nodes_to_sub) const { + auto n_new_nodes = nodes_to_build.size() + nodes_to_sub.size(); + return n_new_nodes + node_map_.size() <= n_cached_nodes_; + } + + /** + * @brief Allocate histogram buffers for all nodes. + * + * The resulting histogram buffer is contiguous for all nodes in the order of + * allocation. 
+ */ + void AllocateHistograms(common::Span nodes_to_build, + common::Span nodes_to_sub) { + auto n_new_nodes = nodes_to_build.size() + nodes_to_sub.size(); + auto alloc_size = n_new_nodes * n_total_bins_; + auto new_size = alloc_size + current_size_; + if (new_size > data_.size()) { + data_.resize(new_size); + } + for (auto nidx : nodes_to_build) { + node_map_[nidx] = current_size_; + current_size_ += n_total_bins_; + } + for (auto nidx : nodes_to_sub) { + node_map_[nidx] = current_size_; + current_size_ += n_total_bins_; + } + CHECK_EQ(current_size_, new_size); + } + void AllocateHistograms(std::vector const& nodes) { + this->AllocateHistograms(common::Span{nodes}, + common::Span{}); + } + + [[nodiscard]] bool HasExceeded() const { return has_exceeded_; } + [[nodiscard]] bool HistogramExists(bst_node_t nidx) const { + return node_map_.find(nidx) != node_map_.cend(); + } + [[nodiscard]] std::size_t Size() const { return current_size_; } +}; +} // namespace xgboost::tree +#endif // XGBOOST_TREE_HIST_HIST_CACHE_H_ diff --git a/src/tree/hist/histogram.cc b/src/tree/hist/histogram.cc new file mode 100644 index 000000000..96abc039c --- /dev/null +++ b/src/tree/hist/histogram.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include "histogram.h" + +#include // for size_t +#include // for accumulate +#include // for swap +#include // for vector + +#include "../../common/transform_iterator.h" // for MakeIndexTransformIter +#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry +#include "xgboost/logging.h" // for CHECK_NE +#include "xgboost/span.h" // for Span +#include "xgboost/tree_model.h" // for RegTree + +namespace xgboost::tree { +void AssignNodes(RegTree const *p_tree, std::vector const &valid_candidates, + common::Span nodes_to_build, common::Span nodes_to_sub) { + CHECK_EQ(nodes_to_build.size(), valid_candidates.size()); + + std::size_t n_idx = 0; + for (auto const &c : valid_candidates) { + auto left_nidx = 
p_tree->LeftChild(c.nid); + auto right_nidx = p_tree->RightChild(c.nid); + + auto build_nidx = left_nidx; + auto subtract_nidx = right_nidx; + auto lit = + common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); }); + auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0); + auto rit = + common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); }); + auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0); + auto fewer_right = right_sum < left_sum; + if (fewer_right) { + std::swap(build_nidx, subtract_nidx); + } + nodes_to_build[n_idx] = build_nidx; + nodes_to_sub[n_idx] = subtract_nidx; + ++n_idx; + } +} + +void AssignNodes(RegTree const *p_tree, std::vector const &candidates, + common::Span nodes_to_build, common::Span nodes_to_sub) { + std::size_t n_idx = 0; + for (auto const &c : candidates) { + auto left_nidx = (*p_tree)[c.nid].LeftChild(); + auto right_nidx = (*p_tree)[c.nid].RightChild(); + auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess(); + + auto build_nidx = left_nidx; + auto subtract_nidx = right_nidx; + if (fewer_right) { + std::swap(build_nidx, subtract_nidx); + } + nodes_to_build[n_idx] = build_nidx; + nodes_to_sub[n_idx] = subtract_nidx; + ++n_idx; + } +} +} // namespace xgboost::tree diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index aef7f6df1..54c716887 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -4,80 +4,85 @@ #ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_ #define XGBOOST_TREE_HIST_HISTOGRAM_H_ -#include -#include -#include +#include // for max +#include // for size_t +#include // for int32_t +#include // for function +#include // for move +#include // for vector -#include "../../collective/communicator-inl.h" -#include "../../common/hist_util.h" -#include "../../data/gradient_index.h" -#include "expand_entry.h" -#include "xgboost/tree_model.h" // for RegTree +#include 
"../../collective/communicator-inl.h" // for Allreduce +#include "../../collective/communicator.h" // for Operation +#include "../../common/hist_util.h" // for GHistRow, ParallelGHi... +#include "../../common/row_set.h" // for RowSetCollection +#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d +#include "../../data/gradient_index.h" // for GHistIndexMatrix +#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry +#include "hist_cache.h" // for BoundedHistCollection +#include "param.h" // for HistMakerTrainParam +#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for BatchIterator, BatchSet +#include "xgboost/linalg.h" // for MatrixView, All, Vect... +#include "xgboost/logging.h" // for CHECK_GE +#include "xgboost/span.h" // for Span +#include "xgboost/tree_model.h" // for RegTree namespace xgboost::tree { -template +/** + * @brief Decide which node as the build node for multi-target trees. + */ +void AssignNodes(RegTree const *p_tree, std::vector const &valid_candidates, + common::Span nodes_to_build, common::Span nodes_to_sub); + +/** + * @brief Decide which node as the build node. + */ +void AssignNodes(RegTree const *p_tree, std::vector const &candidates, + common::Span nodes_to_build, common::Span nodes_to_sub); + class HistogramBuilder { /*! \brief culmulative histogram of gradients. */ - common::HistCollection hist_; + BoundedHistCollection hist_; common::ParallelGHistBuilder buffer_; BatchParam param_; int32_t n_threads_{-1}; - size_t n_batches_{0}; // Whether XGBoost is running in distributed environment. bool is_distributed_{false}; bool is_col_split_{false}; public: /** - * \param total_bins Total number of bins across all features - * \param max_bin_per_feat Maximum number of bins per feature, same as the `max_bin` - * training parameter. - * \param n_threads Number of threads. 
- * \param is_distributed Mostly used for testing to allow injecting parameters instead + * @brief Reset the builder, should be called before growing a new tree. + * + * @param total_bins Total number of bins across all features + * @param is_distributed Mostly used for testing to allow injecting parameters instead * of using global rabit variable. */ - void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches, - bool is_distributed, bool is_col_split) { - CHECK_GE(n_threads, 1); - n_threads_ = n_threads; - n_batches_ = n_batches; + void Reset(Context const *ctx, bst_bin_t total_bins, BatchParam const &p, bool is_distributed, + bool is_col_split, HistMakerTrainParam const *param) { + n_threads_ = ctx->Threads(); param_ = p; - hist_.Init(total_bins); + hist_.Reset(total_bins, param->internal_max_cached_hist_node); buffer_.Init(total_bins); is_distributed_ = is_distributed; is_col_split_ = is_col_split; } template - void BuildLocalHistograms(size_t page_idx, common::BlockedSpace2d space, - GHistIndexMatrix const &gidx, - std::vector const &nodes_for_explicit_hist_build, + void BuildLocalHistograms(common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx, + std::vector const &nodes_to_build, common::RowSetCollection const &row_set_collection, common::Span gpair_h, bool force_read_by_column) { - const size_t n_nodes = nodes_for_explicit_hist_build.size(); - CHECK_GT(n_nodes, 0); - - std::vector target_hists(n_nodes); - for (size_t i = 0; i < n_nodes; ++i) { - auto const nidx = nodes_for_explicit_hist_build[i].nid; - target_hists[i] = hist_[nidx]; - } - if (page_idx == 0) { - // FIXME(jiamingy): Handle different size of space. Right now we use the maximum - // partition size for the buffer, which might not be efficient if partition sizes - // has significant variance. 
- buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); - } - // Parallel processing by nodes and data in each node common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { const auto tid = static_cast(omp_get_thread_num()); - const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; - auto elem = row_set_collection[nid]; + bst_node_t const nidx = nodes_to_build[nid_in_set]; + auto elem = row_set_collection[nidx]; auto start_of_row_set = std::min(r.begin(), elem.Size()); auto end_of_row_set = std::min(r.end(), elem.Size()); auto rid_set = common::RowSetCollection::Elem(elem.begin + start_of_row_set, - elem.begin + end_of_row_set, nid); + elem.begin + end_of_row_set, nidx); auto hist = buffer_.GetInitializedHist(tid, nid_in_set); if (rid_set.Size() != 0) { common::BuildHist(gpair_h, rid_set, gidx, hist, force_read_by_column); @@ -85,117 +90,143 @@ class HistogramBuilder { }); } - void AddHistRows(int *starting_index, - std::vector const &nodes_for_explicit_hist_build, - std::vector const &nodes_for_subtraction_trick) { - for (auto const &entry : nodes_for_explicit_hist_build) { - int nid = entry.nid; - this->hist_.AddHistRow(nid); - (*starting_index) = std::min(nid, (*starting_index)); + /** + * @brief Allocate histogram, rearrange the nodes if `rearrange` is true and the tree + * has reached the cache size limit. + */ + void AddHistRows(RegTree const *p_tree, std::vector *p_nodes_to_build, + std::vector *p_nodes_to_sub, bool rearrange) { + CHECK(p_nodes_to_build); + auto &nodes_to_build = *p_nodes_to_build; + CHECK(p_nodes_to_sub); + auto &nodes_to_sub = *p_nodes_to_sub; + + // We first check whether the cache size is already exceeded or about to be exceeded. + // If not, then we can allocate histograms without clearing the cache and without + // worrying about missing parent histogram. + // + // Otherwise, we need to rearrange the nodes before the allocation to make sure the + // resulting buffer is contiguous. 
This is to facilitate efficient allreduce. + + bool can_host = this->hist_.CanHost(nodes_to_build, nodes_to_sub); + // True if the tree is still within the size of cache limit. Allocate histogram as + // usual. + auto cache_is_valid = can_host && !this->hist_.HasExceeded(); + + if (!can_host) { + this->hist_.Clear(true); } - for (auto const &node : nodes_for_subtraction_trick) { - this->hist_.AddHistRow(node.nid); - } - this->hist_.AllocateAllData(); - } - - /** Main entry point of this class, build histogram for tree nodes. */ - void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx, - RegTree const *p_tree, common::RowSetCollection const &row_set_collection, - std::vector const &nodes_for_explicit_hist_build, - std::vector const &nodes_for_subtraction_trick, - common::Span gpair, bool force_read_by_column = false) { - int starting_index = std::numeric_limits::max(); - if (page_id == 0) { - this->AddHistRows(&starting_index, nodes_for_explicit_hist_build, - nodes_for_subtraction_trick); - } - if (gidx.IsDense()) { - this->BuildLocalHistograms(page_id, space, gidx, nodes_for_explicit_hist_build, - row_set_collection, gpair, force_read_by_column); - } else { - this->BuildLocalHistograms(page_id, space, gidx, nodes_for_explicit_hist_build, - row_set_collection, gpair, force_read_by_column); - } - - CHECK_GE(n_batches_, 1); - if (page_id != n_batches_ - 1) { + if (!rearrange || cache_is_valid) { + // If not rearrange, we allocate the histogram as usual, assuming the nodes have + // been properly arranged by other builders. 
+ this->hist_.AllocateHistograms(nodes_to_build, nodes_to_sub); + if (rearrange) { + CHECK(!this->hist_.HasExceeded()); + } return; } - this->SyncHistogram(p_tree, nodes_for_explicit_hist_build, - nodes_for_subtraction_trick, starting_index); - } - /** same as the other build hist but handles only single batch data (in-core) */ - void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree, - common::RowSetCollection const &row_set_collection, - std::vector const &nodes_for_explicit_hist_build, - std::vector const &nodes_for_subtraction_trick, - common::Span gpair, bool force_read_by_column = false) { - const size_t n_nodes = nodes_for_explicit_hist_build.size(); - // create space of size (# rows in each node) - common::BlockedSpace2d space( - n_nodes, - [&](size_t nidx_in_set) { - const int32_t nidx = nodes_for_explicit_hist_build[nidx_in_set].nid; - return row_set_collection[nidx].Size(); - }, - 256); - this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, nodes_for_explicit_hist_build, - nodes_for_subtraction_trick, gpair, force_read_by_column); - } - - void SyncHistogram(RegTree const *p_tree, - std::vector const &nodes_for_explicit_hist_build, - std::vector const &nodes_for_subtraction_trick, - int starting_index) { - auto n_bins = buffer_.TotalBins(); - common::BlockedSpace2d space( - nodes_for_explicit_hist_build.size(), [&](size_t) { return n_bins; }, 1024); - CHECK(hist_.IsContiguous()); - common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) { - const auto &entry = nodes_for_explicit_hist_build[node]; - auto this_hist = this->hist_[entry.nid]; - // Merging histograms from each thread into once - this->buffer_.ReduceHist(node, r.begin(), r.end()); - }); - - if (is_distributed_ && !is_col_split_) { - collective::Allreduce( - reinterpret_cast(this->hist_[starting_index].data()), - n_bins * nodes_for_explicit_hist_build.size() * 2); + // The cache is full, parent histogram might be removed in previous 
iterations to + // saved memory. + std::vector can_subtract; + for (auto const &v : nodes_to_sub) { + if (this->hist_.HistogramExists(p_tree->Parent(v))) { + // We can still use the subtraction trick for this node + can_subtract.push_back(v); + } else { + // This node requires a full build + nodes_to_build.push_back(v); + } } - common::ParallelFor2d(space, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) { - const auto &entry = nodes_for_explicit_hist_build[nidx_in_set]; - auto this_hist = this->hist_[entry.nid]; - if (!p_tree->IsRoot(entry.nid)) { - auto const parent_id = p_tree->Parent(entry.nid); - auto const subtraction_node_id = nodes_for_subtraction_trick[nidx_in_set].nid; - auto parent_hist = this->hist_[parent_id]; - auto sibling_hist = this->hist_[subtraction_node_id]; - common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end()); + nodes_to_sub = std::move(can_subtract); + this->hist_.AllocateHistograms(nodes_to_build, nodes_to_sub); + } + + /** Main entry point of this class, build histogram for tree nodes. */ + void BuildHist(std::size_t page_idx, common::BlockedSpace2d const &space, + GHistIndexMatrix const &gidx, common::RowSetCollection const &row_set_collection, + std::vector const &nodes_to_build, + linalg::VectorView gpair, bool force_read_by_column = false) { + CHECK(gpair.Contiguous()); + + if (page_idx == 0) { + // Add the local histogram cache to the parallel buffer before processing the first page. 
+ auto n_nodes = nodes_to_build.size(); + std::vector target_hists(n_nodes); + for (size_t i = 0; i < n_nodes; ++i) { + auto const nidx = nodes_to_build[i]; + target_hists[i] = hist_[nidx]; } + buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); + } + + if (gidx.IsDense()) { + this->BuildLocalHistograms(space, gidx, nodes_to_build, row_set_collection, + gpair.Values(), force_read_by_column); + } else { + this->BuildLocalHistograms(space, gidx, nodes_to_build, row_set_collection, + gpair.Values(), force_read_by_column); + } + } + + void SyncHistogram(RegTree const *p_tree, std::vector const &nodes_to_build, + std::vector const &nodes_to_trick) { + auto n_total_bins = buffer_.TotalBins(); + common::BlockedSpace2d space( + nodes_to_build.size(), [&](std::size_t) { return n_total_bins; }, 1024); + common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) { + // Merging histograms from each thread. + this->buffer_.ReduceHist(node, r.begin(), r.end()); }); + if (is_distributed_ && !is_col_split_) { + // The cache is contiguous, we can perform allreduce for all nodes in one go. + CHECK(!nodes_to_build.empty()); + auto first_nidx = nodes_to_build.front(); + std::size_t n = n_total_bins * nodes_to_build.size() * 2; + collective::Allreduce( + reinterpret_cast(this->hist_[first_nidx].data()), n); + } + + common::BlockedSpace2d const &subspace = + nodes_to_trick.size() == nodes_to_build.size() + ? space + : common::BlockedSpace2d{nodes_to_trick.size(), + [&](std::size_t) { return n_total_bins; }, 1024}; + common::ParallelFor2d( + subspace, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) { + auto subtraction_nidx = nodes_to_trick[nidx_in_set]; + auto parent_id = p_tree->Parent(subtraction_nidx); + auto sibling_nidx = p_tree->IsLeftChild(subtraction_nidx) ? 
p_tree->RightChild(parent_id) + : p_tree->LeftChild(parent_id); + auto sibling_hist = this->hist_[sibling_nidx]; + auto parent_hist = this->hist_[parent_id]; + auto subtract_hist = this->hist_[subtraction_nidx]; + common::SubtractionHist(subtract_hist, parent_hist, sibling_hist, r.begin(), r.end()); + }); } public: /* Getters for tests. */ - common::HistCollection const &Histogram() { return hist_; } + [[nodiscard]] BoundedHistCollection const &Histogram() const { return hist_; } + [[nodiscard]] BoundedHistCollection &Histogram() { return hist_; } auto &Buffer() { return buffer_; } }; // Construct a work space for building histogram. Eventually we should move this // function into histogram builder once hist tree method supports external memory. -template +template common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners, - std::vector const &nodes_to_build) { - std::vector partition_size(nodes_to_build.size(), 0); + std::vector const &nodes_to_build) { + // FIXME(jiamingy): Handle different size of space. Right now we use the maximum + // partition size for the buffer, which might not be efficient if partition sizes + // has significant variance. + std::vector partition_size(nodes_to_build.size(), 0); for (auto const &partition : partitioners) { size_t k = 0; - for (auto node : nodes_to_build) { - auto n_rows_in_node = partition.Partitions()[node.nid].Size(); + for (auto nidx : nodes_to_build) { + auto n_rows_in_node = partition.Partitions()[nidx].Size(); partition_size[k] = std::max(partition_size[k], n_rows_in_node); k++; } @@ -204,5 +235,107 @@ common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners, nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256}; return space; } + +/** + * @brief Histogram builder that can handle multiple targets. 
+ */ +class MultiHistogramBuilder { + std::vector target_builders_; + Context const *ctx_; + + public: + /** + * @brief Build the histogram for root node. + */ + template + void BuildRootHist(DMatrix *p_fmat, RegTree const *p_tree, + std::vector const &partitioners, + linalg::MatrixView gpair, ExpandEntry const &best, + BatchParam const ¶m, bool force_read_by_column = false) { + auto n_targets = p_tree->NumTargets(); + CHECK_EQ(gpair.Shape(1), n_targets); + CHECK_EQ(p_fmat->Info().num_row_, gpair.Shape(0)); + CHECK_EQ(target_builders_.size(), n_targets); + std::vector nodes{best.nid}; + std::vector dummy_sub; + + auto space = ConstructHistSpace(partitioners, nodes); + for (bst_target_t t{0}; t < n_targets; ++t) { + this->target_builders_[t].AddHistRows(p_tree, &nodes, &dummy_sub, false); + } + CHECK(dummy_sub.empty()); + + std::size_t page_idx{0}; + for (auto const &gidx : p_fmat->GetBatches(ctx_, param)) { + for (bst_target_t t{0}; t < n_targets; ++t) { + auto t_gpair = gpair.Slice(linalg::All(), t); + this->target_builders_[t].BuildHist(page_idx, space, gidx, + partitioners[page_idx].Partitions(), nodes, t_gpair, + force_read_by_column); + } + ++page_idx; + } + + for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) { + this->target_builders_[t].SyncHistogram(p_tree, nodes, dummy_sub); + } + } + /** + * @brief Build histogram for left and right child of valid candidates + */ + template + void BuildHistLeftRight(DMatrix *p_fmat, RegTree const *p_tree, + std::vector const &partitioners, + std::vector const &valid_candidates, + linalg::MatrixView gpair, BatchParam const ¶m, + bool force_read_by_column = false) { + std::vector nodes_to_build(valid_candidates.size()); + std::vector nodes_to_sub(valid_candidates.size()); + AssignNodes(p_tree, valid_candidates, nodes_to_build, nodes_to_sub); + + // use the first builder for getting number of valid nodes. 
+ target_builders_.front().AddHistRows(p_tree, &nodes_to_build, &nodes_to_sub, true); + CHECK_GE(nodes_to_build.size(), nodes_to_sub.size()); + CHECK_EQ(nodes_to_sub.size() + nodes_to_build.size(), valid_candidates.size() * 2); + + // allocate storage for the rest of the builders + for (bst_target_t t = 1; t < target_builders_.size(); ++t) { + target_builders_[t].AddHistRows(p_tree, &nodes_to_build, &nodes_to_sub, false); + } + + auto space = ConstructHistSpace(partitioners, nodes_to_build); + std::size_t page_idx{0}; + for (auto const &page : p_fmat->GetBatches(ctx_, param)) { + CHECK_EQ(gpair.Shape(1), p_tree->NumTargets()); + for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) { + auto t_gpair = gpair.Slice(linalg::All(), t); + CHECK_EQ(t_gpair.Shape(0), p_fmat->Info().num_row_); + this->target_builders_[t].BuildHist(page_idx, space, page, + partitioners[page_idx].Partitions(), nodes_to_build, + t_gpair, force_read_by_column); + } + page_idx++; + } + + for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) { + this->target_builders_[t].SyncHistogram(p_tree, nodes_to_build, nodes_to_sub); + } + } + + [[nodiscard]] auto const &Histogram(bst_target_t t) const { + return target_builders_[t].Histogram(); + } + [[nodiscard]] auto &Histogram(bst_target_t t) { return target_builders_[t].Histogram(); } + + void Reset(Context const *ctx, bst_bin_t total_bins, bst_target_t n_targets, BatchParam const &p, + bool is_distributed, bool is_col_split, HistMakerTrainParam const *param) { + ctx_ = ctx; + target_builders_.resize(n_targets); + CHECK_GE(n_targets, 1); + for (auto &v : target_builders_) { + v.Reset(ctx, total_bins, p, is_distributed, is_col_split, param); + } + } +}; } // namespace xgboost::tree #endif // XGBOOST_TREE_HIST_HISTOGRAM_H_ diff --git a/src/tree/hist/param.h b/src/tree/hist/param.h index 3dfbf68e1..0f2f4ac00 100644 --- a/src/tree/hist/param.h +++ b/src/tree/hist/param.h @@ -2,12 +2,19 @@ * Copyright 2021-2023, XGBoost Contributors */ #pragma once 
-#include "xgboost/parameter.h" + +#include // for size_t + +#include "xgboost/parameter.h" // for XGBoostParameter #include "xgboost/tree_model.h" // for RegTree namespace xgboost::tree { struct HistMakerTrainParam : public XGBoostParameter { - bool debug_synchronize; + constexpr static std::size_t DefaultNodes() { return static_cast(1) << 16; } + + bool debug_synchronize{false}; + std::size_t internal_max_cached_hist_node{DefaultNodes()}; + void CheckTreesSynchronized(RegTree const* local_tree) const; // declare parameters @@ -15,6 +22,10 @@ struct HistMakerTrainParam : public XGBoostParameter { DMLC_DECLARE_FIELD(debug_synchronize) .set_default(false) .describe("Check if all distributed tree are identical after tree construction."); + DMLC_DECLARE_FIELD(internal_max_cached_hist_node) + .set_default(DefaultNodes()) + .set_lower_bound(1) + .describe("Maximum number of nodes in CPU histogram cache. Only for internal usage."); } }; } // namespace xgboost::tree diff --git a/src/tree/param.h b/src/tree/param.h index e182fe539..5e2a36dfe 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -526,7 +526,7 @@ struct SplitEntryContainer { * \return whether the proposed split is better and can replace current split */ template - bool Update(bst_float new_loss_chg, unsigned split_index, bst_float new_split_value, + bool Update(bst_float new_loss_chg, bst_feature_t split_index, float new_split_value, bool default_left, bool is_cat, GradientSumT const &left_sum, GradientSumT const &right_sum) { if (this->NeedReplace(new_loss_chg, split_index)) { diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 9f496d052..2110cd6e6 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -3,27 +3,39 @@ * * \brief Implementation for the approx tree method. 
*/ -#include -#include -#include +#include // for max, transform, fill_n +#include // for size_t +#include // for map +#include // for allocator, unique_ptr, make_shared, make_unique +#include // for move +#include // for vector -#include "../collective/aggregator.h" -#include "../common/random.h" -#include "../data/gradient_index.h" -#include "common_row_partitioner.h" -#include "driver.h" -#include "hist/evaluate_splits.h" -#include "hist/histogram.h" -#include "hist/param.h" -#include "hist/sampler.h" // for SampleGradient -#include "param.h" // for HistMakerTrainParam -#include "xgboost/base.h" -#include "xgboost/data.h" -#include "xgboost/json.h" -#include "xgboost/linalg.h" -#include "xgboost/task.h" // for ObjInfo -#include "xgboost/tree_model.h" -#include "xgboost/tree_updater.h" // for TreeUpdater +#include "../collective/aggregator.h" // for GlobalSum +#include "../collective/communicator-inl.h" // for IsDistributed +#include "../common/hist_util.h" // for HistogramCuts +#include "../common/random.h" // for ColumnSampler +#include "../common/timer.h" // for Monitor +#include "../data/gradient_index.h" // for GHistIndexMatrix +#include "common_row_partitioner.h" // for CommonRowPartitioner +#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG +#include "driver.h" // for Driver +#include "hist/evaluate_splits.h" // for HistEvaluator, UpdatePredictionCacheImpl +#include "hist/expand_entry.h" // for CPUExpandEntry +#include "hist/histogram.h" // for MultiHistogramBuilder +#include "hist/param.h" // for HistMakerTrainParam +#include "hist/sampler.h" // for SampleGradient +#include "param.h" // for GradStats, TrainParam +#include "xgboost/base.h" // for Args, GradientPair, bst_node_t, bst_bin_t +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for DMatrix, BatchSet, BatchIterator, MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Object, Json, FromJson, ToJson, get +#include 
"xgboost/linalg.h" // for Matrix, MakeTensorView, Empty, MatrixView +#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK +#include "xgboost/span.h" // for Span +#include "xgboost/task.h" // for ObjInfo +#include "xgboost/tree_model.h" // for RegTree, RTreeNodeStat +#include "xgboost/tree_updater.h" // for TreeUpdater, TreeUpdaterReg, XGBOOST_REGISTE... namespace xgboost::tree { @@ -46,7 +58,7 @@ class GloablApproxBuilder { HistMakerTrainParam const *hist_param_{nullptr}; std::shared_ptr col_sampler_; HistEvaluator evaluator_; - HistogramBuilder histogram_builder_; + MultiHistogramBuilder histogram_builder_; Context const *ctx_; ObjInfo const *const task_; @@ -59,7 +71,7 @@ class GloablApproxBuilder { common::HistogramCuts feature_values_; public: - void InitData(DMatrix *p_fmat, common::Span hess) { + void InitData(DMatrix *p_fmat, RegTree const *p_tree, common::Span hess) { monitor_->Start(__func__); n_batches_ = 0; @@ -79,8 +91,9 @@ class GloablApproxBuilder { n_batches_++; } - histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_, - collective::IsDistributed(), p_fmat->Info().IsColumnSplit()); + histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess), + collective::IsDistributed(), p_fmat->Info().IsColumnSplit(), + hist_param_); monitor_->Stop(__func__); } @@ -96,20 +109,16 @@ class GloablApproxBuilder { } collective::GlobalSum(p_fmat->Info(), reinterpret_cast(&root_sum), 2); std::vector nodes{best}; - size_t i = 0; - auto space = ConstructHistSpace(partitioner_, nodes); - for (auto const &page : p_fmat->GetBatches(ctx_, BatchSpec(*param_, hess))) { - histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes, - {}, gpair); - i++; - } + this->histogram_builder_.BuildRootHist(p_fmat, p_tree, partitioner_, + linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1), + best, BatchSpec(*param_, hess)); auto weight = evaluator_.InitRoot(root_sum); 
p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess(); p_tree->Stat(RegTree::kRoot).base_weight = weight; (*p_tree)[RegTree::kRoot].SetLeaf(param_->learning_rate * weight); - auto const &histograms = histogram_builder_.Histogram(); + auto const &histograms = histogram_builder_.Histogram(0); auto ft = p_fmat->Info().feature_types.ConstHostSpan(); evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &nodes); monitor_->Stop(__func__); @@ -130,30 +139,9 @@ class GloablApproxBuilder { std::vector const &valid_candidates, std::vector const &gpair, common::Span hess) { monitor_->Start(__func__); - std::vector nodes_to_build; - std::vector nodes_to_sub; - - for (auto const &c : valid_candidates) { - auto left_nidx = (*p_tree)[c.nid].LeftChild(); - auto right_nidx = (*p_tree)[c.nid].RightChild(); - auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess(); - - auto build_nidx = left_nidx; - auto subtract_nidx = right_nidx; - if (fewer_right) { - std::swap(build_nidx, subtract_nidx); - } - nodes_to_build.push_back(CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}}); - nodes_to_sub.push_back(CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}}); - } - - size_t i = 0; - auto space = ConstructHistSpace(partitioner_, nodes_to_build); - for (auto const &page : p_fmat->GetBatches(ctx_, BatchSpec(*param_, hess))) { - histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), - nodes_to_build, nodes_to_sub, gpair); - i++; - } + this->histogram_builder_.BuildHistLeftRight( + p_fmat, p_tree, partitioner_, valid_candidates, + linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1), BatchSpec(*param_, hess)); monitor_->Stop(__func__); } @@ -185,7 +173,7 @@ class GloablApproxBuilder { void UpdateTree(DMatrix *p_fmat, std::vector const &gpair, common::Span hess, RegTree *p_tree, HostDeviceVector *p_out_position) { p_last_tree_ = p_tree; - this->InitData(p_fmat, hess); + this->InitData(p_fmat, p_tree, 
hess); Driver driver(*param_); auto &tree = *p_tree; @@ -235,7 +223,7 @@ class GloablApproxBuilder { best_splits.push_back(l_best); best_splits.push_back(r_best); } - auto const &histograms = histogram_builder_.Histogram(); + auto const &histograms = histogram_builder_.Histogram(0); auto ft = p_fmat->Info().feature_types.ConstHostSpan(); monitor_->Start("EvaluateSplits"); evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 63aaf27f6..883c18f36 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -7,35 +7,37 @@ #include // for max, copy, transform #include // for size_t #include // for uint32_t, int32_t -#include // for unique_ptr, allocator, make_unique, shared_ptr -#include // for accumulate -#include // for basic_ostream, char_traits, operator<< -#include // for move, swap +#include // for exception +#include // for allocator, unique_ptr, make_unique, shared_ptr +#include // for operator<<, basic_ostream, char_traits +#include // for move #include // for vector #include "../collective/aggregator.h" // for GlobalSum -#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed -#include "../common/hist_util.h" // for HistogramCuts, HistCollection +#include "../collective/communicator-inl.h" // for IsDistributed +#include "../common/hist_util.h" // for HistogramCuts, GHistRow #include "../common/linalg_op.h" // for begin, cbegin, cend #include "../common/random.h" // for ColumnSampler #include "../common/threading_utils.h" // for ParallelFor #include "../common/timer.h" // for Monitor -#include "../common/transform_iterator.h" // for IndexTransformIter, MakeIndexTransformIter +#include "../common/transform_iterator.h" // for IndexTransformIter #include "../data/gradient_index.h" // for GHistIndexMatrix #include "common_row_partitioner.h" // for CommonRowPartitioner #include "dmlc/registry.h" // 
for DMLC_REGISTRY_FILE_TAG #include "driver.h" // for Driver #include "hist/evaluate_splits.h" // for HistEvaluator, HistMultiEvaluator, UpdatePre... #include "hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry -#include "hist/histogram.h" // for HistogramBuilder, ConstructHistSpace +#include "hist/hist_cache.h" // for BoundedHistCollection +#include "hist/histogram.h" // for MultiHistogramBuilder #include "hist/param.h" // for HistMakerTrainParam #include "hist/sampler.h" // for SampleGradient -#include "param.h" // for TrainParam, SplitEntryContainer, GradStats -#include "xgboost/base.h" // for GradientPairInternal, GradientPair, bst_targ... +#include "param.h" // for TrainParam, GradStats +#include "xgboost/base.h" // for Args, GradientPairPrecise, GradientPair, Gra... #include "xgboost/context.h" // for Context -#include "xgboost/data.h" // for BatchIterator, BatchSet, DMatrix, MetaInfo +#include "xgboost/data.h" // for BatchSet, DMatrix, BatchIterator, MetaInfo #include "xgboost/host_device_vector.h" // for HostDeviceVector -#include "xgboost/linalg.h" // for All, MatrixView, TensorView, Matrix, Empty +#include "xgboost/json.h" // for Object, Json, FromJson, ToJson, get +#include "xgboost/linalg.h" // for MatrixView, TensorView, All, Matrix, Empty #include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_GE #include "xgboost/span.h" // for Span, operator!=, SpanIterator #include "xgboost/string_view.h" // for operator<< @@ -120,7 +122,7 @@ class MultiTargetHistBuilder { std::shared_ptr col_sampler_; std::unique_ptr evaluator_; // Histogram builder for each target. - std::vector> histogram_builder_; + std::unique_ptr histogram_builder_; Context const *ctx_{nullptr}; // Partitioner for each data batch. 
std::vector partitioner_; @@ -150,7 +152,6 @@ class MultiTargetHistBuilder { monitor_->Start(__func__); p_last_fmat_ = p_fmat; - std::size_t page_id = 0; bst_bin_t n_total_bins = 0; partitioner_.clear(); for (auto const &page : p_fmat->GetBatches(ctx_, HistBatch(param_))) { @@ -160,16 +161,13 @@ class MultiTargetHistBuilder { CHECK_EQ(n_total_bins, page.cut.TotalBins()); } partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit()); - page_id++; } bst_target_t n_targets = p_tree->NumTargets(); - histogram_builder_.clear(); - for (std::size_t i = 0; i < n_targets; ++i) { - histogram_builder_.emplace_back(); - histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, - collective::IsDistributed(), p_fmat->Info().IsColumnSplit()); - } + histogram_builder_ = std::make_unique(); + histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_), + collective::IsDistributed(), p_fmat->Info().IsColumnSplit(), + hist_param_); evaluator_ = std::make_unique(ctx_, p_fmat->Info(), param_, col_sampler_); p_last_tree_ = p_tree; @@ -204,17 +202,7 @@ class MultiTargetHistBuilder { collective::GlobalSum(p_fmat->Info(), reinterpret_cast(root_sum.Values().data()), root_sum.Size() * 2); - std::vector nodes{best}; - std::size_t i = 0; - auto space = ConstructHistSpace(partitioner_, nodes); - for (auto const &page : p_fmat->GetBatches(ctx_, HistBatch(param_))) { - for (bst_target_t t{0}; t < n_targets; ++t) { - auto t_gpair = gpair.Slice(linalg::All(), t); - histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), - nodes, {}, t_gpair.Values()); - } - i++; - } + histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, best, HistBatch(param_)); auto weight = evaluator_->InitRoot(root_sum); auto weight_t = weight.HostView(); @@ -222,9 +210,10 @@ class MultiTargetHistBuilder { [&](float w) { return w * param_->learning_rate; }); p_tree->SetLeaf(RegTree::kRoot, 
weight_t); - std::vector hists; + std::vector hists; + std::vector nodes{{RegTree::kRoot, 0}}; for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) { - hists.push_back(&histogram_builder_[t].Histogram()); + hists.push_back(&(*histogram_builder_).Histogram(t)); } for (auto const &gmat : p_fmat->GetBatches(ctx_, HistBatch(param_))) { evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, &nodes); @@ -239,50 +228,17 @@ class MultiTargetHistBuilder { std::vector const &valid_candidates, linalg::MatrixView gpair) { monitor_->Start(__func__); - std::vector nodes_to_build; - std::vector nodes_to_sub; - - for (auto const &c : valid_candidates) { - auto left_nidx = p_tree->LeftChild(c.nid); - auto right_nidx = p_tree->RightChild(c.nid); - - auto build_nidx = left_nidx; - auto subtract_nidx = right_nidx; - auto lit = - common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); }); - auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0); - auto rit = - common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); }); - auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0); - auto fewer_right = right_sum < left_sum; - if (fewer_right) { - std::swap(build_nidx, subtract_nidx); - } - nodes_to_build.emplace_back(build_nidx, p_tree->GetDepth(build_nidx)); - nodes_to_sub.emplace_back(subtract_nidx, p_tree->GetDepth(subtract_nidx)); - } - - std::size_t i = 0; - auto space = ConstructHistSpace(partitioner_, nodes_to_build); - for (auto const &page : p_fmat->GetBatches(ctx_, HistBatch(param_))) { - for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) { - auto t_gpair = gpair.Slice(linalg::All(), t); - // Make sure the gradient matrix is f-order. 
- CHECK(t_gpair.Contiguous()); - histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), - nodes_to_build, nodes_to_sub, t_gpair.Values()); - } - i++; - } + histogram_builder_->BuildHistLeftRight(p_fmat, p_tree, partitioner_, valid_candidates, gpair, + HistBatch(param_)); monitor_->Stop(__func__); } void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree, std::vector *best_splits) { monitor_->Start(__func__); - std::vector hists; + std::vector hists; for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) { - hists.push_back(&histogram_builder_[t].Histogram()); + hists.push_back(&(*histogram_builder_).Histogram(t)); } for (auto const &gmat : p_fmat->GetBatches(ctx_, HistBatch(param_))) { evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, best_splits); @@ -349,7 +305,7 @@ class HistUpdater { const RegTree *p_last_tree_{nullptr}; DMatrix const *const p_last_fmat_{nullptr}; - std::unique_ptr> histogram_builder_; + std::unique_ptr histogram_builder_; ObjInfo const *task_{nullptr}; // Context for number of threads Context const *ctx_{nullptr}; @@ -364,7 +320,7 @@ class HistUpdater { col_sampler_{std::move(column_sampler)}, evaluator_{std::make_unique(ctx, param, fmat->Info(), col_sampler_)}, p_last_fmat_(fmat), - histogram_builder_{new HistogramBuilder}, + histogram_builder_{new MultiHistogramBuilder}, task_{task}, ctx_{ctx} { monitor_->Init(__func__); @@ -387,7 +343,6 @@ class HistUpdater { // initialize temp data structure void InitData(DMatrix *fmat, RegTree const *p_tree) { monitor_->Start(__func__); - std::size_t page_id{0}; bst_bin_t n_total_bins{0}; partitioner_.clear(); for (auto const &page : fmat->GetBatches(ctx_, HistBatch(param_))) { @@ -398,10 +353,9 @@ class HistUpdater { } partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->Info().IsColumnSplit()); - ++page_id; } - histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, - collective::IsDistributed(), 
fmat->Info().IsColumnSplit()); + histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(), + fmat->Info().IsColumnSplit(), hist_param_); evaluator_ = std::make_unique(ctx_, this->param_, fmat->Info(), col_sampler_); p_last_tree_ = p_tree; monitor_->Stop(__func__); @@ -410,7 +364,7 @@ class HistUpdater { void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree, std::vector *best_splits) { monitor_->Start(__func__); - auto const &histograms = histogram_builder_->Histogram(); + auto const &histograms = histogram_builder_->Histogram(0); auto ft = p_fmat->Info().feature_types.ConstHostSpan(); for (auto const &gmat : p_fmat->GetBatches(ctx_, HistBatch(param_))) { evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, best_splits); @@ -428,16 +382,8 @@ class HistUpdater { monitor_->Start(__func__); CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0)); - std::size_t page_id = 0; - auto space = ConstructHistSpace(partitioner_, {node}); - for (auto const &gidx : p_fmat->GetBatches(ctx_, HistBatch(param_))) { - std::vector nodes_to_build{node}; - std::vector nodes_to_sub; - this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree, - partitioner_.at(page_id).Partitions(), nodes_to_build, - nodes_to_sub, gpair.Slice(linalg::All(), 0).Values()); - ++page_id; - } + this->histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, node, + HistBatch(param_)); { GradientPairPrecise grad_stat; @@ -451,7 +397,7 @@ class HistUpdater { CHECK_GE(row_ptr.size(), 2); std::uint32_t const ibegin = row_ptr[0]; std::uint32_t const iend = row_ptr[1]; - auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot]; + auto hist = this->histogram_builder_->Histogram(0)[RegTree::kRoot]; auto begin = hist.data(); for (std::uint32_t i = ibegin; i < iend; ++i) { GradientPairPrecise const &et = begin[i]; @@ -474,7 +420,7 @@ class HistUpdater { monitor_->Start("EvaluateSplits"); auto ft = 
p_fmat->Info().feature_types.ConstHostSpan(); for (auto const &gmat : p_fmat->GetBatches(ctx_, HistBatch(param_))) { - evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree, + evaluator_->EvaluateSplits(histogram_builder_->Histogram(0), gmat.cut, ft, *p_tree, &entries); break; } @@ -490,33 +436,8 @@ class HistUpdater { std::vector const &valid_candidates, linalg::MatrixView gpair) { monitor_->Start(__func__); - std::vector nodes_to_build(valid_candidates.size()); - std::vector nodes_to_sub(valid_candidates.size()); - - std::size_t n_idx = 0; - for (auto const &c : valid_candidates) { - auto left_nidx = (*p_tree)[c.nid].LeftChild(); - auto right_nidx = (*p_tree)[c.nid].RightChild(); - auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess(); - - auto build_nidx = left_nidx; - auto subtract_nidx = right_nidx; - if (fewer_right) { - std::swap(build_nidx, subtract_nidx); - } - nodes_to_build[n_idx] = CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}}; - nodes_to_sub[n_idx] = CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}}; - n_idx++; - } - - std::size_t page_id{0}; - auto space = ConstructHistSpace(partitioner_, nodes_to_build); - for (auto const &gidx : p_fmat->GetBatches(ctx_, HistBatch(param_))) { - histogram_builder_->BuildHist(page_id, space, gidx, p_tree, - partitioner_.at(page_id).Partitions(), nodes_to_build, - nodes_to_sub, gpair.Values()); - ++page_id; - } + this->histogram_builder_->BuildHistLeftRight(p_fmat, p_tree, partitioner_, valid_candidates, + gpair, HistBatch(param_)); monitor_->Stop(__func__); } diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index f35a35bb4..70ebecd3d 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -27,8 +27,8 @@ void ParallelGHistBuilderReset() { for(size_t inode = 0; inode < kNodesExtended; inode++) { collection.AddHistRow(inode); + collection.AllocateData(inode); } - 
collection.AllocateAllData(); ParallelGHistBuilder hist_builder; hist_builder.Init(kBins); std::vector target_hist(kNodes); @@ -83,8 +83,8 @@ void ParallelGHistBuilderReduceHist(){ for(size_t inode = 0; inode < kNodes; inode++) { collection.AddHistRow(inode); + collection.AllocateData(inode); } - collection.AllocateAllData(); ParallelGHistBuilder hist_builder; hist_builder.Init(kBins); std::vector target_hist(kNodes); @@ -129,7 +129,7 @@ TEST(CutsBuilder, SearchGroupInd) { auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - std::vector group(kNumGroups); + std::vector group(kNumGroups); group[0] = 2; group[1] = 3; group[2] = 7; diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 3615f7587..48fd2d8e9 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -92,7 +92,7 @@ TEST(Learner, CheckGroup) { std::shared_ptr p_mat{RandomDataGenerator{kNumRows, kNumCols, 0.0f}.GenerateDMatrix()}; std::vector weight(kNumGroups, 1); - std::vector group(kNumGroups); + std::vector group(kNumGroups); group[0] = 2; group[1] = 3; group[2] = 7; diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 7bde3aca2..1685a3c80 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -4,13 +4,13 @@ #include "../test_evaluate_splits.h" #include -#include // for GradientPairPrecise, Args, Gradie... -#include // for Context -#include // for FeatureType, DMatrix, MetaInfo -#include // for CHECK_EQ -#include // for RegTree, RTreeNodeStat +#include // for GradientPairPrecise, Args, Gradie... 
+#include // for Context +#include // for FeatureType, DMatrix, MetaInfo +#include // for CHECK_EQ +#include // for RegTree, RTreeNodeStat -#include // for make_shared, shared_ptr, addressof +#include // for make_shared, shared_ptr, addressof #include "../../../../src/common/hist_util.h" // for HistCollection, HistogramCuts #include "../../../../src/common/random.h" // for ColumnSampler @@ -18,6 +18,8 @@ #include "../../../../src/data/gradient_index.h" // for GHistIndexMatrix #include "../../../../src/tree/hist/evaluate_splits.h" // for HistEvaluator #include "../../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry +#include "../../../../src/tree/hist/hist_cache.h" // for BoundedHistCollection +#include "../../../../src/tree/hist/param.h" // for HistMakerTrainParam #include "../../../../src/tree/param.h" // for GradStats, TrainParam #include "../../helpers.h" // for RandomDataGenerator, AllThreadsFo... @@ -34,7 +36,7 @@ void TestEvaluateSplits(bool force_read_by_column) { auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix(); auto evaluator = HistEvaluator{&ctx, ¶m, dmat->Info(), sampler}; - common::HistCollection hist; + BoundedHistCollection hist; std::vector row_gpairs = { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f}, {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}}; @@ -48,9 +50,9 @@ void TestEvaluateSplits(bool force_read_by_column) { std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection.Init(); - hist.Init(gmat.cut.Ptrs().back()); - hist.AddHistRow(0); - hist.AllocateAllData(); + HistMakerTrainParam hist_param; + hist.Reset(gmat.cut.Ptrs().back(), hist_param.internal_max_cached_hist_node); + hist.AllocateHistograms({0}); common::BuildHist(row_gpairs, row_set_collection[0], gmat, hist[0], force_read_by_column); // Compute total gradient for all data points @@ -111,13 +113,13 @@ TEST(HistMultiEvaluator, Evaluate) { RandomDataGenerator{n_samples, n_features, 
0.5}.Targets(n_targets).GenerateDMatrix(true); HistMultiEvaluator evaluator{&ctx, p_fmat->Info(), ¶m, sampler}; - std::vector histogram(n_targets); + HistMakerTrainParam hist_param; + std::vector histogram(n_targets); linalg::Vector root_sum({2}, Context::kCpuId); for (bst_target_t t{0}; t < n_targets; ++t) { auto &hist = histogram[t]; - hist.Init(n_bins * n_features); - hist.AddHistRow(0); - hist.AllocateAllData(); + hist.Reset(n_bins * n_features, hist_param.internal_max_cached_hist_node); + hist.AllocateHistograms({0}); auto node_hist = hist[0]; node_hist[0] = {-0.5, 0.5}; node_hist[1] = {2.0, 0.5}; @@ -143,7 +145,7 @@ TEST(HistMultiEvaluator, Evaluate) { std::vector entries(1, {/*nidx=*/0, /*depth=*/0}); - std::vector ptrs; + std::vector ptrs; std::transform(histogram.cbegin(), histogram.cend(), std::back_inserter(ptrs), [](auto const &h) { return std::addressof(h); }); @@ -225,16 +227,16 @@ auto CompareOneHotAndPartition(bool onehot) { auto sampler = std::make_shared(); auto evaluator = HistEvaluator{&ctx, ¶m, dmat->Info(), sampler}; std::vector entries(1); + HistMakerTrainParam hist_param; for (auto const &gmat : dmat->GetBatches(&ctx, {32, param.sparse_threshold})) { - common::HistCollection hist; + BoundedHistCollection hist; entries.front().nid = 0; entries.front().depth = 0; - hist.Init(gmat.cut.TotalBins()); - hist.AddHistRow(0); - hist.AllocateAllData(); + hist.Reset(gmat.cut.TotalBins(), hist_param.internal_max_cached_hist_node); + hist.AllocateHistograms({0}); auto node_hist = hist[0]; CHECK_EQ(node_hist.size(), n_cats); @@ -261,10 +263,10 @@ TEST(HistEvaluator, Categorical) { } TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) { - common::HistCollection hist; - hist.Init(cuts_.TotalBins()); - hist.AddHistRow(0); - hist.AllocateAllData(); + BoundedHistCollection hist; + HistMakerTrainParam hist_param; + hist.Reset(cuts_.TotalBins(), hist_param.internal_max_cached_hist_node); + hist.AllocateHistograms({0}); auto node_hist = hist[0]; 
ASSERT_EQ(node_hist.size(), feature_histogram_.size()); std::copy(feature_histogram_.cbegin(), feature_histogram_.cend(), node_hist.begin()); diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index b43f7e360..b90b43101 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -2,16 +2,38 @@ * Copyright 2018-2023 by Contributors */ #include -#include // Context +#include // for bst_node_t, bst_bin_t, Gradient... +#include // for Context +#include // for BatchIterator, BatchSet, DMatrix +#include // for HostDeviceVector +#include // for MakeTensorView +#include // for Error, LogCheck_EQ, LogCheck_LT +#include // for Span, operator!= +#include // for RegTree -#include +#include // for max +#include // for size_t +#include // for int32_t, uint32_t +#include // for function +#include // for back_inserter +#include // for numeric_limits +#include // for shared_ptr, allocator, unique_ptr +#include // for iota, accumulate +#include // for vector -#include "../../../../src/common/categorical.h" -#include "../../../../src/common/row_set.h" -#include "../../../../src/tree/hist/expand_entry.h" -#include "../../../../src/tree/hist/histogram.h" -#include "../../categorical_helpers.h" -#include "../../helpers.h" +#include "../../../../src/collective/communicator-inl.h" // for GetRank, GetWorldSize +#include "../../../../src/common/hist_util.h" // for GHistRow, HistogramCuts, Sketch... 
+#include "../../../../src/common/ref_resource_view.h" // for RefResourceView +#include "../../../../src/common/row_set.h" // for RowSetCollection +#include "../../../../src/common/threading_utils.h" // for BlockedSpace2d +#include "../../../../src/data/gradient_index.h" // for GHistIndexMatrix +#include "../../../../src/tree/common_row_partitioner.h" // for CommonRowPartitioner +#include "../../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry +#include "../../../../src/tree/hist/hist_cache.h" // for BoundedHistCollection +#include "../../../../src/tree/hist/histogram.h" // for HistogramBuilder +#include "../../../../src/tree/hist/param.h" // for HistMakerTrainParam +#include "../../categorical_helpers.h" // for OneHotEncodeFeature +#include "../../helpers.h" // for RandomDataGenerator, GenerateRa... namespace xgboost::tree { namespace { @@ -25,9 +47,8 @@ void InitRowPartitionForTest(common::RowSetCollection *row_set, size_t n_samples void TestAddHistRows(bool is_distributed) { Context ctx; - std::vector nodes_for_explicit_hist_build_; - std::vector nodes_for_subtraction_trick_; - int starting_index = std::numeric_limits::max(); + std::vector nodes_to_build; + std::vector nodes_to_sub; size_t constexpr kNRows = 8, kNCols = 16; int32_t constexpr kMaxBins = 4; @@ -40,24 +61,22 @@ void TestAddHistRows(bool is_distributed) { tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); - nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3)); - nodes_for_explicit_hist_build_.emplace_back(4, tree.GetDepth(4)); - nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5)); - nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6)); + nodes_to_build.emplace_back(3); + nodes_to_build.emplace_back(4); + nodes_to_sub.emplace_back(5); + nodes_to_sub.emplace_back(6); - HistogramBuilder histogram_builder; - 
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1, - is_distributed, false); - histogram_builder.AddHistRows(&starting_index, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_); + HistMakerTrainParam hist_param; + HistogramBuilder histogram_builder; + histogram_builder.Reset(&ctx, gmat.cut.TotalBins(), {kMaxBins, 0.5}, is_distributed, false, + &hist_param); + histogram_builder.AddHistRows(&tree, &nodes_to_build, &nodes_to_sub, false); - ASSERT_EQ(starting_index, 3); - - for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) { - ASSERT_EQ(histogram_builder.Histogram().RowExists(node.nid), true); + for (bst_node_t const &nidx : nodes_to_build) { + ASSERT_TRUE(histogram_builder.Histogram().HistogramExists(nidx)); } - for (const CPUExpandEntry &node : nodes_for_subtraction_trick_) { - ASSERT_EQ(histogram_builder.Histogram().RowExists(node.nid), true); + for (bst_node_t const &nidx : nodes_to_sub) { + ASSERT_TRUE(histogram_builder.Histogram().HistogramExists(nidx)); } } @@ -68,83 +87,77 @@ TEST(CPUHistogram, AddRows) { } void TestSyncHist(bool is_distributed) { - size_t constexpr kNRows = 8, kNCols = 16; - int32_t constexpr kMaxBins = 4; + std::size_t constexpr kNRows = 8, kNCols = 16; + bst_bin_t constexpr kMaxBins = 4; Context ctx; - std::vector nodes_for_explicit_hist_build_; - std::vector nodes_for_subtraction_trick_; - int starting_index = std::numeric_limits::max(); + std::vector nodes_for_explicit_hist_build; + std::vector nodes_for_subtraction_trick; RegTree tree; auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); auto const &gmat = *(p_fmat->GetBatches(&ctx, BatchParam{kMaxBins, 0.5}).begin()); - HistogramBuilder histogram; + HistogramBuilder histogram; uint32_t total_bins = gmat.cut.Ptrs().back(); - histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, false); + HistMakerTrainParam hist_param; + histogram.Reset(&ctx, total_bins, 
{kMaxBins, 0.5}, is_distributed, false, &hist_param); - common::RowSetCollection row_set_collection_; + common::RowSetCollection row_set_collection; { - row_set_collection_.Clear(); - std::vector &row_indices = *row_set_collection_.Data(); + row_set_collection.Clear(); + std::vector &row_indices = *row_set_collection.Data(); row_indices.resize(kNRows); std::iota(row_indices.begin(), row_indices.end(), 0); - row_set_collection_.Init(); + row_set_collection.Init(); } // level 0 - nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0)); - histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_); + nodes_for_explicit_hist_build.emplace_back(0); + histogram.AddHistRows(&tree, &nodes_for_explicit_hist_build, &nodes_for_subtraction_trick, false); tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); + nodes_for_explicit_hist_build.clear(); + nodes_for_subtraction_trick.clear(); // level 1 - nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1)); - nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2)); + nodes_for_explicit_hist_build.emplace_back(tree[0].LeftChild()); + nodes_for_subtraction_trick.emplace_back(tree[0].RightChild()); - histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_); + histogram.AddHistRows(&tree, &nodes_for_explicit_hist_build, &nodes_for_subtraction_trick, false); tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); + nodes_for_explicit_hist_build.clear(); + nodes_for_subtraction_trick.clear(); // level 2 - nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3)); - nodes_for_subtraction_trick_.emplace_back(4, tree.GetDepth(4)); - 
nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5)); - nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6)); + nodes_for_explicit_hist_build.emplace_back(3); + nodes_for_subtraction_trick.emplace_back(4); + nodes_for_explicit_hist_build.emplace_back(5); + nodes_for_subtraction_trick.emplace_back(6); - histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_); + histogram.AddHistRows(&tree, &nodes_for_explicit_hist_build, &nodes_for_subtraction_trick, false); - const size_t n_nodes = nodes_for_explicit_hist_build_.size(); + const size_t n_nodes = nodes_for_explicit_hist_build.size(); ASSERT_EQ(n_nodes, 2ul); - row_set_collection_.AddSplit(0, tree[0].LeftChild(), tree[0].RightChild(), 4, - 4); - row_set_collection_.AddSplit(1, tree[1].LeftChild(), tree[1].RightChild(), 2, - 2); - row_set_collection_.AddSplit(2, tree[2].LeftChild(), tree[2].RightChild(), 2, - 2); + row_set_collection.AddSplit(0, tree[0].LeftChild(), tree[0].RightChild(), 4, 4); + row_set_collection.AddSplit(1, tree[1].LeftChild(), tree[1].RightChild(), 2, 2); + row_set_collection.AddSplit(2, tree[2].LeftChild(), tree[2].RightChild(), 2, 2); common::BlockedSpace2d space( n_nodes, - [&](size_t node) { - const int32_t nid = nodes_for_explicit_hist_build_[node].nid; - return row_set_collection_[nid].Size(); + [&](std::size_t nidx_in_set) { + bst_node_t nidx = nodes_for_explicit_hist_build[nidx_in_set]; + return row_set_collection[nidx].Size(); }, 256); std::vector target_hists(n_nodes); - for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) { - const int32_t nid = nodes_for_explicit_hist_build_[i].nid; - target_hists[i] = histogram.Histogram()[nid]; + for (size_t i = 0; i < nodes_for_explicit_hist_build.size(); ++i) { + bst_node_t nidx = nodes_for_explicit_hist_build[i]; + target_hists[i] = histogram.Histogram()[nidx]; } // set values to specific nodes hist @@ -168,8 +181,7 @@ void TestSyncHist(bool is_distributed) { 
histogram.Buffer().Reset(1, n_nodes, space, target_hists); // sync hist - histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, starting_index); + histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick); using GHistRowT = common::GHistRow; auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right, @@ -182,11 +194,10 @@ void TestSyncHist(bool is_distributed) { } }; size_t node_id = 0; - for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) { - auto this_hist = histogram.Histogram()[node.nid]; - const size_t parent_id = tree[node.nid].Parent(); - const size_t subtraction_node_id = - nodes_for_subtraction_trick_[node_id].nid; + for (auto const &nidx : nodes_for_explicit_hist_build) { + auto this_hist = histogram.Histogram()[nidx]; + const size_t parent_id = tree[nidx].Parent(); + const size_t subtraction_node_id = nodes_for_subtraction_trick[node_id]; auto parent_hist = histogram.Histogram()[parent_id]; auto sibling_hist = histogram.Histogram()[subtraction_node_id]; @@ -194,11 +205,10 @@ void TestSyncHist(bool is_distributed) { ++node_id; } node_id = 0; - for (const CPUExpandEntry &node : nodes_for_subtraction_trick_) { - auto this_hist = histogram.Histogram()[node.nid]; - const size_t parent_id = tree[node.nid].Parent(); - const size_t subtraction_node_id = - nodes_for_explicit_hist_build_[node_id].nid; + for (auto const &nidx : nodes_for_subtraction_trick) { + auto this_hist = histogram.Histogram()[nidx]; + const size_t parent_id = tree[nidx].Parent(); + const size_t subtraction_node_id = nodes_for_explicit_hist_build[node_id]; auto parent_hist = histogram.Histogram()[parent_id]; auto sibling_hist = histogram.Histogram()[subtraction_node_id]; @@ -232,9 +242,9 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_ {0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f}}; bst_node_t nid = 0; - HistogramBuilder 
histogram; - histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, - is_col_split); + HistogramBuilder histogram; + HistMakerTrainParam hist_param; + histogram.Reset(&ctx, total_bins, {kMaxBins, 0.5}, is_distributed, is_col_split, &hist_param); RegTree tree; @@ -246,12 +256,17 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_ row_set_collection.Init(); CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)}; - std::vector nodes_for_explicit_hist_build; - nodes_for_explicit_hist_build.push_back(node); + std::vector nodes_to_build{node.nid}; + std::vector dummy_sub; + + histogram.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false); + common::BlockedSpace2d space{ + 1, [&](std::size_t nidx_in_set) { return row_set_collection[nidx_in_set].Size(); }, 256}; for (auto const &gidx : p_fmat->GetBatches(&ctx, {kMaxBins, 0.5})) { - histogram.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {}, - gpair, force_read_by_column); + histogram.BuildHist(0, space, gidx, row_set_collection, nodes_to_build, + linalg::MakeTensorView(&ctx, gpair, gpair.size()), force_read_by_column); } + histogram.SyncHistogram(&tree, nodes_to_build, {}); // Check if number of histogram bins is correct ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back()); @@ -312,18 +327,18 @@ void ValidateCategoricalHistogram(size_t n_categories, void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) { size_t constexpr kRows = 340; - int32_t constexpr kBins = 256; + bst_bin_t constexpr kBins = 256; auto x = GenerateRandomCategoricalSingleColumn(kRows, n_categories); auto cat_m = GetDMatrixFromData(x, kRows, 1); cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); Context ctx; - BatchParam batch_param{0, static_cast(kBins)}; + BatchParam batch_param{0, kBins}; RegTree tree; - CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)}; - std::vector 
nodes_for_explicit_hist_build; - nodes_for_explicit_hist_build.push_back(node); + CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(RegTree::kRoot)}; + std::vector nodes_to_build; + nodes_to_build.push_back(node.nid); auto gpair = GenerateRandomGradients(kRows, 0, 2); @@ -333,30 +348,41 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) { row_indices.resize(kRows); std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection.Init(); + HistMakerTrainParam hist_param; + std::vector dummy_sub; + + common::BlockedSpace2d space{ + 1, [&](std::size_t nidx_in_set) { return row_set_collection[nidx_in_set].Size(); }, 256}; /** * Generate hist with cat data. */ - HistogramBuilder cat_hist; + HistogramBuilder cat_hist; for (auto const &gidx : cat_m->GetBatches(&ctx, {kBins, 0.5})) { auto total_bins = gidx.cut.TotalBins(); - cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false); - cat_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {}, - gpair.HostVector(), force_read_by_column); + cat_hist.Reset(&ctx, total_bins, {kBins, 0.5}, false, false, &hist_param); + cat_hist.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false); + cat_hist.BuildHist(0, space, gidx, row_set_collection, nodes_to_build, + linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()), + force_read_by_column); } + cat_hist.SyncHistogram(&tree, nodes_to_build, {}); /** * Generate hist with one hot encoded data. 
*/ auto x_encoded = OneHotEncodeFeature(x, n_categories); auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories); - HistogramBuilder onehot_hist; + HistogramBuilder onehot_hist; for (auto const &gidx : encode_m->GetBatches(&ctx, {kBins, 0.5})) { auto total_bins = gidx.cut.TotalBins(); - onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false); - onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {}, - gpair.HostVector(), force_read_by_column); + onehot_hist.Reset(&ctx, total_bins, {kBins, 0.5}, false, false, &hist_param); + onehot_hist.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false); + onehot_hist.BuildHist(0, space, gidx, row_set_collection, nodes_to_build, + linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()), + force_read_by_column); } + onehot_hist.SyncHistogram(&tree, nodes_to_build, {}); auto cat = cat_hist.Histogram()[0]; auto onehot = onehot_hist.Histogram()[0]; @@ -383,19 +409,22 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo batch_param.hess = hess; } - std::vector partition_size(1, 0); - size_t total_bins{0}; - size_t n_samples{0}; + std::vector partition_size(1, 0); + bst_bin_t total_bins{0}; + bst_row_t n_samples{0}; auto gpair = GenerateRandomGradients(m->Info().num_row_, 0.0, 1.0); auto const &h_gpair = gpair.HostVector(); RegTree tree; - std::vector nodes; - nodes.emplace_back(0, tree.GetDepth(0)); + std::vector nodes{RegTree::kRoot}; + common::BlockedSpace2d space{ + 1, [&](std::size_t nidx_in_set) { return partition_size.at(nidx_in_set); }, 256}; common::GHistRow multi_page; - HistogramBuilder multi_build; + HistogramBuilder multi_build; + HistMakerTrainParam hist_param; + std::vector dummy_sub; { /** * Multi page @@ -413,23 +442,21 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo } ASSERT_EQ(n_samples, m->Info().num_row_); - common::BlockedSpace2d space{ - 1, [&](size_t 
nidx_in_set) { return partition_size.at(nidx_in_set); }, - 256}; - - multi_build.Reset(total_bins, batch_param, ctx->Threads(), rows_set.size(), false, false); - - size_t page_idx{0}; + multi_build.Reset(ctx, total_bins, batch_param, false, false, &hist_param); + multi_build.AddHistRows(&tree, &nodes, &dummy_sub, false); + std::size_t page_idx{0}; for (auto const &page : m->GetBatches(ctx, batch_param)) { - multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {}, h_gpair, + multi_build.BuildHist(page_idx, space, page, rows_set[page_idx], nodes, + linalg::MakeTensorView(ctx, h_gpair, h_gpair.size()), force_read_by_column); ++page_idx; } - ASSERT_EQ(page_idx, 2); - multi_page = multi_build.Histogram()[0]; + multi_build.SyncHistogram(&tree, nodes, {}); + + multi_page = multi_build.Histogram()[RegTree::kRoot]; } - HistogramBuilder single_build; + HistogramBuilder single_build; common::GHistRow single_page; { /** @@ -438,18 +465,24 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo common::RowSetCollection row_set_collection; InitRowPartitionForTest(&row_set_collection, n_samples); - single_build.Reset(total_bins, batch_param, ctx->Threads(), 1, false, false); + single_build.Reset(ctx, total_bins, batch_param, false, false, &hist_param); SparsePage concat; std::vector hess(m->Info().num_row_, 1.0f); - for (auto const& page : m->GetBatches()) { + for (auto const &page : m->GetBatches()) { concat.Push(page); } auto cut = common::SketchOnDMatrix(ctx, m.get(), batch_param.max_bin, false, hess); GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false, std::numeric_limits::quiet_NaN(), ctx->Threads()); - single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair, force_read_by_column); - single_page = single_build.Histogram()[0]; + + single_build.AddHistRows(&tree, &nodes, &dummy_sub, false); + single_build.BuildHist(0, space, gmat, row_set_collection, nodes, + 
linalg::MakeTensorView(ctx, h_gpair, h_gpair.size()), + force_read_by_column); + single_build.SyncHistogram(&tree, nodes, {}); + + single_page = single_build.Histogram()[RegTree::kRoot]; } for (size_t i = 0; i < single_page.size(); ++i) { @@ -473,4 +506,108 @@ TEST(CPUHistogram, ExternalMemory) { TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, false); TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true); } + +namespace { +class OverflowTest : public ::testing::TestWithParam> { + public: + std::vector TestOverflow(bool limit, bool is_distributed, + bool is_col_split) { + bst_bin_t constexpr kBins = 256; + Context ctx; + HistMakerTrainParam hist_param; + if (limit) { + hist_param.Init(Args{{"internal_max_cached_hist_node", "1"}}); + } + + std::shared_ptr Xy = + is_col_split ? RandomDataGenerator{8192, 16, 0.5}.GenerateDMatrix(true) + : RandomDataGenerator{8192, 16, 0.5}.Bins(kBins).GenerateQuantileDMatrix(true); + if (is_col_split) { + Xy = + std::shared_ptr{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())}; + } + + double sparse_thresh{TrainParam::DftSparseThreshold()}; + auto batch = BatchParam{kBins, sparse_thresh}; + bst_bin_t n_total_bins{0}; + float split_cond{0}; + for (auto const &page : Xy->GetBatches(&ctx, batch)) { + n_total_bins = page.cut.TotalBins(); + // use a cut point in the second column for split + split_cond = page.cut.Values()[kBins + kBins / 2]; + } + + RegTree tree; + MultiHistogramBuilder hist_builder; + CHECK_EQ(Xy->Info().IsColumnSplit(), is_col_split); + + hist_builder.Reset(&ctx, n_total_bins, tree.NumTargets(), batch, is_distributed, + Xy->Info().IsColumnSplit(), &hist_param); + + std::vector partitioners; + partitioners.emplace_back(&ctx, Xy->Info().num_row_, /*base_rowid=*/0, + Xy->Info().IsColumnSplit()); + + auto gpair = GenerateRandomGradients(Xy->Info().num_row_, 0.0, 1.0); + + CPUExpandEntry best; + hist_builder.BuildRootHist(Xy.get(), &tree, partitioners, + 
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size(), 1), + best, batch); + + best.split.Update(1.0f, 1, split_cond, false, false, GradStats{1.0, 1.0}, GradStats{1.0, 1.0}); + tree.ExpandNode(best.nid, best.split.SplitIndex(), best.split.split_value, false, + /*base_weight=*/2.0f, + /*left_leaf_weight=*/1.0f, /*right_leaf_weight=*/1.0f, best.GetLossChange(), + /*sum_hess=*/2.0f, best.split.left_sum.GetHess(), + best.split.right_sum.GetHess()); + + std::vector valid_candidates{best}; + for (auto const &page : Xy->GetBatches(&ctx, batch)) { + partitioners.front().UpdatePosition(&ctx, page, valid_candidates, &tree); + } + CHECK_NE(partitioners.front()[tree.LeftChild(best.nid)].Size(), 0); + CHECK_NE(partitioners.front()[tree.RightChild(best.nid)].Size(), 0); + + hist_builder.BuildHistLeftRight( + Xy.get(), &tree, partitioners, valid_candidates, + linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size(), 1), batch); + + if (limit) { + CHECK(!hist_builder.Histogram(0).HistogramExists(best.nid)); + } else { + CHECK(hist_builder.Histogram(0).HistogramExists(best.nid)); + } + + std::vector result; + auto hist = hist_builder.Histogram(0)[tree.LeftChild(best.nid)]; + std::copy(hist.cbegin(), hist.cend(), std::back_inserter(result)); + hist = hist_builder.Histogram(0)[tree.RightChild(best.nid)]; + std::copy(hist.cbegin(), hist.cend(), std::back_inserter(result)); + + return result; + } + + void RunTest() { + auto param = GetParam(); + auto res0 = this->TestOverflow(false, std::get<0>(param), std::get<1>(param)); + auto res1 = this->TestOverflow(true, std::get<0>(param), std::get<1>(param)); + ASSERT_EQ(res0, res1); + } +}; + +auto MakeParamsForTest() { + std::vector> configs; + for (auto i : {true, false}) { + for (auto j : {true, false}) { + configs.emplace_back(i, j); + } + } + return configs; +} +} // anonymous namespace + +TEST_P(OverflowTest, Overflow) { this->RunTest(); } + +INSTANTIATE_TEST_SUITE_P(CPUHistogram, OverflowTest, 
::testing::ValuesIn(MakeParamsForTest())); } // namespace xgboost::tree diff --git a/tests/cpp/tree/test_evaluate_splits.h b/tests/cpp/tree/test_evaluate_splits.h index a7e8972e5..04da4777d 100644 --- a/tests/cpp/tree/test_evaluate_splits.h +++ b/tests/cpp/tree/test_evaluate_splits.h @@ -2,22 +2,24 @@ * Copyright 2022-2023 by XGBoost Contributors */ #include -#include // for GradientPairInternal, GradientPairPrecise -#include // for MetaInfo -#include // for HostDeviceVector -#include // for operator!=, Span, SpanIterator +#include // for GradientPairInternal, GradientPairPrecise +#include // for MetaInfo +#include // for HostDeviceVector +#include // for operator!=, Span, SpanIterator -#include // for max, max_element, next_permutation, copy -#include // for isnan -#include // for size_t -#include // for int32_t, uint64_t, uint32_t -#include // for numeric_limits -#include // for iota -#include // for make_tuple, tie, tuple -#include // for pair -#include // for vector +#include // for max, max_element, next_permutation, copy +#include // for isnan +#include // for size_t +#include // for int32_t, uint64_t, uint32_t +#include // for numeric_limits +#include // for iota +#include // for make_tuple, tie, tuple +#include // for pair +#include // for vector #include "../../../src/common/hist_util.h" // for HistogramCuts, HistCollection, GHistRow +#include "../../../src/tree/hist/hist_cache.h" // for HistogramCollection +#include "../../../src/tree/hist/param.h" // for HistMakerTrainParam #include "../../../src/tree/param.h" // for TrainParam, GradStats #include "../../../src/tree/split_evaluator.h" // for TreeEvaluator #include "../helpers.h" // for SimpleLCG, SimpleRealUniformDistribution @@ -35,7 +37,7 @@ class TestPartitionBasedSplit : public ::testing::Test { MetaInfo info_; float best_score_{-std::numeric_limits::infinity()}; common::HistogramCuts cuts_; - common::HistCollection hist_; + BoundedHistCollection hist_; GradientPairPrecise total_gpair_; void SetUp() 
override { @@ -56,9 +58,9 @@ class TestPartitionBasedSplit : public ::testing::Test { cuts_.min_vals_.Resize(1); - hist_.Init(cuts_.TotalBins()); - hist_.AddHistRow(0); - hist_.AllocateAllData(); + HistMakerTrainParam hist_param; + hist_.Reset(cuts_.TotalBins(), hist_param.internal_max_cached_hist_node); + hist_.AllocateHistograms({0}); auto node_hist = hist_[0]; SimpleLCG lcg; diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py index ace17933b..4cfc61321 100644 --- a/tests/python-gpu/test_device_quantile_dmatrix.py +++ b/tests/python-gpu/test_device_quantile_dmatrix.py @@ -7,6 +7,7 @@ from hypothesis import given, settings, strategies import xgboost as xgb from xgboost import testing as tm from xgboost.testing.data import check_inf +from xgboost.testing.data_iter import run_mixed_sparsity sys.path.append("tests/python") import test_quantile_dmatrix as tqd @@ -232,3 +233,6 @@ class TestQuantileDMatrix: rng = cp.random.default_rng(1994) check_inf(rng) + + def test_mixed_sparsity(self) -> None: + run_mixed_sparsity("cuda") diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py index b7428dfac..8ee00b8c0 100644 --- a/tests/python/test_quantile_dmatrix.py +++ b/tests/python/test_quantile_dmatrix.py @@ -16,6 +16,7 @@ from xgboost.testing import ( predictor_equal, ) from xgboost.testing.data import check_inf, np_dtypes +from xgboost.testing.data_iter import run_mixed_sparsity class TestQuantileDMatrix: @@ -334,3 +335,6 @@ class TestQuantileDMatrix: with pytest.raises(ValueError, match="consistent"): xgb.train({}, Xy, num_boost_round=2, xgb_model=booster) + + def test_mixed_sparsity(self) -> None: + run_mixed_sparsity("cpu") diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 5374a2891..3fa32660d 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -11,6 +11,7 @@ from xgboost import testing as tm from 
xgboost.testing.params import ( cat_parameter_strategy, exact_parameter_strategy, + hist_cache_strategy, hist_multi_parameter_strategy, hist_parameter_strategy, ) @@ -40,14 +41,22 @@ class TestTreeMethodMulti: @given( exact_parameter_strategy, hist_parameter_strategy, + hist_cache_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy, ) @settings(deadline=None, print_blob=True) - def test_approx(self, param, hist_param, num_rounds, dataset): + def test_approx( + self, param: Dict[str, Any], + hist_param: Dict[str, Any], + cache_param: Dict[str, Any], + num_rounds: int, + dataset: tm.TestDataset, + ) -> None: param["tree_method"] = "approx" param = dataset.set_params(param) param.update(hist_param) + param.update(cache_param) result = train_result(param, dataset.get_dmat(), num_rounds) note(result) assert tm.non_increasing(result["train"][dataset.metric]) @@ -55,18 +64,25 @@ class TestTreeMethodMulti: @given( exact_parameter_strategy, hist_multi_parameter_strategy, + hist_cache_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy, ) @settings(deadline=None, print_blob=True) def test_hist( - self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset + self, + param: Dict[str, Any], + hist_param: Dict[str, Any], + cache_param: Dict[str, Any], + num_rounds: int, + dataset: tm.TestDataset, ) -> None: if dataset.name.endswith("-l1"): return param["tree_method"] = "hist" param = dataset.set_params(param) param.update(hist_param) + param.update(cache_param) result = train_result(param, dataset.get_dmat(), num_rounds) note(result) assert tm.non_increasing(result["train"][dataset.metric]) @@ -91,14 +107,23 @@ class TestTreeMethod: @given( exact_parameter_strategy, hist_parameter_strategy, + hist_cache_strategy, strategies.integers(1, 20), tm.make_dataset_strategy(), ) @settings(deadline=None, print_blob=True) - def test_approx(self, param, hist_param, num_rounds, dataset): + def test_approx( + self, + param: Dict[str, Any], + 
hist_param: Dict[str, Any], + cache_param: Dict[str, Any], + num_rounds: int, + dataset: tm.TestDataset, + ) -> None: param["tree_method"] = "approx" param = dataset.set_params(param) param.update(hist_param) + param.update(cache_param) result = train_result(param, dataset.get_dmat(), num_rounds) note(result) assert tm.non_increasing(result["train"][dataset.metric]) @@ -130,17 +155,25 @@ class TestTreeMethod: @given( exact_parameter_strategy, hist_parameter_strategy, + hist_cache_strategy, strategies.integers(1, 20), tm.make_dataset_strategy() ) @settings(deadline=None, print_blob=True) - def test_hist(self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset) -> None: - param['tree_method'] = 'hist' + def test_hist( + self, param: Dict[str, Any], + hist_param: Dict[str, Any], + cache_param: Dict[str, Any], + num_rounds: int, + dataset: tm.TestDataset, + ) -> None: + param["tree_method"] = "hist" param = dataset.set_params(param) param.update(hist_param) + param.update(cache_param) result = train_result(param, dataset.get_dmat(), num_rounds) note(result) - assert tm.non_increasing(result['train'][dataset.metric]) + assert tm.non_increasing(result["train"][dataset.metric]) def test_hist_categorical(self): # hist must be same as exact on all-categorial data diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py index 3add01192..5630e5f3e 100644 --- a/tests/test_distributed/test_with_dask/test_with_dask.py +++ b/tests/test_distributed/test_with_dask/test_with_dask.py @@ -24,7 +24,7 @@ from sklearn.datasets import make_classification, make_regression import xgboost as xgb from xgboost import testing as tm from xgboost.data import _is_cudf_df -from xgboost.testing.params import hist_parameter_strategy +from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy from xgboost.testing.shared import ( get_feature_weights, validate_data_initialization, @@ -1512,14 
+1512,23 @@ class TestWithDask: else: assert history[-1] < history[0] - @given(params=hist_parameter_strategy, dataset=tm.make_dataset_strategy()) + @given( + params=hist_parameter_strategy, + cache_param=hist_cache_strategy, + dataset=tm.make_dataset_strategy(), + ) @settings( deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True ) def test_hist( - self, params: Dict, dataset: tm.TestDataset, client: "Client" + self, + params: Dict[str, Any], + cache_param: Dict[str, Any], + dataset: tm.TestDataset, + client: "Client", ) -> None: num_rounds = 10 + params.update(cache_param) self.run_updater_test(client, params, num_rounds, dataset, "hist") def test_quantile_dmatrix(self, client: Client) -> None: @@ -1579,14 +1588,23 @@ class TestWithDask: rmse = result["history"]["Valid"]["rmse"][-1] assert rmse < 32.0 - @given(params=hist_parameter_strategy, dataset=tm.make_dataset_strategy()) + @given( + params=hist_parameter_strategy, + cache_param=hist_cache_strategy, + dataset=tm.make_dataset_strategy() + ) @settings( deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True ) def test_approx( - self, client: "Client", params: Dict, dataset: tm.TestDataset + self, + client: "Client", + params: Dict, + cache_param: Dict[str, Any], + dataset: tm.TestDataset, ) -> None: num_rounds = 10 + params.update(cache_param) self.run_updater_test(client, params, num_rounds, dataset, "approx") def test_adaptive(self) -> None: @@ -2239,7 +2257,7 @@ async def test_worker_left(c, s, a, b): ) await async_poll_for(lambda: len(s.workers) == 2, timeout=5) with pytest.raises(RuntimeError, match="Missing"): - await xgb.dask.train( + await xgb.dask.train( c, {}, d_train, @@ -2256,7 +2274,7 @@ async def test_worker_restarted(c, s, a, b): ) await c.restart_workers([a.worker_address]) with pytest.raises(RuntimeError, match="Missing"): - await xgb.dask.train( + await xgb.dask.train( c, {}, d_train,