From 942b957eef8e81c071ba543847eea8f4d2806df6 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 29 Aug 2023 10:06:03 +0800 Subject: [PATCH] Fix GPU categorical split memory allocation. (#9529) --- src/common/categorical.h | 2 +- src/tree/gpu_hist/evaluate_splits.cu | 21 +++++------ src/tree/gpu_hist/evaluate_splits.cuh | 22 ++++++++---- src/tree/updater_gpu_common.cuh | 36 ++++++++----------- src/tree/updater_gpu_hist.cu | 50 +++++++++++++++------------ 5 files changed, 67 insertions(+), 64 deletions(-) diff --git a/src/common/categorical.h b/src/common/categorical.h index 249a818e5..32b771ad6 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -52,7 +52,7 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) { * * Go to left if it's NOT the matching category, which matches one-hot encoding. */ -inline XGBOOST_DEVICE bool Decision(common::Span cats, float cat) { +inline XGBOOST_DEVICE bool Decision(common::Span cats, float cat) { KCatBitField const s_cats(cats); if (XGBOOST_EXPECT(InvalidCat(cat), false)) { return true; diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index 30941c060..ecfc6c3ce 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -1,5 +1,5 @@ -/*! - * Copyright 2020-2022 by XGBoost Contributors +/** + * Copyright 2020-2023, XGBoost Contributors */ #include // std::max #include @@ -11,9 +11,7 @@ #include "evaluate_splits.cuh" #include "expand_entry.cuh" -namespace xgboost { -namespace tree { - +namespace xgboost::tree { // With constraints XGBOOST_DEVICE float LossChangeMissing(const GradientPairInt64 &scan, const GradientPairInt64 &missing, @@ -315,11 +313,11 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu common::Span out, DeviceSplitCandidate *p_out_split) { auto &out_split = *p_out_split; - out_split.split_cats = common::CatBitField{out}; + auto out_cats = common::CatBitField{out}; // Simple case for one hot split if (common::UseOneHot(shared_inputs.FeatureBins(fidx), shared_inputs.param.max_cat_to_onehot)) { - out_split.split_cats.Set(common::AsCat(out_split.thresh)); + out_cats.Set(common::AsCat(out_split.thresh)); return; } @@ -339,7 +337,7 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu assert(partition > 0 && "Invalid partition."); thrust::for_each(thrust::seq, beg, beg + partition, [&](size_t c) { auto cat = shared_inputs.feature_values[c - node_offset]; - out_split.SetCat(cat); + out_cats.Set(common::AsCat(cat)); }); } @@ -444,8 +442,7 @@ void GPUHistEvaluator::EvaluateSplits( if (split.is_cat) { SetCategoricalSplit(shared_inputs, d_sorted_idx, fidx, i, - device_cats_accessor.GetNodeCatStorage(input.nidx), - &out_splits[i]); + device_cats_accessor.GetNodeCatStorage(input.nidx), &out_splits[i]); } float base_weight = @@ -477,6 +474,4 @@ GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit( cudaMemcpyDeviceToHost)); return root_entry; } - -} // namespace tree -} // namespace xgboost +} // namespace xgboost::tree diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index 25a8cde89..667982aa9 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -37,8 +37,8 @@ struct EvaluateSplitSharedInputs { common::Span feature_values; common::Span min_fvalue; bool is_dense; - XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; } - __device__ auto FeatureBins(bst_feature_t fidx) const { + [[nodiscard]] XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; } + [[nodiscard]] __device__ std::uint32_t FeatureBins(bst_feature_t fidx) const { return feature_segments[fidx + 1] - feature_segments[fidx]; } }; @@ -105,7 +105,7 @@ class GPUHistEvaluator { } /** - * \brief Get device category storage of nidx for internal calculation. + * @brief Get device category storage of nidx for internal calculation. */ auto DeviceCatStorage(const std::vector &nidx) { if (!has_categoricals_) return CatAccessor{}; @@ -120,8 +120,8 @@ class GPUHistEvaluator { /** * \brief Get sorted index storage based on the left node of inputs. */ - auto SortedIdx(int num_nodes, bst_feature_t total_bins) { - if(!need_sort_histogram_) return common::Span(); + auto SortedIdx(int num_nodes, bst_bin_t total_bins) { + if (!need_sort_histogram_) return common::Span{}; cat_sorted_idx_.resize(num_nodes * total_bins); return dh::ToSpan(cat_sorted_idx_); } @@ -146,12 +146,22 @@ class GPUHistEvaluator { * \brief Get host category storage for nidx. Different from the internal version, this * returns strictly 1 node. */ - common::Span GetHostNodeCats(bst_node_t nidx) const { + [[nodiscard]] common::Span GetHostNodeCats(bst_node_t nidx) const { copy_stream_.View().Sync(); auto cats_out = common::Span{h_split_cats_}.subspan( nidx * node_categorical_storage_size_, node_categorical_storage_size_); return cats_out; } + + [[nodiscard]] auto GetDeviceNodeCats(bst_node_t nidx) { + copy_stream_.View().Sync(); + if (has_categoricals_) { + CatAccessor accessor = {dh::ToSpan(split_cats_), node_categorical_storage_size_}; + return common::KCatBitField{accessor.GetNodeCatStorage(nidx)}; + } else { + return common::KCatBitField{}; + } + } /** * \brief Add a split to the internal tree evaluator. */ diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 1637300b6..8f5b27ac6 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -64,7 +64,6 @@ struct DeviceSplitCandidate { // split. bst_cat_t thresh{-1}; - common::CatBitField split_cats; bool is_cat { false }; GradientPairInt64 left_sum; @@ -72,12 +71,6 @@ struct DeviceSplitCandidate { XGBOOST_DEVICE DeviceSplitCandidate() {} // NOLINT - template - XGBOOST_DEVICE void SetCat(T c) { - this->split_cats.Set(common::AsCat(c)); - fvalue = std::max(this->fvalue, static_cast(c)); - } - XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in, float fvalue_in, int findex_in, GradientPairInt64 left_sum_in, GradientPairInt64 right_sum_in, bool cat, @@ -100,22 +93,23 @@ struct DeviceSplitCandidate { */ XGBOOST_DEVICE void UpdateCat(float loss_chg_in, DefaultDirection dir_in, bst_cat_t thresh_in, bst_feature_t findex_in, GradientPairInt64 left_sum_in, - GradientPairInt64 right_sum_in, GPUTrainingParam const& param, const GradientQuantiser& quantiser) { - if (loss_chg_in > loss_chg && - quantiser.ToFloatingPoint(left_sum_in).GetHess() >= param.min_child_weight && - quantiser.ToFloatingPoint(right_sum_in).GetHess() >= param.min_child_weight) { - loss_chg = loss_chg_in; - dir = dir_in; - fvalue = std::numeric_limits::quiet_NaN(); - thresh = thresh_in; - is_cat = true; - left_sum = left_sum_in; - right_sum = right_sum_in; - findex = findex_in; - } + GradientPairInt64 right_sum_in, GPUTrainingParam const& param, + const GradientQuantiser& quantiser) { + if (loss_chg_in > loss_chg && + quantiser.ToFloatingPoint(left_sum_in).GetHess() >= param.min_child_weight && + quantiser.ToFloatingPoint(right_sum_in).GetHess() >= param.min_child_weight) { + loss_chg = loss_chg_in; + dir = dir_in; + fvalue = std::numeric_limits::quiet_NaN(); + thresh = thresh_in; + is_cat = true; + left_sum = left_sum_in; + right_sum = right_sum_in; + findex = findex_in; + } } - XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; } + [[nodiscard]] XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; } friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) { os << "loss_chg:" << c.loss_chg << ", " diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 33dfbf8c5..10fb913b3 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -7,9 +7,9 @@ #include #include -#include -#include -#include +#include // for size_t +#include // for unique_ptr, make_unique +#include // for move #include #include "../collective/communicator-inl.cuh" @@ -216,9 +216,9 @@ struct GPUHistMakerDevice { void InitFeatureGroupsOnce() { if (!feature_groups) { CHECK(page); - feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, - dh::MaxSharedMemoryOptin(ctx_->gpu_id), - sizeof(GradientPairPrecise))); + feature_groups = std::make_unique(page->Cuts(), page->is_dense, + dh::MaxSharedMemoryOptin(ctx_->gpu_id), + sizeof(GradientPairPrecise)); } } @@ -245,10 +245,10 @@ struct GPUHistMakerDevice { this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, dmat->Info().IsColumnSplit(), ctx_->gpu_id); - quantiser.reset(new GradientQuantiser(this->gpair)); + quantiser = std::make_unique(this->gpair); row_partitioner.reset(); // Release the device memory first before reallocating - row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows)); + row_partitioner = std::make_unique(ctx_->gpu_id, sample.sample_rows); // Init histogram hist.Init(ctx_->gpu_id, page->Cuts().TotalBins()); @@ -295,7 +295,7 @@ struct GPUHistMakerDevice { dh::TemporaryArray entries(2 * candidates.size()); // Store the feature set ptrs so they dont go out of scope before the kernel is called std::vector>> feature_sets; - for (size_t i = 0; i < candidates.size(); i++) { + for (std::size_t i = 0; i < candidates.size(); i++) { auto candidate = candidates.at(i); int left_nidx = tree[candidate.nid].LeftChild(); int right_nidx = tree[candidate.nid].RightChild(); @@ -328,14 +328,13 @@ struct GPUHistMakerDevice { d_node_inputs.data().get(), h_node_inputs.data(), h_node_inputs.size() * sizeof(EvaluateSplitInputs), cudaMemcpyDefault)); - this->evaluator_.EvaluateSplits(nidx, max_active_features, - dh::ToSpan(d_node_inputs), shared_inputs, - dh::ToSpan(entries)); + this->evaluator_.EvaluateSplits(nidx, max_active_features, dh::ToSpan(d_node_inputs), + shared_inputs, dh::ToSpan(entries)); dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(), sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost)); dh::DefaultStream().Sync(); - } + } void BuildHist(int nidx) { auto d_node_hist = hist.GetNodeHistogram(nidx); @@ -367,23 +366,29 @@ struct GPUHistMakerDevice { struct NodeSplitData { RegTree::Node split_node; FeatureType split_type; - common::CatBitField node_cats; + common::KCatBitField node_cats; }; - void UpdatePosition(const std::vector& candidates, RegTree* p_tree) { - if (candidates.empty()) return; - std::vector nidx(candidates.size()); - std::vector left_nidx(candidates.size()); - std::vector right_nidx(candidates.size()); + void UpdatePosition(std::vector const& candidates, RegTree* p_tree) { + if (candidates.empty()) { + return; + } + + std::vector nidx(candidates.size()); + std::vector left_nidx(candidates.size()); + std::vector right_nidx(candidates.size()); std::vector split_data(candidates.size()); + for (size_t i = 0; i < candidates.size(); i++) { - auto& e = candidates[i]; + auto const& e = candidates[i]; RegTree::Node split_node = (*p_tree)[e.nid]; auto split_type = p_tree->NodeSplitType(e.nid); nidx.at(i) = e.nid; left_nidx.at(i) = split_node.LeftChild(); right_nidx.at(i) = split_node.RightChild(); - split_data.at(i) = NodeSplitData{split_node, split_type, e.split.split_cats}; + split_data.at(i) = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)}; + + CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat); } auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); @@ -391,7 +396,7 @@ struct GPUHistMakerDevice { nidx, left_nidx, right_nidx, split_data, [=] __device__(bst_uint ridx, const NodeSplitData& data) { // given a row index, returns the node id it belongs to - bst_float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex()); + float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex()); // Missing value bool go_left = true; if (isnan(cut_value)) { @@ -621,7 +626,6 @@ struct GPUHistMakerDevice { CHECK(common::CheckNAN(candidate.split.fvalue)); std::vector split_cats; - CHECK_GT(candidate.split.split_cats.Bits().size(), 0); auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex); split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);