[EM] Return a full DMatrix instead of an Ellpack from the GPU sampler. (#10753)

This commit is contained in:
Jiaming Yuan
2024-08-28 01:05:11 +08:00
committed by GitHub
parent d6ebcfb032
commit bde1265caf
20 changed files with 525 additions and 214 deletions

View File

@@ -39,11 +39,11 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method,
if (fixed_size_sampling) {
EXPECT_EQ(sample.sample_rows, kRows);
EXPECT_EQ(sample.page->n_rows, kRows);
EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
EXPECT_EQ(sample.gpair.size(), kRows);
} else {
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03);
EXPECT_NEAR(sample.page->n_rows, sample_rows, kRows * 0.03f);
EXPECT_NEAR(sample.p_fmat->Info().num_row_, sample_rows, kRows * 0.03f);
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f);
}
@@ -88,25 +88,28 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
auto sampled_page = sample.page;
auto p_fmat = sample.p_fmat;
EXPECT_EQ(sample.sample_rows, kRows);
EXPECT_EQ(sample.gpair.size(), gpair.Size());
EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
EXPECT_EQ(sampled_page->n_rows, kRows);
EXPECT_EQ(p_fmat->Info().num_row_, kRows);
std::vector<common::CompressedByteT> h_gidx_buffer;
auto h_accessor = sampled_page->GetHostAccessor(&ctx, &h_gidx_buffer);
ASSERT_EQ(p_fmat->NumBatches(), 1);
for (auto const& sampled_page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
std::vector<common::CompressedByteT> h_gidx_buffer;
auto h_accessor = sampled_page.Impl()->GetHostAccessor(&ctx, &h_gidx_buffer);
std::size_t offset = 0;
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
auto page = batch.Impl();
std::vector<common::CompressedByteT> h_page_gidx_buffer;
auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer);
size_t num_elements = page->n_rows * page->row_stride;
for (size_t i = 0; i < num_elements; i++) {
EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]);
std::size_t offset = 0;
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
auto page = batch.Impl();
std::vector<common::CompressedByteT> h_page_gidx_buffer;
auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer);
size_t num_elements = page->n_rows * page->row_stride;
for (size_t i = 0; i < num_elements; i++) {
EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]);
}
offset += num_elements;
}
offset += num_elements;
}
}