Add max_cat_threshold to GPU and handle missing cat values. (#8212)
commit b5eb36f1af
parent 441ffc017a
@@ -43,9 +43,9 @@ class EvaluateSplitAgent {
  public:
   using ArgMaxT = cub::KeyValuePair<int, float>;
   using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>;
-  using MaxReduceT =
-      cub::WarpReduce<ArgMaxT>;
+  using MaxReduceT = cub::WarpReduce<ArgMaxT>;
   using SumReduceT = cub::WarpReduce<GradientPairPrecise>;
 
   struct TempStorage {
     typename BlockScanT::TempStorage scan;
     typename MaxReduceT::TempStorage max_reduce;
@@ -159,50 +159,82 @@ class EvaluateSplitAgent {
       if (threadIdx.x == best_thread) {
         int32_t split_gidx = (scan_begin + threadIdx.x);
         float fvalue = feature_values[split_gidx];
-        GradientPairPrecise left =
-            missing_left ? bin + missing : bin;
+        GradientPairPrecise left = missing_left ? bin + missing : bin;
         GradientPairPrecise right = parent_sum - left;
-        best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
-                           true, param);
+        best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir,
+                              static_cast<bst_cat_t>(fvalue), fidx, left, right, param);
       }
     }
   }
-  __device__ __forceinline__ void Partition(DeviceSplitCandidate *__restrict__ best_split,
-                                            bst_feature_t *__restrict__ sorted_idx,
-                                            std::size_t offset) {
-    for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
-      bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
-
-      auto rest = thread_active
-                      ? LoadGpair(node_histogram + sorted_idx[scan_begin + threadIdx.x] - offset)
-                      : GradientPairPrecise();
-      // No min value for cat feature, use inclusive scan.
-      BlockScanT(temp_storage->scan).InclusiveSum(rest, rest, prefix_op);
-      GradientPairPrecise bin = parent_sum - rest - missing;
-
-      // Whether the gradient of missing values is put to the left side.
-      bool missing_left = true;
-      float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
-                                                     evaluator, missing_left)
-                                 : kNullGain;
-
+  /**
+   * \brief Gather and update the best split.
+   */
+  __device__ __forceinline__ void PartitionUpdate(bst_bin_t scan_begin, bool thread_active,
+                                                  bool missing_left, bst_bin_t it,
+                                                  GradientPairPrecise const &left_sum,
+                                                  GradientPairPrecise const &right_sum,
+                                                  DeviceSplitCandidate *__restrict__ best_split) {
+    auto gain =
+        thread_active ? evaluator.CalcSplitGain(param, nidx, fidx, left_sum, right_sum) : kNullGain;
+
     // Find thread with best gain
-    auto best =
-        MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
+    auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
     // This reduce result is only valid in thread 0
     // broadcast to the rest of the warp
     auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
     // Best thread updates the split
     if (threadIdx.x == best_thread) {
-      GradientPairPrecise left = missing_left ? bin + missing : bin;
-      GradientPairPrecise right = parent_sum - left;
-      auto best_thresh =
-          threadIdx.x + (scan_begin - gidx_begin);  // index of best threshold inside a feature.
-      best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left,
-                         right, true, param);
+      assert(thread_active);
+      // index of best threshold inside a feature.
+      auto best_thresh = it - gidx_begin;
+      best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left_sum,
+                            right_sum, param);
     }
   }
+  /**
+   * \brief Partition-based split for categorical feature.
+   */
+  __device__ __forceinline__ void Partition(DeviceSplitCandidate *__restrict__ best_split,
+                                            common::Span<bst_feature_t> sorted_idx,
+                                            std::size_t node_offset,
+                                            GPUTrainingParam const &param) {
+    bst_bin_t n_bins_feature = gidx_end - gidx_begin;
+    auto n_bins = std::min(param.max_cat_threshold, n_bins_feature);
+
+    bst_bin_t it_begin = gidx_begin;
+    bst_bin_t it_end = it_begin + n_bins - 1;
+
+    // forward
+    for (bst_bin_t scan_begin = it_begin; scan_begin < it_end; scan_begin += kBlockSize) {
+      auto it = scan_begin + static_cast<bst_bin_t>(threadIdx.x);
+      bool thread_active = it < it_end;
+
+      auto right_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
+                                     : GradientPairPrecise();
+      // No min value for cat feature, use inclusive scan.
+      BlockScanT(temp_storage->scan).InclusiveSum(right_sum, right_sum, prefix_op);
+      GradientPairPrecise left_sum = parent_sum - right_sum;
+
+      PartitionUpdate(scan_begin, thread_active, true, it, left_sum, right_sum, best_split);
+    }
+
+    // backward
+    it_begin = gidx_end - 1;
+    it_end = it_begin - n_bins + 1;
+    prefix_op = SumCallbackOp<GradientPairPrecise>{};  // reset
+
+    for (bst_bin_t scan_begin = it_begin; scan_begin > it_end; scan_begin -= kBlockSize) {
+      auto it = scan_begin - static_cast<bst_bin_t>(threadIdx.x);
+      bool thread_active = it > it_end;
+
+      auto left_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
+                                    : GradientPairPrecise();
+      // No min value for cat feature, use inclusive scan.
+      BlockScanT(temp_storage->scan).InclusiveSum(left_sum, left_sum, prefix_op);
+      GradientPairPrecise right_sum = parent_sum - left_sum;
+
+      PartitionUpdate(scan_begin, thread_active, false, it, left_sum, right_sum, best_split);
+    }
+  }
 };
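Note: the new Partition() caps the search at max_cat_threshold candidate thresholds counted from each end of the gradient-sorted categories. A host-side sketch of the same enumeration (plain C++ with illustrative names, not the kernel's actual API) may help: the scanned partial sum forms one partition, and parent minus the scan, which implicitly carries the missing-value statistics, forms the other.

#include <algorithm>
#include <cstdint>
#include <vector>

// Host-side sketch of the scan performed by Partition() above; names are
// illustrative, not the actual XGBoost API.  `hist` holds per-category
// gradient sums already ordered by gradient ratio (sorted_idx applied).
struct Gpair {
  double grad{0}, hess{0};
};

struct ThresholdCandidate {
  bool missing_left;  // side that absorbs the missing-value statistics
  int32_t thresh;     // threshold index local to the feature
  Gpair left, right;
};

inline std::vector<ThresholdCandidate> EnumerateThresholds(std::vector<Gpair> const &hist,
                                                           Gpair parent,
                                                           int32_t max_cat_threshold) {
  int32_t n = static_cast<int32_t>(hist.size());
  int32_t n_bins = std::min(max_cat_threshold, n);
  std::vector<ThresholdCandidate> out;
  // Forward: the inclusive scan forms the right partition, so the remainder
  // parent - right (including any missing values) stays on the left.
  Gpair right{};
  for (int32_t it = 0; it < n_bins - 1; ++it) {
    right = {right.grad + hist[it].grad, right.hess + hist[it].hess};
    out.push_back({true, it, {parent.grad - right.grad, parent.hess - right.hess}, right});
  }
  // Backward: scan from the other end; now the scan forms the left partition
  // and the missing values stay on the right.
  Gpair left{};
  for (int32_t it = n - 1; it > n - n_bins; --it) {
    left = {left.grad + hist[it].grad, left.hess + hist[it].hess};
    out.push_back({false, it, left, {parent.grad - left.grad, parent.hess - left.hess}});
  }
  return out;
}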
@@ -242,7 +274,7 @@ __global__ __launch_bounds__(kBlockSize) void EvaluateSplitsKernel(
       auto total_bins = shared_inputs.feature_values.size();
       size_t offset = total_bins * input_idx;
       auto node_sorted_idx = sorted_idx.subspan(offset, total_bins);
-      agent.Partition(&best_split, node_sorted_idx.data(), offset);
+      agent.Partition(&best_split, node_sorted_idx, offset, shared_inputs.param);
     }
   } else {
     agent.Numerical(&best_split);
@@ -273,36 +305,28 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
 
   // Simple case for one hot split
   if (common::UseOneHot(shared_inputs.FeatureBins(fidx), shared_inputs.param.max_cat_to_onehot)) {
-    out_split.split_cats.Set(common::AsCat(out_split.fvalue));
+    out_split.split_cats.Set(common::AsCat(out_split.thresh));
     return;
   }
 
   // partition-based split
   auto node_sorted_idx = d_sorted_idx.subspan(shared_inputs.feature_values.size() * input_idx,
                                               shared_inputs.feature_values.size());
   size_t node_offset = input_idx * shared_inputs.feature_values.size();
-  auto best_thresh = out_split.PopBestThresh();
+  auto const best_thresh = out_split.thresh;
+  if (best_thresh == -1) {
+    return;
+  }
   auto f_sorted_idx = node_sorted_idx.subspan(shared_inputs.feature_segments[fidx],
                                               shared_inputs.FeatureBins(fidx));
-  if (out_split.dir != kLeftDir) {
-    // forward, missing on right
+  bool forward = out_split.dir == kLeftDir;
+  bst_bin_t partition = forward ? best_thresh + 1 : best_thresh;
   auto beg = dh::tcbegin(f_sorted_idx);
-    // Don't put all the categories into one side
-    auto boundary = std::min(static_cast<size_t>((best_thresh + 1)), (f_sorted_idx.size() - 1));
-    boundary = std::max(boundary, static_cast<size_t>(1ul));
-    auto end = beg + boundary;
-    thrust::for_each(thrust::seq, beg, end, [&](auto c) {
-      auto cat = shared_inputs.feature_values[c - node_offset];
-      assert(!out_split.split_cats.Check(cat) && "already set");
-      out_split.SetCat(cat);
-    });
-  } else {
-    assert((f_sorted_idx.size() - best_thresh + 1) != 0 && " == 0");
-    thrust::for_each(thrust::seq, dh::tcrbegin(f_sorted_idx),
-                     dh::tcrbegin(f_sorted_idx) + (f_sorted_idx.size() - best_thresh), [&](auto c) {
+  assert(partition > 0 && "Invalid partition.");
+  thrust::for_each(thrust::seq, beg, beg + partition, [&](size_t c) {
     auto cat = shared_inputs.feature_values[c - node_offset];
     out_split.SetCat(cat);
   });
-  }
 }
 
 void GPUHistEvaluator::LaunchEvaluateSplits(
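Note: SetCategoricalSplit() decodes the stored threshold into the split-category bitset by walking the gradient-sorted bins. A minimal host-side sketch, with std::bitset standing in for common::CatBitField and illustrative names:

#include <bitset>
#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal sketch of the decode step in SetCategoricalSplit() above.
// `sorted_idx` orders the feature's bins by gradient statistics;
// `feature_values` maps a bin to its category value.
inline std::bitset<32> DecodeThreshold(std::vector<uint32_t> const &sorted_idx,
                                       std::vector<float> const &feature_values,
                                       int32_t best_thresh, bool forward) {
  std::bitset<32> split_cats;
  // A forward scan puts bins [0, thresh] into the scanned partition while a
  // backward scan puts [0, thresh) there, hence the +1 only when forward.
  int32_t partition = forward ? best_thresh + 1 : best_thresh;
  for (int32_t i = 0; i < partition; ++i) {
    auto cat = static_cast<std::size_t>(feature_values[sorted_idx[i]]);
    split_cats.set(cat);
  }
  return split_cats;
}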
@@ -141,7 +141,8 @@ class GPUHistEvaluator {
    */
   common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
     copy_stream_.View().Sync();
-    auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_);
+    auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(
+        nidx * node_categorical_storage_size_, node_categorical_storage_size_);
     return cats_out;
   }
   /**
@@ -143,8 +143,12 @@ class HistEvaluator {
     static_assert(d_step == +1 || d_step == -1, "Invalid step.");
 
     auto const &cut_ptr = cut.Ptrs();
+    auto const &cut_val = cut.Values();
     auto const &parent = snode_[nidx];
-    bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
+
+    bst_bin_t f_begin = cut_ptr[fidx];
+    bst_bin_t f_end = cut_ptr[fidx + 1];
+    bst_bin_t n_bins_feature{f_end - f_begin};
     auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
 
     // statistics on both sides of split
@@ -153,19 +157,18 @@ class HistEvaluator {
     // best split so far
     SplitEntry best;
 
-    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
-    bst_bin_t ibegin, iend;
-    bst_bin_t f_begin = cut_ptr[fidx];
+    auto f_hist = hist.subspan(f_begin, n_bins_feature);
+    bst_bin_t it_begin, it_end;
     if (d_step > 0) {
-      ibegin = f_begin;
-      iend = ibegin + n_bins - 1;
+      it_begin = f_begin;
+      it_end = it_begin + n_bins - 1;
     } else {
-      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
-      iend = ibegin - n_bins + 1;
+      it_begin = f_end - 1;
+      it_end = it_begin - n_bins + 1;
     }
 
     bst_bin_t best_thresh{-1};
-    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
+    for (bst_bin_t i = it_begin; i != it_end; i += d_step) {
       auto j = i - f_begin;  // index local to current feature
       if (d_step == 1) {
         right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
@@ -187,13 +190,15 @@ class HistEvaluator {
     }
 
     if (best_thresh != -1) {
-      auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
+      auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
       best.cat_bits = decltype(best.cat_bits)(n, 0);
       common::CatBitField cat_bits{best.cat_bits};
-      bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
+      bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
       CHECK_GT(partition, 0);
-      std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
-                    [&](size_t c) { cat_bits.Set(c); });
+      std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](size_t c) {
+        auto cat = cut_val[c + f_begin];
+        cat_bits.Set(cat);
+      });
     }
 
     p_best->Update(best);
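Note: the substantive fix in the CPU evaluator is inside the std::for_each. The bitset used to record c, the position of a bin in the gradient-sorted order, where it must record that bin's actual category value cut_val[c + f_begin]. A self-contained illustration with made-up values:

#include <bitset>
#include <cstddef>
#include <vector>

// Illustration of the fix above (hypothetical values).  `c` iterates over
// local bin indices stored in sorted_idx; the bitset must record the bin's
// category value, not the bin index itself.
int main() {
  std::vector<float> cut_val{10.0f, 11.0f, 12.0f};  // category value per bin
  std::vector<std::size_t> sorted_idx{2, 0, 1};     // bins ordered by gradient ratio
  std::size_t f_begin = 0;                          // feature's first bin in `hist`
  std::size_t partition = 1;                        // bins falling on one side
  std::bitset<16> cat_bits;
  for (std::size_t i = 0; i < partition; ++i) {
    std::size_t c = sorted_idx[i];
    // Before: cat_bits.set(c) would mark bin position 2.
    // After:  the category value 12 of that bin is marked instead.
    cat_bits.set(static_cast<std::size_t>(cut_val[c + f_begin]));
  }
  return cat_bits[12] ? 0 : 1;
}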
@@ -29,6 +29,7 @@ struct GPUTrainingParam {
   float max_delta_step;
   float learning_rate;
   uint32_t max_cat_to_onehot;
+  bst_bin_t max_cat_threshold;
 
   GPUTrainingParam() = default;
 
@@ -38,7 +39,8 @@ struct GPUTrainingParam {
         reg_alpha(param.reg_alpha),
         max_delta_step(param.max_delta_step),
         learning_rate{param.learning_rate},
-        max_cat_to_onehot{param.max_cat_to_onehot} {}
+        max_cat_to_onehot{param.max_cat_to_onehot},
+        max_cat_threshold{param.max_cat_threshold} {}
 };
 
 /**
@@ -57,6 +59,9 @@ struct DeviceSplitCandidate {
   DefaultDirection dir {kLeftDir};
   int findex {-1};
   float fvalue {0};
+  // categorical split, either it's the split category for OHE or the threshold for partition-based
+  // split.
+  bst_cat_t thresh{-1};
 
   common::CatBitField split_cats;
   bool is_cat { false };
@@ -75,22 +80,6 @@ struct DeviceSplitCandidate {
       *this = other;
     }
   }
-  /**
-   * \brief The largest encoded category in the split bitset
-   */
-  bst_cat_t MaxCat() const {
-    // Reuse the fvalue for categorical values.
-    return static_cast<bst_cat_t>(fvalue);
-  }
-  /**
-   * \brief Return the best threshold for cat split, reset the value after return.
-   */
-  XGBOOST_DEVICE size_t PopBestThresh() {
-    // fvalue is also being used for storing the threshold for categorical split
-    auto best_thresh = static_cast<size_t>(this->fvalue);
-    this->fvalue = 0;
-    return best_thresh;
-  }
 
   template <typename T>
   XGBOOST_DEVICE void SetCat(T c) {
@@ -116,6 +105,26 @@ struct DeviceSplitCandidate {
       findex = findex_in;
     }
   }
+
+  /**
+   * \brief Update for partition-based splits.
+   */
+  XGBOOST_DEVICE void UpdateCat(float loss_chg_in, DefaultDirection dir_in, bst_cat_t thresh_in,
+                                bst_feature_t findex_in, GradientPairPrecise left_sum_in,
+                                GradientPairPrecise right_sum_in, GPUTrainingParam const& param) {
+    if (loss_chg_in > loss_chg && left_sum_in.GetHess() >= param.min_child_weight &&
+        right_sum_in.GetHess() >= param.min_child_weight) {
+      loss_chg = loss_chg_in;
+      dir = dir_in;
+      fvalue = std::numeric_limits<float>::quiet_NaN();
+      thresh = thresh_in;
+      is_cat = true;
+      left_sum = left_sum_in;
+      right_sum = right_sum_in;
+      findex = findex_in;
+    }
+  }
 
   XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
 
   friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
@@ -123,6 +132,7 @@ struct DeviceSplitCandidate {
        << "dir: " << c.dir << ", "
        << "findex: " << c.findex << ", "
        << "fvalue: " << c.fvalue << ", "
+       << "thresh: " << c.thresh << ", "
        << "is_cat: " << c.is_cat << ", "
        << "left sum: " << c.left_sum << ", "
        << "right sum: " << c.right_sum << std::endl;
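Note: this change drops the old trick of smuggling the categorical threshold through the float fvalue (the removed MaxCat/PopBestThresh) in favour of an explicit integral thresh, poisoning fvalue with NaN so stale readers trip the CHECKs added below. A reduced sketch of the invariant, with a hypothetical stand-in struct:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Reduced sketch: a categorical candidate stores its ordinal threshold in
// `thresh` and poisons `fvalue` with NaN, so any code still reading `fvalue`
// as a threshold fails fast.
struct SplitSketch {
  float fvalue{0};
  int32_t thresh{-1};
  bool is_cat{false};

  void UpdateCat(int32_t thresh_in) {
    fvalue = std::numeric_limits<float>::quiet_NaN();
    thresh = thresh_in;
    is_cat = true;
  }
};

int main() {
  SplitSketch s;
  s.UpdateCat(2);
  assert(s.is_cat && std::isnan(s.fvalue) && s.thresh == 2);
  return 0;
}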
@@ -601,13 +601,14 @@ struct GPUHistMakerDevice {
 
     auto is_cat = candidate.split.is_cat;
     if (is_cat) {
-      CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
-          << "Categorical feature value too large.";
-      std::vector<uint32_t> split_cats;
+      // should be set to nan in evaluation split.
+      CHECK(common::CheckNAN(candidate.split.fvalue));
+      std::vector<common::CatBitField::value_type> split_cats;
 
       CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
       auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
-      auto max_cat = candidate.split.MaxCat();
-      split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
+      auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex);
+      split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);
       CHECK_LE(split_cats.size(), h_cats.size());
       std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
 
@@ -616,6 +617,7 @@ struct GPUHistMakerDevice {
           base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
           candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
     } else {
+      CHECK(!common::CheckNAN(candidate.split.fvalue));
       tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
                       candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight,
                       candidate.split.loss_chg, parent_sum.GetHess(),
@@ -2,6 +2,7 @@
  * Copyright 2020-2022 by XGBoost contributors
  */
 #include <gtest/gtest.h>
 
 #include "../../../../src/tree/gpu_hist/evaluate_splits.cuh"
 #include "../../helpers.h"
+#include "../../histogram_helpers.h"
@@ -17,29 +18,292 @@ auto ZeroParam() {
   tparam.UpdateAllowUnknown(args);
   return tparam;
 }
 
 }  // anonymous namespace
 
+TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
+  thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
+  GPUTrainingParam param{param_};
+  cuts_.cut_ptrs_.SetDevice(0);
+  cuts_.cut_values_.SetDevice(0);
+  cuts_.min_vals_.SetDevice(0);
+  thrust::device_vector<GradientPairPrecise> feature_histogram{feature_histogram_};
+
+  dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
+  auto d_feature_types = dh::ToSpan(feature_types);
+
+  EvaluateSplitInputs input{1, 0, parent_sum_, dh::ToSpan(feature_set),
+                            dh::ToSpan(feature_histogram)};
+  EvaluateSplitSharedInputs shared_inputs{
+      param,
+      d_feature_types,
+      cuts_.cut_ptrs_.ConstDeviceSpan(),
+      cuts_.cut_values_.ConstDeviceSpan(),
+      cuts_.min_vals_.ConstDeviceSpan(),
+  };
+
+  GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(feature_set.size()), 0};
+
+  evaluator.Reset(cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, 0);
+  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+
+  ASSERT_EQ(result.thresh, 1);
+  this->CheckResult(result.loss_chg, result.findex, result.fvalue, result.is_cat,
+                    result.dir == kLeftDir, result.left_sum, result.right_sum);
+}
+
+TEST(GpuHist, PartitionBasic) {
+  TrainParam tparam = ZeroParam();
+  tparam.max_cat_to_onehot = 0;
+  GPUTrainingParam param{tparam};
+
+  common::HistogramCuts cuts;
+  cuts.cut_values_.HostVector() = std::vector<float>{0.0, 1.0, 2.0};
+  cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 3};
+  cuts.min_vals_.HostVector() = std::vector<float>{0.0};
+  cuts.cut_ptrs_.SetDevice(0);
+  cuts.cut_values_.SetDevice(0);
+  cuts.min_vals_.SetDevice(0);
+  thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
+
+  thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
+  dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
+  common::Span<FeatureType> d_feature_types;
+  auto max_cat =
+      *std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
+  cuts.SetCategorical(true, max_cat);
+  d_feature_types = dh::ToSpan(feature_types);
+
+  EvaluateSplitSharedInputs shared_inputs{
+      param,
+      d_feature_types,
+      cuts.cut_ptrs_.ConstDeviceSpan(),
+      cuts.cut_values_.ConstDeviceSpan(),
+      cuts.min_vals_.ConstDeviceSpan(),
+  };
+
+  GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
+  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
+
+  {
+    // -1.0s go right
+    // -3.0s go left
+    GradientPairPrecise parent_sum(-5.0, 3.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram =
+        std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}};
+    EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
+                              dh::ToSpan(feature_histogram)};
+    DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+    auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
+    EXPECT_EQ(result.dir, kLeftDir);
+    EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
+    EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
+    EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
+  }
+
+  {
+    // -1.0s go right
+    // -3.0s go left
+    GradientPairPrecise parent_sum(-7.0, 3.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram =
+        std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-3.0, 1.0}, {-3.0, 1.0}};
+    EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
+                              dh::ToSpan(feature_histogram)};
+    DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+    auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
+    EXPECT_EQ(result.dir, kLeftDir);
+    EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
+    EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
+    EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
+  }
+  {
+    // All -1.0, gain from splitting should be 0.0
+    GradientPairPrecise parent_sum(-3.0, 3.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram =
+        std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}};
+    EvaluateSplitInputs input{2, 0, parent_sum, dh::ToSpan(feature_set),
+                              dh::ToSpan(feature_histogram)};
+    DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+    EXPECT_EQ(result.dir, kLeftDir);
+    EXPECT_FLOAT_EQ(result.loss_chg, 0.0f);
+    EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
+    EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
+  }
+  // With 3.0/3.0 missing values
+  // Forward, first 2 categories are selected, while the last one goes to the left along with the
+  // missing values
+  {
+    GradientPairPrecise parent_sum(0.0, 6.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram =
+        std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}};
+    EvaluateSplitInputs input{3, 0, parent_sum, dh::ToSpan(feature_set),
+                              dh::ToSpan(feature_histogram)};
+    DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+    auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
+    EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
+    EXPECT_EQ(result.dir, kLeftDir);
+    EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
+    EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
+  }
+  {
+    // -1.0s go right
+    // -3.0s go left
+    GradientPairPrecise parent_sum(-5.0, 3.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram =
+        std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-3.0, 1.0}, {-1.0, 1.0}};
+    EvaluateSplitInputs input{4, 0, parent_sum, dh::ToSpan(feature_set),
+                              dh::ToSpan(feature_histogram)};
+    DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+    auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
+    EXPECT_EQ(result.dir, kLeftDir);
+    EXPECT_EQ(cats, std::bitset<32>("10100000000000000000000000000000"));
+    EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
+    EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
+  }
+  {
+    // -1.0s go right
+    // -3.0s go left
+    GradientPairPrecise parent_sum(-5.0, 3.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram =
+        std::vector<GradientPairPrecise>{{-3.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}};
+    EvaluateSplitInputs input{5, 0, parent_sum, dh::ToSpan(feature_set),
+                              dh::ToSpan(feature_histogram)};
+    DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+    auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
+    EXPECT_EQ(cats, std::bitset<32>("01000000000000000000000000000000"));
+    EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
+    EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
+  }
+}
+
+TEST(GpuHist, PartitionTwoFeatures) {
+  TrainParam tparam = ZeroParam();
+  tparam.max_cat_to_onehot = 0;
+  GPUTrainingParam param{tparam};
+
+  common::HistogramCuts cuts;
+  cuts.cut_values_.HostVector() = std::vector<float>{0.0, 1.0, 2.0, 0.0, 1.0, 2.0};
+  cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 3, 6};
+  cuts.min_vals_.HostVector() = std::vector<float>{0.0, 0.0};
+  cuts.cut_ptrs_.SetDevice(0);
+  cuts.cut_values_.SetDevice(0);
+  cuts.min_vals_.SetDevice(0);
+  thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
+
+  thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
+  dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
+  common::Span<FeatureType> d_feature_types(dh::ToSpan(feature_types));
+  auto max_cat =
+      *std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
+  cuts.SetCategorical(true, max_cat);
+
+  EvaluateSplitSharedInputs shared_inputs{
+      param,
+      d_feature_types,
+      cuts.cut_ptrs_.ConstDeviceSpan(),
+      cuts.cut_values_.ConstDeviceSpan(),
+      cuts.min_vals_.ConstDeviceSpan(),
+  };
+
+  GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
+  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
+
+  {
+    GradientPairPrecise parent_sum(-6.0, 3.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram = std::vector<GradientPairPrecise>{
+        {-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
+    EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
+                              dh::ToSpan(feature_histogram)};
+    DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+    auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
+    EXPECT_EQ(result.findex, 1);
+    EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
+    EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
+    EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
+  }
+
+  {
+    GradientPairPrecise parent_sum(-6.0, 3.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram = std::vector<GradientPairPrecise>{
+        {-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}};
+    EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
+                              dh::ToSpan(feature_histogram)};
+    DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+    auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
+    EXPECT_EQ(result.findex, 1);
+    EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
+    EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), parent_sum.GetGrad());
+    EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(), parent_sum.GetHess());
+  }
+}
+
+TEST(GpuHist, PartitionTwoNodes) {
+  TrainParam tparam = ZeroParam();
+  tparam.max_cat_to_onehot = 0;
+  GPUTrainingParam param{tparam};
+
+  common::HistogramCuts cuts;
+  cuts.cut_values_.HostVector() = std::vector<float>{0.0, 1.0, 2.0};
+  cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 3};
+  cuts.min_vals_.HostVector() = std::vector<float>{0.0};
+  cuts.cut_ptrs_.SetDevice(0);
+  cuts.cut_values_.SetDevice(0);
+  cuts.min_vals_.SetDevice(0);
+  thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
+
+  thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
+  dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
+  common::Span<FeatureType> d_feature_types(dh::ToSpan(feature_types));
+  auto max_cat =
+      *std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
+  cuts.SetCategorical(true, max_cat);
+
+  EvaluateSplitSharedInputs shared_inputs{
+      param,
+      d_feature_types,
+      cuts.cut_ptrs_.ConstDeviceSpan(),
+      cuts.cut_values_.ConstDeviceSpan(),
+      cuts.min_vals_.ConstDeviceSpan(),
+  };
+
+  GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
+  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
+
+  {
+    GradientPairPrecise parent_sum(-6.0, 3.0);
+    thrust::device_vector<GradientPairPrecise> feature_histogram_a =
+        std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0},
+                                         {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
+    thrust::device_vector<EvaluateSplitInputs> inputs(2);
+    inputs[0] = EvaluateSplitInputs{0, 0, parent_sum, dh::ToSpan(feature_set),
+                                    dh::ToSpan(feature_histogram_a)};
+    thrust::device_vector<GradientPairPrecise> feature_histogram_b =
+        std::vector<GradientPairPrecise>{{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}};
+    inputs[1] = EvaluateSplitInputs{1, 0, parent_sum, dh::ToSpan(feature_set),
+                                    dh::ToSpan(feature_histogram_b)};
+    thrust::device_vector<GPUExpandEntry> results(2);
+    evaluator.EvaluateSplits({0, 1}, 1, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(results));
+    GPUExpandEntry result_a = results[0];
+    GPUExpandEntry result_b = results[1];
+    EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(0)[0]),
+              std::bitset<32>("10000000000000000000000000000000"));
+    EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(1)[0]),
+              std::bitset<32>("11000000000000000000000000000000"));
+  }
+}
+
 void TestEvaluateSingleSplit(bool is_categorical) {
   GradientPairPrecise parent_sum(0.0, 1.0);
   TrainParam tparam = ZeroParam();
   GPUTrainingParam param{tparam};
 
-  common::HistogramCuts cuts;
-  cuts.cut_values_.HostVector() = std::vector<float>{1.0, 2.0, 11.0, 12.0};
-  cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 2, 4};
-  cuts.min_vals_.HostVector() = std::vector<float>{0.0, 0.0};
-  cuts.cut_ptrs_.SetDevice(0);
-  cuts.cut_values_.SetDevice(0);
-  cuts.min_vals_.SetDevice(0);
-  thrust::device_vector<bst_feature_t> feature_set =
-      std::vector<bst_feature_t>{0, 1};
+  common::HistogramCuts cuts{MakeCutsForTest({1.0, 2.0, 11.0, 12.0}, {0, 2, 4}, {0.0, 0.0}, 0)};
+  thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
 
   // Setup gradients so that second feature gets higher gain
   thrust::device_vector<GradientPairPrecise> feature_histogram =
       std::vector<GradientPairPrecise>{
           {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
 
   thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
   dh::device_vector<FeatureType> feature_types(feature_set.size(),
                                                FeatureType::kCategorical);
   common::Span<FeatureType> d_feature_types;
@@ -50,9 +314,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
     d_feature_types = dh::ToSpan(feature_types);
   }
 
-  EvaluateSplitInputs input{1,0,
-                            parent_sum,
-                            dh::ToSpan(feature_set),
+  EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
                             dh::ToSpan(feature_histogram)};
   EvaluateSplitSharedInputs shared_inputs{
       param,
@@ -68,7 +330,11 @@ void TestEvaluateSingleSplit(bool is_categorical) {
   DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
 
   EXPECT_EQ(result.findex, 1);
+  if (is_categorical) {
+    ASSERT_TRUE(std::isnan(result.fvalue));
+  } else {
     EXPECT_EQ(result.fvalue, 11.0);
+  }
   EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(),
                   parent_sum.GetGrad());
   EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(),
@@ -79,7 +345,7 @@ TEST(GpuHist, EvaluateSingleSplit) {
   TestEvaluateSingleSplit(false);
 }
 
-TEST(GpuHist, EvaluateCategoricalSplit) {
+TEST(GpuHist, EvaluateSingleCategoricalSplit) {
   TestEvaluateSingleSplit(true);
 }
 
@@ -96,7 +362,6 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
   thrust::device_vector<float> feature_min_values = std::vector<float>{0.0};
   thrust::device_vector<GradientPairPrecise> feature_histogram =
       std::vector<GradientPairPrecise>{{-0.5, 0.5}, {0.5, 0.5}};
-  thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
   EvaluateSplitInputs input{1,0,
                             parent_sum,
                             dh::ToSpan(feature_set),
@@ -146,7 +411,6 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
   thrust::device_vector<GradientPairPrecise> feature_histogram =
       std::vector<GradientPairPrecise>{
           {-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
-  thrust::device_vector<int> monotonic_constraints(2, 0);
   EvaluateSplitInputs input{1,0,
                             parent_sum,
                             dh::ToSpan(feature_set),
@@ -186,7 +450,6 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
   thrust::device_vector<GradientPairPrecise> feature_histogram =
       std::vector<GradientPairPrecise>{
          {-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
-  thrust::device_vector<int> monotonic_constraints(2, 0);
   EvaluateSplitInputs input{1,0,
                             parent_sum,
                             dh::ToSpan(feature_set),
@@ -227,7 +490,6 @@ TEST(GpuHist, EvaluateSplits) {
   thrust::device_vector<GradientPairPrecise> feature_histogram_right =
       std::vector<GradientPairPrecise>{
          {-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
-  thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
   EvaluateSplitInputs input_left{
       1,0,
       parent_sum,
@@ -263,8 +525,7 @@ TEST(GpuHist, EvaluateSplits) {
 
 TEST_F(TestPartitionBasedSplit, GpuHist) {
   dh::device_vector<FeatureType> ft{std::vector<FeatureType>{FeatureType::kCategorical}};
-  GPUHistEvaluator evaluator{param_,
-                             static_cast<bst_feature_t>(info_.num_col_), 0};
+  GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(info_.num_col_), 0};
 
   cuts_.cut_ptrs_.SetDevice(0);
   cuts_.cut_values_.SetDevice(0);
@@ -287,6 +548,5 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
   auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
   ASSERT_NEAR(split.loss_chg, best_score_, 1e-16);
 }
-
 }  // namespace tree
 }  // namespace xgboost
@@ -185,5 +185,33 @@ TEST(HistEvaluator, Categorical) {
 
   ASSERT_EQ(with_onehot.split.loss_chg, with_part.split.loss_chg);
 }
+
+TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
+  common::HistCollection hist;
+  hist.Init(cuts_.TotalBins());
+  hist.AddHistRow(0);
+  hist.AllocateAllData();
+  auto node_hist = hist[0];
+  ASSERT_EQ(node_hist.size(), feature_histogram_.size());
+  std::copy(feature_histogram_.cbegin(), feature_histogram_.cend(), node_hist.begin());
+
+  auto sampler = std::make_shared<common::ColumnSampler>();
+  MetaInfo info;
+  info.num_col_ = 1;
+  info.feature_types = {FeatureType::kCategorical};
+  auto evaluator =
+      HistEvaluator<CPUExpandEntry>{param_, info, common::OmpGetNumThreads(0), sampler};
+  evaluator.InitRoot(GradStats{parent_sum_});
+
+  std::vector<CPUExpandEntry> entries(1);
+  RegTree tree;
+  evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries);
+  auto const& split = entries.front().split;
+
+  this->CheckResult(split.loss_chg, split.SplitIndex(), split.split_value, split.is_cat,
+                    split.DefaultLeft(),
+                    GradientPairPrecise{split.left_sum.GetGrad(), split.left_sum.GetHess()},
+                    GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()});
+}
 }  // namespace tree
 }  // namespace xgboost
@@ -97,5 +97,59 @@ class TestPartitionBasedSplit : public ::testing::Test {
     } while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end()));
   }
 };
+
+inline auto MakeCutsForTest(std::vector<float> values, std::vector<uint32_t> ptrs,
+                            std::vector<float> min_values, int32_t device) {
+  common::HistogramCuts cuts;
+  cuts.cut_values_.HostVector() = values;
+  cuts.cut_ptrs_.HostVector() = ptrs;
+  cuts.min_vals_.HostVector() = min_values;
+
+  if (device >= 0) {
+    cuts.cut_ptrs_.SetDevice(device);
+    cuts.cut_values_.SetDevice(device);
+    cuts.min_vals_.SetDevice(device);
+  }
+
+  return cuts;
+}
+
+class TestCategoricalSplitWithMissing : public testing::Test {
+ protected:
+  common::HistogramCuts cuts_;
+  // Setup gradients and parent sum with missing values.
+  GradientPairPrecise parent_sum_{1.0, 6.0};
+  std::vector<GradientPairPrecise> feature_histogram_{
+      {0.5, 0.5}, {0.5, 0.5}, {1.0, 1.0}, {1.0, 1.0}};
+  TrainParam param_;
+
+  void SetUp() override {
+    cuts_ = MakeCutsForTest({0.0, 1.0, 2.0, 3.0}, {0, 4}, {0.0}, -1);
+    auto max_cat = *std::max_element(cuts_.cut_values_.HostVector().begin(),
+                                     cuts_.cut_values_.HostVector().end());
+    cuts_.SetCategorical(true, max_cat);
+    param_.UpdateAllowUnknown(
+        Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "1"}});
+  }
+
+  void CheckResult(float loss_chg, bst_feature_t split_ind, float fvalue, bool is_cat,
+                   bool dft_left, GradientPairPrecise left_sum, GradientPairPrecise right_sum) {
+    // forward
+    // it: 0, gain: 0.545455
+    // it: 1, gain: 1.000000
+    // it: 2, gain: 2.250000
+    // backward
+    // it: 3, gain: 1.000000
+    // it: 2, gain: 2.250000
+    // it: 1, gain: 3.142857
+    ASSERT_NEAR(loss_chg, 2.97619, kRtEps);
+    ASSERT_TRUE(is_cat);
+    ASSERT_TRUE(std::isnan(fvalue));
+    ASSERT_EQ(split_ind, 0);
+    ASSERT_FALSE(dft_left);
+    ASSERT_EQ(left_sum.GetHess(), 2.5);
+    ASSERT_EQ(right_sum.GetHess(), parent_sum_.GetHess() - left_sum.GetHess());
+  }
+};
 }  // namespace tree
 }  // namespace xgboost
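Note: the expected values in CheckResult follow directly from the fixture. The four bins sum to (3, 3), so against parent_sum_ (1, 6) the missing bucket implicitly holds (-2, 3). With reg_lambda and min_child_weight at 0 the split score is G_l^2/H_l + G_r^2/H_r, and loss_chg subtracts the parent score 1^2/6. A standalone sketch reproducing the winning backward `it: 1` candidate:

#include <cassert>
#include <cmath>
#include <cstdio>

// Reproduce the backward `it: 1` gain commented in CheckResult above.
int main() {
  double const parent_grad = 1.0, parent_hess = 6.0;
  // Backward scan at it = 1: the left partition is the suffix
  // {bin1, bin2, bin3} = (0.5 + 1 + 1, 0.5 + 1 + 1) = (2.5, 2.5); the
  // missing-value bucket (-2, 3) stays on the right.
  double left_grad = 2.5, left_hess = 2.5;
  double right_grad = parent_grad - left_grad;  // -1.5
  double right_hess = parent_hess - left_hess;  //  3.5
  double split_score = left_grad * left_grad / left_hess +
                       right_grad * right_grad / right_hess;      // 3.142857
  double parent_score = parent_grad * parent_grad / parent_hess;  // 0.166667
  std::printf("loss_chg = %f\n", split_score - parent_score);     // 2.976190
  assert(std::abs((split_score - parent_score) - 2.97619) < 1e-5);
  return 0;
}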
@@ -1,3 +1,4 @@
+from typing import Dict, Any
 import numpy as np
 import sys
 import gc
@@ -77,6 +78,48 @@ class TestGPUUpdaters:
     def test_categorical_ohe(self, rows, cols, rounds, cats):
         self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
 
+    @given(
+        tm.categorical_dataset_strategy,
+        test_up.exact_parameter_strategy,
+        test_up.hist_parameter_strategy,
+        test_up.cat_parameter_strategy,
+        strategies.integers(4, 32),
+    )
+    @settings(deadline=None, print_blob=True)
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_categorical(
+        self,
+        dataset: tm.TestDataset,
+        exact_parameters: Dict[str, Any],
+        hist_parameters: Dict[str, Any],
+        cat_parameters: Dict[str, Any],
+        n_rounds: int,
+    ) -> None:
+        cat_parameters.update(exact_parameters)
+        cat_parameters.update(hist_parameters)
+        cat_parameters["tree_method"] = "gpu_hist"
+
+        results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
+        tm.non_increasing(results["train"]["rmse"])
+
+    @given(
+        test_up.hist_parameter_strategy,
+        test_up.cat_parameter_strategy,
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_categorical_ames_housing(
+        self,
+        hist_parameters: Dict[str, Any],
+        cat_parameters: Dict[str, Any],
+    ) -> None:
+        cat_parameters.update(hist_parameters)
+        dataset = tm.TestDataset(
+            "ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
+        )
+        cat_parameters["tree_method"] = "gpu_hist"
+        results = train_result(cat_parameters, dataset.get_dmat(), 16)
+        tm.non_increasing(results["train"]["rmse"])
+
     @given(
         strategies.integers(10, 400),
         strategies.integers(3, 8),
@@ -234,7 +234,7 @@ class TestTreeMethod:
     ) -> None:
         parameters: Dict[str, Any] = {"tree_method": tree_method}
         cat, label = tm.make_categorical(
-            n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
+            n_samples=rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
         )
         Xy = xgb.DMatrix(cat, label, enable_categorical=True)
@@ -259,9 +259,6 @@ class TestTreeMethod:
         # Test with OHE split
         run(self.USE_ONEHOT)
 
-        if tree_method == "gpu_hist":  # fixme: Test with GPU.
-            return
-
         # Test with partition-based split
         run(self.USE_PART)