Small refactor to categoricals (#7858)

parent 14ef38b834
commit 7ef54e39ec

.gitignore (vendored) | 5
@@ -130,4 +130,7 @@ credentials.csv
 # Visual Studio code + extensions
 .vscode
 .metals
 .bloop
+
+# hypothesis python tests
+.hypothesis
@@ -271,12 +271,19 @@ __device__ DeviceSplitCandidate operator+(const DeviceSplitCandidate& a,
  * \brief Set the bits for categorical splits based on the split threshold.
  */
 template <typename GradientSumT>
-__device__ void SortBasedSplit(EvaluateSplitInputs<GradientSumT> const &input,
+__device__ void SetCategoricalSplit(EvaluateSplitInputs<GradientSumT> const &input,
                                common::Span<bst_feature_t const> d_sorted_idx, bst_feature_t fidx,
                                bool is_left, common::Span<common::CatBitField::value_type> out,
                                DeviceSplitCandidate *p_out_split) {
   auto &out_split = *p_out_split;
+  out_split.split_cats = common::CatBitField{out};
+
+  // Simple case for one hot split
+  if (common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
+    out_split.split_cats.Set(common::AsCat(out_split.fvalue));
+    return;
+  }
 
   auto node_sorted_idx =
       is_left ? d_sorted_idx.subspan(0, input.feature_values.size())
               : d_sorted_idx.subspan(input.feature_values.size(), input.feature_values.size());
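Note: the one-hot fast path added above only needs to flip one category bit
derived from the split's fvalue. A minimal standalone model of that bit
arithmetic (plain C++; TinyCatBits and its LSB-first packing are illustrative
stand-ins, not xgboost's CatBitField, which appears to index from the most
significant bit, per the 1u << 30 expectation in the test removed further down):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // One bit per category, packed into 32-bit words (illustrative layout).
    struct TinyCatBits {
      std::vector<std::uint32_t> storage;
      explicit TinyCatBits(std::size_t n_cats) : storage((n_cats + 31) / 32, 0u) {}
      void Set(std::size_t cat) { storage[cat / 32] |= 1u << (cat % 32); }
      bool Check(std::size_t cat) const {
        return (storage[cat / 32] >> (cat % 32)) & 1u;
      }
    };

    int main() {
      TinyCatBits bits(40);  // categories 0..39 -> two storage words
      bits.Set(33);          // one-hot: the split category itself joins the set
      assert(bits.Check(33) && !bits.Check(0));
      return 0;
    }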
@@ -311,7 +318,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
     EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right,
     TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
     common::Span<DeviceSplitCandidate> out_splits) {
-  if (!split_cats_.empty()) {
+  if (need_sort_histogram_) {
     this->SortHistogram(left, right, evaluator);
   }
 
@@ -352,14 +359,13 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
 template <typename GradientSumT>
 void GPUHistEvaluator<GradientSumT>::CopyToHost(EvaluateSplitInputs<GradientSumT> const &input,
                                                 common::Span<CatST> cats_out) {
-  if (has_sort_) {
-    dh::CUDAEvent event;
-    event.Record(dh::DefaultStream());
-    auto h_cats = this->HostCatStorage(input.nidx);
-    copy_stream_.View().Wait(event);
-    dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(),
-                                  cudaMemcpyDeviceToHost, copy_stream_.View()));
-  }
+  if (cats_out.empty()) return;
+  dh::CUDAEvent event;
+  event.Record(dh::DefaultStream());
+  auto h_cats = this->HostCatStorage(input.nidx);
+  copy_stream_.View().Wait(event);
+  dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(),
+                                cudaMemcpyDeviceToHost, copy_stream_.View()));
 }
 
 template <typename GradientSumT>
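Note: CopyToHost keeps the device-to-host copy off the critical path by
recording an event on the default stream and making a dedicated copy stream
wait on it before the async memcpy. A minimal CUDA sketch of the same pattern
(hypothetical helper name, raw runtime calls, error checking elided):

    #include <cstddef>
    #include <cstdint>
    #include <cuda_runtime.h>

    // Copy n words device -> host on copy_stream, ordered after all work
    // previously submitted to the default stream.
    void CopyAfterDefaultStream(std::uint32_t *dst, const std::uint32_t *src,
                                std::size_t n, cudaStream_t copy_stream) {
      cudaEvent_t event;
      cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
      cudaEventRecord(event, /*stream=*/0);        // mark default-stream progress
      cudaStreamWaitEvent(copy_stream, event, 0);  // copy waits for that mark only
      cudaMemcpyAsync(dst, src, n * sizeof(std::uint32_t),
                      cudaMemcpyDeviceToHost, copy_stream);
      cudaEventDestroy(event);                     // safe: destruction is deferred
    }

A later GetHostNodeCats then only has to synchronize the copy stream rather
than the whole device.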
@@ -376,17 +382,16 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate,
   auto d_sorted_idx = this->SortedIdx(left);
   auto d_entries = out_entries;
   auto cats_out = this->DeviceCatStorage(left.nidx);
-  // turn candidate into entry, along with hanlding sort based split.
+  // turn candidate into entry, along with handling sort based split.
   dh::LaunchN(right.feature_set.empty() ? 1 : 2, [=] __device__(size_t i) {
     auto const &input = i == 0 ? left : right;
     auto &split = out_splits[i];
     auto fidx = out_splits[i].findex;
 
-    if (split.is_cat &&
-        !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
+    if (split.is_cat) {
       bool is_left = i == 0;
       auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2);
-      SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]);
+      SetCategoricalSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]);
     }
 
     float base_weight =
@@ -418,9 +423,8 @@ GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
     auto &split = out_split[i];
     auto fidx = out_split[i].findex;
 
-    if (split.is_cat &&
-        !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
-      SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
+    if (split.is_cat) {
+      SetCategoricalSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
     }
 
     float left_weight = evaluator.CalcWeight(0, input.param, GradStats{split.left_sum});
@@ -58,9 +58,12 @@ class GPUHistEvaluator {
   dh::device_vector<bst_feature_t> feature_idx_;
   // Training param used for evaluation
   TrainParam param_;
-  // whether the input data requires sort based split, which is more complicated so we try
-  // to avoid it if possible.
-  bool has_sort_{false};
+  // Do we have any categorical features that require sorting histograms?
+  // use this to skip the expensive sort step
+  bool need_sort_histogram_ = false;
+  // Number of elements of categorical storage type
+  // needed to hold categoricals for a single node
+  std::size_t node_categorical_storage_size_ = 0;
 
   // Copy the categories from device to host asynchronously.
   void CopyToHost(EvaluateSplitInputs<GradientSumT> const &input, common::Span<CatST> cats_out);
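Note: node_categorical_storage_size_ counts storage words, not categories. For
32-bit storage that is ceil((max_category + 1) / 32), which is presumably what
common::CatBitField::ComputeStorageSize computes. A standalone check of the
arithmetic (hypothetical helper name):

    #include <cassert>
    #include <cstddef>

    // Words of 32-bit storage needed for bits 0..max_category.
    constexpr std::size_t StorageWords(std::size_t max_category) {
      return (max_category + 1 + 31) / 32;
    }

    int main() {
      assert(StorageWords(0) == 1);      // even one category needs a word
      assert(StorageWords(31) == 1);     // bits 0..31 fit exactly
      assert(StorageWords(32) == 2);     // one more category spills over
      assert(StorageWords(1000) == 32);  // ~128 bytes per node
      return 0;
    }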
@@ -69,12 +72,17 @@ class GPUHistEvaluator {
    * \brief Get host category storage of nidx for internal calculation.
    */
   auto HostCatStorage(bst_node_t nidx) {
-    auto cat_bits = h_split_cats_.size() / param_.MaxNodes();
+    std::size_t min_size=(nidx+2)*node_categorical_storage_size_;
+    if(h_split_cats_.size()<min_size){
+      h_split_cats_.resize(min_size);
+    }
+
     if (nidx == RegTree::kRoot) {
-      auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits);
+      auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_);
       return cats_out;
     }
-    auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits * 2);
+    auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_ * 2);
     return cats_out;
   }
 
@@ -82,12 +90,15 @@ class GPUHistEvaluator {
    * \brief Get device category storage of nidx for internal calculation.
    */
   auto DeviceCatStorage(bst_node_t nidx) {
-    auto cat_bits = split_cats_.size() / param_.MaxNodes();
+    std::size_t min_size=(nidx+2)*node_categorical_storage_size_;
+    if(split_cats_.size()<min_size){
+      split_cats_.resize(min_size);
+    }
     if (nidx == RegTree::kRoot) {
-      auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits);
+      auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_);
       return cats_out;
     }
-    auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits * 2);
+    auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_ * 2);
     return cats_out;
   }
 
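Note: both storage helpers above share one layout. With S =
node_categorical_storage_size_, node n's bitset occupies words [n*S, (n+1)*S);
a non-root node is handed 2*S words so its sibling n+1 rides along, and
resizing to (n+2)*S guarantees both slots exist. A small standalone model of
that arithmetic (illustrative types and names, not the xgboost classes):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    struct NodeCatStorage {
      std::size_t S;                   // storage words per node bitset
      std::vector<std::uint32_t> buf;  // grown lazily, two sibling slots at a time

      // Returns {offset, length} of node nidx's span in buf.
      std::pair<std::size_t, std::size_t> Slot(std::size_t nidx, bool is_root) {
        std::size_t min_size = (nidx + 2) * S;  // room for nidx and nidx + 1
        if (buf.size() < min_size) buf.resize(min_size, 0);
        return {nidx * S, is_root ? S : 2 * S};
      }
    };

    int main() {
      NodeCatStorage st{/*S=*/3, {}};
      auto root = st.Slot(0, true);   // root: one slot of 3 words
      assert(root.first == 0 && root.second == 3);
      auto pair = st.Slot(5, false);  // nodes 5 and 6 evaluated together
      assert(pair.first == 15 && pair.second == 6);
      assert(st.buf.size() == 21);    // (5 + 2) * 3
      return 0;
    }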
@@ -123,8 +134,7 @@ class GPUHistEvaluator {
    */
   common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
     copy_stream_.View().Sync();
-    auto cat_bits = h_split_cats_.size() / param_.MaxNodes();
-    auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits);
+    auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_);
     return cats_out;
   }
   /**
@@ -30,46 +30,40 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
     // This condition avoids sort-based split function calls if the users want
     // onehot-encoding-based splits.
     // For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x.
-    has_sort_ = thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) {
-      auto idx = i - 1;
-      if (common::IsCat(ft, idx)) {
-        auto n_bins = ptrs[i] - ptrs[idx];
-        bool use_sort = !common::UseOneHot(n_bins, to_onehot);
-        return use_sort;
-      }
-      return false;
-    });
+    need_sort_histogram_ =
+        thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) {
+          auto idx = i - 1;
+          if (common::IsCat(ft, idx)) {
+            auto n_bins = ptrs[i] - ptrs[idx];
+            bool use_sort = !common::UseOneHot(n_bins, to_onehot);
+            return use_sort;
+          }
+          return false;
+        });
 
-    if (has_sort_) {
-      auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1);
-      CHECK_NE(bit_storage_size, 0);
-      // We need to allocate for all nodes since the updater can grow the tree layer by
-      // layer, all nodes in the same layer must be preserved until that layer is
-      // finished. We can allocate one layer at a time, but the best case is reducing the
-      // size of the bitset by about a half, at the cost of invoking CUDA malloc many more
-      // times than necessary.
-      split_cats_.resize(param.MaxNodes() * bit_storage_size);
-      h_split_cats_.resize(split_cats_.size());
-      dh::safe_cuda(
-          cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
+    node_categorical_storage_size_ =
+        common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1);
+    CHECK_NE(node_categorical_storage_size_, 0);
+    split_cats_.resize(node_categorical_storage_size_);
+    h_split_cats_.resize(node_categorical_storage_size_);
+    dh::safe_cuda(
+        cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
 
-      cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2);  // evaluate 2 nodes at a time.
-      sort_input_.resize(cat_sorted_idx_.size());
+    cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2);  // evaluate 2 nodes at a time.
+    sort_input_.resize(cat_sorted_idx_.size());
 
-      /**
-       * cache feature index binary search result
-       */
-      feature_idx_.resize(cat_sorted_idx_.size());
-      auto d_fidxes = dh::ToSpan(feature_idx_);
-      auto it = thrust::make_counting_iterator(0ul);
-      auto values = cuts.cut_values_.ConstDeviceSpan();
-      auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
-      thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(),
-                        feature_idx_.begin(), [=] XGBOOST_DEVICE(size_t i) {
-                          auto fidx = dh::SegmentId(ptrs, i);
-                          return fidx;
-                        });
-    }
+    /**
+     * cache feature index binary search result
+     */
+    feature_idx_.resize(cat_sorted_idx_.size());
+    auto d_fidxes = dh::ToSpan(feature_idx_);
+    auto it = thrust::make_counting_iterator(0ul);
+    auto values = cuts.cut_values_.ConstDeviceSpan();
+    thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), feature_idx_.begin(),
+                      [=] XGBOOST_DEVICE(size_t i) {
+                        auto fidx = dh::SegmentId(ptrs, i);
+                        return fidx;
+                      });
   }
 }
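Note: the old Reset reserved category bitsets for every possible node up front
(param.MaxNodes() * bit_storage_size words), which the deleted comment justified
as avoiding repeated CUDA mallocs. The new code starts at a single node's worth
and lets HostCatStorage/DeviceCatStorage grow the buffers on demand. Rough
numbers under assumed parameters (sketch only):

    #include <cstddef>
    #include <cstdio>

    int main() {
      std::size_t S = 32;                 // words per node, e.g. ~1000 categories
      std::size_t max_nodes = 511;        // e.g. a depth-8 tree: 2^9 - 1 nodes
      std::size_t eager = max_nodes * S;  // old: 16352 words (~64 KiB) up front
      std::size_t lazy = S;               // new: 32 words until deeper nodes ask
      std::printf("eager=%zu lazy-start=%zu\n", eager, lazy);
      return 0;
    }

The trade-off runs the other way on allocation count: the lazy path may resize
(and thus reallocate on the device) as the tree grows instead of once per Reset.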
@@ -194,8 +194,6 @@ struct GPUHistMakerDevice {
   std::unique_ptr<GradientBasedSampler> sampler;
 
   std::unique_ptr<FeatureGroups> feature_groups;
-  // Storing split categories for last node.
-  dh::caching_device_vector<uint32_t> node_categories;
 
   GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
                      common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
@@ -239,7 +237,8 @@ struct GPUHistMakerDevice {
                                           param.colsample_bytree);
     dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
 
-    this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id);
+    this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
+                           ctx_->gpu_id);
 
     this->interaction_constraints.Reset();
     std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});
@@ -349,14 +348,14 @@ struct GPUHistMakerDevice {
     return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
   }
 
-  void UpdatePosition(int nidx, RegTree* p_tree) {
-    RegTree::Node split_node = (*p_tree)[nidx];
-    auto split_type = p_tree->NodeSplitType(nidx);
+  void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) {
+    RegTree::Node split_node = (*p_tree)[e.nid];
+    auto split_type = p_tree->NodeSplitType(e.nid);
     auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
-    auto node_cats = dh::ToSpan(node_categories);
+    auto node_cats = e.split.split_cats.Bits();
 
     row_partitioner->UpdatePosition(
-        nidx, split_node.LeftChild(), split_node.RightChild(),
+        e.nid, split_node.LeftChild(), split_node.RightChild(),
         [=] __device__(bst_uint ridx) {
           // given a row index, returns the node id it belongs to
           bst_float cut_value =
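Note: passing the whole GPUExpandEntry means the kernel can capture the split's
category bitset straight from the candidate instead of a separately maintained
node_categories vector. Schematically, per-row routing looks like the sketch
below (plain C++; NextNode, the LSB-first bit test, and the "set bit goes left"
convention are assumptions for illustration, not xgboost's actual Decision
helper):

    #include <cassert>
    #include <cstdint>

    inline int NextNode(float fvalue, bool is_cat, float split_cond,
                        const std::uint32_t *node_cats, int left, int right) {
      if (is_cat) {
        auto cat = static_cast<std::uint32_t>(fvalue);  // category arrives as float
        bool go_left = (node_cats[cat / 32] >> (cat % 32)) & 1u;
        return go_left ? left : right;
      }
      return fvalue <= split_cond ? left : right;  // plain numeric split
    }

    int main() {
      std::uint32_t cats[1] = {0b10u};  // only category 1 is in the set
      assert(NextNode(1.0f, true, 0.f, cats, 2, 3) == 2);  // cat 1 -> left
      assert(NextNode(0.0f, true, 0.f, cats, 2, 3) == 3);  // cat 0 -> right
      assert(NextNode(0.5f, false, 1.0f, nullptr, 2, 3) == 2);
      return 0;
    }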
@@ -569,27 +568,12 @@ struct GPUHistMakerDevice {
     CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
         << "Categorical feature value too large.";
     std::vector<uint32_t> split_cats;
-    if (candidate.split.split_cats.Bits().empty()) {
-      if (common::InvalidCat(candidate.split.fvalue)) {
-        common::InvalidCategory();
-      }
-      auto cat = common::AsCat(candidate.split.fvalue);
-      split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0);
-      common::CatBitField cats_bits(split_cats);
-      cats_bits.Set(cat);
-      dh::CopyToD(split_cats, &node_categories);
-    } else {
-      auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
-      auto max_cat = candidate.split.MaxCat();
-      split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
-      CHECK_LE(split_cats.size(), h_cats.size());
-      std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
-
-      node_categories.resize(candidate.split.split_cats.Bits().size());
-      dh::safe_cuda(cudaMemcpyAsync(
-          node_categories.data().get(), candidate.split.split_cats.Data(),
-          candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice));
-    }
+    CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
+    auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
+    auto max_cat = candidate.split.MaxCat();
+    split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
+    CHECK_LE(split_cats.size(), h_cats.size());
+    std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
 
     tree.ExpandCategorical(
         candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir,
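Note: because SetCategoricalSplit now fills a bitset for one-hot splits too,
the one-hot special case in ApplySplit disappears; the host cache is simply
truncated to just enough words for the split's largest category. The sizing
and copy, modeled standalone (assumed 32-bit storage words and values):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Host cache can be larger than this particular node needs.
      std::vector<std::uint32_t> h_cats = {0xFFu, 0x1u, 0u, 0u};
      std::uint32_t max_cat = 40;  // largest category used by the split (assumed)
      // ceil((max_cat + 1) / 32) words hold bits 0..max_cat.
      std::vector<std::uint32_t> split_cats((max_cat + 1 + 31) / 32, 0);
      assert(split_cats.size() == 2 && split_cats.size() <= h_cats.size());
      std::copy(h_cats.begin(), h_cats.begin() + split_cats.size(),
                split_cats.begin());
      assert(split_cats[0] == 0xFFu && split_cats[1] == 0x1u);
      return 0;
    }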
@@ -676,7 +660,7 @@ struct GPUHistMakerDevice {
     // Update position is only run when child is valid, instead of right after apply
     // split (as in approx tree method). Hense we have the finalise position call
     // in GPU Hist.
-    this->UpdatePosition(candidate.nid, p_tree);
+    this->UpdatePosition(candidate, p_tree);
     monitor.Stop("UpdatePosition");
 
     monitor.Start("BuildHist");
@@ -24,14 +24,16 @@ void TestEvaluateSingleSplit(bool is_categorical) {
   TrainParam tparam = ZeroParam();
   GPUTrainingParam param{tparam};
 
+  common::HistogramCuts cuts;
+  cuts.cut_values_.HostVector() = std::vector<float>{1.0, 2.0, 11.0, 12.0};
+  cuts.cut_ptrs_.HostVector() = std::vector<uint32_t>{0, 2, 4};
+  cuts.min_vals_.HostVector() = std::vector<float>{0.0, 0.0};
+  cuts.cut_ptrs_.SetDevice(0);
+  cuts.cut_values_.SetDevice(0);
+  cuts.min_vals_.SetDevice(0);
   thrust::device_vector<bst_feature_t> feature_set =
       std::vector<bst_feature_t>{0, 1};
-  thrust::device_vector<uint32_t> feature_segments =
-      std::vector<bst_row_t>{0, 2, 4};
-  thrust::device_vector<float> feature_values =
-      std::vector<float>{1.0, 2.0, 11.0, 12.0};
-  thrust::device_vector<float> feature_min_values =
-      std::vector<float>{0.0, 0.0};
 
   // Setup gradients so that second feature gets higher gain
   thrust::device_vector<GradientPair> feature_histogram =
       std::vector<GradientPair>{
@@ -42,22 +44,27 @@ void TestEvaluateSingleSplit(bool is_categorical) {
                                               FeatureType::kCategorical);
   common::Span<FeatureType> d_feature_types;
   if (is_categorical) {
+    auto max_cat = *std::max_element(cuts.cut_values_.HostVector().begin(),
+                                     cuts.cut_values_.HostVector().end());
+    cuts.SetCategorical(true, max_cat);
     d_feature_types = dh::ToSpan(feature_types);
   }
 
   EvaluateSplitInputs<GradientPair> input{1,
                                           parent_sum,
                                           param,
                                           dh::ToSpan(feature_set),
                                           d_feature_types,
-                                          dh::ToSpan(feature_segments),
-                                          dh::ToSpan(feature_values),
-                                          dh::ToSpan(feature_min_values),
+                                          cuts.cut_ptrs_.ConstDeviceSpan(),
+                                          cuts.cut_values_.ConstDeviceSpan(),
+                                          cuts.min_vals_.ConstDeviceSpan(),
                                           dh::ToSpan(feature_histogram)};
 
   GPUHistEvaluator<GradientPair> evaluator{
-      tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
-  dh::device_vector<common::CatBitField::value_type> out_cats;
-  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split;
+      tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
+  evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
+  DeviceSplitCandidate result =
+      evaluator.EvaluateSingleSplit(input, 0).split;
 
   EXPECT_EQ(result.findex, 1);
   EXPECT_EQ(result.fvalue, 11.0);
@@ -137,48 +137,6 @@ TEST(GpuHist, BuildHistSharedMem) {
   TestBuildHist<GradientPair>(true);
 }
 
-TEST(GpuHist, ApplySplit) {
-  RegTree tree;
-  GPUExpandEntry candidate;
-  candidate.nid = 0;
-  candidate.left_weight = 1.0f;
-  candidate.right_weight = 2.0f;
-  candidate.base_weight = 3.0f;
-  candidate.split.is_cat = true;
-  candidate.split.fvalue = 1.0f;  // at cat 1
-
-  size_t n_rows = 10;
-  size_t n_cols = 10;
-
-  auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
-  GenericParameter p;
-  p.InitAllowUnknown(Args{});
-
-  TrainParam tparam;
-  tparam.InitAllowUnknown(Args{});
-  BatchParam bparam;
-  bparam.gpu_id = 0;
-  bparam.max_bin = 3;
-  Context ctx{CreateEmptyGenericParam(0)};
-
-  for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
-    auto impl = ellpack.Impl();
-    HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
-    feature_types.SetDevice(bparam.gpu_id);
-    tree::GPUHistMakerDevice<GradientPairPrecise> updater(
-        &ctx, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, bparam);
-    updater.ApplySplit(candidate, &tree);
-
-    ASSERT_EQ(tree.GetSplitTypes().size(), 3);
-    ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
-    ASSERT_EQ(tree.GetSplitCategories().size(), 1);
-    uint32_t bits = 1u << 30;  // bits: 0, 1, 0, 0, 0, ..., 0
-    ASSERT_EQ(tree.GetSplitCategories().back(), bits);
-
-    ASSERT_EQ(updater.node_categories.size(), 1);
-  }
-}
-
 HistogramCutsWrapper GetHostCutMatrix () {
   HistogramCutsWrapper cmat;
   cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
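Note: the deleted test expected category 1 to appear as 1u << 30 because the
split-category bitfield indexes bits from the most significant end of each
32-bit word (position 0 is bit 31, position 1 is bit 30). A quick standalone
check of that mapping, as the test's own comment described it:

    #include <cassert>
    #include <cstdint>

    // MSB-first position -> mask within one 32-bit word.
    constexpr std::uint32_t MsbFirstMask(std::uint32_t pos) {
      return 1u << (31 - (pos % 32));
    }

    int main() {
      assert(MsbFirstMask(0) == (1u << 31));
      assert(MsbFirstMask(1) == (1u << 30));  // "bits: 0, 1, 0, 0, ..., 0"
      return 0;
    }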