[EM] Have one partitioner for each batch. (#10760)

- Initialize one partitioner for each batch.
- Collect partition size during initialization.
- Support base ridx in the finalization.
This commit is contained in:
Jiaming Yuan
2024-08-29 01:35:17 +08:00
committed by GitHub
parent 3043827efc
commit 4fe67f10b4
10 changed files with 211 additions and 181 deletions

View File

@@ -38,11 +38,9 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method,
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
if (fixed_size_sampling) {
EXPECT_EQ(sample.sample_rows, kRows);
EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
EXPECT_EQ(sample.gpair.size(), kRows);
} else {
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03);
EXPECT_NEAR(sample.p_fmat->Info().num_row_, sample_rows, kRows * 0.03f);
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f);
}
@@ -89,7 +87,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
auto p_fmat = sample.p_fmat;
EXPECT_EQ(sample.sample_rows, kRows);
EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
EXPECT_EQ(sample.gpair.size(), gpair.Size());
EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
EXPECT_EQ(p_fmat->Info().num_row_, kRows);