Fix GPU categorical split memory allocation. (#9529)

This commit is contained in:
Jiaming Yuan
2023-08-29 10:06:03 +08:00
committed by GitHub
parent be6a552956
commit 942b957eef
5 changed files with 67 additions and 64 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2020-2022 by XGBoost Contributors
/**
* Copyright 2020-2023, XGBoost Contributors
*/
#include <algorithm> // std::max
#include <vector>
@@ -11,9 +11,7 @@
#include "evaluate_splits.cuh"
#include "expand_entry.cuh"
namespace xgboost {
namespace tree {
namespace xgboost::tree {
// With constraints
XGBOOST_DEVICE float LossChangeMissing(const GradientPairInt64 &scan,
const GradientPairInt64 &missing,
@@ -315,11 +313,11 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
common::Span<common::CatBitField::value_type> out,
DeviceSplitCandidate *p_out_split) {
auto &out_split = *p_out_split;
out_split.split_cats = common::CatBitField{out};
auto out_cats = common::CatBitField{out};
// Simple case for one hot split
if (common::UseOneHot(shared_inputs.FeatureBins(fidx), shared_inputs.param.max_cat_to_onehot)) {
out_split.split_cats.Set(common::AsCat(out_split.thresh));
out_cats.Set(common::AsCat(out_split.thresh));
return;
}
@@ -339,7 +337,7 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
assert(partition > 0 && "Invalid partition.");
thrust::for_each(thrust::seq, beg, beg + partition, [&](size_t c) {
auto cat = shared_inputs.feature_values[c - node_offset];
out_split.SetCat(cat);
out_cats.Set(common::AsCat(cat));
});
}
@@ -444,8 +442,7 @@ void GPUHistEvaluator::EvaluateSplits(
if (split.is_cat) {
SetCategoricalSplit(shared_inputs, d_sorted_idx, fidx, i,
device_cats_accessor.GetNodeCatStorage(input.nidx),
&out_splits[i]);
device_cats_accessor.GetNodeCatStorage(input.nidx), &out_splits[i]);
}
float base_weight =
@@ -477,6 +474,4 @@ GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit(
cudaMemcpyDeviceToHost));
return root_entry;
}
} // namespace tree
} // namespace xgboost
} // namespace xgboost::tree

View File

@@ -37,8 +37,8 @@ struct EvaluateSplitSharedInputs {
common::Span<const float> feature_values;
common::Span<const float> min_fvalue;
bool is_dense;
XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
__device__ auto FeatureBins(bst_feature_t fidx) const {
[[nodiscard]] XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
[[nodiscard]] __device__ std::uint32_t FeatureBins(bst_feature_t fidx) const {
return feature_segments[fidx + 1] - feature_segments[fidx];
}
};
@@ -105,7 +105,7 @@ class GPUHistEvaluator {
}
/**
* \brief Get device category storage of nidx for internal calculation.
* @brief Get device category storage of nidx for internal calculation.
*/
auto DeviceCatStorage(const std::vector<bst_node_t> &nidx) {
if (!has_categoricals_) return CatAccessor{};
@@ -120,8 +120,8 @@ class GPUHistEvaluator {
/**
* \brief Get sorted index storage based on the left node of inputs.
*/
auto SortedIdx(int num_nodes, bst_feature_t total_bins) {
if(!need_sort_histogram_) return common::Span<bst_feature_t>();
auto SortedIdx(int num_nodes, bst_bin_t total_bins) {
if (!need_sort_histogram_) return common::Span<bst_feature_t>{};
cat_sorted_idx_.resize(num_nodes * total_bins);
return dh::ToSpan(cat_sorted_idx_);
}
@@ -146,12 +146,22 @@ class GPUHistEvaluator {
* \brief Get host category storage for nidx. Different from the internal version, this
* returns strictly 1 node.
*/
common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
[[nodiscard]] common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
copy_stream_.View().Sync();
auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(
nidx * node_categorical_storage_size_, node_categorical_storage_size_);
return cats_out;
}
[[nodiscard]] auto GetDeviceNodeCats(bst_node_t nidx) {
copy_stream_.View().Sync();
if (has_categoricals_) {
CatAccessor accessor = {dh::ToSpan(split_cats_), node_categorical_storage_size_};
return common::KCatBitField{accessor.GetNodeCatStorage(nidx)};
} else {
return common::KCatBitField{};
}
}
/**
* \brief Add a split to the internal tree evaluator.
*/