Fix LTR with weighted Quantile DMatrix. (#7975)

* Fix LTR with weighted Quantile DMatrix.

* Better tests.
This commit is contained in:
Jiaming Yuan
2022-06-09 01:33:41 +08:00
committed by GitHub
parent 1a33b50a0d
commit 8f8bd8147a
6 changed files with 83 additions and 42 deletions

View File

@@ -366,6 +366,7 @@ void TestSketchFromWeights(bool with_group) {
ValidateCuts(cuts, m.get(), kBins);
if (with_group) {
m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
for (size_t i = 0; i < cuts.Values().size(); ++i) {
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
@@ -377,6 +378,17 @@ void TestSketchFromWeights(bool with_group) {
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
}
}
if (with_group) {
auto& h_weights = info.weights_.HostVector();
h_weights.resize(kGroups);
// Generate different weight.
for (size_t i = 0; i < h_weights.size(); ++i) {
h_weights[i] = static_cast<float>(i + 1) / static_cast<float>(kGroups);
}
HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
ValidateCuts(weighted, m.get(), kBins);
}
}
TEST(HistUtil, SketchFromWeights) {