Always use partition based categorical splits. (#7857)
This commit is contained in:
@@ -199,13 +199,11 @@ __device__ void EvaluateFeature(
|
||||
}
|
||||
|
||||
template <int BLOCK_THREADS, typename GradientSumT>
|
||||
__global__ void EvaluateSplitsKernel(
|
||||
EvaluateSplitInputs<GradientSumT> left,
|
||||
EvaluateSplitInputs<GradientSumT> right,
|
||||
ObjInfo task,
|
||||
common::Span<bst_feature_t> sorted_idx,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
common::Span<DeviceSplitCandidate> out_candidates) {
|
||||
__global__ void EvaluateSplitsKernel(EvaluateSplitInputs<GradientSumT> left,
|
||||
EvaluateSplitInputs<GradientSumT> right,
|
||||
common::Span<bst_feature_t> sorted_idx,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
common::Span<DeviceSplitCandidate> out_candidates) {
|
||||
// KeyValuePair here used as threadIdx.x -> gain_value
|
||||
using ArgMaxT = cub::KeyValuePair<int, float>;
|
||||
using BlockScanT =
|
||||
@@ -241,7 +239,7 @@ __global__ void EvaluateSplitsKernel(
|
||||
|
||||
if (common::IsCat(inputs.feature_types, fidx)) {
|
||||
auto n_bins_in_feat = inputs.feature_segments[fidx + 1] - inputs.feature_segments[fidx];
|
||||
if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot, task)) {
|
||||
if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot)) {
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
|
||||
kOneHot>(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage);
|
||||
} else {
|
||||
@@ -310,7 +308,7 @@ __device__ void SortBasedSplit(EvaluateSplitInputs<GradientSumT> const &input,
|
||||
|
||||
template <typename GradientSumT>
|
||||
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
|
||||
EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
|
||||
EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
common::Span<DeviceSplitCandidate> out_splits) {
|
||||
if (!split_cats_.empty()) {
|
||||
@@ -323,7 +321,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
|
||||
// One block for each feature
|
||||
uint32_t constexpr kBlockThreads = 256;
|
||||
dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads, 0}(
|
||||
EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, task, this->SortedIdx(left),
|
||||
EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, this->SortedIdx(left),
|
||||
evaluator, dh::ToSpan(feature_best_splits));
|
||||
|
||||
// Reduce to get best candidate for left and right child over all features
|
||||
@@ -365,7 +363,7 @@ void GPUHistEvaluator<GradientSumT>::CopyToHost(EvaluateSplitInputs<GradientSumT
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
|
||||
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate,
|
||||
EvaluateSplitInputs<GradientSumT> left,
|
||||
EvaluateSplitInputs<GradientSumT> right,
|
||||
common::Span<GPUExpandEntry> out_entries) {
|
||||
@@ -373,7 +371,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, Ob
|
||||
|
||||
dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(2);
|
||||
auto out_splits = dh::ToSpan(splits_out_storage);
|
||||
this->EvaluateSplits(left, right, task, evaluator, out_splits);
|
||||
this->EvaluateSplits(left, right, evaluator, out_splits);
|
||||
|
||||
auto d_sorted_idx = this->SortedIdx(left);
|
||||
auto d_entries = out_entries;
|
||||
@@ -385,7 +383,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, Ob
|
||||
auto fidx = out_splits[i].findex;
|
||||
|
||||
if (split.is_cat &&
|
||||
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
|
||||
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
|
||||
bool is_left = i == 0;
|
||||
auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2);
|
||||
SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]);
|
||||
@@ -405,11 +403,11 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, Ob
|
||||
|
||||
template <typename GradientSumT>
|
||||
GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
|
||||
EvaluateSplitInputs<GradientSumT> input, float weight, ObjInfo task) {
|
||||
EvaluateSplitInputs<GradientSumT> input, float weight) {
|
||||
dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
|
||||
auto out_split = dh::ToSpan(splits_out);
|
||||
auto evaluator = tree_evaluator_.GetEvaluator<GPUTrainingParam>();
|
||||
this->EvaluateSplits(input, {}, task, evaluator, out_split);
|
||||
this->EvaluateSplits(input, {}, evaluator, out_split);
|
||||
|
||||
auto cats_out = this->DeviceCatStorage(input.nidx);
|
||||
auto d_sorted_idx = this->SortedIdx(input);
|
||||
@@ -421,7 +419,7 @@ GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
|
||||
auto fidx = out_split[i].findex;
|
||||
|
||||
if (split.is_cat &&
|
||||
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
|
||||
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
|
||||
SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
|
||||
}
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class GPUHistEvaluator {
|
||||
/**
|
||||
* \brief Reset the evaluator, should be called before any use.
|
||||
*/
|
||||
void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft, ObjInfo task,
|
||||
void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
|
||||
bst_feature_t n_features, TrainParam const ¶m, int32_t device);
|
||||
|
||||
/**
|
||||
@@ -150,21 +150,20 @@ class GPUHistEvaluator {
|
||||
|
||||
// impl of evaluate splits, contains CUDA kernels so it's public
|
||||
void EvaluateSplits(EvaluateSplitInputs<GradientSumT> left,
|
||||
EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
|
||||
EvaluateSplitInputs<GradientSumT> right,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
|
||||
common::Span<DeviceSplitCandidate> out_splits);
|
||||
/**
|
||||
* \brief Evaluate splits for left and right nodes.
|
||||
*/
|
||||
void EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
|
||||
void EvaluateSplits(GPUExpandEntry candidate,
|
||||
EvaluateSplitInputs<GradientSumT> left,
|
||||
EvaluateSplitInputs<GradientSumT> right,
|
||||
common::Span<GPUExpandEntry> out_splits);
|
||||
/**
|
||||
* \brief Evaluate splits for root node.
|
||||
*/
|
||||
GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs<GradientSumT> input, float weight,
|
||||
ObjInfo task);
|
||||
GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs<GradientSumT> input, float weight);
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -16,12 +16,12 @@ namespace xgboost {
|
||||
namespace tree {
|
||||
template <typename GradientSumT>
|
||||
void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
|
||||
common::Span<FeatureType const> ft, ObjInfo task,
|
||||
common::Span<FeatureType const> ft,
|
||||
bst_feature_t n_features, TrainParam const ¶m,
|
||||
int32_t device) {
|
||||
param_ = param;
|
||||
tree_evaluator_ = TreeEvaluator{param, n_features, device};
|
||||
if (cuts.HasCategorical() && !task.UseOneHot()) {
|
||||
if (cuts.HasCategorical()) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
|
||||
auto beg = thrust::make_counting_iterator<size_t>(1ul);
|
||||
@@ -34,7 +34,7 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
|
||||
auto idx = i - 1;
|
||||
if (common::IsCat(ft, idx)) {
|
||||
auto n_bins = ptrs[i] - ptrs[idx];
|
||||
bool use_sort = !common::UseOneHot(n_bins, to_onehot, task);
|
||||
bool use_sort = !common::UseOneHot(n_bins, to_onehot);
|
||||
return use_sort;
|
||||
}
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user