[EM] Refactor GPU histogram builder. (#10764)

- Expose the maximum number of cached nodes to be consistent with the CPU implementation. This also makes testing easier.
- Extract the subtraction trick for easier testing.
- Split up the `GradientQuantiser` to avoid a circular dependency.
This commit is contained in:
Jiaming Yuan
2024-08-30 02:39:14 +08:00
committed by GitHub
parent 34937fea41
commit 61dd854a52
17 changed files with 394 additions and 187 deletions

View File

@@ -51,7 +51,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
row_set_collection.Init();
HistMakerTrainParam hist_param;
hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node);
hist.Reset(gmat.cut.Ptrs().back(), hist_param.MaxCachedHistNodes(ctx.Device()));
hist.AllocateHistograms({0});
auto const &elem = row_set_collection[0];
common::BuildHist<false>(row_gpairs, common::Span{elem.begin(), elem.end()}, gmat, hist[0],
@@ -120,7 +120,7 @@ TEST(HistMultiEvaluator, Evaluate) {
linalg::Vector<GradientPairPrecise> root_sum({2}, DeviceOrd::CPU());
for (bst_target_t t{0}; t < n_targets; ++t) {
auto &hist = histogram[t];
hist.Reset(n_bins * n_features, hist_param.max_cached_hist_node);
hist.Reset(n_bins * n_features, hist_param.MaxCachedHistNodes(ctx.Device()));
hist.AllocateHistograms({0});
auto node_hist = hist[0];
node_hist[0] = {-0.5, 0.5};
@@ -237,7 +237,7 @@ auto CompareOneHotAndPartition(bool onehot) {
entries.front().nid = 0;
entries.front().depth = 0;
hist.Reset(gmat.cut.TotalBins(), hist_param.max_cached_hist_node);
hist.Reset(gmat.cut.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device()));
hist.AllocateHistograms({0});
auto node_hist = hist[0];
@@ -265,9 +265,10 @@ TEST(HistEvaluator, Categorical) {
}
TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
Context ctx;
BoundedHistCollection hist;
HistMakerTrainParam hist_param;
hist.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node);
hist.Reset(cuts_.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device()));
hist.AllocateHistograms({0});
auto node_hist = hist[0];
ASSERT_EQ(node_hist.size(), feature_histogram_.size());
@@ -277,10 +278,9 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
MetaInfo info;
info.num_col_ = 1;
info.feature_types = {FeatureType::kCategorical};
Context ctx;
auto evaluator = HistEvaluator{&ctx, &param_, info, sampler};
evaluator.InitRoot(GradStats{parent_sum_});
std::vector<CPUExpandEntry> entries(1);
RegTree tree;
evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries);