Small refactor to categoricals (#7858)

This commit is contained in:
Rory Mitchell 2022-05-05 17:47:02 +02:00 committed by GitHub
parent 14ef38b834
commit 7ef54e39ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 110 additions and 150 deletions

5
.gitignore vendored
View File

@ -130,4 +130,7 @@ credentials.csv
# Visual Studio code + extensions # Visual Studio code + extensions
.vscode .vscode
.metals .metals
.bloop .bloop
# hypothesis python tests
.hypothesis

View File

@ -271,12 +271,19 @@ __device__ DeviceSplitCandidate operator+(const DeviceSplitCandidate& a,
* \brief Set the bits for categorical splits based on the split threshold. * \brief Set the bits for categorical splits based on the split threshold.
*/ */
template <typename GradientSumT> template <typename GradientSumT>
__device__ void SortBasedSplit(EvaluateSplitInputs<GradientSumT> const &input, __device__ void SetCategoricalSplit(EvaluateSplitInputs<GradientSumT> const &input,
common::Span<bst_feature_t const> d_sorted_idx, bst_feature_t fidx, common::Span<bst_feature_t const> d_sorted_idx, bst_feature_t fidx,
bool is_left, common::Span<common::CatBitField::value_type> out, bool is_left, common::Span<common::CatBitField::value_type> out,
DeviceSplitCandidate *p_out_split) { DeviceSplitCandidate *p_out_split) {
auto &out_split = *p_out_split; auto &out_split = *p_out_split;
out_split.split_cats = common::CatBitField{out}; out_split.split_cats = common::CatBitField{out};
// Simple case for one hot split
if (common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
out_split.split_cats.Set(common::AsCat(out_split.fvalue));
return;
}
auto node_sorted_idx = auto node_sorted_idx =
is_left ? d_sorted_idx.subspan(0, input.feature_values.size()) is_left ? d_sorted_idx.subspan(0, input.feature_values.size())
: d_sorted_idx.subspan(input.feature_values.size(), input.feature_values.size()); : d_sorted_idx.subspan(input.feature_values.size(), input.feature_values.size());
@ -311,7 +318,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right, EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_splits) { common::Span<DeviceSplitCandidate> out_splits) {
if (!split_cats_.empty()) { if (need_sort_histogram_) {
this->SortHistogram(left, right, evaluator); this->SortHistogram(left, right, evaluator);
} }
@ -352,14 +359,13 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
template <typename GradientSumT> template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::CopyToHost(EvaluateSplitInputs<GradientSumT> const &input, void GPUHistEvaluator<GradientSumT>::CopyToHost(EvaluateSplitInputs<GradientSumT> const &input,
common::Span<CatST> cats_out) { common::Span<CatST> cats_out) {
if (has_sort_) { if (cats_out.empty()) return;
dh::CUDAEvent event; dh::CUDAEvent event;
event.Record(dh::DefaultStream()); event.Record(dh::DefaultStream());
auto h_cats = this->HostCatStorage(input.nidx); auto h_cats = this->HostCatStorage(input.nidx);
copy_stream_.View().Wait(event); copy_stream_.View().Wait(event);
dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(), dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(),
cudaMemcpyDeviceToHost, copy_stream_.View())); cudaMemcpyDeviceToHost, copy_stream_.View()));
}
} }
template <typename GradientSumT> template <typename GradientSumT>
@ -376,17 +382,16 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate,
auto d_sorted_idx = this->SortedIdx(left); auto d_sorted_idx = this->SortedIdx(left);
auto d_entries = out_entries; auto d_entries = out_entries;
auto cats_out = this->DeviceCatStorage(left.nidx); auto cats_out = this->DeviceCatStorage(left.nidx);
// turn candidate into entry, along with hanlding sort based split. // turn candidate into entry, along with handling sort based split.
dh::LaunchN(right.feature_set.empty() ? 1 : 2, [=] __device__(size_t i) { dh::LaunchN(right.feature_set.empty() ? 1 : 2, [=] __device__(size_t i) {
auto const &input = i == 0 ? left : right; auto const &input = i == 0 ? left : right;
auto &split = out_splits[i]; auto &split = out_splits[i];
auto fidx = out_splits[i].findex; auto fidx = out_splits[i].findex;
if (split.is_cat && if (split.is_cat) {
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
bool is_left = i == 0; bool is_left = i == 0;
auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2); auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2);
SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]); SetCategoricalSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]);
} }
float base_weight = float base_weight =
@ -418,9 +423,8 @@ GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
auto &split = out_split[i]; auto &split = out_split[i];
auto fidx = out_split[i].findex; auto fidx = out_split[i].findex;
if (split.is_cat && if (split.is_cat) {
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) { SetCategoricalSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
} }
float left_weight = evaluator.CalcWeight(0, input.param, GradStats{split.left_sum}); float left_weight = evaluator.CalcWeight(0, input.param, GradStats{split.left_sum});

View File

@ -58,9 +58,12 @@ class GPUHistEvaluator {
dh::device_vector<bst_feature_t> feature_idx_; dh::device_vector<bst_feature_t> feature_idx_;
// Training param used for evaluation // Training param used for evaluation
TrainParam param_; TrainParam param_;
// whether the input data requires sort based split, which is more complicated so we try // Do we have any categorical features that require sorting histograms?
// to avoid it if possible. // use this to skip the expensive sort step
bool has_sort_{false}; bool need_sort_histogram_ = false;
// Number of elements of categorical storage type
// needed to hold categoricals for a single mode
std::size_t node_categorical_storage_size_ = 0;
// Copy the categories from device to host asynchronously. // Copy the categories from device to host asynchronously.
void CopyToHost(EvaluateSplitInputs<GradientSumT> const &input, common::Span<CatST> cats_out); void CopyToHost(EvaluateSplitInputs<GradientSumT> const &input, common::Span<CatST> cats_out);
@ -69,12 +72,17 @@ class GPUHistEvaluator {
* \brief Get host category storage of nidx for internal calculation. * \brief Get host category storage of nidx for internal calculation.
*/ */
auto HostCatStorage(bst_node_t nidx) { auto HostCatStorage(bst_node_t nidx) {
auto cat_bits = h_split_cats_.size() / param_.MaxNodes();
std::size_t min_size=(nidx+2)*node_categorical_storage_size_;
if(h_split_cats_.size()<min_size){
h_split_cats_.resize(min_size);
}
if (nidx == RegTree::kRoot) { if (nidx == RegTree::kRoot) {
auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits); auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_);
return cats_out; return cats_out;
} }
auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits * 2); auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_ * 2);
return cats_out; return cats_out;
} }
@ -82,12 +90,15 @@ class GPUHistEvaluator {
* \brief Get device category storage of nidx for internal calculation. * \brief Get device category storage of nidx for internal calculation.
*/ */
auto DeviceCatStorage(bst_node_t nidx) { auto DeviceCatStorage(bst_node_t nidx) {
auto cat_bits = split_cats_.size() / param_.MaxNodes(); std::size_t min_size=(nidx+2)*node_categorical_storage_size_;
if(split_cats_.size()<min_size){
split_cats_.resize(min_size);
}
if (nidx == RegTree::kRoot) { if (nidx == RegTree::kRoot) {
auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits); auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_);
return cats_out; return cats_out;
} }
auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits * 2); auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_ * 2);
return cats_out; return cats_out;
} }
@ -123,8 +134,7 @@ class GPUHistEvaluator {
*/ */
common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const { common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
copy_stream_.View().Sync(); copy_stream_.View().Sync();
auto cat_bits = h_split_cats_.size() / param_.MaxNodes(); auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_);
auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits);
return cats_out; return cats_out;
} }
/** /**

View File

@ -30,46 +30,40 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
// This condition avoids sort-based split function calls if the users want // This condition avoids sort-based split function calls if the users want
// onehot-encoding-based splits. // onehot-encoding-based splits.
// For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x. // For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x.
has_sort_ = thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) { need_sort_histogram_ =
auto idx = i - 1; thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) {
if (common::IsCat(ft, idx)) { auto idx = i - 1;
auto n_bins = ptrs[i] - ptrs[idx]; if (common::IsCat(ft, idx)) {
bool use_sort = !common::UseOneHot(n_bins, to_onehot); auto n_bins = ptrs[i] - ptrs[idx];
return use_sort; bool use_sort = !common::UseOneHot(n_bins, to_onehot);
} return use_sort;
return false; }
}); return false;
});
if (has_sort_) { node_categorical_storage_size_ =
auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1);
CHECK_NE(bit_storage_size, 0); CHECK_NE(node_categorical_storage_size_, 0);
// We need to allocate for all nodes since the updater can grow the tree layer by split_cats_.resize(node_categorical_storage_size_);
// layer, all nodes in the same layer must be preserved until that layer is h_split_cats_.resize(node_categorical_storage_size_);
// finished. We can allocate one layer at a time, but the best case is reducing the dh::safe_cuda(
// size of the bitset by about a half, at the cost of invoking CUDA malloc many more cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
// times than necessary.
split_cats_.resize(param.MaxNodes() * bit_storage_size);
h_split_cats_.resize(split_cats_.size());
dh::safe_cuda(
cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time. cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time.
sort_input_.resize(cat_sorted_idx_.size()); sort_input_.resize(cat_sorted_idx_.size());
/** /**
* cache feature index binary search result * cache feature index binary search result
*/ */
feature_idx_.resize(cat_sorted_idx_.size()); feature_idx_.resize(cat_sorted_idx_.size());
auto d_fidxes = dh::ToSpan(feature_idx_); auto d_fidxes = dh::ToSpan(feature_idx_);
auto it = thrust::make_counting_iterator(0ul); auto it = thrust::make_counting_iterator(0ul);
auto values = cuts.cut_values_.ConstDeviceSpan(); auto values = cuts.cut_values_.ConstDeviceSpan();
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), feature_idx_.begin(),
thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), [=] XGBOOST_DEVICE(size_t i) {
feature_idx_.begin(), [=] XGBOOST_DEVICE(size_t i) { auto fidx = dh::SegmentId(ptrs, i);
auto fidx = dh::SegmentId(ptrs, i); return fidx;
return fidx; });
});
}
} }
} }

View File

@ -194,8 +194,6 @@ struct GPUHistMakerDevice {
std::unique_ptr<GradientBasedSampler> sampler; std::unique_ptr<GradientBasedSampler> sampler;
std::unique_ptr<FeatureGroups> feature_groups; std::unique_ptr<FeatureGroups> feature_groups;
// Storing split categories for last node.
dh::caching_device_vector<uint32_t> node_categories;
GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page, GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
common::Span<FeatureType const> _feature_types, bst_uint _n_rows, common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
@ -239,7 +237,8 @@ struct GPUHistMakerDevice {
param.colsample_bytree); param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id); this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
ctx_->gpu_id);
this->interaction_constraints.Reset(); this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{}); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});
@ -349,14 +348,14 @@ struct GPUHistMakerDevice {
return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent); return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
} }
void UpdatePosition(int nidx, RegTree* p_tree) { void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) {
RegTree::Node split_node = (*p_tree)[nidx]; RegTree::Node split_node = (*p_tree)[e.nid];
auto split_type = p_tree->NodeSplitType(nidx); auto split_type = p_tree->NodeSplitType(e.nid);
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
auto node_cats = dh::ToSpan(node_categories); auto node_cats = e.split.split_cats.Bits();
row_partitioner->UpdatePosition( row_partitioner->UpdatePosition(
nidx, split_node.LeftChild(), split_node.RightChild(), e.nid, split_node.LeftChild(), split_node.RightChild(),
[=] __device__(bst_uint ridx) { [=] __device__(bst_uint ridx) {
// given a row index, returns the node id it belongs to // given a row index, returns the node id it belongs to
bst_float cut_value = bst_float cut_value =
@ -569,27 +568,12 @@ struct GPUHistMakerDevice {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max()) CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large."; << "Categorical feature value too large.";
std::vector<uint32_t> split_cats; std::vector<uint32_t> split_cats;
if (candidate.split.split_cats.Bits().empty()) { CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
if (common::InvalidCat(candidate.split.fvalue)) { auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
common::InvalidCategory(); auto max_cat = candidate.split.MaxCat();
} split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
auto cat = common::AsCat(candidate.split.fvalue); CHECK_LE(split_cats.size(), h_cats.size());
split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0); std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
common::CatBitField cats_bits(split_cats);
cats_bits.Set(cat);
dh::CopyToD(split_cats, &node_categories);
} else {
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
auto max_cat = candidate.split.MaxCat();
split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
CHECK_LE(split_cats.size(), h_cats.size());
std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
node_categories.resize(candidate.split.split_cats.Bits().size());
dh::safe_cuda(cudaMemcpyAsync(
node_categories.data().get(), candidate.split.split_cats.Data(),
candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice));
}
tree.ExpandCategorical( tree.ExpandCategorical(
candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir, candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir,
@ -676,7 +660,7 @@ struct GPUHistMakerDevice {
// Update position is only run when child is valid, instead of right after apply // Update position is only run when child is valid, instead of right after apply
// split (as in approx tree method). Hense we have the finalise position call // split (as in approx tree method). Hense we have the finalise position call
// in GPU Hist. // in GPU Hist.
this->UpdatePosition(candidate.nid, p_tree); this->UpdatePosition(candidate, p_tree);
monitor.Stop("UpdatePosition"); monitor.Stop("UpdatePosition");
monitor.Start("BuildHist"); monitor.Start("BuildHist");

View File

@ -24,14 +24,16 @@ void TestEvaluateSingleSplit(bool is_categorical) {
TrainParam tparam = ZeroParam(); TrainParam tparam = ZeroParam();
GPUTrainingParam param{tparam}; GPUTrainingParam param{tparam};
common::HistogramCuts cuts;
cuts.cut_values_.HostVector() = std::vector<float>{1.0, 2.0, 11.0, 12.0};
cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 2, 4};
cuts.min_vals_.HostVector() = std::vector<float>{0.0, 0.0};
cuts.cut_ptrs_.SetDevice(0);
cuts.cut_values_.SetDevice(0);
cuts.min_vals_.SetDevice(0);
thrust::device_vector<bst_feature_t> feature_set = thrust::device_vector<bst_feature_t> feature_set =
std::vector<bst_feature_t>{0, 1}; std::vector<bst_feature_t>{0, 1};
thrust::device_vector<uint32_t> feature_segments =
std::vector<bst_row_t>{0, 2, 4};
thrust::device_vector<float> feature_values =
std::vector<float>{1.0, 2.0, 11.0, 12.0};
thrust::device_vector<float> feature_min_values =
std::vector<float>{0.0, 0.0};
// Setup gradients so that second feature gets higher gain // Setup gradients so that second feature gets higher gain
thrust::device_vector<GradientPair> feature_histogram = thrust::device_vector<GradientPair> feature_histogram =
std::vector<GradientPair>{ std::vector<GradientPair>{
@ -42,22 +44,27 @@ void TestEvaluateSingleSplit(bool is_categorical) {
FeatureType::kCategorical); FeatureType::kCategorical);
common::Span<FeatureType> d_feature_types; common::Span<FeatureType> d_feature_types;
if (is_categorical) { if (is_categorical) {
auto max_cat = *std::max_element(cuts.cut_values_.HostVector().begin(),
cuts.cut_values_.HostVector().end());
cuts.SetCategorical(true, max_cat);
d_feature_types = dh::ToSpan(feature_types); d_feature_types = dh::ToSpan(feature_types);
} }
EvaluateSplitInputs<GradientPair> input{1, EvaluateSplitInputs<GradientPair> input{1,
parent_sum, parent_sum,
param, param,
dh::ToSpan(feature_set), dh::ToSpan(feature_set),
d_feature_types, d_feature_types,
dh::ToSpan(feature_segments), cuts.cut_ptrs_.ConstDeviceSpan(),
dh::ToSpan(feature_values), cuts.cut_values_.ConstDeviceSpan(),
dh::ToSpan(feature_min_values), cuts.min_vals_.ConstDeviceSpan(),
dh::ToSpan(feature_histogram)}; dh::ToSpan(feature_histogram)};
GPUHistEvaluator<GradientPair> evaluator{ GPUHistEvaluator<GradientPair> evaluator{
tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0}; tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
dh::device_vector<common::CatBitField::value_type> out_cats; evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split; DeviceSplitCandidate result =
evaluator.EvaluateSingleSplit(input, 0).split;
EXPECT_EQ(result.findex, 1); EXPECT_EQ(result.findex, 1);
EXPECT_EQ(result.fvalue, 11.0); EXPECT_EQ(result.fvalue, 11.0);

View File

@ -137,48 +137,6 @@ TEST(GpuHist, BuildHistSharedMem) {
TestBuildHist<GradientPair>(true); TestBuildHist<GradientPair>(true);
} }
TEST(GpuHist, ApplySplit) {
RegTree tree;
GPUExpandEntry candidate;
candidate.nid = 0;
candidate.left_weight = 1.0f;
candidate.right_weight = 2.0f;
candidate.base_weight = 3.0f;
candidate.split.is_cat = true;
candidate.split.fvalue = 1.0f; // at cat 1
size_t n_rows = 10;
size_t n_cols = 10;
auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
GenericParameter p;
p.InitAllowUnknown(Args{});
TrainParam tparam;
tparam.InitAllowUnknown(Args{});
BatchParam bparam;
bparam.gpu_id = 0;
bparam.max_bin = 3;
Context ctx{CreateEmptyGenericParam(0)};
for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
auto impl = ellpack.Impl();
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
feature_types.SetDevice(bparam.gpu_id);
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
&ctx, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, bparam);
updater.ApplySplit(candidate, &tree);
ASSERT_EQ(tree.GetSplitTypes().size(), 3);
ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
ASSERT_EQ(tree.GetSplitCategories().size(), 1);
uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0
ASSERT_EQ(tree.GetSplitCategories().back(), bits);
ASSERT_EQ(updater.node_categories.size(), 1);
}
}
HistogramCutsWrapper GetHostCutMatrix () { HistogramCutsWrapper GetHostCutMatrix () {
HistogramCutsWrapper cmat; HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});