Fix GPU categorical split memory allocation. (#9529)
parent be6a552956
commit 942b957eef
@@ -52,7 +52,7 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
  *
  * Go to left if it's NOT the matching category, which matches one-hot encoding.
  */
-inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat) {
+inline XGBOOST_DEVICE bool Decision(common::Span<CatBitField::value_type const> cats, float cat) {
   KCatBitField const s_cats(cats);
   if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
     return true;
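Note: the one-hot `Decision` above reduces to a single bit probe over the per-split category bit field. A minimal host-side sketch of the semantics, assuming 32-bit storage words (the helper name `DecisionSketch` is illustrative, not part of the XGBoost API):

#include <cstdint>
#include <vector>

// Host-side analogue of the device-side Decision(): the split stores one bit
// per category, and a row goes LEFT when its category bit is NOT set (the
// one-hot convention noted in the comment above). Illustrative only.
inline bool DecisionSketch(std::vector<std::uint32_t> const& bits, std::uint32_t cat) {
  std::uint32_t word = cat / 32;          // which storage word holds this category
  std::uint32_t mask = 1u << (cat % 32);  // bit position within that word
  if (word >= bits.size()) return true;   // out-of-range category: treat as unmatched
  return (bits[word] & mask) == 0;        // left iff the matching bit is unset
}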
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2020-2022 by XGBoost Contributors
+/**
+ * Copyright 2020-2023, XGBoost Contributors
  */
 #include <algorithm>  // std::max
 #include <vector>
@@ -11,9 +11,7 @@
 #include "evaluate_splits.cuh"
 #include "expand_entry.cuh"
 
-namespace xgboost {
-namespace tree {
-
+namespace xgboost::tree {
 // With constraints
 XGBOOST_DEVICE float LossChangeMissing(const GradientPairInt64 &scan,
                                        const GradientPairInt64 &missing,
@@ -315,11 +313,11 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inputs,
                                     common::Span<common::CatBitField::value_type> out,
                                     DeviceSplitCandidate *p_out_split) {
   auto &out_split = *p_out_split;
-  out_split.split_cats = common::CatBitField{out};
+  auto out_cats = common::CatBitField{out};
 
   // Simple case for one hot split
   if (common::UseOneHot(shared_inputs.FeatureBins(fidx), shared_inputs.param.max_cat_to_onehot)) {
-    out_split.split_cats.Set(common::AsCat(out_split.thresh));
+    out_cats.Set(common::AsCat(out_split.thresh));
     return;
   }
 
@@ -339,7 +337,7 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inputs,
   assert(partition > 0 && "Invalid partition.");
   thrust::for_each(thrust::seq, beg, beg + partition, [&](size_t c) {
     auto cat = shared_inputs.feature_values[c - node_offset];
-    out_split.SetCat(cat);
+    out_cats.Set(common::AsCat(cat));
   });
 }
 
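Note: the `thrust::for_each` above walks the first `partition` categories of the gradient-sorted order and sets one bit per category, now writing into the externally allocated `out_cats` rather than into the candidate's own bit field. A host-side sketch of that marking step, assuming 32-bit words (`MarkPartition` and `SetBit` are illustrative names):

#include <algorithm>  // for for_each
#include <cstddef>    // for size_t
#include <cstdint>
#include <vector>

// Illustrative host sketch of the partition loop: categories are ordered by
// their gradient-based score, and every category inside the chosen partition
// gets its bit set (cf. out_cats.Set above). Names here are not XGBoost API.
void SetBit(std::vector<std::uint32_t>* bits, std::uint32_t cat) {
  (*bits)[cat / 32] |= 1u << (cat % 32);
}

std::vector<std::uint32_t> MarkPartition(std::vector<std::uint32_t> const& sorted_cats,
                                         std::size_t partition, std::size_t n_cats) {
  std::vector<std::uint32_t> bits((n_cats + 31) / 32, 0);  // one bit per category
  std::for_each(sorted_cats.begin(), sorted_cats.begin() + partition,
                [&](std::uint32_t cat) { SetBit(&bits, cat); });
  return bits;
}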
@@ -444,8 +442,7 @@ void GPUHistEvaluator::EvaluateSplits(
 
     if (split.is_cat) {
       SetCategoricalSplit(shared_inputs, d_sorted_idx, fidx, i,
-                          device_cats_accessor.GetNodeCatStorage(input.nidx),
-                          &out_splits[i]);
+                          device_cats_accessor.GetNodeCatStorage(input.nidx), &out_splits[i]);
     }
 
     float base_weight =
@@ -477,6 +474,4 @@ GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit(
                                 cudaMemcpyDeviceToHost));
   return root_entry;
 }
-
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
@@ -37,8 +37,8 @@ struct EvaluateSplitSharedInputs {
   common::Span<const float> feature_values;
   common::Span<const float> min_fvalue;
   bool is_dense;
-  XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
-  __device__ auto FeatureBins(bst_feature_t fidx) const {
+  [[nodiscard]] XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
+  [[nodiscard]] __device__ std::uint32_t FeatureBins(bst_feature_t fidx) const {
     return feature_segments[fidx + 1] - feature_segments[fidx];
   }
 };
@@ -105,7 +105,7 @@ class GPUHistEvaluator {
   }
 
   /**
-   * \brief Get device category storage of nidx for internal calculation.
+   * @brief Get device category storage of nidx for internal calculation.
    */
   auto DeviceCatStorage(const std::vector<bst_node_t> &nidx) {
     if (!has_categoricals_) return CatAccessor{};
@@ -120,8 +120,8 @@ class GPUHistEvaluator {
   /**
    * \brief Get sorted index storage based on the left node of inputs.
    */
-  auto SortedIdx(int num_nodes, bst_feature_t total_bins) {
-    if(!need_sort_histogram_) return common::Span<bst_feature_t>();
+  auto SortedIdx(int num_nodes, bst_bin_t total_bins) {
+    if (!need_sort_histogram_) return common::Span<bst_feature_t>{};
     cat_sorted_idx_.resize(num_nodes * total_bins);
     return dh::ToSpan(cat_sorted_idx_);
   }
@@ -146,12 +146,22 @@ class GPUHistEvaluator {
    * \brief Get host category storage for nidx. Different from the internal version, this
    * returns strictly 1 node.
    */
-  common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
+  [[nodiscard]] common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
     copy_stream_.View().Sync();
     auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(
         nidx * node_categorical_storage_size_, node_categorical_storage_size_);
     return cats_out;
   }
 
+  [[nodiscard]] auto GetDeviceNodeCats(bst_node_t nidx) {
+    copy_stream_.View().Sync();
+    if (has_categoricals_) {
+      CatAccessor accessor = {dh::ToSpan(split_cats_), node_categorical_storage_size_};
+      return common::KCatBitField{accessor.GetNodeCatStorage(nidx)};
+    } else {
+      return common::KCatBitField{};
+    }
+  }
+
   /**
    * \brief Add a split to the internal tree evaluator.
    */
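Note: the new `GetDeviceNodeCats` is the crux of the allocation fix: a node's winning categories now live in the evaluator's `split_cats_` buffer, indexed by node id, instead of in a `CatBitField` hanging off each `DeviceSplitCandidate`. A toy host-side model of that per-node slicing, assuming 32-bit words (all names below are illustrative, not XGBoost API):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// One flat buffer owns the category bits for every node; a node id selects a
// fixed-size slice of it (cf. split_cats_ / node_categorical_storage_size_).
struct NodeCatStorageSketch {
  std::vector<std::uint32_t> words;  // a device buffer in the real code
  std::size_t words_per_node{0};

  // Return the slice of bit-field words belonging to node `nidx`.
  std::pair<std::uint32_t const*, std::size_t> NodeCats(std::size_t nidx) const {
    return {words.data() + nidx * words_per_node, words_per_node};
  }
};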
@@ -64,7 +64,6 @@ struct DeviceSplitCandidate {
   // split.
   bst_cat_t thresh{-1};
 
-  common::CatBitField split_cats;
   bool is_cat { false };
 
   GradientPairInt64 left_sum;
@@ -72,12 +71,6 @@ struct DeviceSplitCandidate {
 
   XGBOOST_DEVICE DeviceSplitCandidate() {}  // NOLINT
 
-  template <typename T>
-  XGBOOST_DEVICE void SetCat(T c) {
-    this->split_cats.Set(common::AsCat(c));
-    fvalue = std::max(this->fvalue, static_cast<float>(c));
-  }
-
   XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in, float fvalue_in,
                              int findex_in, GradientPairInt64 left_sum_in,
                              GradientPairInt64 right_sum_in, bool cat,
@@ -100,22 +93,23 @@ struct DeviceSplitCandidate {
    */
   XGBOOST_DEVICE void UpdateCat(float loss_chg_in, DefaultDirection dir_in, bst_cat_t thresh_in,
                                 bst_feature_t findex_in, GradientPairInt64 left_sum_in,
-                                GradientPairInt64 right_sum_in, GPUTrainingParam const& param, const GradientQuantiser& quantiser) {
+                                GradientPairInt64 right_sum_in, GPUTrainingParam const& param,
+                                const GradientQuantiser& quantiser) {
     if (loss_chg_in > loss_chg &&
         quantiser.ToFloatingPoint(left_sum_in).GetHess() >= param.min_child_weight &&
         quantiser.ToFloatingPoint(right_sum_in).GetHess() >= param.min_child_weight) {
       loss_chg = loss_chg_in;
       dir = dir_in;
       fvalue = std::numeric_limits<float>::quiet_NaN();
       thresh = thresh_in;
       is_cat = true;
       left_sum = left_sum_in;
       right_sum = right_sum_in;
       findex = findex_in;
     }
   }
 
-  XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
+  [[nodiscard]] XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
 
   friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
     os << "loss_chg:" << c.loss_chg << ", "
@@ -7,9 +7,9 @@
 
 #include <algorithm>
 #include <cmath>
-#include <limits>
-#include <memory>
-#include <utility>
+#include <cstddef>  // for size_t
+#include <memory>   // for unique_ptr, make_unique
+#include <utility>  // for move
 #include <vector>
 
 #include "../collective/communicator-inl.cuh"
@@ -216,9 +216,9 @@ struct GPUHistMakerDevice {
   void InitFeatureGroupsOnce() {
     if (!feature_groups) {
       CHECK(page);
-      feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
-                                             dh::MaxSharedMemoryOptin(ctx_->gpu_id),
-                                             sizeof(GradientPairPrecise)));
+      feature_groups = std::make_unique<FeatureGroups>(page->Cuts(), page->is_dense,
+                                                       dh::MaxSharedMemoryOptin(ctx_->gpu_id),
+                                                       sizeof(GradientPairPrecise));
     }
   }
 
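Note: the hunks above and below replace `ptr.reset(new T(...))` with `std::make_unique<T>(...)`, which avoids a naked `new` and is the idiomatic way to (re)assign a `std::unique_ptr`. The pattern in isolation, with a generic type standing in for the XGBoost ones:

#include <memory>

struct Widget {
  Widget(int a, int b) : a{a}, b{b} {}
  int a, b;
};

int main() {
  std::unique_ptr<Widget> w;
  w.reset(new Widget{1, 2});           // old style: explicit new
  w = std::make_unique<Widget>(1, 2);  // modern style: no naked new
  return 0;
}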
@@ -245,10 +245,10 @@ struct GPUHistMakerDevice {
     this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
                            dmat->Info().IsColumnSplit(), ctx_->gpu_id);
 
-    quantiser.reset(new GradientQuantiser(this->gpair));
+    quantiser = std::make_unique<GradientQuantiser>(this->gpair);
 
     row_partitioner.reset();  // Release the device memory first before reallocating
-    row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows));
+    row_partitioner = std::make_unique<RowPartitioner>(ctx_->gpu_id, sample.sample_rows);
 
     // Init histogram
     hist.Init(ctx_->gpu_id, page->Cuts().TotalBins());
@@ -295,7 +295,7 @@ struct GPUHistMakerDevice {
     dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
     // Store the feature set ptrs so they dont go out of scope before the kernel is called
    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_sets;
-    for (size_t i = 0; i < candidates.size(); i++) {
+    for (std::size_t i = 0; i < candidates.size(); i++) {
       auto candidate = candidates.at(i);
       int left_nidx = tree[candidate.nid].LeftChild();
       int right_nidx = tree[candidate.nid].RightChild();
@@ -328,14 +328,13 @@ struct GPUHistMakerDevice {
                                   d_node_inputs.data().get(), h_node_inputs.data(),
                                   h_node_inputs.size() * sizeof(EvaluateSplitInputs), cudaMemcpyDefault));
 
-    this->evaluator_.EvaluateSplits(nidx, max_active_features,
-                                    dh::ToSpan(d_node_inputs), shared_inputs,
-                                    dh::ToSpan(entries));
+    this->evaluator_.EvaluateSplits(nidx, max_active_features, dh::ToSpan(d_node_inputs),
+                                    shared_inputs, dh::ToSpan(entries));
     dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(),
                                   entries.data().get(), sizeof(GPUExpandEntry) * entries.size(),
                                   cudaMemcpyDeviceToHost));
     dh::DefaultStream().Sync();
   }
 
   void BuildHist(int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
@@ -367,23 +366,29 @@ struct GPUHistMakerDevice {
   struct NodeSplitData {
     RegTree::Node split_node;
     FeatureType split_type;
-    common::CatBitField node_cats;
+    common::KCatBitField node_cats;
   };
 
-  void UpdatePosition(const std::vector<GPUExpandEntry>& candidates, RegTree* p_tree) {
-    if (candidates.empty()) return;
-    std::vector<int> nidx(candidates.size());
-    std::vector<int> left_nidx(candidates.size());
-    std::vector<int> right_nidx(candidates.size());
+  void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) {
+    if (candidates.empty()) {
+      return;
+    }
+
+    std::vector<bst_node_t> nidx(candidates.size());
+    std::vector<bst_node_t> left_nidx(candidates.size());
+    std::vector<bst_node_t> right_nidx(candidates.size());
     std::vector<NodeSplitData> split_data(candidates.size());
 
     for (size_t i = 0; i < candidates.size(); i++) {
-      auto& e = candidates[i];
+      auto const& e = candidates[i];
       RegTree::Node split_node = (*p_tree)[e.nid];
       auto split_type = p_tree->NodeSplitType(e.nid);
       nidx.at(i) = e.nid;
       left_nidx.at(i) = split_node.LeftChild();
       right_nidx.at(i) = split_node.RightChild();
-      split_data.at(i) = NodeSplitData{split_node, split_type, e.split.split_cats};
+      split_data.at(i) = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
+
+      CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
     }
 
     auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
@@ -391,7 +396,7 @@ struct GPUHistMakerDevice {
         nidx, left_nidx, right_nidx, split_data,
         [=] __device__(bst_uint ridx, const NodeSplitData& data) {
           // given a row index, returns the node id it belongs to
-          bst_float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
+          float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
           // Missing value
           bool go_left = true;
           if (isnan(cut_value)) {
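Note: inside the routing lambda above, rows with a missing feature value follow the stored default direction, categorical splits consult the node's bit field (now fetched via `GetDeviceNodeCats`), and numerical splits compare against the cut value. A condensed host-side sketch of that control flow, assuming 32-bit words and the left-if-bit-unset convention from `Decision` (names and the exact comparison are illustrative):

#include <cmath>  // for isnan
#include <cstdint>
#include <vector>

// Condensed host sketch of the device routing lambda: missing values follow
// the split's default direction, categorical splits test the node's category
// bit, numerical splits compare against the cut value.
bool GoLeftSketch(float fvalue, bool default_left, bool is_cat,
                  std::vector<std::uint32_t> const& node_cats, float split_cond) {
  if (std::isnan(fvalue)) {
    return default_left;  // missing value: use the learned default direction
  }
  if (is_cat) {
    auto cat = static_cast<std::uint32_t>(fvalue);
    return (node_cats[cat / 32] & (1u << (cat % 32))) == 0;  // left iff bit unset
  }
  return fvalue <= split_cond;  // numerical split
}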
@@ -621,7 +626,6 @@ struct GPUHistMakerDevice {
     CHECK(common::CheckNAN(candidate.split.fvalue));
     std::vector<common::CatBitField::value_type> split_cats;
 
-    CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
     auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
     auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex);
     split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);
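Note: `CatBitField::ComputeStorageSize` in the final hunk rounds the number of category bins up to whole storage words, which is exactly what the host-side `split_cats` buffer reserves here. The arithmetic, as a self-checking sketch assuming 32-bit words:

#include <cstddef>
#include <cstdint>

// Sketch of the storage-size arithmetic behind CatBitField::ComputeStorageSize:
// with 32-bit words, n category bins need ceil(n / 32) words.
constexpr std::size_t ComputeStorageSizeSketch(std::size_t n_cats) {
  constexpr std::size_t kBits = sizeof(std::uint32_t) * 8;
  return (n_cats + kBits - 1) / kBits;
}

static_assert(ComputeStorageSizeSketch(1) == 1, "1 bit fits in one word");
static_assert(ComputeStorageSizeSketch(32) == 1, "32 bits fill one word exactly");
static_assert(ComputeStorageSizeSketch(33) == 2, "33 bits spill into a second word");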