Fix LTR with weighted Quantile DMatrix. (#7975)
* Fix LTR with weighted Quantile DMatrix. * Better tests.
This commit is contained in:
@@ -697,6 +697,29 @@ class WXQuantileSketch :
|
||||
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
||||
std::vector<float> const &group_weights = info.weights_.HostVector();
|
||||
if (group_weights.empty()) {
|
||||
return group_weights;
|
||||
}
|
||||
|
||||
size_t n_samples = info.num_row_;
|
||||
auto const &group_ptr = info.group_ptr_;
|
||||
std::vector<float> results(n_samples);
|
||||
CHECK_GE(group_ptr.size(), 2);
|
||||
CHECK_EQ(group_ptr.back(), n_samples);
|
||||
size_t cur_group = 0;
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
results[i] = group_weights[cur_group];
|
||||
if (i == group_ptr[cur_group + 1]) {
|
||||
cur_group++;
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
class HistogramCuts;
|
||||
|
||||
/*!
|
||||
|
||||
Reference in New Issue
Block a user