Test categorical features with column-split gpu quantile (#9595)
This commit is contained in:
@@ -634,12 +634,25 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
|
||||
});
|
||||
CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1);
|
||||
max_values.resize(d_in_columns_ptr.size() - 1);
|
||||
|
||||
// In some cases (e.g. column-wise data split), we may have empty columns, so we need to keep
|
||||
// track of the unique keys (feature indices) after the thrust::reduce_by_key` call.
|
||||
dh::caching_device_vector<size_t> d_max_keys(d_in_columns_ptr.size() - 1);
|
||||
dh::caching_device_vector<SketchEntry> d_max_values(d_in_columns_ptr.size() - 1);
|
||||
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it,
|
||||
thrust::make_discard_iterator(), d_max_values.begin(),
|
||||
thrust::equal_to<bst_feature_t>{},
|
||||
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
|
||||
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_values));
|
||||
auto new_end = thrust::reduce_by_key(
|
||||
thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(),
|
||||
d_max_values.begin(), thrust::equal_to<bst_feature_t>{},
|
||||
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
|
||||
d_max_keys.erase(new_end.first, d_max_keys.end());
|
||||
d_max_values.erase(new_end.second, d_max_values.end());
|
||||
|
||||
// The device vector needs to be initialized explicitly since we may have some missing columns.
|
||||
SketchEntry default_entry{};
|
||||
dh::caching_device_vector<SketchEntry> d_max_results(d_in_columns_ptr.size() - 1,
|
||||
default_entry);
|
||||
thrust::scatter(thrust::cuda::par(alloc), d_max_values.begin(), d_max_values.end(),
|
||||
d_max_keys.begin(), d_max_results.begin());
|
||||
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_results));
|
||||
auto max_it = MakeIndexTransformIter([&](auto i) {
|
||||
if (IsCat(h_feature_types, i)) {
|
||||
return max_values[i].value;
|
||||
|
||||
@@ -35,13 +35,13 @@ struct WQSummary {
|
||||
/*! \brief an entry in the sketch summary */
|
||||
struct Entry {
|
||||
/*! \brief minimum rank */
|
||||
RType rmin;
|
||||
RType rmin{};
|
||||
/*! \brief maximum rank */
|
||||
RType rmax;
|
||||
RType rmax{};
|
||||
/*! \brief maximum weight */
|
||||
RType wmin;
|
||||
RType wmin{};
|
||||
/*! \brief the value of data */
|
||||
DType value;
|
||||
DType value{};
|
||||
// constructor
|
||||
XGBOOST_DEVICE Entry() {} // NOLINT
|
||||
// constructor
|
||||
|
||||
Reference in New Issue
Block a user