Implement weighted sketching for adapter. (#5760)
* Bounded memory tests.
* Fixed memory estimation.
This commit is contained in:
@@ -50,8 +50,7 @@ TEST(HistUtil, DeviceSketch) {
|
||||
// Duplicate this function from hist_util.cu so we don't have to expose it in
|
||||
// header
|
||||
size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) {
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bins);
|
||||
double eps = 1.0 / (SketchContainer::kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
@@ -59,6 +58,15 @@ size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) {
|
||||
return std::min(num_cuts, num_rows);
|
||||
}
|
||||
|
||||
size_t BytesRequiredForTest(size_t num_rows, size_t num_columns, size_t num_bins,
|
||||
bool with_weights) {
|
||||
size_t bytes_num_elements = BytesPerElement(with_weights) * num_rows * num_columns;
|
||||
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
|
||||
sizeof(DenseCuts::WQSketch::Entry);
|
||||
// divide by 2 is because the memory quota used in sorting is reused for storing cuts.
|
||||
return bytes_num_elements / 2 + bytes_cuts;
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMemory) {
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
@@ -71,12 +79,10 @@ TEST(HistUtil, DeviceSketchMemory) {
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
|
||||
size_t bytes_num_elements = num_rows * num_columns*sizeof(Entry);
|
||||
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
|
||||
sizeof(DenseCuts::WQSketch::Entry);
|
||||
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
|
||||
size_t bytes_constant = 1000;
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
|
||||
bytes_num_elements + bytes_cuts + bytes_constant);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
|
||||
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMemoryWeights) {
|
||||
@@ -92,12 +98,9 @@ TEST(HistUtil, DeviceSketchMemoryWeights) {
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
|
||||
size_t bytes_num_elements =
|
||||
num_rows * num_columns * (sizeof(Entry) + sizeof(float));
|
||||
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
|
||||
sizeof(DenseCuts::WQSketch::Entry);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
|
||||
size_t((bytes_num_elements + bytes_cuts) * 1.05));
|
||||
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
|
||||
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchDeterminism) {
|
||||
@@ -192,6 +195,20 @@ TEST(HistUtil, DeviceSketchBatches) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
|
||||
num_rows = 1000;
|
||||
size_t batches = 16;
|
||||
auto x = GenerateRandom(num_rows * batches, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows * batches, num_columns);
|
||||
auto cuts_with_batches = DeviceSketch(0, dmat.get(), num_bins, num_rows);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
|
||||
|
||||
auto const& cut_values_batched = cuts_with_batches.Values();
|
||||
auto const& cut_values = cuts.Values();
|
||||
CHECK_EQ(cut_values.size(), cut_values_batched.size());
|
||||
for (size_t i = 0; i < cut_values.size(); ++i) {
|
||||
ASSERT_NEAR(cut_values_batched[i], cut_values[i], 1e5);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
|
||||
@@ -210,6 +227,19 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows,
|
||||
DMatrix* dmat) {
|
||||
common::HistogramCuts batched_cuts;
|
||||
SketchContainer sketch_container(num_bins, num_columns, num_rows);
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
|
||||
0, &sketch_container);
|
||||
common::DenseCuts dense_cuts(&batched_cuts);
|
||||
dense_cuts.Init(&sketch_container.sketches_, num_bins, num_rows);
|
||||
ValidateCuts(batched_cuts, dmat, num_bins);
|
||||
}
|
||||
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketch) {
|
||||
int rows = 5;
|
||||
int cols = 1;
|
||||
@@ -244,14 +274,56 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
|
||||
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
|
||||
std::numeric_limits<float>::quiet_NaN());
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
|
||||
size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry);
|
||||
size_t bytes_num_columns = (num_columns + 1) * sizeof(size_t);
|
||||
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
|
||||
sizeof(DenseCuts::WQSketch::Entry);
|
||||
size_t bytes_constant = 1000;
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
|
||||
bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant);
|
||||
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
|
||||
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
|
||||
}
|
||||
|
||||
// Sketches an unweighted 1000 x 100 random input through AdapterDeviceSketch
// and checks that the peak memory reported by dh::GlobalMemoryLogger() lies in
// [bytes_required, bytes_required + bytes_constant].
TEST(HistUtil, AdapterSketchBatchMemory) {
  int const num_columns = 100;
  int const num_rows = 1000;
  int const num_bins = 256;
  auto host_values = GenerateRandom(num_rows, num_columns);
  thrust::device_vector<float> device_values(host_values);
  auto adapter = AdapterFromData(device_values, num_rows, num_columns);

  // Reset the logger so only the sketching below contributes to the peak.
  dh::GlobalMemoryLogger().Clear();
  // NOTE(review): presumably verbosity 3 is what makes the memory logger
  // record allocations - confirm.
  ConsoleLogger::Configure({{"verbosity", "3"}});
  common::HistogramCuts cuts_unused;
  SketchContainer container(num_bins, num_columns, num_rows);
  AdapterDeviceSketch(adapter.Value(), num_bins,
                      std::numeric_limits<float>::quiet_NaN(), 0, &container);
  ConsoleLogger::Configure({{"verbosity", "0"}});

  // Slack allowed on top of the estimate.
  size_t const bytes_constant = 1000;
  size_t const bytes_required =
      BytesRequiredForTest(num_rows, num_columns, num_bins, false);
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
  EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}
|
||||
|
||||
// Sketches a 1000 x 100 random input with unit row weights through
// AdapterDeviceSketchWeighted and checks that the peak memory reported by
// dh::GlobalMemoryLogger() lies in [bytes_required, bytes_required * 1.05].
TEST(HistUtil, AdapterSketchBatchWeightedMemory) {
  int const num_columns = 100;
  int const num_rows = 1000;
  int const num_bins = 256;
  auto host_values = GenerateRandom(num_rows, num_columns);
  thrust::device_vector<float> device_values(host_values);
  auto adapter = AdapterFromData(device_values, num_rows, num_columns);

  // One weight of 1.0 per row.
  MetaInfo info;
  auto& row_weights = info.weights_.HostVector();
  row_weights.assign(num_rows, 1.0f);

  // Reset the logger so only the sketching below contributes to the peak.
  dh::GlobalMemoryLogger().Clear();
  // NOTE(review): presumably verbosity 3 is what makes the memory logger
  // record allocations - confirm.
  ConsoleLogger::Configure({{"verbosity", "3"}});
  common::HistogramCuts cuts_unused;
  SketchContainer container(num_bins, num_columns, num_rows);
  AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info,
                              std::numeric_limits<float>::quiet_NaN(), 0,
                              &container);
  ConsoleLogger::Configure({{"verbosity", "0"}});

  size_t const bytes_required =
      BytesRequiredForTest(num_rows, num_columns, num_bins, true);
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
  EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketchCategorical) {
|
||||
@@ -284,6 +356,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
|
||||
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
|
||||
std::numeric_limits<float>::quiet_NaN());
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -302,6 +375,7 @@ TEST(HistUtil, AdapterDeviceSketchBatches) {
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
batch_size);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,6 +397,8 @@ TEST(HistUtil, SketchingEquivalent) {
|
||||
EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values());
|
||||
EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs());
|
||||
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
|
||||
|
||||
ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -330,7 +406,7 @@ TEST(HistUtil, SketchingEquivalent) {
|
||||
TEST(HistUtil, DeviceSketchFromGroupWeights) {
|
||||
size_t constexpr kRows = 3000, kCols = 200, kBins = 256;
|
||||
size_t constexpr kGroups = 10;
|
||||
auto m = RandomDataGenerator {kRows, kCols, 0}.GenerateDMatrix();
|
||||
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||
auto& h_weights = m->Info().weights_.HostVector();
|
||||
h_weights.resize(kRows);
|
||||
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
|
||||
@@ -357,6 +433,71 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) {
|
||||
for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
|
||||
ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i));
|
||||
}
|
||||
ValidateCuts(weighted_cuts, m.get(), kBins);
|
||||
}
|
||||
|
||||
void TestAdapterSketchFromWeights(bool with_group) {
|
||||
size_t constexpr kRows = 300, kCols = 20, kBins = 256;
|
||||
size_t constexpr kGroups = 10;
|
||||
HostDeviceVector<float> storage;
|
||||
std::string m =
|
||||
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
|
||||
&storage);
|
||||
MetaInfo info;
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
h_weights.resize(kRows);
|
||||
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
|
||||
|
||||
std::vector<bst_group_t> groups(kGroups);
|
||||
if (with_group) {
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
info.weights_.SetDevice(0);
|
||||
info.num_row_ = kRows;
|
||||
info.num_col_ = kCols;
|
||||
|
||||
data::CupyAdapter adapter(m);
|
||||
auto const& batch = adapter.Value();
|
||||
SketchContainer sketch_container(kBins, kCols, kRows);
|
||||
AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
0,
|
||||
&sketch_container);
|
||||
common::HistogramCuts cuts;
|
||||
common::DenseCuts dense_cuts(&cuts);
|
||||
dense_cuts.Init(&sketch_container.sketches_, kBins, kRows);
|
||||
|
||||
auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
|
||||
if (with_group) {
|
||||
dmat->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
dmat->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
dmat->Info().num_col_ = kCols;
|
||||
dmat->Info().num_row_ = kRows;
|
||||
ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
|
||||
ValidateCuts(cuts, dmat.get(), kBins);
|
||||
|
||||
if (with_group) {
|
||||
HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0);
|
||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||
}
|
||||
for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
|
||||
ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]);
|
||||
}
|
||||
for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
|
||||
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the weighted adapter sketch check first without and then with group
// information.
TEST(HistUtil, AdapterSketchFromWeights) {
  for (bool const with_group : {false, true}) {
    TestAdapterSketchFromWeights(with_group);
  }
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -151,7 +151,8 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
|
||||
size_t num_bins) {
|
||||
|
||||
// Check the endpoints are correct
|
||||
EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front());
|
||||
CHECK_GT(sorted_column.size(), 0);
|
||||
EXPECT_LT(cuts.MinValues().at(column_idx), sorted_column.front());
|
||||
EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front());
|
||||
EXPECT_GE(cuts.Values()[cuts.Ptrs()[column_idx+1]-1], sorted_column.back());
|
||||
|
||||
@@ -189,6 +190,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
|
||||
// Collect data into columns
|
||||
std::vector<std::vector<float>> columns(dmat->Info().num_col_);
|
||||
for (auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
CHECK_GT(batch.Size(), 0);
|
||||
for (auto i = 0ull; i < batch.Size(); i++) {
|
||||
for (auto e : batch[i]) {
|
||||
columns[e.index].push_back(e.fvalue);
|
||||
|
||||
Reference in New Issue
Block a user