More tests for column split and vertical federated learning (#8985)

Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning.

Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`.

Some refactoring of the testing fixtures.
This commit is contained in:
Rong Ou
2023-03-28 01:40:26 -07:00
committed by GitHub
parent 401ce5cf5e
commit ff26cd3212
18 changed files with 212 additions and 94 deletions

View File

@@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->IsColumnSplit(), n_threads);
m->Info().IsColumnSplit(), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
@@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
} else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->IsColumnSplit(), n_threads};
m->Info().IsColumnSplit(), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}