Enable compiling with system cub. (#7232)

- Tested with all CUDA 11.x.
- Work around cub scan by using a discard iterator in AUC.
- Limit the size of Argsort when compiled with CUDA cub.
This commit is contained in:
Jiaming Yuan
2021-09-17 14:28:18 +08:00
committed by GitHub
parent b18f5f61b0
commit c311a8c1d8
6 changed files with 67 additions and 26 deletions

View File

@@ -331,24 +331,25 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
// expand to tuple to include class id
auto fptp_it_in = dh::MakeTransformIterator<Triple>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
uint32_t class_id = i / n_samples;
return thrust::make_tuple(class_id, d_fptp[i].first, d_fptp[i].second);
return thrust::make_tuple(i, d_fptp[i].first, d_fptp[i].second);
});
// shrink down to pair
auto fptp_it_out = thrust::make_transform_output_iterator(
dh::tbegin(d_fptp), [=] __device__(Triple const &t) {
return thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t));
dh::TypedDiscard<Triple>{}, [d_fptp] __device__(Triple const &t) {
d_fptp[thrust::get<0>(t)] =
thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t));
return t;
});
dh::InclusiveScan(
fptp_it_in, fptp_it_out,
[=] __device__(Triple const &l, Triple const &r) {
uint32_t l_cid = thrust::get<0>(l);
uint32_t r_cid = thrust::get<0>(r);
uint32_t l_cid = thrust::get<0>(l) / n_samples;
uint32_t r_cid = thrust::get<0>(r) / n_samples;
if (l_cid != r_cid) {
return r;
}
return Triple(r_cid, // class_id
return Triple(thrust::get<0>(r),
thrust::get<1>(l) + thrust::get<1>(r), // fp
thrust::get<2>(l) + thrust::get<2>(r)); // tp
},
@@ -521,7 +522,7 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
dh::TemporaryArray<float> d_auc(group_ptr.size() - 1);
auto s_d_auc = dh::ToSpan(d_auc);
auto out = thrust::make_transform_output_iterator(
Discard<RankScanItem>(), [=] __device__(RankScanItem const &item) -> RankScanItem {
dh::TypedDiscard<RankScanItem>{}, [=] __device__(RankScanItem const &item) -> RankScanItem {
auto group_id = item.group_id;
assert(group_id < d_group_ptr.size());
auto data_group_begin = d_group_ptr[group_id];