Extract device algorithms. (#8789)

This commit is contained in:
Jiaming Yuan 2023-02-13 20:53:53 +08:00 committed by GitHub
parent 457f704e3d
commit 31d3ec07af
13 changed files with 361 additions and 218 deletions

View File

@@ -1,27 +1,190 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023 by XGBoost Contributors
  */
-#pragma once
-#include <thrust/binary_search.h>     // thrust::upper_bound
-#include <thrust/execution_policy.h>  // thrust::seq
-
-#include "xgboost/base.h"
-#include "xgboost/span.h"
+#ifndef XGBOOST_COMMON_ALGORITHM_CUH_
+#define XGBOOST_COMMON_ALGORITHM_CUH_
+
+#include <thrust/copy.h>   // copy
+#include <thrust/sort.h>   // stable_sort_by_key
+#include <thrust/tuple.h>  // tuple,get
+
+#include <cstddef>      // size_t
+#include <cstdint>      // int32_t
+#include <cub/cub.cuh>  // DispatchSegmentedRadixSort,NullType,DoubleBuffer
+#include <iterator>     // distance
+#include <limits>       // numeric_limits
+#include <type_traits>  // conditional_t,remove_const_t
+
+#include "common.h"            // safe_cuda
+#include "cuda_context.cuh"    // CUDAContext
+#include "device_helpers.cuh"  // TemporaryArray,SegmentId,LaunchN,Iota,device_vector
+#include "xgboost/base.h"      // XGBOOST_DEVICE
+#include "xgboost/context.h"   // Context
+#include "xgboost/logging.h"   // CHECK
+#include "xgboost/span.h"      // Span,byte
 
 namespace xgboost {
 namespace common {
-namespace cuda {
-template <typename It>
-size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
-  size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) - 1 - first;
-  return segment_id;
-}
-
-template <typename T>
-size_t XGBOOST_DEVICE SegmentId(Span<T> segments_ptr, size_t idx) {
-  return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
-}
-}  // namespace cuda
+namespace detail {
+// Wrapper around cub sort to define is_descending.
+template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
+          typename EndOffsetIteratorT>
+static void DeviceSegmentedRadixSortKeys(CUDAContext const *ctx, void *d_temp_storage,
+                                         std::size_t &temp_storage_bytes,  // NOLINT
+                                         const KeyT *d_keys_in, KeyT *d_keys_out, int num_items,
+                                         int num_segments, BeginOffsetIteratorT d_begin_offsets,
+                                         EndOffsetIteratorT d_end_offsets, int begin_bit = 0,
+                                         int end_bit = sizeof(KeyT) * 8,
+                                         bool debug_synchronous = false) {
+  using OffsetT = int;
+
+  // Null value type
+  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
+  cub::DoubleBuffer<cub::NullType> d_values;
+
+  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
+                 IS_DESCENDING, KeyT, cub::NullType, BeginOffsetIteratorT, EndOffsetIteratorT,
+                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
+                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
+                                    end_bit, false, ctx->Stream(), debug_synchronous)));
+}
+
+// Wrapper around cub sort for easier `descending` sort.
+template <bool descending, typename KeyT, typename ValueT, typename BeginOffsetIteratorT,
+          typename EndOffsetIteratorT>
+void DeviceSegmentedRadixSortPair(void *d_temp_storage,
+                                  std::size_t &temp_storage_bytes,  // NOLINT
+                                  const KeyT *d_keys_in, KeyT *d_keys_out,
+                                  const ValueT *d_values_in, ValueT *d_values_out,
+                                  std::size_t num_items, std::size_t num_segments,
+                                  BeginOffsetIteratorT d_begin_offsets,
+                                  EndOffsetIteratorT d_end_offsets, dh::CUDAStreamView stream,
+                                  int begin_bit = 0, int end_bit = sizeof(KeyT) * 8) {
+  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
+  cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in), d_values_out);
+  // In old versions of cub, num_items in dispatch is also int32_t; no way to change it.
+  using OffsetT = std::conditional_t<dh::BuildWithCUDACub() && dh::HasThrustMinorVer<13>(),
+                                     std::size_t, std::int32_t>;
+  CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
+  // For Thrust >= 1.12 or CUDA >= 11.4, we require a system cub installation.
+#if THRUST_MAJOR_VERSION >= 2
+  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
+                 descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
+                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
+                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
+                                    end_bit, false, stream)));
+#elif (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13)
+  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
+                 descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
+                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
+                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
+                                    end_bit, false, stream, false)));
+#else
+  dh::safe_cuda(
+      (cub::DispatchSegmentedRadixSort<descending, KeyT, ValueT, BeginOffsetIteratorT,
+                                       OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes,
+                                                          d_keys, d_values, num_items, num_segments,
+                                                          d_begin_offsets, d_end_offsets, begin_bit,
+                                                          end_bit, false, stream, false)));
+#endif
+}
+}  // namespace detail
+
+template <typename U, typename V>
+void SegmentedSequence(Context const *ctx, Span<U> d_offset_ptr, Span<V> out_sequence) {
+  dh::LaunchN(out_sequence.size(), ctx->CUDACtx()->Stream(),
+              [out_sequence, d_offset_ptr] __device__(std::size_t idx) {
+                auto group = dh::SegmentId(d_offset_ptr, idx);
+                out_sequence[idx] = idx - d_offset_ptr[group];
+              });
+}
+
+template <bool descending, typename U, typename V>
+inline void SegmentedSortKeys(Context const *ctx, Span<V const> group_ptr,
+                              Span<U> out_sorted_values) {
+  CHECK_GE(group_ptr.size(), 1ul);
+  std::size_t n_groups = group_ptr.size() - 1;
+  std::size_t bytes = 0;
+  auto const *cuctx = ctx->CUDACtx();
+  CHECK(cuctx);
+  detail::DeviceSegmentedRadixSortKeys<descending>(
+      cuctx, nullptr, bytes, out_sorted_values.data(), out_sorted_values.data(),
+      out_sorted_values.size(), n_groups, group_ptr.data(), group_ptr.data() + 1);
+  dh::TemporaryArray<byte> temp_storage(bytes);
+  detail::DeviceSegmentedRadixSortKeys<descending>(
+      cuctx, temp_storage.data().get(), bytes, out_sorted_values.data(), out_sorted_values.data(),
+      out_sorted_values.size(), n_groups, group_ptr.data(), group_ptr.data() + 1);
+}
+
+/**
+ * \brief Create sorted index for data with multiple segments.
+ *
+ * \tparam accending     Sort in non-decreasing order.
+ * \tparam per_seg_index Index starts from 0 for each segment if true, otherwise the
+ *                       index spans the whole data.
+ */
+template <bool accending, bool per_seg_index, typename U, typename V, typename IdxT>
+void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
+                      Span<IdxT> sorted_idx) {
+  CHECK_GE(group_ptr.size(), 1ul);
+  std::size_t n_groups = group_ptr.size() - 1;
+  std::size_t bytes = 0;
+  if (per_seg_index) {
+    SegmentedSequence(ctx, group_ptr, sorted_idx);
+  } else {
+    dh::Iota(sorted_idx);
+  }
+  dh::TemporaryArray<std::remove_const_t<U>> values_out(values.size());
+  dh::TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());
+
+  detail::DeviceSegmentedRadixSortPair<!accending>(
+      nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
+      sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
+      group_ptr.data() + 1, ctx->CUDACtx()->Stream());
+  dh::TemporaryArray<byte> temp_storage(bytes);
+  detail::DeviceSegmentedRadixSortPair<!accending>(
+      temp_storage.data().get(), bytes, values.data(), values_out.data().get(), sorted_idx.data(),
+      sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
+      group_ptr.data() + 1, ctx->CUDACtx()->Stream());
+
+  dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
+                                sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
+}
+
+/**
+ * \brief Different from the radix-sort-based argsort, this one can handle cases where a
+ *        segment doesn't start from 0, but as a result it uses comparison sort.
+ */
+template <typename SegIt, typename ValIt>
+void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, ValIt val_begin,
+                           ValIt val_end, dh::device_vector<std::size_t> *p_sorted_idx) {
+  using Tup = thrust::tuple<std::int32_t, float>;
+  auto &sorted_idx = *p_sorted_idx;
+  std::size_t n = std::distance(val_begin, val_end);
+  sorted_idx.resize(n);
+  dh::Iota(dh::ToSpan(sorted_idx));
+  dh::device_vector<Tup> keys(sorted_idx.size());
+  auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
+                                               [=] XGBOOST_DEVICE(std::size_t i) -> Tup {
+                                                 std::int32_t seg_idx;
+                                                 if (i < *seg_begin) {
+                                                   seg_idx = -1;
+                                                 } else {
+                                                   seg_idx = dh::SegmentId(seg_begin, seg_end, i);
+                                                 }
+                                                 auto residue = val_begin[i];
+                                                 return thrust::make_tuple(seg_idx, residue);
+                                               });
+  thrust::copy(ctx->CUDACtx()->CTP(), key_it, key_it + keys.size(), keys.begin());
+  thrust::stable_sort_by_key(ctx->CUDACtx()->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
+                             [=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
+                               if (thrust::get<0>(l) != thrust::get<0>(r)) {
+                                 return thrust::get<0>(l) < thrust::get<0>(r);  // segment index
+                               }
+                               return thrust::get<1>(l) < thrust::get<1>(r);  // residue
+                             });
+}
 }  // namespace common
 }  // namespace xgboost
+#endif  // XGBOOST_COMMON_ALGORITHM_CUH_
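
The extracted `SegmentedArgSort` now takes a `Context` plus a `per_seg_index` flag in addition to the sort direction. A minimal usage sketch, mirroring the test added later in this commit (assumes a CUDA-enabled build; the function name `SegmentedArgSortSketch` is illustrative only):

#include "algorithm.cuh"       // xgboost::common::SegmentedArgSort
#include "device_helpers.cuh"  // dh::device_vector, dh::ToSpan, dh::Iota
#include "xgboost/context.h"   // xgboost::Context

void SegmentedArgSortSketch() {
  xgboost::Context ctx;
  ctx.gpu_id = 0;

  dh::device_vector<float> values(20);
  dh::Iota(dh::ToSpan(values));  // 0, 1, ..., 19
  dh::device_vector<size_t> sorted_idx(20);

  // Two segments: [0, 10) and [10, 20).
  dh::device_vector<size_t> groups(3);
  groups[0] = 0;
  groups[1] = 10;
  groups[2] = 20;

  // accending = false sorts each segment in descending order;
  // per_seg_index = false keeps indices relative to the whole span.
  xgboost::common::SegmentedArgSort<false, false>(&ctx, dh::ToSpan(values), dh::ToSpan(groups),
                                                  dh::ToSpan(sorted_idx));
  // sorted_idx is now {9, 8, ..., 0, 19, 18, ..., 10}.
}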

View File

@@ -1,43 +1,39 @@
-/*!
- * Copyright 2017-2022 XGBoost contributors
+/**
+ * Copyright 2017-2023 XGBoost contributors
  */
 #pragma once
+#include <thrust/binary_search.h>  // thrust::upper_bound
+#include <thrust/device_malloc_allocator.h>
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
-#include <thrust/device_malloc_allocator.h>
+#include <thrust/execution_policy.h>  // thrust::seq
+#include <thrust/gather.h>            // gather
 #include <thrust/iterator/discard_iterator.h>
-#include <thrust/iterator/transform_output_iterator.h>
+#include <thrust/iterator/transform_output_iterator.h>  // make_transform_output_iterator
+#include <thrust/logical.h>
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
-#include <thrust/execution_policy.h>
 #include <thrust/transform_scan.h>
-#include <thrust/logical.h>
-#include <thrust/gather.h>
 #include <thrust/unique.h>
-#include <thrust/binary_search.h>
-
-#include <cub/cub.cuh>
-#include <cub/util_allocator.cuh>
 
 #include <algorithm>
 #include <chrono>
+#include <cub/cub.cuh>
+#include <cub/util_allocator.cuh>
 #include <numeric>
 #include <sstream>
 #include <string>
-#include <vector>
 #include <tuple>
+#include <vector>
 
-#include "xgboost/logging.h"
-#include "xgboost/host_device_vector.h"
-#include "xgboost/span.h"
-#include "xgboost/global_config.h"
-
 #include "../collective/communicator-inl.h"
 #include "common.h"
-#include "algorithm.cuh"
+#include "xgboost/global_config.h"
+#include "xgboost/host_device_vector.h"
+#include "xgboost/logging.h"
+#include "xgboost/span.h"
 
 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
@@ -1015,7 +1011,16 @@ XGBOOST_DEVICE thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIt
   return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
 }
 
-using xgboost::common::cuda::SegmentId;  // import it for compatibility
+template <typename It>
+size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
+  size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) - 1 - first;
+  return segment_id;
+}
+
+template <typename T>
+size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span<T> segments_ptr, size_t idx) {
+  return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
+}
 
 namespace detail {
 template <typename Key, typename KeyOutIt>
@@ -1288,114 +1293,6 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
                             sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
 }
 
-namespace detail {
-// Wrapper around cub sort for easier `descending` sort.
-template <bool descending, typename KeyT, typename ValueT,
-          typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
-void DeviceSegmentedRadixSortPair(
-    void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in,  // NOLINT
-    KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out,
-    size_t num_items, size_t num_segments, BeginOffsetIteratorT d_begin_offsets,
-    EndOffsetIteratorT d_end_offsets, int begin_bit = 0,
-    int end_bit = sizeof(KeyT) * 8) {
-  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
-  cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in),
-                                     d_values_out);
-  // In old version of cub, num_items in dispatch is also int32_t, no way to change.
-  using OffsetT =
-      std::conditional_t<BuildWithCUDACub() && HasThrustMinorVer<13>(), size_t,
-                         int32_t>;
-  CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
-  // For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation
-#if THRUST_MAJOR_VERSION >= 2
-  safe_cuda((cub::DispatchSegmentedRadixSort<
-             descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
-             OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
-                                d_values, num_items, num_segments,
-                                d_begin_offsets, d_end_offsets, begin_bit,
-                                end_bit, false, nullptr)));
-#elif (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13)
-  safe_cuda((cub::DispatchSegmentedRadixSort<
-             descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
-             OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
-                                d_values, num_items, num_segments,
-                                d_begin_offsets, d_end_offsets, begin_bit,
-                                end_bit, false, nullptr, false)));
-#else
-  safe_cuda((cub::DispatchSegmentedRadixSort<
-             descending, KeyT, ValueT, BeginOffsetIteratorT,
-             OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
-                                d_values, num_items, num_segments,
-                                d_begin_offsets, d_end_offsets, begin_bit,
-                                end_bit, false, nullptr, false)));
-#endif
-}
-}  // namespace detail
-
-template <bool accending, typename U, typename V, typename IdxT>
-void SegmentedArgSort(xgboost::common::Span<U> values,
-                      xgboost::common::Span<V> group_ptr,
-                      xgboost::common::Span<IdxT> sorted_idx) {
-  CHECK_GE(group_ptr.size(), 1ul);
-  size_t n_groups = group_ptr.size() - 1;
-  size_t bytes = 0;
-  Iota(sorted_idx);
-  TemporaryArray<std::remove_const_t<U>> values_out(values.size());
-  TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());
-
-  detail::DeviceSegmentedRadixSortPair<!accending>(
-      nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
-      sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
-      group_ptr.data() + 1);
-  TemporaryArray<xgboost::common::byte> temp_storage(bytes);
-  detail::DeviceSegmentedRadixSortPair<!accending>(
-      temp_storage.data().get(), bytes, values.data(), values_out.data().get(),
-      sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(),
-      n_groups, group_ptr.data(), group_ptr.data() + 1);
-
-  safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
-                            sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
-}
-
-/**
- * \brief Different from the above one, this one can handle cases where segment doesn't
- *        start from 0, but as a result it uses comparison sort.
- */
-template <typename SegIt, typename ValIt>
-void SegmentedArgSort(SegIt seg_begin, SegIt seg_end, ValIt val_begin, ValIt val_end,
-                      dh::device_vector<size_t> *p_sorted_idx) {
-  using Tup = thrust::tuple<int32_t, float>;
-  auto &sorted_idx = *p_sorted_idx;
-  size_t n = std::distance(val_begin, val_end);
-  sorted_idx.resize(n);
-  dh::Iota(dh::ToSpan(sorted_idx));
-  dh::device_vector<Tup> keys(sorted_idx.size());
-  auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
-                                               [=] XGBOOST_DEVICE(size_t i) -> Tup {
-                                                 int32_t leaf_idx;
-                                                 if (i < *seg_begin) {
-                                                   leaf_idx = -1;
-                                                 } else {
-                                                   leaf_idx = dh::SegmentId(seg_begin, seg_end, i);
-                                                 }
-                                                 auto residue = val_begin[i];
-                                                 return thrust::make_tuple(leaf_idx, residue);
-                                               });
-  dh::XGBCachingDeviceAllocator<char> caching;
-  thrust::copy(thrust::cuda::par(caching), key_it, key_it + keys.size(), keys.begin());
-
-  dh::XGBDeviceAllocator<char> alloc;
-  thrust::stable_sort_by_key(thrust::cuda::par(alloc), keys.begin(), keys.end(), sorted_idx.begin(),
-                             [=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
-                               if (thrust::get<0>(l) != thrust::get<0>(r)) {
-                                 return thrust::get<0>(l) < thrust::get<0>(r);  // segment index
-                               }
-                               return thrust::get<1>(l) < thrust::get<1>(r);  // residue
-                             });
-}
-
 class CUDAStreamView;
 
 class CUDAEvent {
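
`SegmentId` stays in device_helpers.cuh and the compatibility alias is dropped. For reference, a host-side sketch of what it computes (the offset values are made up for illustration):

#include <cstddef>  // std::size_t
#include <vector>   // std::vector

#include "device_helpers.cuh"  // dh::SegmentId
#include "xgboost/span.h"      // xgboost::common::Span

// SegmentId finds the segment containing idx with upper_bound over the
// CSR-style offsets, then steps back by one.
void SegmentIdSketch() {
  std::vector<std::size_t> offsets{0, 2, 78, 100};
  auto d_offsets = xgboost::common::Span<std::size_t>{offsets.data(), offsets.size()};
  std::size_t seg = dh::SegmentId(d_offsets, 5);  // upper_bound -> 78 at position 2, so seg == 1
  (void)seg;
}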

View File

@@ -17,6 +17,7 @@
 #include <limits>       // std::numeric_limits
 #include <type_traits>  // std::is_floating_point,std::iterator_traits
 
+#include "algorithm.cuh"     // SegmentedArgMergeSort
 #include "cuda_context.cuh"  // CUDAContext
 #include "device_helpers.cuh"
 #include "xgboost/context.h"  // Context
@@ -150,7 +151,7 @@ void SegmentedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_begin, Se
                        ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
   dh::device_vector<std::size_t> sorted_idx;
   using Tup = thrust::tuple<std::size_t, float>;
-  dh::SegmentedArgSort(seg_begin, seg_end, val_begin, val_end, &sorted_idx);
+  common::SegmentedArgMergeSort(ctx, seg_begin, seg_end, val_begin, val_end, &sorted_idx);
   auto n_segments = std::distance(seg_begin, seg_end) - 1;
   if (n_segments <= 0) {
     return;
@@ -203,7 +204,7 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b
                                HostDeviceVector<float>* quantiles) {
   auto cuctx = ctx->CUDACtx();
   dh::device_vector<std::size_t> sorted_idx;
-  dh::SegmentedArgSort(seg_beg, seg_end, val_begin, val_end, &sorted_idx);
+  common::SegmentedArgMergeSort(ctx, seg_beg, seg_end, val_begin, val_end, &sorted_idx);
   auto d_sorted_idx = dh::ToSpan(sorted_idx);
   std::size_t n_weights = std::distance(w_begin, w_end);
   dh::device_vector<float> weights_cdf(n_weights);
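
The quantile helpers switch from `dh::SegmentedArgSort` to `common::SegmentedArgMergeSort` because their segment iterators may start at a non-zero offset, which the radix-sort path cannot represent. A sketch of such a call (offset and value contents are hypothetical, not from this commit):

#include <thrust/sequence.h>   // thrust::sequence

#include "algorithm.cuh"       // xgboost::common::SegmentedArgMergeSort
#include "device_helpers.cuh"  // dh::device_vector
#include "xgboost/context.h"   // xgboost::Context

void MergeArgSortSketch(xgboost::Context const *ctx) {
  // The first offset is non-zero; elements before it get segment id -1 and
  // are grouped at the front.
  dh::device_vector<std::size_t> offsets(3);
  offsets[0] = 2;
  offsets[1] = 5;
  offsets[2] = 8;

  dh::device_vector<float> values(8);
  thrust::sequence(values.rbegin(), values.rend());  // 7, 6, ..., 0

  dh::device_vector<std::size_t> sorted_idx;
  xgboost::common::SegmentedArgMergeSort(ctx, offsets.begin(), offsets.end(), values.begin(),
                                         values.end(), &sorted_idx);
  // Within each segment, sorted_idx now lists indices in ascending value order.
}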

View File

@@ -86,7 +86,7 @@ class IterativeDMatrix : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
     return nullptr;
   }
-  DMatrix *SliceCol(int num_slices, int slice_id) override {
+  DMatrix *SliceCol(int, int) override {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
     return nullptr;
   }

View File

@@ -87,7 +87,7 @@ class DMatrixProxy : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
     return nullptr;
   }
-  DMatrix* SliceCol(int num_slices, int slice_id) override {
+  DMatrix* SliceCol(int, int) override {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
     return nullptr;
   }

View File

@@ -107,7 +107,7 @@ class SparsePageDMatrix : public DMatrix {
     LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
     return nullptr;
   }
-  DMatrix *SliceCol(int num_slices, int slice_id) override {
+  DMatrix *SliceCol(int, int) override {
     LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
     return nullptr;
   }
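
The three `SliceCol` overrides above are always-fatal stubs, so their arguments are never read; leaving the parameters unnamed keeps the override signatures intact while silencing `-Wunused-parameter`.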

View File

@@ -345,8 +345,8 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
       std::tie(auc, valid_groups) =
           RankingAUC<true>(predts.ConstHostVector(), info, n_threads);
     } else {
-      std::tie(auc, valid_groups) = GPURankingAUC(
-          predts.ConstDeviceSpan(), info, ctx_->gpu_id, &this->d_cache_);
+      std::tie(auc, valid_groups) =
+          GPURankingAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
     }
     return std::make_pair(auc, valid_groups);
   }
@@ -360,8 +360,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
       auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads,
                           BinaryROCAUC);
     } else {
-      auc = GPUMultiClassROCAUC(predts.ConstDeviceSpan(), info, ctx_->gpu_id,
-                                &this->d_cache_, n_classes);
+      auc = GPUMultiClassROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_, n_classes);
     }
     return auc;
   }
@@ -398,14 +397,15 @@ std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const>, Me
   return {};
 }
 
-double GPUMultiClassROCAUC(common::Span<float const>, MetaInfo const &, std::int32_t,
+double GPUMultiClassROCAUC(Context const *, common::Span<float const>, MetaInfo const &,
                            std::shared_ptr<DeviceAUCCache> *, std::size_t) {
   common::AssertGPUSupport();
   return 0.0;
 }
 
-std::pair<double, std::uint32_t> GPURankingAUC(common::Span<float const>, MetaInfo const &,
-                                               std::int32_t, std::shared_ptr<DeviceAUCCache> *) {
+std::pair<double, std::uint32_t> GPURankingAUC(Context const *, common::Span<float const>,
+                                               MetaInfo const &,
+                                               std::shared_ptr<DeviceAUCCache> *) {
   common::AssertGPUSupport();
   return {};
 }
@@ -437,8 +437,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
       return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads,
                            BinaryPRAUC);
     } else {
-      return GPUMultiClassPRAUC(predts.ConstDeviceSpan(), info, ctx_->gpu_id,
-                                &d_cache_, n_classes);
+      return GPUMultiClassPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_, n_classes);
     }
   }
@@ -455,8 +454,8 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
       std::tie(auc, valid_groups) =
          RankingAUC<false>(predts.ConstHostVector(), info, n_threads);
     } else {
-      std::tie(auc, valid_groups) = GPURankingPRAUC(
-          predts.ConstDeviceSpan(), info, ctx_->gpu_id, &d_cache_);
+      std::tie(auc, valid_groups) =
+          GPURankingPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_);
     }
     return std::make_pair(auc, valid_groups);
   }
@@ -476,14 +475,15 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const>, Met
   return {};
 }
 
-double GPUMultiClassPRAUC(common::Span<float const>, MetaInfo const &, std::int32_t,
+double GPUMultiClassPRAUC(Context const *, common::Span<float const>, MetaInfo const &,
                           std::shared_ptr<DeviceAUCCache> *, std::size_t) {
   common::AssertGPUSupport();
   return {};
 }
 
-std::pair<double, std::uint32_t> GPURankingPRAUC(common::Span<float const>, MetaInfo const &,
-                                                 std::int32_t, std::shared_ptr<DeviceAUCCache> *) {
+std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *, common::Span<float const>,
+                                                 MetaInfo const &,
+                                                 std::shared_ptr<DeviceAUCCache> *) {
   common::AssertGPUSupport();
   return {};
 }
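
All of the GPU AUC entry points now take `Context const *` as the first argument instead of a bare device ordinal. A sketch of the pattern this enables (the function name is hypothetical; `CUDACtx`, `Stream`, and `TP` are the accessors used throughout this commit):

#include "xgboost/context.h"  // xgboost::Context

// With the full Context, device code can reach the CUDA stream and the
// thrust execution policies rather than just a device id.
double SomeGPUMetricSketch(xgboost::Context const *ctx) {
  auto const *cuctx = ctx->CUDACtx();  // CUDAContext: stream + policies
  // ... launch kernels on cuctx->Stream(), run thrust with cuctx->TP() ...
  (void)cuctx;
  return 0.0;
}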

View File

@@ -12,6 +12,7 @@
 #include <utility>
 
 #include "../collective/device_communicator.cuh"
+#include "../common/algorithm.cuh"  // SegmentedArgSort
 #include "../common/optional_weight.h"    // OptionalWeights
 #include "../common/threading_utils.cuh"  // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
 #include "auc.h"
@@ -20,6 +21,9 @@
 namespace xgboost {
 namespace metric {
+// Tag this file, used by force static link later.
+DMLC_REGISTRY_FILE_TAG(auc_gpu);
+
 namespace {
 // Pair of FP/TP
 using Pair = thrust::pair<double, double>;
@@ -436,7 +440,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
   return ScaleClasses(d_results, local_area, tp, auc, n_classes);
 }
 
-void MultiClassSortedIdx(common::Span<float const> predts,
+void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
                          common::Span<uint32_t> d_class_ptr,
                          std::shared_ptr<DeviceAUCCache> cache) {
   size_t n_classes = d_class_ptr.size() - 1;
@@ -449,11 +453,11 @@ void MultiClassSortedIdx(common::Span<float const> predts,
   dh::LaunchN(n_classes + 1,
               [=] XGBOOST_DEVICE(size_t i) { d_class_ptr[i] = i * n_samples; });
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
-  dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);
+  common::SegmentedArgSort<false, false>(ctx, d_predts_t, d_class_ptr, d_sorted_idx);
 }
 
-double GPUMultiClassROCAUC(common::Span<float const> predts, MetaInfo const &info,
-                           std::int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache,
+double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
+                           MetaInfo const &info, std::shared_ptr<DeviceAUCCache> *p_cache,
                            std::size_t n_classes) {
   auto& cache = *p_cache;
   InitCacheOnce<true>(predts, p_cache);
@@ -462,13 +466,13 @@ double GPUMultiClassROCAUC(common::Span<float const> predts, MetaInfo const &inf
    * Create sorted index for each class
    */
   dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
-  MultiClassSortedIdx(predts, dh::ToSpan(class_ptr), cache);
+  MultiClassSortedIdx(ctx, predts, dh::ToSpan(class_ptr), cache);
 
   auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev,
                               double tp, size_t /*class_id*/) {
     return TrapezoidArea(fp_prev, fp, tp_prev, tp);
   };
-  return GPUMultiClassAUCOVR<true>(info, device, dh::ToSpan(class_ptr), n_classes, cache, fn);
+  return GPUMultiClassAUCOVR<true>(info, ctx->gpu_id, dh::ToSpan(class_ptr), n_classes, cache, fn);
 }
 
@@ -480,8 +484,8 @@ struct RankScanItem {
 };
 }  // anonymous namespace
 
-std::pair<double, std::uint32_t> GPURankingAUC(common::Span<float const> predts,
-                                               MetaInfo const &info, std::int32_t device,
+std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<float const> predts,
+                                               MetaInfo const &info,
                                                std::shared_ptr<DeviceAUCCache> *p_cache) {
   auto& cache = *p_cache;
   InitCacheOnce<false>(predts, p_cache);
@@ -509,10 +513,10 @@ std::pair<double, std::uint32_t> GPURankingAUC(common::Span<float const> predts,
   /**
    * Sort the labels
    */
-  auto d_labels = info.labels.View(device);
+  auto d_labels = info.labels.View(ctx->gpu_id);
 
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
-  dh::SegmentedArgSort<false>(d_labels.Values(), d_group_ptr, d_sorted_idx);
+  common::SegmentedArgSort<false, false>(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx);
 
   auto d_weights = info.weights_.ConstDeviceSpan();
@@ -640,8 +644,8 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
   return std::make_tuple(1.0, 1.0, auc);
 }
 
-double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info,
-                          std::int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache,
+double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
+                          MetaInfo const &info, std::shared_ptr<DeviceAUCCache> *p_cache,
                           std::size_t n_classes) {
   auto& cache = *p_cache;
   InitCacheOnce<true>(predts, p_cache);
@@ -651,7 +655,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info
    */
   dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
   auto d_class_ptr = dh::ToSpan(class_ptr);
-  MultiClassSortedIdx(predts, d_class_ptr, cache);
+  MultiClassSortedIdx(ctx, predts, d_class_ptr, cache);
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
 
   auto d_weights = info.weights_.ConstDeviceSpan();
@@ -659,7 +663,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info
   /**
    * Get total positive/negative
    */
-  auto labels = info.labels.View(device);
+  auto labels = info.labels.View(ctx->gpu_id);
   auto n_samples = info.num_row_;
   dh::caching_device_vector<Pair> totals(n_classes);
   auto key_it =
@@ -692,7 +696,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info
     return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
                                   d_totals[class_id].first);
   };
-  return GPUMultiClassAUCOVR<false>(info, device, d_class_ptr, n_classes, cache, fn);
+  return GPUMultiClassAUCOVR<false>(info, ctx->gpu_id, d_class_ptr, n_classes, cache, fn);
 }
 
 template <typename Fn>
@@ -815,10 +819,11 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
   return std::make_pair(auc, n_groups - invalid_groups);
 }
 
-std::pair<double, std::uint32_t> GPURankingPRAUC(common::Span<float const> predts,
-                                                 MetaInfo const &info, std::int32_t device,
+std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
+                                                 common::Span<float const> predts,
+                                                 MetaInfo const &info,
                                                  std::shared_ptr<DeviceAUCCache> *p_cache) {
-  dh::safe_cuda(cudaSetDevice(device));
+  dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
   if (predts.empty()) {
     return std::make_pair(0.0, static_cast<uint32_t>(0));
   }
@@ -836,10 +841,10 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(common::Span<float const> predt
    * Create sorted index for each group
    */
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
-  dh::SegmentedArgSort<false>(predts, d_group_ptr, d_sorted_idx);
+  common::SegmentedArgSort<false, false>(ctx, predts, d_group_ptr, d_sorted_idx);
 
   dh::XGBDeviceAllocator<char> alloc;
-  auto labels = info.labels.View(device);
+  auto labels = info.labels.View(ctx->gpu_id);
   if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()),
                      dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
     InvalidLabels();
@@ -878,7 +883,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(common::Span<float const> predt
     return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
                                   d_totals[group_id].first);
   };
-  return GPURankingPRAUCImpl(predts, info, d_group_ptr, device, cache, fn);
+  return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->gpu_id, cache, fn);
 }
 }  // namespace metric
 }  // namespace xgboost

View File

@@ -33,12 +33,12 @@ std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> pre
                                                    MetaInfo const &info, std::int32_t device,
                                                    std::shared_ptr<DeviceAUCCache> *p_cache);
 
-double GPUMultiClassROCAUC(common::Span<float const> predts, MetaInfo const &info,
-                           std::int32_t device, std::shared_ptr<DeviceAUCCache> *cache,
+double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
+                           MetaInfo const &info, std::shared_ptr<DeviceAUCCache> *p_cache,
                            std::size_t n_classes);
 
-std::pair<double, std::uint32_t> GPURankingAUC(common::Span<float const> predts,
-                                               MetaInfo const &info, std::int32_t device,
+std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<float const> predts,
+                                               MetaInfo const &info,
                                                std::shared_ptr<DeviceAUCCache> *cache);
 
 /**********
@@ -48,12 +48,13 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
                                                   MetaInfo const &info, std::int32_t device,
                                                   std::shared_ptr<DeviceAUCCache> *p_cache);
 
-double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info,
-                          std::int32_t device, std::shared_ptr<DeviceAUCCache> *cache,
+double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
+                          MetaInfo const &info, std::shared_ptr<DeviceAUCCache> *p_cache,
                           std::size_t n_classes);
 
-std::pair<double, std::uint32_t> GPURankingPRAUC(common::Span<float const> predts,
-                                                 MetaInfo const &info, std::int32_t device,
+std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
+                                                 common::Span<float const> predts,
+                                                 MetaInfo const &info,
                                                  std::shared_ptr<DeviceAUCCache> *cache);
 
 namespace detail {

View File

@@ -84,6 +84,7 @@ DMLC_REGISTRY_LINK_TAG(multiclass_metric);
 DMLC_REGISTRY_LINK_TAG(survival_metric);
 DMLC_REGISTRY_LINK_TAG(rank_metric);
 #ifdef XGBOOST_USE_CUDA
+DMLC_REGISTRY_LINK_TAG(auc_gpu);
 DMLC_REGISTRY_LINK_TAG(rank_metric_gpu);
 #endif
 }  // namespace metric
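
`DMLC_REGISTRY_FILE_TAG` and `DMLC_REGISTRY_LINK_TAG` are the dmlc-core pair for keeping an object file alive under static linking when its only entry points are static registrations. As used in this commit:

// In auc.cu: define a marker for this translation unit.
DMLC_REGISTRY_FILE_TAG(auc_gpu);

// In metric.cc, which is always linked: reference the marker so the linker
// keeps auc.cu and its static metric registrations.
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(auc_gpu);
#endif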

View File

@@ -0,0 +1,97 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <thrust/copy.h> // copy
#include <thrust/sequence.h> // sequence
#include <thrust/sort.h> // is_sorted
#include <algorithm> // is_sorted
#include <cstddef> // size_t
#include "../../../src/common/algorithm.cuh"
#include "../../../src/common/device_helpers.cuh"
#include "../helpers.h" // CreateEmptyGenericParam
namespace xgboost {
namespace common {
void TestSegmentedArgSort() {
Context ctx;
ctx.gpu_id = 0;
size_t constexpr kElements = 100, kGroups = 3;
dh::device_vector<size_t> sorted_idx(kElements, 0);
dh::device_vector<size_t> offset_ptr(kGroups + 1, 0);
offset_ptr[0] = 0;
offset_ptr[1] = 2;
offset_ptr[2] = 78;
offset_ptr[kGroups] = kElements;
auto d_offset_ptr = dh::ToSpan(offset_ptr);
auto d_sorted_idx = dh::ToSpan(sorted_idx);
dh::LaunchN(sorted_idx.size(), [=] XGBOOST_DEVICE(size_t idx) {
auto group = dh::SegmentId(d_offset_ptr, idx);
d_sorted_idx[idx] = idx - d_offset_ptr[group];
});
dh::device_vector<float> values(kElements, 0.0f);
thrust::sequence(values.begin(), values.end(), 0.0f);
SegmentedArgSort<false, true>(&ctx, dh::ToSpan(values), d_offset_ptr, d_sorted_idx);
std::vector<size_t> h_sorted_index(sorted_idx.size());
thrust::copy(sorted_idx.begin(), sorted_idx.end(), h_sorted_index.begin());
for (size_t i = 1; i < kGroups + 1; ++i) {
auto group_sorted_idx = common::Span<size_t>(h_sorted_index)
.subspan(offset_ptr[i - 1], offset_ptr[i] - offset_ptr[i - 1]);
ASSERT_TRUE(std::is_sorted(group_sorted_idx.begin(), group_sorted_idx.end(), std::greater<>{}));
ASSERT_EQ(group_sorted_idx.back(), 0);
for (auto j : group_sorted_idx) {
ASSERT_LT(j, group_sorted_idx.size());
}
}
}
TEST(Algorithms, SegmentedArgSort) { TestSegmentedArgSort(); }
TEST(Algorithms, ArgSort) {
Context ctx;
ctx.gpu_id = 0;
dh::device_vector<float> values(20);
dh::Iota(dh::ToSpan(values));  // ascending
dh::device_vector<size_t> sorted_idx(20);
dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx)); // sort to descending
ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(), sorted_idx.end(),
thrust::greater<size_t>{}));
dh::Iota(dh::ToSpan(values));
dh::device_vector<size_t> groups(3);
groups[0] = 0;
groups[1] = 10;
groups[2] = 20;
SegmentedArgSort<false, false>(&ctx, dh::ToSpan(values), dh::ToSpan(groups),
dh::ToSpan(sorted_idx));
ASSERT_FALSE(thrust::is_sorted(thrust::device, sorted_idx.begin(), sorted_idx.end(),
thrust::greater<size_t>{}));
ASSERT_TRUE(
thrust::is_sorted(sorted_idx.begin(), sorted_idx.begin() + 10, thrust::greater<size_t>{}));
ASSERT_TRUE(
thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(), thrust::greater<size_t>{}));
}
TEST(Algorithms, SegmentedSequence) {
dh::device_vector<std::size_t> idx(16);
dh::device_vector<std::size_t> ptr(3);
Context ctx = CreateEmptyGenericParam(0);
ptr[0] = 0;
ptr[1] = 4;
ptr[2] = idx.size();
SegmentedSequence(&ctx, dh::ToSpan(ptr), dh::ToSpan(idx));
ASSERT_EQ(idx[0], 0);
ASSERT_EQ(idx[4], 0);
ASSERT_EQ(idx[3], 3);
ASSERT_EQ(idx[15], 11);
}
} // namespace common
} // namespace xgboost
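
Why the `SegmentedSequence` assertions hold: with `ptr = {0, 4, 16}` the helper writes `idx[i] = i - ptr[SegmentId(ptr, i)]`, i.e. each element's position within its own segment. A host-only sketch of the same invariant:

#include <cassert>  // assert
#include <cstddef>  // std::size_t
#include <vector>   // std::vector

int main() {
  std::vector<std::size_t> ptr{0, 4, 16};
  std::vector<std::size_t> out(16);
  for (std::size_t i = 0, seg = 0; i < out.size(); ++i) {
    while (seg + 1 < ptr.size() && ptr[seg + 1] <= i) {
      ++seg;  // advance to the segment containing i
    }
    out[i] = i - ptr[seg];  // index local to the segment
  }
  // Matches the device-side checks: the sequence restarts at each boundary.
  assert(out[0] == 0 && out[3] == 3 && out[4] == 0 && out[15] == 11);
  return 0;
}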

View File

@@ -172,28 +172,4 @@ TEST(Allocator, OOM) {
   // Clear last error so we don't fail subsequent tests
   cudaGetLastError();
 }
-
-TEST(DeviceHelpers, ArgSort) {
-  dh::device_vector<float> values(20);
-  dh::Iota(dh::ToSpan(values));  // accending
-  dh::device_vector<size_t> sorted_idx(20);
-  dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx));  // sort to descending
-  ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
-                                sorted_idx.end(), thrust::greater<size_t>{}));
-
-  dh::Iota(dh::ToSpan(values));
-  dh::device_vector<size_t> groups(3);
-  groups[0] = 0;
-  groups[1] = 10;
-  groups[2] = 20;
-  dh::SegmentedArgSort<false>(dh::ToSpan(values), dh::ToSpan(groups),
-                              dh::ToSpan(sorted_idx));
-  ASSERT_FALSE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
-                                 sorted_idx.end(), thrust::greater<size_t>{}));
-  ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin(), sorted_idx.begin() + 10,
-                                thrust::greater<size_t>{}));
-  ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
-                                thrust::greater<size_t>{}));
-}
 }  // namespace xgboost

View File

@@ -341,6 +341,7 @@ TEST(GPUQuantile, MultiMerge) {
 namespace {
 void TestAllReduceBasic(int32_t n_gpus) {
   auto const world = collective::GetWorldSize();
+  CHECK_EQ(world, n_gpus);
   constexpr size_t kRows = 1000, kCols = 100;
   RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
     auto const device = collective::GetRank();
@@ -425,8 +426,9 @@ TEST(GPUQuantile, MGPUAllReduceBasic) {
 }
 
 namespace {
-void TestSameOnAllWorkers(int32_t n_gpus) {
+void TestSameOnAllWorkers(std::int32_t n_gpus) {
   auto world = collective::GetWorldSize();
+  CHECK_EQ(world, n_gpus);
   constexpr size_t kRows = 1000, kCols = 100;
   RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
                                  MetaInfo const &info) {