Add max_cat_threshold to GPU and handle missing cat values. (#8212)

This commit is contained in:
Jiaming Yuan
2022-09-07 00:57:51 +08:00
committed by GitHub
parent 441ffc017a
commit b5eb36f1af
10 changed files with 546 additions and 122 deletions

View File

@@ -185,5 +185,33 @@ TEST(HistEvaluator, Categorical) {
ASSERT_EQ(with_onehot.split.loss_chg, with_part.split.loss_chg);
}
// Runs the CPU HistEvaluator on the fixture's single categorical feature and
// hands the split it selects for the root node to the fixture's CheckResult.
TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
  // Describe the dataset: one column, marked as categorical.
  MetaInfo meta;
  meta.num_col_ = 1;
  meta.feature_types = {FeatureType::kCategorical};

  // Allocate a histogram row for node 0 and fill it from the fixture's histogram.
  common::HistCollection hist_collection;
  hist_collection.Init(cuts_.TotalBins());
  hist_collection.AddHistRow(0);
  hist_collection.AllocateAllData();
  auto root_hist = hist_collection[0];
  // The copy below writes feature_histogram_.size() entries, so the sizes must agree.
  ASSERT_EQ(root_hist.size(), feature_histogram_.size());
  std::copy(feature_histogram_.cbegin(), feature_histogram_.cend(), root_hist.begin());

  // Build the evaluator and seed it with the parent gradient statistics.
  auto column_sampler = std::make_shared<common::ColumnSampler>();
  auto const n_threads = common::OmpGetNumThreads(0);
  HistEvaluator<CPUExpandEntry> evaluator{param_, meta, n_threads, column_sampler};
  evaluator.InitRoot(GradStats{parent_sum_});

  // Evaluate a single candidate entry against a default-constructed tree.
  RegTree tree;
  std::vector<CPUExpandEntry> candidates(1);
  evaluator.EvaluateSplits(hist_collection, cuts_, meta.feature_types.ConstHostSpan(), tree,
                           &candidates);

  // Forward every field of the chosen split to the fixture's checker.
  auto const& best = candidates.front().split;
  auto const& left = best.left_sum;
  auto const& right = best.right_sum;
  this->CheckResult(best.loss_chg, best.SplitIndex(), best.split_value, best.is_cat,
                    best.DefaultLeft(),
                    GradientPairPrecise{left.GetGrad(), left.GetHess()},
                    GradientPairPrecise{right.GetGrad(), right.GetHess()});
}
} // namespace tree
} // namespace xgboost