Remove unnecessary calls to iota. (#6797)

This commit is contained in:
Jiaming Yuan 2021-03-31 15:27:23 +08:00 committed by GitHub
parent 79b8b560d2
commit 138fe8516a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -219,8 +219,6 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
/** /**
* Create sorted index for each class * Create sorted index for each class
*/ */
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::Iota(d_sorted_idx, device);
auto d_predts_t = dh::ToSpan(cache->predts_t); auto d_predts_t = dh::ToSpan(cache->predts_t);
Transpose(predts, d_predts_t, n_samples, n_classes, device); Transpose(predts, d_predts_t, n_samples, n_classes, device);
@ -231,6 +229,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
}); });
// no out-of-place sort for thrust, cub sort doesn't accept general iterator. So can't // no out-of-place sort for thrust, cub sort doesn't accept general iterator. So can't
// use transform iterator in sorting. // use transform iterator in sorting.
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx); dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);
/** /**
@ -447,10 +446,9 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
/** /**
* Sort the labels * Sort the labels
*/ */
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto d_labels = info.labels_.ConstDeviceSpan(); auto d_labels = info.labels_.ConstDeviceSpan();
dh::Iota(d_sorted_idx, device); auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_labels, d_group_ptr, d_sorted_idx); dh::SegmentedArgSort<false>(d_labels, d_group_ptr, d_sorted_idx);
auto d_weights = info.weights_.ConstDeviceSpan(); auto d_weights = info.weights_.ConstDeviceSpan();