Fix LTR with weighted Quantile DMatrix. (#7975)
* Fix LTR with weighted Quantile DMatrix. * Better tests.
This commit is contained in:
@@ -184,8 +184,6 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
info.weights_.SetDevice(device);
|
||||
auto weights = info.weights_.ConstDeviceSpan();
|
||||
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
||||
auto d_group_ptr = dh::ToSpan(group_ptr);
|
||||
|
||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||
thrust::make_counting_iterator(0llu),
|
||||
@@ -205,9 +203,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
auto d_temp_weights = dh::ToSpan(temp_weights);
|
||||
|
||||
if (is_ranking) {
|
||||
if (!weights.empty()) {
|
||||
CHECK_EQ(weights.size(), info.group_ptr_.size() - 1);
|
||||
}
|
||||
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
||||
auto d_group_ptr = dh::ToSpan(group_ptr);
|
||||
auto const weight_iter = dh::MakeTransformIterator<float>(
|
||||
thrust::make_constant_iterator(0lu),
|
||||
[=]__device__(size_t idx) -> float {
|
||||
thrust::make_counting_iterator(0lu), [=] __device__(size_t idx) -> float {
|
||||
auto ridx = batch.GetElement(idx).row_idx;
|
||||
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
|
||||
return weights[group_idx];
|
||||
@@ -272,7 +274,7 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
|
||||
size_t num_cols = batch.NumCols();
|
||||
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
|
||||
int32_t device = sketch_container->DeviceIdx();
|
||||
bool weighted = info.weights_.Size() != 0;
|
||||
bool weighted = !info.weights_.Empty();
|
||||
|
||||
if (weighted) {
|
||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||
|
||||
Reference in New Issue
Block a user