From 61dd854a523def39d6b0d2952656bcce857a70f8 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Fri, 30 Aug 2024 02:39:14 +0800
Subject: [PATCH] [EM] Refactor GPU histogram builder. (#10764)

- Expose the maximum number of cached nodes to be consistent with the CPU
  implementation. This also makes testing easier.
- Extract the subtraction trick for easier testing.
- Split up the `GradientQuantiser` to avoid a circular dependency.
---
 doc/parameter.rst                           |   6 +-
 include/xgboost/c_api.h                     |   1 +
 src/data/ellpack_page_raw_format.cu         |   8 +-
 src/tree/gpu_hist/expand_entry.cuh          |   8 +-
 src/tree/gpu_hist/histogram.cu              |  27 ++-
 src/tree/gpu_hist/histogram.cuh             | 130 ++++++++-----
 src/tree/gpu_hist/quantiser.cuh             |  39 ++++
 src/tree/hist/hist_cache.h                  |   8 +-
 src/tree/hist/histogram.h                   |   2 +-
 src/tree/hist/param.h                       |  26 ++-
 src/tree/updater_gpu_common.cuh             |   8 +-
 src/tree/updater_gpu_hist.cu                | 202 +++++++++++---------
 tests/cpp/tree/gpu_hist/test_histogram.cu   |  70 +++++--
 tests/cpp/tree/hist/test_evaluate_splits.cc |  12 +-
 tests/cpp/tree/test_evaluate_splits.h       |   3 +-
 tests/cpp/tree/test_gpu_hist.cu             |   2 +-
 tests/python-gpu/test_gpu_updaters.py       |  29 +++
 17 files changed, 394 insertions(+), 187 deletions(-)
 create mode 100644 src/tree/gpu_hist/quantiser.cuh

diff --git a/doc/parameter.rst b/doc/parameter.rst
index a77655922..49d42f838 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -232,12 +232,12 @@ Parameters for Tree Booster

 * ``max_cached_hist_node``, [default = 65536]

-  Maximum number of cached nodes for CPU histogram.
+  Maximum number of cached nodes for the histogram.

   .. versionadded:: 2.0.0

-  - For most of the cases this parameter should not be set except for growing deep trees
-    on CPU.
+  - In most cases this parameter should not be set, except when growing deep
+    trees. After 3.0, this parameter affects GPU algorithms as well.

 .. _cat-param:

diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index ffff11ddb..c4ab4f246 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -522,6 +522,7 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
 *   - nthread (optional): Number of threads used for initializing DMatrix.
 *   - max_bin (optional): Maximum number of bins for building histogram. Must be consistent
 *     with the corresponding booster training parameter.
+ *   - on_host (optional): Whether the data should be placed in host memory. Used by GPU inputs.
 * @param out The created Quantile DMatrix.
 *
 * @return 0 when success, -1 when failure happens
diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu
index 4f39497b8..8d317aca5 100644
--- a/src/data/ellpack_page_raw_format.cu
+++ b/src/data/ellpack_page_raw_format.cu
@@ -60,10 +60,10 @@ template
   RET_IF_NOT(fi->Read(&impl->is_dense));
   RET_IF_NOT(fi->Read(&impl->row_stride));

-  if (has_hmm_ats_ && !this->param_.prefetch_copy) {
-    RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
-  } else {
+  if (this->param_.prefetch_copy || !has_hmm_ats_) {
     RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer));
+  } else {
+    RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
   }
   RET_IF_NOT(fi->Read(&impl->base_rowid));
   dh::DefaultStream().Sync();
@@ -95,7 +95,7 @@ template
   CHECK(this->cuts_->cut_values_.DeviceCanRead());
   impl->SetCuts(this->cuts_);

-  fi->Read(page, this->param_.prefetch_copy);
+  fi->Read(page, this->param_.prefetch_copy || !this->has_hmm_ats_);

   dh::DefaultStream().Sync();

   return true;
diff --git a/src/tree/gpu_hist/expand_entry.cuh b/src/tree/gpu_hist/expand_entry.cuh
index 42dc7f49a..b4dc41da2 100644
--- a/src/tree/gpu_hist/expand_entry.cuh
+++ b/src/tree/gpu_hist/expand_entry.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2023, XGBoost Contributors
+ * Copyright 2020-2024, XGBoost Contributors
 */
 #ifndef EXPAND_ENTRY_CUH_
 #define EXPAND_ENTRY_CUH_
@@ -7,9 +7,9 @@
 #include <limits>   // for numeric_limits
 #include <utility>  // for move

-#include "../param.h"
-#include "../updater_gpu_common.cuh"
-#include "xgboost/base.h"  // for bst_node_t
+#include "../param.h"                 // for TrainParam
+#include "../updater_gpu_common.cuh"  // for DeviceSplitCandidate
+#include "xgboost/base.h"             // for bst_node_t

 namespace xgboost::tree {
 struct GPUExpandEntry {
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index 731e71367..364df3fe4 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -356,13 +356,19 @@ class DeviceHistogramBuilderImpl {
 };

 DeviceHistogramBuilder::DeviceHistogramBuilder()
-    : p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {}
+    : p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {
+  monitor_.Init(__func__);
+}

 DeviceHistogramBuilder::~DeviceHistogramBuilder() = default;

-void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
-                                   bool force_global_memory) {
+void DeviceHistogramBuilder::Reset(Context const* ctx, std::size_t max_cached_hist_nodes,
+                                   FeatureGroupsAccessor const& feature_groups,
+                                   bst_bin_t n_total_bins, bool force_global_memory) {
+  this->monitor_.Start(__func__);
   this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
+  this->hist_.Reset(ctx, n_total_bins, max_cached_hist_nodes);
+  this->monitor_.Stop(__func__);
 }

 void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
@@ -372,6 +378,21 @@ void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
                                             common::Span<const std::uint32_t> ridx,
                                             common::Span<GradientPairInt64> histogram,
                                             GradientQuantiser rounding) {
+  this->monitor_.Start(__func__);
   this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
+  this->monitor_.Stop(__func__);
+}
+
+void DeviceHistogramBuilder::AllReduceHist(Context const* ctx, MetaInfo const& info,
+                                           bst_node_t nidx, std::size_t num_histograms) {
+  this->monitor_.Start(__func__);
+  auto d_node_hist = hist_.GetNodeHistogram(nidx);
+  using ReduceT = typename std::remove_pointer<decltype(d_node_hist.data())>::type::ValueT;
+  auto rc = collective::GlobalSum(
+      ctx, info,
+      linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist.data()),
+                      d_node_hist.size() * 2 * num_histograms, ctx->Device()));
+  SafeColl(rc);
+  this->monitor_.Stop(__func__);
+}
 }  // namespace xgboost::tree
diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh
index 87c60a8bf..95a00fd79 100644
--- a/src/tree/gpu_hist/histogram.cuh
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -9,7 +9,9 @@
 #include "../../common/device_helpers.cuh"  // for LaunchN
 #include "../../common/device_vector.cuh"   // for device_vector
 #include "../../data/ellpack_page.cuh"      // for EllpackDeviceAccessor
+#include "expand_entry.cuh"                 // for GPUExpandEntry
 #include "feature_groups.cuh"               // for FeatureGroupsAccessor
+#include "quantiser.cuh"                    // for GradientQuantiser
 #include "xgboost/base.h"                   // for GradientPair, GradientPairInt64
 #include "xgboost/context.h"                // for Context
 #include "xgboost/span.h"                   // for Span
@@ -34,77 +36,51 @@ XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t* dst, int64_t src) {
   atomicAdd(y_high, sig);
 }

-class GradientQuantiser {
- private:
-  /* Convert gradient to fixed point representation. */
-  GradientPairPrecise to_fixed_point_;
-  /* Convert fixed point representation back to floating point. */
-  GradientPairPrecise to_floating_point_;
-
- public:
-  GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair, MetaInfo const& info);
-  [[nodiscard]] XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
-    auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
-                                      gpair.GetHess() * to_fixed_point_.GetHess());
-    return adjusted;
-  }
-  [[nodiscard]] XGBOOST_DEVICE GradientPairInt64
-  ToFixedPoint(GradientPairPrecise const& gpair) const {
-    auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
-                                      gpair.GetHess() * to_fixed_point_.GetHess());
-    return adjusted;
-  }
-  [[nodiscard]] XGBOOST_DEVICE GradientPairPrecise
-  ToFloatingPoint(const GradientPairInt64& gpair) const {
-    auto g = gpair.GetQuantisedGrad() * to_floating_point_.GetGrad();
-    auto h = gpair.GetQuantisedHess() * to_floating_point_.GetHess();
-    return {g,h};
-  }
-};
+namespace cuda_impl {
+// Start with about 16mb
+std::size_t constexpr DftReserveSize() { return 1 << 22; }
+}  // namespace cuda_impl

 /**
  * @brief Data storage for node histograms on device. Automatically expands.
  *
- * @tparam kStopGrowingSize  Do not grow beyond this size
- *
  * @author  Rory
  * @date    28/07/2018
  */
-template <size_t kStopGrowingSize>
 class DeviceHistogramStorage {
  private:
   using GradientSumT = GradientPairInt64;
+  std::size_t stop_growing_size_{0};
   /** @brief Map nidx to starting index of its histogram. */
   std::map<int, size_t> nidx_map_;
   // Large buffer of zeroed memory, caches histograms
   dh::device_vector<typename GradientSumT::ValueT> data_;
-  // If we run out of storage allocate one histogram at a time
-  // in overflow. Not cached, overwritten when a new histogram
-  // is requested
+  // If we run out of storage allocate one histogram at a time in overflow. Not cached,
+  // overwritten when a new histogram is requested
   dh::device_vector<typename GradientSumT::ValueT> overflow_;
   std::map<int, size_t> overflow_nidx_map_;
   int n_bins_;
-  DeviceOrd device_id_;
-  static constexpr size_t kNumItemsInGradientSum =
+  static constexpr std::size_t kNumItemsInGradientSum =
       sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT);
   static_assert(kNumItemsInGradientSum == 2, "Number of items in gradient type should be 2.");

  public:
-  // Start with about 16mb
-  DeviceHistogramStorage() { data_.reserve(1 << 22); }
-  void Init(DeviceOrd device_id, int n_bins) {
-    this->n_bins_ = n_bins;
-    this->device_id_ = device_id;
-  }
+  explicit DeviceHistogramStorage() { data_.reserve(cuda_impl::DftReserveSize()); }

-  void Reset(Context const* ctx) {
+  void Reset(Context const* ctx, bst_bin_t n_total_bins, std::size_t max_cached_nodes) {
+    this->n_bins_ = n_total_bins;
     auto d_data = data_.data().get();
     dh::LaunchN(data_.size(), ctx->CUDACtx()->Stream(),
                 [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
     nidx_map_.clear();
     overflow_nidx_map_.clear();
+
+    auto max_cached_bin_values =
+        static_cast<std::size_t>(n_total_bins) * max_cached_nodes * kNumItemsInGradientSum;
+    this->stop_growing_size_ = max_cached_bin_values;
   }
-  [[nodiscard]] bool HistogramExists(int nidx) const {
+
+  [[nodiscard]] bool HistogramExists(bst_node_t nidx) const {
     return nidx_map_.find(nidx) != nidx_map_.cend() ||
            overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
   }
@@ -112,14 +88,15 @@ class DeviceHistogramStorage {
   [[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
   dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }

-  void AllocateHistograms(Context const* ctx, const std::vector<int>& new_nidxs) {
+  void AllocateHistograms(Context const* ctx, std::vector<int> const& new_nidxs) {
     for (int nidx : new_nidxs) {
       CHECK(!HistogramExists(nidx));
     }
     // Number of items currently used in data
     const size_t used_size = nidx_map_.size() * HistogramSize();
     const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size();
-    if (used_size >= kStopGrowingSize) {
+    CHECK_GE(this->stop_growing_size_, kNumItemsInGradientSum);
+    if (used_size >= this->stop_growing_size_) {
       // Use overflow
       // Delete previous entries
       overflow_nidx_map_.clear();
@@ -171,18 +148,77 @@ class DeviceHistogramBuilderImpl;

 class DeviceHistogramBuilder {
   std::unique_ptr<DeviceHistogramBuilderImpl> p_impl_;
+  DeviceHistogramStorage hist_;
+  common::Monitor monitor_;

  public:
-  DeviceHistogramBuilder();
+  explicit DeviceHistogramBuilder();
   ~DeviceHistogramBuilder();

-  void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
+  void Reset(Context const* ctx, std::size_t max_cached_hist_nodes,
+             FeatureGroupsAccessor const& feature_groups, bst_bin_t n_total_bins,
              bool force_global_memory);
   void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
                       FeatureGroupsAccessor const& feature_groups,
                       common::Span<GradientPair const> gpair,
                       common::Span<const std::uint32_t> ridx,
                       common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
+
+  [[nodiscard]] auto GetNodeHistogram(bst_node_t nidx) { return hist_.GetNodeHistogram(nidx); }
+
+  // num histograms is the number of contiguous histograms in memory to reduce over
+  void AllReduceHist(Context const* ctx, MetaInfo const& info, bst_node_t nidx,
+                     std::size_t num_histograms);
+
+  // Attempt to do subtraction trick
+  // return true if succeeded
+  [[nodiscard]] bool SubtractionTrick(bst_node_t nidx_parent, bst_node_t nidx_histogram,
+                                      bst_node_t nidx_subtraction) {
+    if (!hist_.HistogramExists(nidx_histogram) || !hist_.HistogramExists(nidx_parent)) {
+      return false;
+    }
+    auto d_node_hist_parent = hist_.GetNodeHistogram(nidx_parent);
+    auto d_node_hist_histogram = hist_.GetNodeHistogram(nidx_histogram);
+    auto d_node_hist_subtraction = hist_.GetNodeHistogram(nidx_subtraction);
+
+    dh::LaunchN(d_node_hist_parent.size(), [=] __device__(size_t idx) {
+      d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx];
+    });
+    return true;
+  }
+
+  [[nodiscard]] auto SubtractHist(std::vector<GPUExpandEntry> const& candidates,
+                                  std::vector<bst_node_t> const& build_nidx,
+                                  std::vector<bst_node_t> const& subtraction_nidx) {
+    this->monitor_.Start(__func__);
+    std::vector<bst_node_t> need_build;
+    for (std::size_t i = 0; i < subtraction_nidx.size(); i++) {
+      auto build_hist_nidx = build_nidx.at(i);
+      auto subtraction_trick_nidx = subtraction_nidx.at(i);
+      auto parent_nidx = candidates.at(i).nid;
+
+      if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
+        need_build.push_back(subtraction_trick_nidx);
+      }
+    }
+    this->monitor_.Stop(__func__);
+    return need_build;
+  }
+
+  void AllocateHistograms(Context const* ctx, std::vector<bst_node_t> const& nodes_to_build,
+                          std::vector<bst_node_t> const& nodes_to_sub) {
+    this->monitor_.Start(__func__);
+    std::vector<bst_node_t> all_new = nodes_to_build;
+    all_new.insert(all_new.end(), nodes_to_sub.cbegin(), nodes_to_sub.cend());
+    // Allocate the histograms
+    // Guaranteed contiguous memory
+    this->AllocateHistograms(ctx, all_new);
+    this->monitor_.Stop(__func__);
+  }
+
+  void AllocateHistograms(Context const* ctx, std::vector<bst_node_t> const& new_nidxs) {
+    this->hist_.AllocateHistograms(ctx, new_nidxs);
+  }
 };
 }  // namespace xgboost::tree

 #endif  // HISTOGRAM_CUH_
diff --git a/src/tree/gpu_hist/quantiser.cuh b/src/tree/gpu_hist/quantiser.cuh
new file mode 100644
index 000000000..36bd5a1d3
--- /dev/null
+++ b/src/tree/gpu_hist/quantiser.cuh
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020-2024, XGBoost Contributors
+ */
+#pragma once
+#include "xgboost/base.h"     // for GradientPairPrecise, GradientPairInt64
+#include "xgboost/context.h"  // for Context
+#include "xgboost/data.h"     // for MetaInfo
+#include "xgboost/span.h"     // for Span
+
+namespace xgboost::tree {
+class GradientQuantiser {
+ private:
+  /* Convert gradient to fixed point representation. */
+  GradientPairPrecise to_fixed_point_;
+  /* Convert fixed point representation back to floating point. */
+  GradientPairPrecise to_floating_point_;
+
+ public:
+  GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair,
+                    MetaInfo const& info);
+  [[nodiscard]] XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
+    auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
+                                      gpair.GetHess() * to_fixed_point_.GetHess());
+    return adjusted;
+  }
+  [[nodiscard]] XGBOOST_DEVICE GradientPairInt64
+  ToFixedPoint(GradientPairPrecise const& gpair) const {
+    auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
+                                      gpair.GetHess() * to_fixed_point_.GetHess());
+    return adjusted;
+  }
+  [[nodiscard]] XGBOOST_DEVICE GradientPairPrecise
+  ToFloatingPoint(const GradientPairInt64& gpair) const {
+    auto g = gpair.GetQuantisedGrad() * to_floating_point_.GetGrad();
+    auto h = gpair.GetQuantisedHess() * to_floating_point_.GetHess();
+    return {g, h};
+  }
+};
+}  // namespace xgboost::tree
diff --git a/src/tree/hist/hist_cache.h b/src/tree/hist/hist_cache.h
index 715e1d73e..d70941b0c 100644
--- a/src/tree/hist/hist_cache.h
+++ b/src/tree/hist/hist_cache.h
@@ -11,7 +11,7 @@
 #include "../../common/hist_util.h"          // for GHistRow, ConstGHistRow
 #include "../../common/ref_resource_view.h"  // for ReallocVector
 #include "xgboost/base.h"                    // for bst_node_t, bst_bin_t
-#include "xgboost/logging.h"                 // for CHECK_GT
+#include "xgboost/logging.h"                 // for CHECK_EQ
 #include "xgboost/span.h"                    // for Span

 namespace xgboost::tree {
@@ -40,7 +40,7 @@ class BoundedHistCollection {
   // number of histogram bins across all features
   bst_bin_t n_total_bins_{0};
   // limits the number of nodes that can be in the cache for each tree
-  std::size_t n_cached_nodes_{0};
+  std::size_t max_cached_nodes_{0};
   // whether the tree has grown beyond the cache limit
   bool has_exceeded_{false};

@@ -58,7 +58,7 @@ class BoundedHistCollection {
   }
   void Reset(bst_bin_t n_total_bins, std::size_t n_cached_nodes) {
     n_total_bins_ = n_total_bins;
-    n_cached_nodes_ = n_cached_nodes;
+    max_cached_nodes_ = n_cached_nodes;
     this->Clear(false);
   }
   /**
@@ -73,7 +73,7 @@ class BoundedHistCollection {
   [[nodiscard]] bool CanHost(common::Span<bst_node_t const> nodes_to_build,
                              common::Span<bst_node_t const> nodes_to_sub) const {
     auto n_new_nodes = nodes_to_build.size() + nodes_to_sub.size();
-    return n_new_nodes + node_map_.size() <= n_cached_nodes_;
+    return n_new_nodes + node_map_.size() <= max_cached_nodes_;
   }

   /**
diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h
index 1e9dc9c7d..fcfa03e03 100644
--- a/src/tree/hist/histogram.h
+++ b/src/tree/hist/histogram.h
@@ -61,7 +61,7 @@ class HistogramBuilder {
              bool is_col_split, HistMakerTrainParam const *param) {
     n_threads_ = ctx->Threads();
     param_ = p;
-    hist_.Reset(total_bins, param->max_cached_hist_node);
+    hist_.Reset(total_bins, param->MaxCachedHistNodes(ctx->Device()));
     buffer_.Init(total_bins);
     is_distributed_ = is_distributed;
     is_col_split_ = is_col_split;
diff --git a/src/tree/hist/param.h b/src/tree/hist/param.h
index aa9d8cedf..e981e886a 100644
--- a/src/tree/hist/param.h
+++ b/src/tree/hist/param.h
@@ -1,31 +1,47 @@
 /**
- * Copyright 2021-2023, XGBoost Contributors
+ * Copyright 2021-2024, XGBoost Contributors
 */
 #pragma once
 #include <cstddef>  // for size_t
+#include <limits>   // for numeric_limits

 #include "xgboost/parameter.h"   // for XGBoostParameter
 #include "xgboost/tree_model.h"  // for RegTree
+#include "xgboost/context.h"     // for DeviceOrd

 namespace xgboost::tree {
 struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
-  constexpr static std::size_t DefaultNodes() { return static_cast<std::size_t>(1) << 16; }
+ private:
+  constexpr static std::size_t NotSet() { return std::numeric_limits<std::size_t>::max(); }
+
+  std::size_t max_cached_hist_node{NotSet()};  // NOLINT
+
+ public:
+  // Smaller for GPU due to memory limitations.
+  constexpr static std::size_t CpuDefaultNodes() { return static_cast<std::size_t>(1) << 16; }
+  constexpr static std::size_t CudaDefaultNodes() { return static_cast<std::size_t>(1) << 12; }

   bool debug_synchronize{false};
-  std::size_t max_cached_hist_node{DefaultNodes()};

   void CheckTreesSynchronized(Context const* ctx, RegTree const* local_tree) const;

+  std::size_t MaxCachedHistNodes(DeviceOrd device) const {
+    if (max_cached_hist_node != NotSet()) {
+      return max_cached_hist_node;
+    }
+    return device.IsCPU() ? CpuDefaultNodes() : CudaDefaultNodes();
+  }
+
   // declare parameters
   DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
     DMLC_DECLARE_FIELD(debug_synchronize)
         .set_default(false)
         .describe("Check if all distributed tree are identical after tree construction.");
     DMLC_DECLARE_FIELD(max_cached_hist_node)
-        .set_default(DefaultNodes())
+        .set_default(NotSet())
         .set_lower_bound(1)
-        .describe("Maximum number of nodes in CPU histogram cache. Only for internal usage.");
+        .describe("Maximum number of nodes in the histogram cache.");
   }
 };
 }  // namespace xgboost::tree
diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh
index f60d45196..f0e353e22 100644
--- a/src/tree/updater_gpu_common.cuh
+++ b/src/tree/updater_gpu_common.cuh
@@ -5,10 +5,10 @@
 #include <limits>   // for numeric_limits
 #include <ostream>  // for ostream

-#include "gpu_hist/histogram.cuh"
-#include "param.h"  // for TrainParam
-#include "xgboost/base.h"
-#include "xgboost/task.h"  // for ObjInfo
+#include "gpu_hist/quantiser.cuh"  // for GradientQuantiser
+#include "param.h"                 // for TrainParam
+#include "xgboost/base.h"          // for bst_bin_t
+#include "xgboost/task.h"          // for ObjInfo

 namespace xgboost::tree {
 struct GPUTrainingParam {
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 03b0e5a42..e4e27b72a 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -64,6 +64,47 @@ struct NodeSplitData {
 };
 static_assert(std::is_trivially_copyable_v<NodeSplitData>);

+// To be tuned.
+constexpr double ExtMemPrefetchThresh() { return 4.0; }
+
+// For some nodes we compute the histogram directly; for the others we use the subtraction trick.
+[[nodiscard]] bool AssignNodes(RegTree const* p_tree, GradientQuantiser const* quantizer,
+                               std::vector<GPUExpandEntry> const& candidates,
+                               common::Span<bst_node_t> nodes_to_build,
+                               common::Span<bst_node_t> nodes_to_sub) {
+  auto const& tree = *p_tree;
+  std::size_t nidx_in_set{0};
+  double total{0.0}, smaller{0.0};
+  auto p_build_nidx = nodes_to_build.data();
+  auto p_sub_nidx = nodes_to_sub.data();
+  for (auto& e : candidates) {
+    // Decide whether to build the left or the right histogram. Use the sum of Hessian as a
+    // heuristic to select the node with fewer training instances. This optimization is for
+    // distributed training, to avoid an allreduce call for synchronizing the number of
+    // instances for each node.
+    auto left_sum = quantizer->ToFloatingPoint(e.split.left_sum);
+    auto right_sum = quantizer->ToFloatingPoint(e.split.right_sum);
+    bool fewer_right = right_sum.GetHess() < left_sum.GetHess();
+    total += left_sum.GetHess() + right_sum.GetHess();
+    if (fewer_right) {
+      p_build_nidx[nidx_in_set] = tree[e.nid].RightChild();
+      p_sub_nidx[nidx_in_set] = tree[e.nid].LeftChild();
+      smaller += right_sum.GetHess();
+    } else {
+      p_build_nidx[nidx_in_set] = tree[e.nid].LeftChild();
+      p_sub_nidx[nidx_in_set] = tree[e.nid].RightChild();
+      smaller += left_sum.GetHess();
+    }
+    ++nidx_in_set;
+  }
+
+  if (-kRtEps < smaller && smaller < kRtEps) {  // Too close to 0, don't prefetch.
+    return false;
+  }
+  // Prefetch if these smaller nodes are not quite small.
+  return (total / smaller) < ExtMemPrefetchThresh();
+}
+
 // GPU tree updater implementation.
 struct GPUHistMakerDevice {
  private:
@@ -78,11 +119,31 @@ struct GPUHistMakerDevice {
   std::vector<bst_idx_t> batch_ptr_;
   // node idx for each sample
   dh::device_vector<bst_node_t> positions_;
+  HistMakerTrainParam const* hist_param_;
   std::shared_ptr<common::HistogramCuts> cuts_{nullptr};

- public:
-  DeviceHistogramStorage<> hist{};
+  auto CreatePartitionNodes(RegTree const* p_tree, std::vector<GPUExpandEntry> const& candidates) {
+    std::vector<bst_node_t> nidx(candidates.size());
+    std::vector<bst_node_t> left_nidx(candidates.size());
+    std::vector<bst_node_t> right_nidx(candidates.size());
+    std::vector<NodeSplitData> split_data(candidates.size());

+    for (std::size_t i = 0, n = candidates.size(); i < n; i++) {
+      auto const& e = candidates[i];
+      RegTree::Node split_node = (*p_tree)[e.nid];
+      auto split_type = p_tree->NodeSplitType(e.nid);
+      nidx.at(i) = e.nid;
+      left_nidx[i] = split_node.LeftChild();
+      right_nidx[i] = split_node.RightChild();
+      split_data[i] =
+          NodeSplitData{split_node, split_type, this->evaluator_.GetDeviceNodeCats(e.nid)};
+
+      CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
+    }
+    return std::make_tuple(nidx, left_nidx, right_nidx, split_data);
+  }
+
+ public:
   dh::device_vector<GradientPair> d_gpair;  // storage for gpair;
   common::Span<GradientPair> gpair;
@@ -102,7 +163,7 @@ struct GPUHistMakerDevice {

   std::unique_ptr<FeatureGroups> feature_groups;

-  GPUHistMakerDevice(Context const* ctx, TrainParam _param,
+  GPUHistMakerDevice(Context const* ctx, TrainParam _param, HistMakerTrainParam const* hist_param,
                      std::shared_ptr<common::ColumnSampler> column_sampler, BatchParam batch_param,
                      MetaInfo const& info, std::vector<bst_idx_t> batch_ptr,
                      std::shared_ptr<common::HistogramCuts> cuts)
@@ -112,8 +173,9 @@ struct GPUHistMakerDevice {
         column_sampler_(std::move(column_sampler)),
         interaction_constraints(param, static_cast<bst_feature_t>(info.num_col_)),
         batch_ptr_{std::move(batch_ptr)},
+        hist_param_{hist_param},
         cuts_{std::move(cuts)} {
-    sampler =
+    this->sampler =
         std::make_unique<GradientBasedSampler>(ctx, info.num_row_, batch_param, param.subsample,
                                                param.sampling_method, batch_ptr_.size() > 2);
     if (!param.monotone_constraints.empty()) {
@@ -132,7 +194,7 @@ struct GPUHistMakerDevice {
     CHECK(cuts_);
     feature_groups = std::make_unique<FeatureGroups>(*cuts_, info.IsDense(),
                                                      dh::MaxSharedMemoryOptin(ctx_->Ordinal()),
-                                                     sizeof(GradientPairPrecise));
+                                                     sizeof(GradientPairInt64));
     }
   }

@@ -142,7 +204,7 @@ struct GPUHistMakerDevice {
     this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
                                 param.colsample_bynode, param.colsample_bylevel,
                                 param.colsample_bytree);
-    dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
+    common::SetDevice(ctx_->Ordinal());

     this->interaction_constraints.Reset();

@@ -185,13 +247,12 @@ struct GPUHistMakerDevice {
     quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());

-    // Init histogram
-    hist.Init(ctx_->Device(), this->cuts_->TotalBins());
-    hist.Reset(ctx_);

     this->InitFeatureGroupsOnce(info);

-    this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
+    this->histogram_.Reset(ctx_, this->hist_param_->MaxCachedHistNodes(ctx_->Device()),
+                           feature_groups->DeviceAccessor(ctx_->Device()), cuts_->TotalBins(),
+                           false);
+
     return p_fmat;
   }

@@ -202,7 +263,7 @@ struct GPUHistMakerDevice {
     sampled_features->SetDevice(ctx_->Device());
     common::Span<bst_feature_t> feature_set =
         interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
-    EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
+    EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, histogram_.GetNodeHistogram(nidx)};
     EvaluateSplitSharedInputs shared_inputs{gpu_param,
                                             *quantiser,
                                             p_fmat->Info().feature_types.ConstDeviceSpan(),
@@ -250,12 +311,10 @@ struct GPUHistMakerDevice {
       common::Span<bst_feature_t> right_feature_set =
           interaction_constraints.Query(right_sampled_features->DeviceSpan(), right_nidx);

-      h_node_inputs[i * 2] = {left_nidx, candidate.depth + 1,
-                              candidate.split.left_sum, left_feature_set,
-                              hist.GetNodeHistogram(left_nidx)};
-      h_node_inputs[i * 2 + 1] = {right_nidx, candidate.depth + 1,
-                                  candidate.split.right_sum, right_feature_set,
-                                  hist.GetNodeHistogram(right_nidx)};
+      h_node_inputs[i * 2] = {left_nidx, candidate.depth + 1, candidate.split.left_sum,
+                              left_feature_set, histogram_.GetNodeHistogram(left_nidx)};
+      h_node_inputs[i * 2 + 1] = {right_nidx, candidate.depth + 1, candidate.split.right_sum,
+                                  right_feature_set, histogram_.GetNodeHistogram(right_nidx)};
     }
     bst_feature_t max_active_features = 0;
     for (auto input : h_node_inputs) {
@@ -274,28 +333,17 @@ struct GPUHistMakerDevice {
     this->monitor.Stop(__func__);
   }

-  void BuildHist(EllpackPageImpl const* page, int nidx) {
-    auto d_node_hist = hist.GetNodeHistogram(nidx);
-    auto d_ridx = partitioners_.front()->GetRows(nidx);
-    this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
+  void BuildHist(EllpackPage const& page, std::int32_t k, bst_bin_t nidx) {
+    monitor.Start(__func__);
+    auto d_node_hist = histogram_.GetNodeHistogram(nidx);
+    auto batch = page.Impl();
+    auto acc = batch->GetDeviceAccessor(ctx_->Device());
+
+    auto d_ridx = partitioners_.at(k)->GetRows(nidx);
+    this->histogram_.BuildHistogram(ctx_->CUDACtx(), acc,
                                     feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
                                     d_node_hist, *quantiser);
-  }
-
-  // Attempt to do subtraction trick
-  // return true if succeeded
-  bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
-    if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) {
-      return false;
-    }
-    auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent);
-    auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
-    auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
-
-    dh::LaunchN(cuts_->TotalBins(), [=] __device__(size_t idx) {
-      d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx];
-    });
-    return true;
+    monitor.Stop(__func__);
   }

   void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
@@ -349,6 +397,7 @@ struct GPUHistMakerDevice {
     };
     collective::SafeColl(rc);

+    CHECK_EQ(partitioners_.size(), 1) << "External memory with column split is not yet supported.";
     partitioners_.front()->UpdatePositionBatch(
         nidx, left_nidx, right_nidx, split_data,
         [=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) {
@@ -393,10 +442,7 @@ struct GPUHistMakerDevice {
     monitor.Start(__func__);

-    std::vector<bst_node_t> nidx(candidates.size());
-    std::vector<bst_node_t> left_nidx(candidates.size());
-    std::vector<bst_node_t> right_nidx(candidates.size());
-    std::vector<NodeSplitData> split_data(candidates.size());
+    auto [nidx, left_nidx, right_nidx, split_data] = this->CreatePartitionNodes(p_tree, candidates);

     for (size_t i = 0; i < candidates.size(); i++) {
       auto const& e = candidates[i];
@@ -531,19 +577,6 @@ struct GPUHistMakerDevice {
     return true;
   }

-  // num histograms is the number of contiguous histograms in memory to reduce over
-  void AllReduceHist(MetaInfo const& info, bst_node_t nidx, int num_histograms) {
-    monitor.Start(__func__);
-    auto d_node_hist = hist.GetNodeHistogram(nidx);
-    using ReduceT = typename std::remove_pointer<decltype(d_node_hist.data())>::type::ValueT;
-    auto rc = collective::GlobalSum(
-        ctx_, info,
-        linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist.data()),
-                        d_node_hist.size() * 2 * num_histograms, ctx_->Device()));
-    SafeColl(rc);
-    monitor.Stop(__func__);
-  }
-
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
@@ -555,48 +588,44 @@ struct GPUHistMakerDevice {
     this->monitor.Start(__func__);
     // Some nodes we will manually compute histograms
     // others we will do by subtraction
-    std::vector<bst_node_t> hist_nidx;
-    std::vector<bst_node_t> subtraction_nidx;
-    for (auto& e : candidates) {
-      // Decide whether to build the left histogram or right histogram
-      // Use sum of Hessian as a heuristic to select node with fewest training instances
-      bool fewer_right = e.split.right_sum.GetQuantisedHess() < e.split.left_sum.GetQuantisedHess();
-      if (fewer_right) {
-        hist_nidx.emplace_back(tree[e.nid].RightChild());
-        subtraction_nidx.emplace_back(tree[e.nid].LeftChild());
-      } else {
-        hist_nidx.emplace_back(tree[e.nid].LeftChild());
-        subtraction_nidx.emplace_back(tree[e.nid].RightChild());
-      }
-    }
+    std::vector<bst_node_t> hist_nidx(candidates.size());
+    std::vector<bst_node_t> subtraction_nidx(candidates.size());
+    auto prefetch_copy =
+        AssignNodes(&tree, this->quantiser.get(), candidates, hist_nidx, subtraction_nidx);
+
     std::vector<bst_node_t> all_new = hist_nidx;
     all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
     // Allocate the histograms
     // Guaranteed contiguous memory
-    hist.AllocateHistograms(ctx_, all_new);
+    histogram_.AllocateHistograms(ctx_, all_new);

-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+    std::int32_t k = 0;
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
       for (auto nidx : hist_nidx) {
-        this->BuildHist(page.Impl(), nidx);
+        this->BuildHist(page, k, nidx);
       }
+      ++k;
     }

     // Reduce all in one go
     // This gives much better latency in a distributed setting
     // when processing a large batch
-    this->AllReduceHist(p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());
+    this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());

     for (size_t i = 0; i < subtraction_nidx.size(); i++) {
       auto build_hist_nidx = hist_nidx.at(i);
       auto subtraction_trick_nidx = subtraction_nidx.at(i);
       auto parent_nidx = candidates.at(i).nid;

-      if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
+      if (!this->histogram_.SubtractionTrick(parent_nidx, build_hist_nidx,
+                                             subtraction_trick_nidx)) {
         // Calculate other histogram manually
+        std::int32_t k = 0;
         for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
-          this->BuildHist(page.Impl(), subtraction_trick_nidx);
+          this->BuildHist(page, k, subtraction_trick_nidx);
+          ++k;
         }
-        this->AllReduceHist(p_fmat->Info(), subtraction_trick_nidx, 1);
+        this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), subtraction_trick_nidx, 1);
       }
     }
     this->monitor.Stop(__func__);
@@ -666,11 +695,13 @@ struct GPUHistMakerDevice {
         ctx_, p_fmat->Info(),
         linalg::MakeVec(reinterpret_cast<int64_t*>(&root_sum_quantised), 2));
     collective::SafeColl(rc);

-    hist.AllocateHistograms(ctx_, {kRootNIdx});
+    histogram_.AllocateHistograms(ctx_, {kRootNIdx});
+    std::int32_t k = 0;
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
-      this->BuildHist(page.Impl(), kRootNIdx);
+      this->BuildHist(page, k, kRootNIdx);
+      ++k;
     }
-    this->AllReduceHist(p_fmat->Info(), kRootNIdx, 1);
+    this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), kRootNIdx, 1);

     // Remember root stats
     auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
@@ -812,15 +843,15 @@ class GPUHistMaker : public TreeUpdater {
         ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0));
     this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);

-    dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
+    common::SetDevice(ctx_->Ordinal());
     p_fmat->Info().feature_types.SetDevice(ctx_->Device());

     std::vector<bst_idx_t> batch_ptr;
     auto batch = HistBatch(*param);
     auto cuts = InitBatchCuts(ctx_, p_fmat, batch, &batch_ptr);

-    this->maker = std::make_unique<GPUHistMakerDevice>(ctx_, *param, column_sampler_, batch,
-                                                       p_fmat->Info(), batch_ptr, cuts);
+    this->maker = std::make_unique<GPUHistMakerDevice>(
+        ctx_, *param, &hist_maker_param_, column_sampler_, batch, p_fmat->Info(), batch_ptr, cuts);

     p_last_fmat_ = p_fmat;
     initialised_ = true;
@@ -888,9 +919,6 @@ class GPUGlobalApproxMaker : public TreeUpdater {
     // Used in test to count how many configurations are performed
     LOG(DEBUG) << "[GPU Approx]: Configure";
     hist_maker_param_.UpdateAllowUnknown(args);
-    if (hist_maker_param_.max_cached_hist_node != HistMakerTrainParam::DefaultNodes()) {
-      LOG(WARNING) << "The `max_cached_hist_node` is ignored in GPU.";
-    }
     common::CheckComputeCapability();
     initialised_ = false;

@@ -932,8 +960,8 @@ class GPUGlobalApproxMaker : public TreeUpdater {
     auto cuts = InitBatchCuts(ctx_, p_fmat, batch, &batch_ptr);
     batch.regen = false;  // Regen only at the beginning of the iteration.
-    this->maker_ = std::make_unique<GPUHistMakerDevice>(ctx_, *param, column_sampler_, batch,
-                                                        p_fmat->Info(), batch_ptr, cuts);
+    this->maker_ = std::make_unique<GPUHistMakerDevice>(
+        ctx_, *param, &hist_maker_param_, column_sampler_, batch, p_fmat->Info(), batch_ptr, cuts);

     std::size_t t_idx{0};
     for (xgboost::RegTree* tree : trees) {
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index 15c8f7def..06666e963 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -9,6 +9,7 @@

 #include "../../../../src/tree/gpu_hist/histogram.cuh"
 #include "../../../../src/tree/gpu_hist/row_partitioner.cuh"  // for RowPartitioner
+#include "../../../../src/tree/hist/param.h"                  // for HistMakerTrainParam
 #include "../../../../src/tree/param.h"                       // for TrainParam
 #include "../../categorical_helpers.h"                        // for OneHotEncodeFeature
 #include "../../helpers.h"
@@ -21,13 +22,13 @@ TEST(Histogram, DeviceHistogramStorage) {
   constexpr size_t kNBins = 128;
   constexpr int kNNodes = 4;
   constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
-  DeviceHistogramStorage<kStopGrowing> histogram;
-  histogram.Init(FstCU(), kNBins);
+  DeviceHistogramStorage histogram{};
+  histogram.Reset(&ctx, kNBins, kNNodes);
   for (int i = 0; i < kNNodes; ++i) {
     histogram.AllocateHistograms(&ctx, {i});
   }
-  histogram.Reset(&ctx);
   ASSERT_EQ(histogram.Data().size(), kStopGrowing);
+  histogram.Reset(&ctx, kNBins, kNNodes);

   // Use allocated memory but do not erase nidx_map.
   for (int i = 0; i < kNNodes; ++i) {
@@ -55,6 +56,35 @@ TEST(Histogram, DeviceHistogramStorage) {
   EXPECT_ANY_THROW(histogram.AllocateHistograms(&ctx, {kNNodes + 1}););
 }

+TEST(Histogram, SubtractionTrick) {
+  auto ctx = MakeCUDACtx(0);
+
+  auto page = BuildEllpackPage(&ctx, 64, 4);
+  auto cuts = page->CutsShared();
+  FeatureGroups fg{*cuts, true, std::numeric_limits<std::size_t>::max(),
+                   sizeof(GradientPairPrecise)};
+  auto fg_acc = fg.DeviceAccessor(ctx.Device());
+  auto n_total_bins = cuts->TotalBins();
+
+  // 2 nodes
+  auto max_cached_hist_nodes = 2ull;
+  DeviceHistogramBuilder histogram;
+  histogram.Reset(&ctx, max_cached_hist_nodes, fg_acc, n_total_bins, false);
+  histogram.AllocateHistograms(&ctx, {0, 1, 2});
+  GPUExpandEntry root;
+  root.nid = 0;
+  auto need_build = histogram.SubtractHist({root}, {0}, {1});
+
+  std::vector<GPUExpandEntry> candidates(2);
+  candidates[0].nid = 1;
+  candidates[1].nid = 2;
+
+  need_build = histogram.SubtractHist(candidates, {3, 5}, {4, 6});
+  ASSERT_EQ(need_build.size(), 2);
+  ASSERT_EQ(need_build[0], 4);
+  ASSERT_EQ(need_build[1], 6);
+}
+
 std::vector<GradientPairPrecise> GetHostHistGpair() {
   // 24 bins, 3 bins for each feature (column).
   std::vector<GradientPairPrecise> hist_gpair = {
@@ -101,17 +131,16 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   auto shm_size = use_shared_memory_histograms ?
dh::MaxSharedMemoryOptin(ctx.Ordinal()) : 0; FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64)); - DeviceHistogramStorage hist; - hist.Init(ctx.Device(), page->Cuts().TotalBins()); - hist.AllocateHistograms(&ctx, {0}); - DeviceHistogramBuilder builder; - builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), !use_shared_memory_histograms); + builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), + feature_groups.DeviceAccessor(ctx.Device()), page->Cuts().TotalBins(), + !use_shared_memory_histograms); + builder.AllocateHistograms(&ctx, {0}); builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), - row_partitioner->GetRows(0), hist.GetNodeHistogram(0), *quantiser); + row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser); - auto node_histogram = hist.GetNodeHistogram(0); + auto node_histogram = builder.GetNodeHistogram(0); std::vector h_result(node_histogram.size()); dh::CopyDeviceSpanToVector(&h_result, node_histogram); @@ -158,7 +187,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo()); DeviceHistogramBuilder builder; - builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global); + builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), + feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global); builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, d_histogram, quantiser); @@ -173,7 +203,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo()); DeviceHistogramBuilder builder; - builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global); + builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), + feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global); builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, d_new_histogram, quantiser); @@ -197,7 +228,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) dh::device_vector baseline(num_bins); DeviceHistogramBuilder builder; - builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), force_global); + builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), + single_group.DeviceAccessor(ctx.Device()), num_bins, force_global); builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, dh::ToSpan(baseline), quantiser); @@ -264,7 +296,8 @@ void TestGPUHistogramCategorical(size_t num_categories) { auto* page = batch.Impl(); FeatureGroups single_group(page->Cuts()); DeviceHistogramBuilder builder; - builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false); + builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), + single_group.DeviceAccessor(ctx.Device()), num_categories, false); builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist), quantiser); @@ -280,7 +313,8 @@ void TestGPUHistogramCategorical(size_t num_categories) { auto* page = batch.Impl(); FeatureGroups 
single_group(page->Cuts()); DeviceHistogramBuilder builder; - builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false); + builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), + single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false); builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist), quantiser); @@ -429,7 +463,8 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParamGetRows(0); auto d_histogram = dh::ToSpan(multi_hist); DeviceHistogramBuilder builder; - builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global); + builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), + fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global); builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()), fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx, d_histogram, quantiser); @@ -454,7 +489,8 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParamDeviceAccessor(ctx.Device()), force_global); + builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()), + d_histogram.size(), force_global); builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()), fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx, d_histogram, quantiser); diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index b7aae1b57..43dc4f46a 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -51,7 +51,7 @@ void TestEvaluateSplits(bool force_read_by_column) { row_set_collection.Init(); HistMakerTrainParam hist_param; - hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node); + hist.Reset(gmat.cut.Ptrs().back(), hist_param.MaxCachedHistNodes(ctx.Device())); hist.AllocateHistograms({0}); auto const &elem = row_set_collection[0]; common::BuildHist(row_gpairs, common::Span{elem.begin(), elem.end()}, gmat, hist[0], @@ -120,7 +120,7 @@ TEST(HistMultiEvaluator, Evaluate) { linalg::Vector root_sum({2}, DeviceOrd::CPU()); for (bst_target_t t{0}; t < n_targets; ++t) { auto &hist = histogram[t]; - hist.Reset(n_bins * n_features, hist_param.max_cached_hist_node); + hist.Reset(n_bins * n_features, hist_param.MaxCachedHistNodes(ctx.Device())); hist.AllocateHistograms({0}); auto node_hist = hist[0]; node_hist[0] = {-0.5, 0.5}; @@ -237,7 +237,7 @@ auto CompareOneHotAndPartition(bool onehot) { entries.front().nid = 0; entries.front().depth = 0; - hist.Reset(gmat.cut.TotalBins(), hist_param.max_cached_hist_node); + hist.Reset(gmat.cut.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device())); hist.AllocateHistograms({0}); auto node_hist = hist[0]; @@ -265,9 +265,10 @@ TEST(HistEvaluator, Categorical) { } TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) { + Context ctx; BoundedHistCollection hist; HistMakerTrainParam hist_param; - hist.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node); + hist.Reset(cuts_.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device())); hist.AllocateHistograms({0}); auto node_hist = hist[0]; ASSERT_EQ(node_hist.size(), feature_histogram_.size()); @@ -277,10 +278,9 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) { MetaInfo info; info.num_col_ = 1; info.feature_types = {FeatureType::kCategorical}; - Context ctx; + auto evaluator = HistEvaluator{&ctx, ¶m_, info, sampler}; 
evaluator.InitRoot(GradStats{parent_sum_}); - std::vector entries(1); RegTree tree; evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries); diff --git a/tests/cpp/tree/test_evaluate_splits.h b/tests/cpp/tree/test_evaluate_splits.h index a25e75aef..c7c6854f5 100644 --- a/tests/cpp/tree/test_evaluate_splits.h +++ b/tests/cpp/tree/test_evaluate_splits.h @@ -56,8 +56,9 @@ class TestPartitionBasedSplit : public ::testing::Test { cuts_.min_vals_.Resize(1); + Context ctx; HistMakerTrainParam hist_param; - hist_.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node); + hist_.Reset(cuts_.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device())); hist_.AllocateHistograms({0}); auto node_hist = hist_[0]; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 61f764757..ebd92510d 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -216,7 +216,7 @@ TEST(GpuHist, ConfigIO) { } TEST(GpuHist, MaxDepth) { - Context ctx(MakeCUDACtx(0)); + auto ctx = MakeCUDACtx(0); size_t constexpr kRows = 16; size_t constexpr kCols = 4; auto p_mat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 91e76a06f..21f7f76fe 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -10,6 +10,7 @@ from xgboost import testing as tm from xgboost.testing.params import ( cat_parameter_strategy, exact_parameter_strategy, + hist_cache_strategy, hist_parameter_strategy, ) from xgboost.testing.updater import ( @@ -46,6 +47,7 @@ class TestGPUUpdaters: @given( exact_parameter_strategy, hist_parameter_strategy, + hist_cache_strategy, strategies.integers(1, 20), tm.make_dataset_strategy(), ) @@ -54,19 +56,44 @@ class TestGPUUpdaters: self, param: Dict[str, Any], hist_param: Dict[str, Any], + cache_param: Dict[str, Any], num_rounds: int, dataset: tm.TestDataset, ) -> None: param.update({"tree_method": "hist", "device": "cuda"}) param.update(hist_param) + param.update(cache_param) param = dataset.set_params(param) result = train_result(param, dataset.get_dmat(), num_rounds) note(str(result)) assert tm.non_increasing(result["train"][dataset.metric]) + @pytest.mark.parametrize("tree_method", ["approx", "hist"]) + def test_cache_size(self, tree_method: str) -> None: + from sklearn.datasets import make_regression + + X, y = make_regression(n_samples=4096, n_features=64, random_state=1994) + Xy = xgb.DMatrix(X, y) + results = [] + for cache_size in [1, 3, 2048]: + params: Dict[str, Any] = {"tree_method": tree_method, "device": "cuda"} + params["max_cached_hist_node"] = cache_size + evals_result: Dict[str, Dict[str, list]] = {} + xgb.train( + params, + Xy, + num_boost_round=4, + evals=[(Xy, "Train")], + evals_result=evals_result, + ) + results.append(evals_result["Train"]["rmse"]) + for i in range(1, len(results)): + np.testing.assert_allclose(results[0], results[i]) + @given( exact_parameter_strategy, hist_parameter_strategy, + hist_cache_strategy, strategies.integers(1, 20), tm.make_dataset_strategy(), ) @@ -75,11 +102,13 @@ class TestGPUUpdaters: self, param: Dict[str, Any], hist_param: Dict[str, Any], + cache_param: Dict[str, Any], num_rounds: int, dataset: tm.TestDataset, ) -> None: param.update({"tree_method": "approx", "device": "cuda"}) param.update(hist_param) + param.update(cache_param) param = dataset.set_params(param) result = train_result(param, dataset.get_dmat(), num_rounds) note(str(result))
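
A background note on the subtraction trick that this patch moves into DeviceHistogramBuilder::SubtractionTrick: bin by bin, a parent's histogram equals the sum of its two children's histograms, so AssignNodes builds only the child with the smaller Hessian sum (a proxy for fewer training instances) and derives the sibling by subtraction. A minimal host-side sketch of the idea follows; the names and data here are hypothetical, not code from this patch:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Hist = std::vector<std::int64_t>;  // one quantised gradient sum per bin

// Sibling histogram = parent histogram - built child, bin by bin.
Hist Subtract(Hist const& parent, Hist const& built) {
  Hist sibling(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    sibling[i] = parent[i] - built[i];
  }
  return sibling;
}

int main() {
  Hist parent{10, 4, 6};
  // AssignNodes-style heuristic: build the child with the smaller Hessian sum.
  double left_hess = 1.0, right_hess = 3.0;
  bool build_left = left_hess <= right_hess;  // fewer instances -> cheaper build pass
  Hist built{4, 1, 2};                        // histogram of the chosen (left) child
  Hist sibling = Subtract(parent, built);     // right child, no build pass needed
  assert(build_left);
  assert((sibling == Hist{6, 3, 4}));
  return 0;
}

Because the quantised histograms are plain integers, the subtraction is exact; the GPU version performs the same element-wise loop inside a dh::LaunchN kernel.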
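Likewise, the GradientQuantiser split out into quantiser.cuh converts floating-point gradient pairs into a 64-bit fixed-point form so that device-side atomic adds and the collective allreduce accumulate integers, keeping histogram sums deterministic. A rough sketch of that round trip, assuming the scale is simply derived from the largest absolute gradient (the real scale selection differs and also accounts for distributed workers):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> grad{0.5, -1.25, 0.75};
  // Choose a fixed-point scale from the largest absolute gradient (assumed policy).
  double max_abs = 0.0;
  for (double g : grad) {
    max_abs = std::max(max_abs, std::abs(g));
  }
  double to_fixed = static_cast<double>(std::int64_t{1} << 30) / max_abs;
  double to_float = 1.0 / to_fixed;

  // Integer accumulation is associative, so the total is the same regardless of
  // the order in which atomics or the allreduce combine the partial sums.
  std::int64_t sum = 0;
  for (double g : grad) {
    sum += static_cast<std::int64_t>(std::round(g * to_fixed));
  }
  std::printf("quantised sum: %lld, recovered: %f\n", static_cast<long long>(sum),
              static_cast<double>(sum) * to_float);
  return 0;
}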