Device dmatrix (#5420)

This commit is contained in:
Rory Mitchell
2020-03-28 14:42:21 +13:00
committed by GitHub
parent 780de49ddb
commit 13b10a6370
24 changed files with 915 additions and 310 deletions

View File

@@ -284,5 +284,28 @@ TEST(hist_util, AdapterDeviceSketchBatches) {
ValidateCuts(cuts, dmat.get(), num_bins);
}
}
// Check sketching from adapter or DMatrix results in the same answer
// Consistency here is useful for testing and user experience
TEST(hist_util, SketchingEquivalent) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
for (auto num_bins : bin_sizes) {
auto dmat_cuts = DeviceSketch(0, dmat.get(), num_bins);
auto x_device = thrust::device_vector<float>(x);
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
auto adapter_cuts = AdapterDeviceSketch(
&adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values());
EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs());
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
}
}
}
} // namespace common
} // namespace xgboost