Extract Sketch Entry from hist maker. (#7503)
* Extract Sketch Entry from hist maker. * Add a new sketch container for sorted inputs. * Optimize bin search.
This commit is contained in:
@@ -19,7 +19,22 @@ TEST(Quantile, LoadBalance) {
|
||||
}
|
||||
CHECK_EQ(n_cols, kCols);
|
||||
}
|
||||
namespace {
|
||||
template <bool use_column>
|
||||
using ContainerType = std::conditional_t<use_column, SortedSketchContainer, HostSketchContainer>;
|
||||
|
||||
// Dispatch for push page.
|
||||
void PushPage(SortedSketchContainer* container, SparsePage const& page, MetaInfo const& info,
|
||||
Span<float const> hessian) {
|
||||
container->PushColPage(page, info, hessian);
|
||||
}
|
||||
void PushPage(HostSketchContainer* container, SparsePage const& page, MetaInfo const& info,
|
||||
Span<float const> hessian) {
|
||||
container->PushRowPage(page, info, hessian);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
template <bool use_column>
|
||||
void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
std::string msg {"Skipping AllReduce test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
@@ -48,12 +63,23 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
.Lower(.0f)
|
||||
.Upper(1.0f)
|
||||
.GenerateDMatrix();
|
||||
HostSketchContainer sketch_distributed(
|
||||
column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
|
||||
OmpGetNumThreads(0));
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
sketch_distributed.PushRowPage(page, m->Info());
|
||||
|
||||
std::vector<float> hessian(rows, 1.0);
|
||||
auto hess = Span<float const>{hessian};
|
||||
|
||||
ContainerType<use_column> sketch_distributed(n_bins, m->Info(), column_size, false, hess,
|
||||
OmpGetNumThreads(0));
|
||||
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
PushPage(&sketch_distributed, page, m->Info(), hess);
|
||||
}
|
||||
}
|
||||
|
||||
HistogramCuts distributed_cuts;
|
||||
sketch_distributed.MakeCuts(&distributed_cuts);
|
||||
|
||||
@@ -61,17 +87,25 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
rabit::Finalize();
|
||||
CHECK_EQ(rabit::GetWorldSize(), 1);
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
HostSketchContainer sketch_on_single_node(
|
||||
column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
|
||||
OmpGetNumThreads(0));
|
||||
m->Info().num_row_ = world * rows;
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info(), column_size, false, hess,
|
||||
OmpGetNumThreads(0));
|
||||
m->Info().num_row_ = rows;
|
||||
|
||||
for (auto rank = 0; rank < world; ++rank) {
|
||||
auto m = RandomDataGenerator{rows, cols, sparsity}
|
||||
.Seed(rank)
|
||||
.Lower(.0f)
|
||||
.Upper(1.0f)
|
||||
.GenerateDMatrix();
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
sketch_on_single_node.PushRowPage(page, m->Info());
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||
}
|
||||
} else {
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
PushPage(&sketch_on_single_node, page, m->Info(), hess);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +121,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
|
||||
ASSERT_EQ(sptrs.size(), dptrs.size());
|
||||
for (size_t i = 0; i < sptrs.size(); ++i) {
|
||||
ASSERT_EQ(sptrs[i], dptrs[i]);
|
||||
ASSERT_EQ(sptrs[i], dptrs[i]) << i;
|
||||
}
|
||||
|
||||
ASSERT_EQ(svals.size(), dvals.size());
|
||||
@@ -104,14 +138,28 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
TEST(Quantile, DistributedBasic) {
|
||||
#if defined(__unix__)
|
||||
constexpr size_t kRows = 10, kCols = 10;
|
||||
TestDistributedQuantile(kRows, kCols);
|
||||
TestDistributedQuantile<false>(kRows, kCols);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(Quantile, Distributed) {
|
||||
#if defined(__unix__)
|
||||
constexpr size_t kRows = 1000, kCols = 200;
|
||||
TestDistributedQuantile(kRows, kCols);
|
||||
constexpr size_t kRows = 4000, kCols = 200;
|
||||
TestDistributedQuantile<false>(kRows, kCols);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(Quantile, SortedDistributedBasic) {
|
||||
#if defined(__unix__)
|
||||
constexpr size_t kRows = 10, kCols = 10;
|
||||
TestDistributedQuantile<true>(kRows, kCols);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(Quantile, SortedDistributed) {
|
||||
#if defined(__unix__)
|
||||
constexpr size_t kRows = 4000, kCols = 200;
|
||||
TestDistributedQuantile<true>(kRows, kCols);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user