[EM] Merge GPU partitioning with histogram building. (#10766)

- Stop concatenating pages if there's no subsampling.
- Use a single iteration for histogram build and partitioning.
This commit is contained in:
Jiaming Yuan
2024-08-31 03:25:37 +08:00
committed by GitHub
parent 98ac153265
commit e1a2c1bbb3
7 changed files with 118 additions and 159 deletions

View File

@@ -67,7 +67,6 @@ TEST(GradientBasedSampler, NoSampling) {
VerifySampling(kPageSize, kSubsample, kSamplingMethod);
}
// In external mode, when not sampling, we concatenate the pages together.
TEST(GradientBasedSampler, NoSamplingExternalMemory) {
constexpr size_t kRows = 2048;
constexpr size_t kCols = 1;
@@ -81,34 +80,11 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
gpair.SetDevice(ctx.Device());
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
EXPECT_NE(page->n_rows, kRows);
GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
auto p_fmat = sample.p_fmat;
EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
EXPECT_EQ(sample.gpair.size(), gpair.Size());
EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
EXPECT_EQ(p_fmat->Info().num_row_, kRows);
ASSERT_EQ(p_fmat->NumBatches(), 1);
for (auto const& sampled_page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
std::vector<common::CompressedByteT> h_gidx_buffer;
auto h_accessor = sampled_page.Impl()->GetHostAccessor(&ctx, &h_gidx_buffer);
std::size_t offset = 0;
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
auto page = batch.Impl();
std::vector<common::CompressedByteT> h_page_gidx_buffer;
auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer);
size_t num_elements = page->n_rows * page->row_stride;
for (size_t i = 0; i < num_elements; i++) {
EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]);
}
offset += num_elements;
}
}
ASSERT_EQ(p_fmat, dmat.get());
}
TEST(GradientBasedSampler, UniformSampling) {