Fix LTR with weighted Quantile DMatrix. (#7975)
* Fix LTR with weighted Quantile DMatrix. * Better tests.
This commit is contained in:
@@ -98,7 +98,11 @@ inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx,
|
||||
int num_bins) {
|
||||
std::map<int, int> bin_weights;
|
||||
for (auto i = 0ull; i < sorted_column.size(); i++) {
|
||||
bin_weights[cuts.SearchBin(sorted_column[i], column_idx)] += sorted_weights[i];
|
||||
auto bin_idx = cuts.SearchBin(sorted_column[i], column_idx);
|
||||
if (bin_weights.find(bin_idx) == bin_weights.cend()) {
|
||||
bin_weights[bin_idx] = 0;
|
||||
}
|
||||
bin_weights.at(bin_idx) += sorted_weights[i];
|
||||
}
|
||||
int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
|
||||
auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(),0);
|
||||
@@ -176,8 +180,7 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
|
||||
}
|
||||
}
|
||||
|
||||
inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
|
||||
int num_bins) {
|
||||
inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, int num_bins) {
|
||||
// Collect data into columns
|
||||
std::vector<std::vector<float>> columns(dmat->Info().num_col_);
|
||||
for (auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
@@ -189,17 +192,22 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// construct weights.
|
||||
std::vector<float> w = dmat->Info().group_ptr_.empty() ? dmat->Info().weights_.HostVector()
|
||||
: detail::UnrollGroupWeights(dmat->Info());
|
||||
|
||||
// Sort
|
||||
for (auto i = 0ull; i < columns.size(); i++) {
|
||||
auto& col = columns.at(i);
|
||||
const auto& w = dmat->Info().weights_.HostVector();
|
||||
std::vector<size_t > index(col.size());
|
||||
std::vector<size_t> index(col.size());
|
||||
std::iota(index.begin(), index.end(), 0);
|
||||
std::sort(index.begin(), index.end(),
|
||||
[=](size_t a, size_t b) { return col[a] < col[b]; });
|
||||
std::sort(index.begin(), index.end(), [=](size_t a, size_t b) { return col[a] < col[b]; });
|
||||
|
||||
std::vector<float> sorted_column(col.size());
|
||||
std::vector<float> sorted_weights(col.size(), 1.0);
|
||||
const auto& w = dmat->Info().weights_.HostVector();
|
||||
|
||||
for (auto j = 0ull; j < col.size(); j++) {
|
||||
sorted_column[j] = col[index[j]];
|
||||
if (w.size() == col.size()) {
|
||||
|
||||
Reference in New Issue
Block a user