Add max_cat_threshold to GPU and handle missing cat values. (#8212)

This commit is contained in:
Jiaming Yuan 2022-09-07 00:57:51 +08:00 committed by GitHub
parent 441ffc017a
commit b5eb36f1af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 546 additions and 122 deletions

View File

@ -43,9 +43,9 @@ class EvaluateSplitAgent {
public: public:
using ArgMaxT = cub::KeyValuePair<int, float>; using ArgMaxT = cub::KeyValuePair<int, float>;
using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>; using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>;
using MaxReduceT = using MaxReduceT = cub::WarpReduce<ArgMaxT>;
cub::WarpReduce<ArgMaxT>;
using SumReduceT = cub::WarpReduce<GradientPairPrecise>; using SumReduceT = cub::WarpReduce<GradientPairPrecise>;
struct TempStorage { struct TempStorage {
typename BlockScanT::TempStorage scan; typename BlockScanT::TempStorage scan;
typename MaxReduceT::TempStorage max_reduce; typename MaxReduceT::TempStorage max_reduce;
@ -159,50 +159,82 @@ class EvaluateSplitAgent {
if (threadIdx.x == best_thread) { if (threadIdx.x == best_thread) {
int32_t split_gidx = (scan_begin + threadIdx.x); int32_t split_gidx = (scan_begin + threadIdx.x);
float fvalue = feature_values[split_gidx]; float fvalue = feature_values[split_gidx];
GradientPairPrecise left = GradientPairPrecise left = missing_left ? bin + missing : bin;
missing_left ? bin + missing : bin;
GradientPairPrecise right = parent_sum - left; GradientPairPrecise right = parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir,
true, param); static_cast<bst_cat_t>(fvalue), fidx, left, right, param);
} }
} }
} }
__device__ __forceinline__ void Partition(DeviceSplitCandidate *__restrict__ best_split, /**
bst_feature_t * __restrict__ sorted_idx, * \brief Gather and update the best split.
std::size_t offset) { */
for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) { __device__ __forceinline__ void PartitionUpdate(bst_bin_t scan_begin, bool thread_active,
bool thread_active = (scan_begin + threadIdx.x) < gidx_end; bool missing_left, bst_bin_t it,
GradientPairPrecise const &left_sum,
auto rest = thread_active GradientPairPrecise const &right_sum,
? LoadGpair(node_histogram + sorted_idx[scan_begin + threadIdx.x] - offset) DeviceSplitCandidate *__restrict__ best_split) {
: GradientPairPrecise(); auto gain =
// No min value for cat feature, use inclusive scan. thread_active ? evaluator.CalcSplitGain(param, nidx, fidx, left_sum, right_sum) : kNullGain;
BlockScanT(temp_storage->scan).InclusiveSum(rest, rest, prefix_op);
GradientPairPrecise bin = parent_sum - rest - missing;
// Whether the gradient of missing values is put to the left side.
bool missing_left = true;
float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
evaluator, missing_left)
: kNullGain;
// Find thread with best gain // Find thread with best gain
auto best = auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
// This reduce result is only valid in thread 0 // This reduce result is only valid in thread 0
// broadcast to the rest of the warp // broadcast to the rest of the warp
auto best_thread = __shfl_sync(0xffffffff, best.key, 0); auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
// Best thread updates the split // Best thread updates the split
if (threadIdx.x == best_thread) { if (threadIdx.x == best_thread) {
GradientPairPrecise left = missing_left ? bin + missing : bin; assert(thread_active);
GradientPairPrecise right = parent_sum - left; // index of best threshold inside a feature.
auto best_thresh = auto best_thresh = it - gidx_begin;
threadIdx.x + (scan_begin - gidx_begin); // index of best threshold inside a feature. best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left_sum,
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left, right_sum, param);
right, true, param);
} }
} }
/**
* \brief Partition-based split for categorical feature.
*/
__device__ __forceinline__ void Partition(DeviceSplitCandidate *__restrict__ best_split,
common::Span<bst_feature_t> sorted_idx,
std::size_t node_offset,
GPUTrainingParam const &param) {
bst_bin_t n_bins_feature = gidx_end - gidx_begin;
auto n_bins = std::min(param.max_cat_threshold, n_bins_feature);
bst_bin_t it_begin = gidx_begin;
bst_bin_t it_end = it_begin + n_bins - 1;
// forward
for (bst_bin_t scan_begin = it_begin; scan_begin < it_end; scan_begin += kBlockSize) {
auto it = scan_begin + static_cast<bst_bin_t>(threadIdx.x);
bool thread_active = it < it_end;
auto right_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
: GradientPairPrecise();
// No min value for cat feature, use inclusive scan.
BlockScanT(temp_storage->scan).InclusiveSum(right_sum, right_sum, prefix_op);
GradientPairPrecise left_sum = parent_sum - right_sum;
PartitionUpdate(scan_begin, thread_active, true, it, left_sum, right_sum, best_split);
}
// backward
it_begin = gidx_end - 1;
it_end = it_begin - n_bins + 1;
prefix_op = SumCallbackOp<GradientPairPrecise>{}; // reset
for (bst_bin_t scan_begin = it_begin; scan_begin > it_end; scan_begin -= kBlockSize) {
auto it = scan_begin - static_cast<bst_bin_t>(threadIdx.x);
bool thread_active = it > it_end;
auto left_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
: GradientPairPrecise();
// No min value for cat feature, use inclusive scan.
BlockScanT(temp_storage->scan).InclusiveSum(left_sum, left_sum, prefix_op);
GradientPairPrecise right_sum = parent_sum - left_sum;
PartitionUpdate(scan_begin, thread_active, false, it, left_sum, right_sum, best_split);
}
} }
}; };
@ -242,7 +274,7 @@ __global__ __launch_bounds__(kBlockSize) void EvaluateSplitsKernel(
auto total_bins = shared_inputs.feature_values.size(); auto total_bins = shared_inputs.feature_values.size();
size_t offset = total_bins * input_idx; size_t offset = total_bins * input_idx;
auto node_sorted_idx = sorted_idx.subspan(offset, total_bins); auto node_sorted_idx = sorted_idx.subspan(offset, total_bins);
agent.Partition(&best_split, node_sorted_idx.data(), offset); agent.Partition(&best_split, node_sorted_idx, offset, shared_inputs.param);
} }
} else { } else {
agent.Numerical(&best_split); agent.Numerical(&best_split);
@ -273,37 +305,29 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
// Simple case for one hot split // Simple case for one hot split
if (common::UseOneHot(shared_inputs.FeatureBins(fidx), shared_inputs.param.max_cat_to_onehot)) { if (common::UseOneHot(shared_inputs.FeatureBins(fidx), shared_inputs.param.max_cat_to_onehot)) {
out_split.split_cats.Set(common::AsCat(out_split.fvalue)); out_split.split_cats.Set(common::AsCat(out_split.thresh));
return; return;
} }
// partition-based split
auto node_sorted_idx = d_sorted_idx.subspan(shared_inputs.feature_values.size() * input_idx, auto node_sorted_idx = d_sorted_idx.subspan(shared_inputs.feature_values.size() * input_idx,
shared_inputs.feature_values.size()); shared_inputs.feature_values.size());
size_t node_offset = input_idx * shared_inputs.feature_values.size(); size_t node_offset = input_idx * shared_inputs.feature_values.size();
auto best_thresh = out_split.PopBestThresh(); auto const best_thresh = out_split.thresh;
if (best_thresh == -1) {
return;
}
auto f_sorted_idx = node_sorted_idx.subspan(shared_inputs.feature_segments[fidx], auto f_sorted_idx = node_sorted_idx.subspan(shared_inputs.feature_segments[fidx],
shared_inputs.FeatureBins(fidx)); shared_inputs.FeatureBins(fidx));
if (out_split.dir != kLeftDir) { bool forward = out_split.dir == kLeftDir;
// forward, missing on right bst_bin_t partition = forward ? best_thresh + 1 : best_thresh;
auto beg = dh::tcbegin(f_sorted_idx); auto beg = dh::tcbegin(f_sorted_idx);
// Don't put all the categories into one side assert(partition > 0 && "Invalid partition.");
auto boundary = std::min(static_cast<size_t>((best_thresh + 1)), (f_sorted_idx.size() - 1)); thrust::for_each(thrust::seq, beg, beg + partition, [&](size_t c) {
boundary = std::max(boundary, static_cast<size_t>(1ul));
auto end = beg + boundary;
thrust::for_each(thrust::seq, beg, end, [&](auto c) {
auto cat = shared_inputs.feature_values[c - node_offset];
assert(!out_split.split_cats.Check(cat) && "already set");
out_split.SetCat(cat);
});
} else {
assert((f_sorted_idx.size() - best_thresh + 1) != 0 && " == 0");
thrust::for_each(thrust::seq, dh::tcrbegin(f_sorted_idx),
dh::tcrbegin(f_sorted_idx) + (f_sorted_idx.size() - best_thresh), [&](auto c) {
auto cat = shared_inputs.feature_values[c - node_offset]; auto cat = shared_inputs.feature_values[c - node_offset];
out_split.SetCat(cat); out_split.SetCat(cat);
}); });
} }
}
void GPUHistEvaluator::LaunchEvaluateSplits( void GPUHistEvaluator::LaunchEvaluateSplits(
bst_feature_t number_active_features, common::Span<const EvaluateSplitInputs> d_inputs, bst_feature_t number_active_features, common::Span<const EvaluateSplitInputs> d_inputs,

View File

@ -141,7 +141,8 @@ class GPUHistEvaluator {
*/ */
common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const { common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
copy_stream_.View().Sync(); copy_stream_.View().Sync();
auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_); auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(
nidx * node_categorical_storage_size_, node_categorical_storage_size_);
return cats_out; return cats_out;
} }
/** /**

View File

@ -143,8 +143,12 @@ class HistEvaluator {
static_assert(d_step == +1 || d_step == -1, "Invalid step."); static_assert(d_step == +1 || d_step == -1, "Invalid step.");
auto const &cut_ptr = cut.Ptrs(); auto const &cut_ptr = cut.Ptrs();
auto const &cut_val = cut.Values();
auto const &parent = snode_[nidx]; auto const &parent = snode_[nidx];
bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
bst_bin_t f_begin = cut_ptr[fidx];
bst_bin_t f_end = cut_ptr[fidx + 1];
bst_bin_t n_bins_feature{f_end - f_begin};
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature); auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
// statistics on both sides of split // statistics on both sides of split
@ -153,19 +157,18 @@ class HistEvaluator {
// best split so far // best split so far
SplitEntry best; SplitEntry best;
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature); auto f_hist = hist.subspan(f_begin, n_bins_feature);
bst_bin_t ibegin, iend; bst_bin_t it_begin, it_end;
bst_bin_t f_begin = cut_ptr[fidx];
if (d_step > 0) { if (d_step > 0) {
ibegin = f_begin; it_begin = f_begin;
iend = ibegin + n_bins - 1; it_end = it_begin + n_bins - 1;
} else { } else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1; it_begin = f_end - 1;
iend = ibegin - n_bins + 1; it_end = it_begin - n_bins + 1;
} }
bst_bin_t best_thresh{-1}; bst_bin_t best_thresh{-1};
for (bst_bin_t i = ibegin; i != iend; i += d_step) { for (bst_bin_t i = it_begin; i != it_end; i += d_step) {
auto j = i - f_begin; // index local to current feature auto j = i - f_begin; // index local to current feature
if (d_step == 1) { if (d_step == 1) {
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess()); right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
@ -187,13 +190,15 @@ class HistEvaluator {
} }
if (best_thresh != -1) { if (best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1); auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
best.cat_bits = decltype(best.cat_bits)(n, 0); best.cat_bits = decltype(best.cat_bits)(n, 0);
common::CatBitField cat_bits{best.cat_bits}; common::CatBitField cat_bits{best.cat_bits};
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin); bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
CHECK_GT(partition, 0); CHECK_GT(partition, 0);
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](size_t c) {
[&](size_t c) { cat_bits.Set(c); }); auto cat = cut_val[c + f_begin];
cat_bits.Set(cat);
});
} }
p_best->Update(best); p_best->Update(best);

View File

@ -29,6 +29,7 @@ struct GPUTrainingParam {
float max_delta_step; float max_delta_step;
float learning_rate; float learning_rate;
uint32_t max_cat_to_onehot; uint32_t max_cat_to_onehot;
bst_bin_t max_cat_threshold;
GPUTrainingParam() = default; GPUTrainingParam() = default;
@ -38,7 +39,8 @@ struct GPUTrainingParam {
reg_alpha(param.reg_alpha), reg_alpha(param.reg_alpha),
max_delta_step(param.max_delta_step), max_delta_step(param.max_delta_step),
learning_rate{param.learning_rate}, learning_rate{param.learning_rate},
max_cat_to_onehot{param.max_cat_to_onehot} {} max_cat_to_onehot{param.max_cat_to_onehot},
max_cat_threshold{param.max_cat_threshold} {}
}; };
/** /**
@ -57,6 +59,9 @@ struct DeviceSplitCandidate {
DefaultDirection dir {kLeftDir}; DefaultDirection dir {kLeftDir};
int findex {-1}; int findex {-1};
float fvalue {0}; float fvalue {0};
// categorical split, either it's the split category for OHE or the threshold for partition-based
// split.
bst_cat_t thresh{-1};
common::CatBitField split_cats; common::CatBitField split_cats;
bool is_cat { false }; bool is_cat { false };
@ -75,22 +80,6 @@ struct DeviceSplitCandidate {
*this = other; *this = other;
} }
} }
/**
* \brief The largest encoded category in the split bitset
*/
bst_cat_t MaxCat() const {
// Reuse the fvalue for categorical values.
return static_cast<bst_cat_t>(fvalue);
}
/**
* \brief Return the best threshold for cat split, reset the value after return.
*/
XGBOOST_DEVICE size_t PopBestThresh() {
// fvalue is also being used for storing the threshold for categorical split
auto best_thresh = static_cast<size_t>(this->fvalue);
this->fvalue = 0;
return best_thresh;
}
template <typename T> template <typename T>
XGBOOST_DEVICE void SetCat(T c) { XGBOOST_DEVICE void SetCat(T c) {
@ -116,6 +105,26 @@ struct DeviceSplitCandidate {
findex = findex_in; findex = findex_in;
} }
} }
/**
* \brief Update for partition-based splits.
*/
XGBOOST_DEVICE void UpdateCat(float loss_chg_in, DefaultDirection dir_in, bst_cat_t thresh_in,
bst_feature_t findex_in, GradientPairPrecise left_sum_in,
GradientPairPrecise right_sum_in, GPUTrainingParam const& param) {
if (loss_chg_in > loss_chg && left_sum_in.GetHess() >= param.min_child_weight &&
right_sum_in.GetHess() >= param.min_child_weight) {
loss_chg = loss_chg_in;
dir = dir_in;
fvalue = std::numeric_limits<float>::quiet_NaN();
thresh = thresh_in;
is_cat = true;
left_sum = left_sum_in;
right_sum = right_sum_in;
findex = findex_in;
}
}
XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; } XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) { friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
@ -123,6 +132,7 @@ struct DeviceSplitCandidate {
<< "dir: " << c.dir << ", " << "dir: " << c.dir << ", "
<< "findex: " << c.findex << ", " << "findex: " << c.findex << ", "
<< "fvalue: " << c.fvalue << ", " << "fvalue: " << c.fvalue << ", "
<< "thresh: " << c.thresh << ", "
<< "is_cat: " << c.is_cat << ", " << "is_cat: " << c.is_cat << ", "
<< "left sum: " << c.left_sum << ", " << "left sum: " << c.left_sum << ", "
<< "right sum: " << c.right_sum << std::endl; << "right sum: " << c.right_sum << std::endl;

View File

@ -601,13 +601,14 @@ struct GPUHistMakerDevice {
auto is_cat = candidate.split.is_cat; auto is_cat = candidate.split.is_cat;
if (is_cat) { if (is_cat) {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max()) // should be set to nan in evaluation split.
<< "Categorical feature value too large."; CHECK(common::CheckNAN(candidate.split.fvalue));
std::vector<uint32_t> split_cats; std::vector<common::CatBitField::value_type> split_cats;
CHECK_GT(candidate.split.split_cats.Bits().size(), 0); CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
auto max_cat = candidate.split.MaxCat(); auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex);
split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);
CHECK_LE(split_cats.size(), h_cats.size()); CHECK_LE(split_cats.size(), h_cats.size());
std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
@ -616,6 +617,7 @@ struct GPUHistMakerDevice {
base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(), base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess()); candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
} else { } else {
CHECK(!common::CheckNAN(candidate.split.fvalue));
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue, tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight, candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.loss_chg, parent_sum.GetHess(),

View File

@ -2,6 +2,7 @@
* Copyright 2020-2022 by XGBoost contributors * Copyright 2020-2022 by XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "../../../../src/tree/gpu_hist/evaluate_splits.cuh" #include "../../../../src/tree/gpu_hist/evaluate_splits.cuh"
#include "../../helpers.h" #include "../../helpers.h"
#include "../../histogram_helpers.h" #include "../../histogram_helpers.h"
@ -17,29 +18,292 @@ auto ZeroParam() {
tparam.UpdateAllowUnknown(args); tparam.UpdateAllowUnknown(args);
return tparam; return tparam;
} }
} // anonymous namespace } // anonymous namespace
TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
GPUTrainingParam param{param_};
cuts_.cut_ptrs_.SetDevice(0);
cuts_.cut_values_.SetDevice(0);
cuts_.min_vals_.SetDevice(0);
thrust::device_vector<GradientPairPrecise> feature_histogram{feature_histogram_};
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
auto d_feature_types = dh::ToSpan(feature_types);
EvaluateSplitInputs input{1, 0, parent_sum_, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
EvaluateSplitSharedInputs shared_inputs{
param,
d_feature_types,
cuts_.cut_ptrs_.ConstDeviceSpan(),
cuts_.cut_values_.ConstDeviceSpan(),
cuts_.min_vals_.ConstDeviceSpan(),
};
GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(feature_set.size()), 0};
evaluator.Reset(cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, 0);
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
ASSERT_EQ(result.thresh, 1);
this->CheckResult(result.loss_chg, result.findex, result.fvalue, result.is_cat,
result.dir == kLeftDir, result.left_sum, result.right_sum);
}
TEST(GpuHist, PartitionBasic) {
TrainParam tparam = ZeroParam();
tparam.max_cat_to_onehot = 0;
GPUTrainingParam param{tparam};
common::HistogramCuts cuts;
cuts.cut_values_.HostVector() = std::vector<float>{0.0, 1.0, 2.0};
cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 3};
cuts.min_vals_.HostVector() = std::vector<float>{0.0};
cuts.cut_ptrs_.SetDevice(0);
cuts.cut_values_.SetDevice(0);
cuts.min_vals_.SetDevice(0);
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
common::Span<FeatureType> d_feature_types;
auto max_cat =
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
cuts.SetCategorical(true, max_cat);
d_feature_types = dh::ToSpan(feature_types);
EvaluateSplitSharedInputs shared_inputs{
param,
d_feature_types,
cuts.cut_ptrs_.ConstDeviceSpan(),
cuts.cut_values_.ConstDeviceSpan(),
cuts.min_vals_.ConstDeviceSpan(),
};
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
{
// -1.0s go right
// -3.0s go left
GradientPairPrecise parent_sum(-5.0, 3.0);
thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}};
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
EXPECT_EQ(result.dir, kLeftDir);
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
}
{
// -1.0s go right
// -3.0s go left
GradientPairPrecise parent_sum(-7.0, 3.0);
thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-3.0, 1.0}, {-3.0, 1.0}};
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
EXPECT_EQ(result.dir, kLeftDir);
EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
}
{
// All -1.0, gain from splitting should be 0.0
GradientPairPrecise parent_sum(-3.0, 3.0);
thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}};
EvaluateSplitInputs input{2, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
EXPECT_EQ(result.dir, kLeftDir);
EXPECT_FLOAT_EQ(result.loss_chg, 0.0f);
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
}
// With 3.0/3.0 missing values
// Forward, first 2 categories are selected, while the last one go to left along with missing value
{
GradientPairPrecise parent_sum(0.0, 6.0);
thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}};
EvaluateSplitInputs input{3, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
EXPECT_EQ(result.dir, kLeftDir);
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
}
{
// -1.0s go right
// -3.0s go left
GradientPairPrecise parent_sum(-5.0, 3.0);
thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-3.0, 1.0}, {-1.0, 1.0}};
EvaluateSplitInputs input{4, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
EXPECT_EQ(result.dir, kLeftDir);
EXPECT_EQ(cats, std::bitset<32>("10100000000000000000000000000000"));
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
}
{
// -1.0s go right
// -3.0s go left
GradientPairPrecise parent_sum(-5.0, 3.0);
thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{{-3.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}};
EvaluateSplitInputs input{5, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
EXPECT_EQ(cats, std::bitset<32>("01000000000000000000000000000000"));
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
}
}
TEST(GpuHist, PartitionTwoFeatures) {
TrainParam tparam = ZeroParam();
tparam.max_cat_to_onehot = 0;
GPUTrainingParam param{tparam};
common::HistogramCuts cuts;
cuts.cut_values_.HostVector() = std::vector<float>{0.0, 1.0, 2.0, 0.0, 1.0, 2.0};
cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 3, 6};
cuts.min_vals_.HostVector() = std::vector<float>{0.0, 0.0};
cuts.cut_ptrs_.SetDevice(0);
cuts.cut_values_.SetDevice(0);
cuts.min_vals_.SetDevice(0);
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
common::Span<FeatureType> d_feature_types(dh::ToSpan(feature_types));
auto max_cat =
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
cuts.SetCategorical(true, max_cat);
EvaluateSplitSharedInputs shared_inputs{
param,
d_feature_types,
cuts.cut_ptrs_.ConstDeviceSpan(),
cuts.cut_values_.ConstDeviceSpan(),
cuts.min_vals_.ConstDeviceSpan(),
};
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
{
GradientPairPrecise parent_sum(-6.0, 3.0);
thrust::device_vector<GradientPairPrecise> feature_histogram = std::vector<GradientPairPrecise>{
{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
EXPECT_EQ(result.findex, 1);
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
}
{
GradientPairPrecise parent_sum(-6.0, 3.0);
thrust::device_vector<GradientPairPrecise> feature_histogram = std::vector<GradientPairPrecise>{
{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}};
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)};
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
EXPECT_EQ(result.findex, 1);
EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
}
}
TEST(GpuHist, PartitionTwoNodes) {
TrainParam tparam = ZeroParam();
tparam.max_cat_to_onehot = 0;
GPUTrainingParam param{tparam};
common::HistogramCuts cuts;
cuts.cut_values_.HostVector() = std::vector<float>{0.0, 1.0, 2.0};
cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 3};
cuts.min_vals_.HostVector() = std::vector<float>{0.0};
cuts.cut_ptrs_.SetDevice(0);
cuts.cut_values_.SetDevice(0);
cuts.min_vals_.SetDevice(0);
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
common::Span<FeatureType> d_feature_types(dh::ToSpan(feature_types));
auto max_cat =
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
cuts.SetCategorical(true, max_cat);
EvaluateSplitSharedInputs shared_inputs{
param,
d_feature_types,
cuts.cut_ptrs_.ConstDeviceSpan(),
cuts.cut_values_.ConstDeviceSpan(),
cuts.min_vals_.ConstDeviceSpan(),
};
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
{
GradientPairPrecise parent_sum(-6.0, 3.0);
thrust::device_vector<GradientPairPrecise> feature_histogram_a =
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0},
{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
thrust::device_vector<EvaluateSplitInputs> inputs(2);
inputs[0] = EvaluateSplitInputs{0, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram_a)};
thrust::device_vector<GradientPairPrecise> feature_histogram_b =
std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
inputs[1] = EvaluateSplitInputs{1, 0, parent_sum, dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram_b)};
thrust::device_vector<GPUExpandEntry> results(2);
evaluator.EvaluateSplits({0, 1}, 1, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(results));
GPUExpandEntry result_a = results[0];
GPUExpandEntry result_b = results[1];
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(0)[0]),
std::bitset<32>("10000000000000000000000000000000"));
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(1)[0]),
std::bitset<32>("11000000000000000000000000000000"));
}
}
void TestEvaluateSingleSplit(bool is_categorical) { void TestEvaluateSingleSplit(bool is_categorical) {
GradientPairPrecise parent_sum(0.0, 1.0); GradientPairPrecise parent_sum(0.0, 1.0);
TrainParam tparam = ZeroParam(); TrainParam tparam = ZeroParam();
GPUTrainingParam param{tparam}; GPUTrainingParam param{tparam};
common::HistogramCuts cuts; common::HistogramCuts cuts{MakeCutsForTest({1.0, 2.0, 11.0, 12.0}, {0, 2, 4}, {0.0, 0.0}, 0)};
cuts.cut_values_.HostVector() = std::vector<float>{1.0, 2.0, 11.0, 12.0}; thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 2, 4};
cuts.min_vals_.HostVector() = std::vector<float>{0.0, 0.0};
cuts.cut_ptrs_.SetDevice(0);
cuts.cut_values_.SetDevice(0);
cuts.min_vals_.SetDevice(0);
thrust::device_vector<bst_feature_t> feature_set =
std::vector<bst_feature_t>{0, 1};
// Setup gradients so that second feature gets higher gain // Setup gradients so that second feature gets higher gain
thrust::device_vector<GradientPairPrecise> feature_histogram = thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{ std::vector<GradientPairPrecise>{
{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}}; {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
dh::device_vector<FeatureType> feature_types(feature_set.size(), dh::device_vector<FeatureType> feature_types(feature_set.size(),
FeatureType::kCategorical); FeatureType::kCategorical);
common::Span<FeatureType> d_feature_types; common::Span<FeatureType> d_feature_types;
@ -50,9 +314,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
d_feature_types = dh::ToSpan(feature_types); d_feature_types = dh::ToSpan(feature_types);
} }
EvaluateSplitInputs input{1,0, EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
parent_sum,
dh::ToSpan(feature_set),
dh::ToSpan(feature_histogram)}; dh::ToSpan(feature_histogram)};
EvaluateSplitSharedInputs shared_inputs{ EvaluateSplitSharedInputs shared_inputs{
param, param,
@ -68,7 +330,11 @@ void TestEvaluateSingleSplit(bool is_categorical) {
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split; DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
EXPECT_EQ(result.findex, 1); EXPECT_EQ(result.findex, 1);
if (is_categorical) {
ASSERT_TRUE(std::isnan(result.fvalue));
} else {
EXPECT_EQ(result.fvalue, 11.0); EXPECT_EQ(result.fvalue, 11.0);
}
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(),
parent_sum.GetGrad()); parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(),
@ -79,7 +345,7 @@ TEST(GpuHist, EvaluateSingleSplit) {
TestEvaluateSingleSplit(false); TestEvaluateSingleSplit(false);
} }
TEST(GpuHist, EvaluateCategoricalSplit) { TEST(GpuHist, EvaluateSingleCategoricalSplit) {
TestEvaluateSingleSplit(true); TestEvaluateSingleSplit(true);
} }
@ -96,7 +362,6 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
thrust::device_vector<float> feature_min_values = std::vector<float>{0.0}; thrust::device_vector<float> feature_min_values = std::vector<float>{0.0};
thrust::device_vector<GradientPairPrecise> feature_histogram = thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{{-0.5, 0.5}, {0.5, 0.5}}; std::vector<GradientPairPrecise>{{-0.5, 0.5}, {0.5, 0.5}};
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
EvaluateSplitInputs input{1,0, EvaluateSplitInputs input{1,0,
parent_sum, parent_sum,
dh::ToSpan(feature_set), dh::ToSpan(feature_set),
@ -146,7 +411,6 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
thrust::device_vector<GradientPairPrecise> feature_histogram = thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{ std::vector<GradientPairPrecise>{
{-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}}; {-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
thrust::device_vector<int> monotonic_constraints(2, 0);
EvaluateSplitInputs input{1,0, EvaluateSplitInputs input{1,0,
parent_sum, parent_sum,
dh::ToSpan(feature_set), dh::ToSpan(feature_set),
@ -186,7 +450,6 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
thrust::device_vector<GradientPairPrecise> feature_histogram = thrust::device_vector<GradientPairPrecise> feature_histogram =
std::vector<GradientPairPrecise>{ std::vector<GradientPairPrecise>{
{-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}}; {-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
thrust::device_vector<int> monotonic_constraints(2, 0);
EvaluateSplitInputs input{1,0, EvaluateSplitInputs input{1,0,
parent_sum, parent_sum,
dh::ToSpan(feature_set), dh::ToSpan(feature_set),
@ -227,7 +490,6 @@ TEST(GpuHist, EvaluateSplits) {
thrust::device_vector<GradientPairPrecise> feature_histogram_right = thrust::device_vector<GradientPairPrecise> feature_histogram_right =
std::vector<GradientPairPrecise>{ std::vector<GradientPairPrecise>{
{-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}}; {-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
EvaluateSplitInputs input_left{ EvaluateSplitInputs input_left{
1,0, 1,0,
parent_sum, parent_sum,
@ -263,8 +525,7 @@ TEST(GpuHist, EvaluateSplits) {
TEST_F(TestPartitionBasedSplit, GpuHist) { TEST_F(TestPartitionBasedSplit, GpuHist) {
dh::device_vector<FeatureType> ft{std::vector<FeatureType>{FeatureType::kCategorical}}; dh::device_vector<FeatureType> ft{std::vector<FeatureType>{FeatureType::kCategorical}};
GPUHistEvaluator evaluator{param_, GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(info_.num_col_), 0};
static_cast<bst_feature_t>(info_.num_col_), 0};
cuts_.cut_ptrs_.SetDevice(0); cuts_.cut_ptrs_.SetDevice(0);
cuts_.cut_values_.SetDevice(0); cuts_.cut_values_.SetDevice(0);
@ -287,6 +548,5 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split; auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
ASSERT_NEAR(split.loss_chg, best_score_, 1e-16); ASSERT_NEAR(split.loss_chg, best_score_, 1e-16);
} }
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -185,5 +185,33 @@ TEST(HistEvaluator, Categorical) {
ASSERT_EQ(with_onehot.split.loss_chg, with_part.split.loss_chg); ASSERT_EQ(with_onehot.split.loss_chg, with_part.split.loss_chg);
} }
TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
common::HistCollection hist;
hist.Init(cuts_.TotalBins());
hist.AddHistRow(0);
hist.AllocateAllData();
auto node_hist = hist[0];
ASSERT_EQ(node_hist.size(), feature_histogram_.size());
std::copy(feature_histogram_.cbegin(), feature_histogram_.cend(), node_hist.begin());
auto sampler = std::make_shared<common::ColumnSampler>();
MetaInfo info;
info.num_col_ = 1;
info.feature_types = {FeatureType::kCategorical};
auto evaluator =
HistEvaluator<CPUExpandEntry>{param_, info, common::OmpGetNumThreads(0), sampler};
evaluator.InitRoot(GradStats{parent_sum_});
std::vector<CPUExpandEntry> entries(1);
RegTree tree;
evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries);
auto const& split = entries.front().split;
this->CheckResult(split.loss_chg, split.SplitIndex(), split.split_value, split.is_cat,
split.DefaultLeft(),
GradientPairPrecise{split.left_sum.GetGrad(), split.left_sum.GetHess()},
GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()});
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -97,5 +97,59 @@ class TestPartitionBasedSplit : public ::testing::Test {
} while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end())); } while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end()));
} }
}; };
inline auto MakeCutsForTest(std::vector<float> values, std::vector<uint32_t> ptrs,
std::vector<float> min_values, int32_t device) {
common::HistogramCuts cuts;
cuts.cut_values_.HostVector() = values;
cuts.cut_ptrs_.HostVector() = ptrs;
cuts.min_vals_.HostVector() = min_values;
if (device >= 0) {
cuts.cut_ptrs_.SetDevice(device);
cuts.cut_values_.SetDevice(device);
cuts.min_vals_.SetDevice(device);
}
return cuts;
}
class TestCategoricalSplitWithMissing : public testing::Test {
protected:
common::HistogramCuts cuts_;
// Setup gradients and parent sum with missing values.
GradientPairPrecise parent_sum_{1.0, 6.0};
std::vector<GradientPairPrecise> feature_histogram_{
{0.5, 0.5}, {0.5, 0.5}, {1.0, 1.0}, {1.0, 1.0}};
TrainParam param_;
void SetUp() override {
cuts_ = MakeCutsForTest({0.0, 1.0, 2.0, 3.0}, {0, 4}, {0.0}, -1);
auto max_cat = *std::max_element(cuts_.cut_values_.HostVector().begin(),
cuts_.cut_values_.HostVector().end());
cuts_.SetCategorical(true, max_cat);
param_.UpdateAllowUnknown(
Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "1"}});
}
void CheckResult(float loss_chg, bst_feature_t split_ind, float fvalue, bool is_cat,
bool dft_left, GradientPairPrecise left_sum, GradientPairPrecise right_sum) {
// forward
// it: 0, gain: 0.545455
// it: 1, gain: 1.000000
// it: 2, gain: 2.250000
// backward
// it: 3, gain: 1.000000
// it: 2, gain: 2.250000
// it: 1, gain: 3.142857
ASSERT_NEAR(loss_chg, 2.97619, kRtEps);
ASSERT_TRUE(is_cat);
ASSERT_TRUE(std::isnan(fvalue));
ASSERT_EQ(split_ind, 0);
ASSERT_FALSE(dft_left);
ASSERT_EQ(left_sum.GetHess(), 2.5);
ASSERT_EQ(right_sum.GetHess(), parent_sum_.GetHess() - left_sum.GetHess());
}
};
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -1,3 +1,4 @@
from typing import Dict, Any
import numpy as np import numpy as np
import sys import sys
import gc import gc
@ -77,6 +78,48 @@ class TestGPUUpdaters:
def test_categorical_ohe(self, rows, cols, rounds, cats): def test_categorical_ohe(self, rows, cols, rounds, cats):
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist") self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
@given(
tm.categorical_dataset_strategy,
test_up.exact_parameter_strategy,
test_up.hist_parameter_strategy,
test_up.cat_parameter_strategy,
strategies.integers(4, 32),
)
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(
self,
dataset: tm.TestDataset,
exact_parameters: Dict[str, Any],
hist_parameters: Dict[str, Any],
cat_parameters: Dict[str, Any],
n_rounds: int,
) -> None:
cat_parameters.update(exact_parameters)
cat_parameters.update(hist_parameters)
cat_parameters["tree_method"] = "gpu_hist"
results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
tm.non_increasing(results["train"]["rmse"])
@given(
test_up.hist_parameter_strategy,
test_up.cat_parameter_strategy,
)
@settings(deadline=None, print_blob=True)
def test_categorical_ames_housing(
self,
hist_parameters: Dict[str, Any],
cat_parameters: Dict[str, Any],
) -> None:
cat_parameters.update(hist_parameters)
dataset = tm.TestDataset(
"ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
)
cat_parameters["tree_method"] = "gpu_hist"
results = train_result(cat_parameters, dataset.get_dmat(), 16)
tm.non_increasing(results["train"]["rmse"])
@given( @given(
strategies.integers(10, 400), strategies.integers(10, 400),
strategies.integers(3, 8), strategies.integers(3, 8),

View File

@ -234,7 +234,7 @@ class TestTreeMethod:
) -> None: ) -> None:
parameters: Dict[str, Any] = {"tree_method": tree_method} parameters: Dict[str, Any] = {"tree_method": tree_method}
cat, label = tm.make_categorical( cat, label = tm.make_categorical(
n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5 n_samples=rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
) )
Xy = xgb.DMatrix(cat, label, enable_categorical=True) Xy = xgb.DMatrix(cat, label, enable_categorical=True)
@ -259,9 +259,6 @@ class TestTreeMethod:
# Test with OHE split # Test with OHE split
run(self.USE_ONEHOT) run(self.USE_ONEHOT)
if tree_method == "gpu_hist": # fixme: Test with GPU.
return
# Test with partition-based split # Test with partition-based split
run(self.USE_PART) run(self.USE_PART)