Fix LTR with weighted Quantile DMatrix. (#7975)
* Fix LTR with weighted Quantile DMatrix. * Better tests.
This commit is contained in:
@@ -366,6 +366,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
ValidateCuts(cuts, m.get(), kBins);
|
||||
|
||||
if (with_group) {
|
||||
m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight
|
||||
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||
@@ -377,6 +378,17 @@ void TestSketchFromWeights(bool with_group) {
|
||||
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
|
||||
}
|
||||
}
|
||||
|
||||
if (with_group) {
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
h_weights.resize(kGroups);
|
||||
// Generate different weight.
|
||||
for (size_t i = 0; i < h_weights.size(); ++i) {
|
||||
h_weights[i] = static_cast<float>(i + 1) / static_cast<float>(kGroups);
|
||||
}
|
||||
HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
ValidateCuts(weighted, m.get(), kBins);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, SketchFromWeights) {
|
||||
|
||||
Reference in New Issue
Block a user