[EM] Refactor GPU histogram builder. (#10764)
- Expose the maximum number of cached nodes to be consistent with the CPU implementation. Also easier for testing. - Extract the subtraction trick for easier testing. - Split up the `GradientQuantiser` to avoid circular dependency.
This commit is contained in:
@@ -51,7 +51,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
||||
row_set_collection.Init();
|
||||
|
||||
HistMakerTrainParam hist_param;
|
||||
hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node);
|
||||
hist.Reset(gmat.cut.Ptrs().back(), hist_param.MaxCachedHistNodes(ctx.Device()));
|
||||
hist.AllocateHistograms({0});
|
||||
auto const &elem = row_set_collection[0];
|
||||
common::BuildHist<false>(row_gpairs, common::Span{elem.begin(), elem.end()}, gmat, hist[0],
|
||||
@@ -120,7 +120,7 @@ TEST(HistMultiEvaluator, Evaluate) {
|
||||
linalg::Vector<GradientPairPrecise> root_sum({2}, DeviceOrd::CPU());
|
||||
for (bst_target_t t{0}; t < n_targets; ++t) {
|
||||
auto &hist = histogram[t];
|
||||
hist.Reset(n_bins * n_features, hist_param.max_cached_hist_node);
|
||||
hist.Reset(n_bins * n_features, hist_param.MaxCachedHistNodes(ctx.Device()));
|
||||
hist.AllocateHistograms({0});
|
||||
auto node_hist = hist[0];
|
||||
node_hist[0] = {-0.5, 0.5};
|
||||
@@ -237,7 +237,7 @@ auto CompareOneHotAndPartition(bool onehot) {
|
||||
entries.front().nid = 0;
|
||||
entries.front().depth = 0;
|
||||
|
||||
hist.Reset(gmat.cut.TotalBins(), hist_param.max_cached_hist_node);
|
||||
hist.Reset(gmat.cut.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device()));
|
||||
hist.AllocateHistograms({0});
|
||||
auto node_hist = hist[0];
|
||||
|
||||
@@ -265,9 +265,10 @@ TEST(HistEvaluator, Categorical) {
|
||||
}
|
||||
|
||||
TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
|
||||
Context ctx;
|
||||
BoundedHistCollection hist;
|
||||
HistMakerTrainParam hist_param;
|
||||
hist.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node);
|
||||
hist.Reset(cuts_.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device()));
|
||||
hist.AllocateHistograms({0});
|
||||
auto node_hist = hist[0];
|
||||
ASSERT_EQ(node_hist.size(), feature_histogram_.size());
|
||||
@@ -277,10 +278,9 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
|
||||
MetaInfo info;
|
||||
info.num_col_ = 1;
|
||||
info.feature_types = {FeatureType::kCategorical};
|
||||
Context ctx;
|
||||
|
||||
auto evaluator = HistEvaluator{&ctx, ¶m_, info, sampler};
|
||||
evaluator.InitRoot(GradStats{parent_sum_});
|
||||
|
||||
std::vector<CPUExpandEntry> entries(1);
|
||||
RegTree tree;
|
||||
evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries);
|
||||
|
||||
Reference in New Issue
Block a user