More refactoring to take advantage of collective aggregators (#9081)

This commit is contained in:
Rong Ou
2023-04-25 12:36:09 -07:00
committed by GitHub
parent 49ccae7fb9
commit a320b402a5
10 changed files with 81 additions and 81 deletions

View File

@@ -73,7 +73,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
auto hess = Span<float const>{hessian};
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, false, AllThreadsForTest());
column_size, false, AllThreadsForTest());
if (use_column) {
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
@@ -86,7 +86,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
}
HistogramCuts distributed_cuts;
sketch_distributed.MakeCuts(&distributed_cuts);
sketch_distributed.MakeCuts(m->Info(), &distributed_cuts);
// Generate cuts for single node environment
collective::Finalize();
@@ -94,7 +94,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
m->Info().num_row_ = world * rows;
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, false, AllThreadsForTest());
column_size, false, AllThreadsForTest());
m->Info().num_row_ = rows;
for (auto rank = 0; rank < world; ++rank) {
@@ -117,7 +117,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
}
HistogramCuts single_node_cuts;
sketch_on_single_node.MakeCuts(&single_node_cuts);
sketch_on_single_node.MakeCuts(m->Info(), &single_node_cuts);
auto const& sptrs = single_node_cuts.Ptrs();
auto const& dptrs = distributed_cuts.Ptrs();
@@ -205,7 +205,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
HistogramCuts distributed_cuts;
{
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, true, AllThreadsForTest());
column_size, false, AllThreadsForTest());
std::vector<float> hessian(rows, 1.0);
auto hess = Span<float const>{hessian};
@@ -219,7 +219,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
}
}
sketch_distributed.MakeCuts(&distributed_cuts);
sketch_distributed.MakeCuts(m->Info(), &distributed_cuts);
}
// Generate cuts for single node environment
@@ -228,7 +228,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
HistogramCuts single_node_cuts;
{
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, false, AllThreadsForTest());
column_size, false, AllThreadsForTest());
std::vector<float> hessian(rows, 1.0);
auto hess = Span<float const>{hessian};
@@ -242,7 +242,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
}
}
sketch_on_single_node.MakeCuts(&single_node_cuts);
sketch_on_single_node.MakeCuts(m->Info(), &single_node_cuts);
}
auto const& sptrs = single_node_cuts.Ptrs();