[EM] Return a full DMatrix instead of a Ellpack from the GPU sampler. (#10753)
This commit is contained in:
@@ -39,11 +39,11 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method,
|
||||
|
||||
if (fixed_size_sampling) {
|
||||
EXPECT_EQ(sample.sample_rows, kRows);
|
||||
EXPECT_EQ(sample.page->n_rows, kRows);
|
||||
EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
|
||||
EXPECT_EQ(sample.gpair.size(), kRows);
|
||||
} else {
|
||||
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03);
|
||||
EXPECT_NEAR(sample.page->n_rows, sample_rows, kRows * 0.03f);
|
||||
EXPECT_NEAR(sample.p_fmat->Info().num_row_, sample_rows, kRows * 0.03f);
|
||||
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f);
|
||||
}
|
||||
|
||||
@@ -88,25 +88,28 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
|
||||
|
||||
GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
|
||||
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
|
||||
auto sampled_page = sample.page;
|
||||
auto p_fmat = sample.p_fmat;
|
||||
EXPECT_EQ(sample.sample_rows, kRows);
|
||||
EXPECT_EQ(sample.gpair.size(), gpair.Size());
|
||||
EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
|
||||
EXPECT_EQ(sampled_page->n_rows, kRows);
|
||||
EXPECT_EQ(p_fmat->Info().num_row_, kRows);
|
||||
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer;
|
||||
auto h_accessor = sampled_page->GetHostAccessor(&ctx, &h_gidx_buffer);
|
||||
ASSERT_EQ(p_fmat->NumBatches(), 1);
|
||||
for (auto const& sampled_page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer;
|
||||
auto h_accessor = sampled_page.Impl()->GetHostAccessor(&ctx, &h_gidx_buffer);
|
||||
|
||||
std::size_t offset = 0;
|
||||
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
auto page = batch.Impl();
|
||||
std::vector<common::CompressedByteT> h_page_gidx_buffer;
|
||||
auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer);
|
||||
size_t num_elements = page->n_rows * page->row_stride;
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]);
|
||||
std::size_t offset = 0;
|
||||
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
auto page = batch.Impl();
|
||||
std::vector<common::CompressedByteT> h_page_gidx_buffer;
|
||||
auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer);
|
||||
size_t num_elements = page->n_rows * page->row_stride;
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]);
|
||||
}
|
||||
offset += num_elements;
|
||||
}
|
||||
offset += num_elements;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user