Fix LTR with weighted Quantile DMatrix. (#7975)

* Fix LTR with weighted Quantile DMatrix.

* Better tests.
This commit is contained in:
Jiaming Yuan
2022-06-09 01:33:41 +08:00
committed by GitHub
parent 1a33b50a0d
commit 8f8bd8147a
6 changed files with 83 additions and 42 deletions

View File

@@ -366,6 +366,7 @@ void TestSketchFromWeights(bool with_group) {
ValidateCuts(cuts, m.get(), kBins);
if (with_group) {
m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
for (size_t i = 0; i < cuts.Values().size(); ++i) {
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
@@ -377,6 +378,17 @@ void TestSketchFromWeights(bool with_group) {
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
}
}
if (with_group) {
auto& h_weights = info.weights_.HostVector();
h_weights.resize(kGroups);
// Generate different weight.
for (size_t i = 0; i < h_weights.size(); ++i) {
h_weights[i] = static_cast<float>(i + 1) / static_cast<float>(kGroups);
}
HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
ValidateCuts(weighted, m.get(), kBins);
}
}
TEST(HistUtil, SketchFromWeights) {

View File

@@ -593,9 +593,10 @@ void TestAdapterSketchFromWeights(bool with_group) {
ValidateCuts(cuts, dmat.get(), kBins);
if (with_group) {
dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight
HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0);
for (size_t i = 0; i < cuts.Values().size(); ++i) {
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
}
for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]);
@@ -604,6 +605,24 @@ void TestAdapterSketchFromWeights(bool with_group) {
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
}
}
if (with_group) {
common::HistogramCuts weighted;
auto& h_weights = info.weights_.HostVector();
h_weights.resize(kGroups);
// Generate different weight.
for (size_t i = 0; i < h_weights.size(); ++i) {
// FIXME(jiamingy): Some entries generated GPU test cannot pass the validate cuts if
// we use more diverse weights, partially caused by
// https://github.com/dmlc/xgboost/issues/7946
h_weights[i] = (i % 2 == 0 ? 1 : 2) / static_cast<float>(kGroups);
}
SketchContainer sketch_container(ft, kBins, kCols, kRows, 0);
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
&sketch_container);
sketch_container.MakeCuts(&weighted);
ValidateCuts(weighted, dmat.get(), kBins);
}
}
TEST(HistUtil, AdapterSketchFromWeights) {

View File

@@ -98,7 +98,11 @@ inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx,
int num_bins) {
std::map<int, int> bin_weights;
for (auto i = 0ull; i < sorted_column.size(); i++) {
bin_weights[cuts.SearchBin(sorted_column[i], column_idx)] += sorted_weights[i];
auto bin_idx = cuts.SearchBin(sorted_column[i], column_idx);
if (bin_weights.find(bin_idx) == bin_weights.cend()) {
bin_weights[bin_idx] = 0;
}
bin_weights.at(bin_idx) += sorted_weights[i];
}
int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(),0);
@@ -176,8 +180,7 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
}
}
inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
int num_bins) {
inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, int num_bins) {
// Collect data into columns
std::vector<std::vector<float>> columns(dmat->Info().num_col_);
for (auto& batch : dmat->GetBatches<SparsePage>()) {
@@ -189,17 +192,22 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
}
}
}
// construct weights.
std::vector<float> w = dmat->Info().group_ptr_.empty() ? dmat->Info().weights_.HostVector()
: detail::UnrollGroupWeights(dmat->Info());
// Sort
for (auto i = 0ull; i < columns.size(); i++) {
auto& col = columns.at(i);
const auto& w = dmat->Info().weights_.HostVector();
std::vector<size_t > index(col.size());
std::vector<size_t> index(col.size());
std::iota(index.begin(), index.end(), 0);
std::sort(index.begin(), index.end(),
[=](size_t a, size_t b) { return col[a] < col[b]; });
std::sort(index.begin(), index.end(), [=](size_t a, size_t b) { return col[a] < col[b]; });
std::vector<float> sorted_column(col.size());
std::vector<float> sorted_weights(col.size(), 1.0);
const auto& w = dmat->Info().weights_.HostVector();
for (auto j = 0ull; j < col.size(); j++) {
sorted_column[j] = col[index[j]];
if (w.size() == col.size()) {