Categorical data support in CPU sketching. (#7221)

This commit is contained in:
Jiaming Yuan
2021-09-17 04:37:09 +08:00
committed by GitHub
parent 9f63d6fead
commit 31c1e13f90
7 changed files with 129 additions and 57 deletions

View File

@@ -43,12 +43,14 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
// Generate cuts for distributed environment.
auto sparsity = 0.5f;
auto rank = rabit::GetRank();
HostSketchContainer sketch_distributed(column_size, n_bins, false, OmpGetNumThreads(0));
auto m = RandomDataGenerator{rows, cols, sparsity}
.Seed(rank)
.Lower(.0f)
.Upper(1.0f)
.GenerateDMatrix();
HostSketchContainer sketch_distributed(
column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
OmpGetNumThreads(0));
for (auto const &page : m->GetBatches<SparsePage>()) {
sketch_distributed.PushRowPage(page, m->Info());
}
@@ -59,7 +61,9 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
rabit::Finalize();
CHECK_EQ(rabit::GetWorldSize(), 1);
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
HostSketchContainer sketch_on_single_node(column_size, n_bins, false, OmpGetNumThreads(0));
HostSketchContainer sketch_on_single_node(
column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
OmpGetNumThreads(0));
for (auto rank = 0; rank < world; ++rank) {
auto m = RandomDataGenerator{rows, cols, sparsity}
.Seed(rank)