Fix LTR with weighted Quantile DMatrix. (#7975)

* Fix LTR with weighted Quantile DMatrix.

* Better tests.
This commit is contained in:
Jiaming Yuan
2022-06-09 01:33:41 +08:00
committed by GitHub
parent 1a33b50a0d
commit 8f8bd8147a
6 changed files with 83 additions and 42 deletions

View File

@@ -184,8 +184,6 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
dh::safe_cuda(cudaSetDevice(device));
info.weights_.SetDevice(device);
auto weights = info.weights_.ConstDeviceSpan();
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
auto d_group_ptr = dh::ToSpan(group_ptr);
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
@@ -205,9 +203,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
auto d_temp_weights = dh::ToSpan(temp_weights);
if (is_ranking) {
if (!weights.empty()) {
CHECK_EQ(weights.size(), info.group_ptr_.size() - 1);
}
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
auto d_group_ptr = dh::ToSpan(group_ptr);
auto const weight_iter = dh::MakeTransformIterator<float>(
thrust::make_constant_iterator(0lu),
[=]__device__(size_t idx) -> float {
thrust::make_counting_iterator(0lu), [=] __device__(size_t idx) -> float {
auto ridx = batch.GetElement(idx).row_idx;
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
return weights[group_idx];
@@ -272,7 +274,7 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
size_t num_cols = batch.NumCols();
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
int32_t device = sketch_container->DeviceIdx();
bool weighted = info.weights_.Size() != 0;
bool weighted = !info.weights_.Empty();
if (weighted) {
sketch_batch_num_elements = detail::SketchBatchNumElements(

View File

@@ -122,27 +122,6 @@ std::vector<float> MergeWeights(MetaInfo const &info, Span<float const> hessian,
}
return results;
}
std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
std::vector<float> const &group_weights = info.weights_.HostVector();
if (group_weights.empty()) {
return group_weights;
}
size_t n_samples = info.num_row_;
auto const &group_ptr = info.group_ptr_;
std::vector<float> results(n_samples);
CHECK_GE(group_ptr.size(), 2);
CHECK_EQ(group_ptr.back(), n_samples);
size_t cur_group = 0;
for (size_t i = 0; i < n_samples; ++i) {
results[i] = group_weights[cur_group];
if (i == group_ptr[cur_group + 1]) {
cur_group++;
}
}
return results;
}
} // anonymous namespace
template <typename WQSketch>
@@ -156,12 +135,10 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
// glue these conditions using ternary operator to avoid making data copies.
auto const &weights =
hessian.empty()
? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight
: info.weights_.HostVector()) // use sample weight
: MergeWeights(
info, hessian, use_group_ind_,
n_threads_); // use hessian merged with group/sample weights
hessian.empty() ? (use_group_ind_ ? detail::UnrollGroupWeights(info) // use group weight
: info.weights_.HostVector()) // use sample weight
: MergeWeights(info, hessian, use_group_ind_,
n_threads_); // use hessian merged with group/sample weights
if (!weights.empty()) {
CHECK_EQ(weights.size(), info.num_row_);
}
@@ -563,8 +540,8 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
monitor_.Start(__func__);
// glue these conditions using ternary operator to avoid making data copies.
auto const &weights =
hessian.empty() ? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight
: info.weights_.HostVector()) // use sample weight
hessian.empty() ? (use_group_ind_ ? detail::UnrollGroupWeights(info) // use group weight
: info.weights_.HostVector()) // use sample weight
: MergeWeights(info, hessian, use_group_ind_,
n_threads_); // use hessian merged with group/sample weights
CHECK_EQ(weights.size(), info.num_row_);

View File

@@ -697,6 +697,29 @@ class WXQuantileSketch :
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
};
namespace detail {
inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
std::vector<float> const &group_weights = info.weights_.HostVector();
if (group_weights.empty()) {
return group_weights;
}
size_t n_samples = info.num_row_;
auto const &group_ptr = info.group_ptr_;
std::vector<float> results(n_samples);
CHECK_GE(group_ptr.size(), 2);
CHECK_EQ(group_ptr.back(), n_samples);
size_t cur_group = 0;
for (size_t i = 0; i < n_samples; ++i) {
results[i] = group_weights[cur_group];
if (i == group_ptr[cur_group + 1]) {
cur_group++;
}
}
return results;
}
} // namespace detail
class HistogramCuts;
/*!