Loop over thrust::reduce. (#6229)
* Check input chunk size of dqdm.
* Add doc for current limitation.
This commit is contained in:
@@ -1132,4 +1132,21 @@ size_t SegmentedUnique(Inputs &&...inputs) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
|
||||
}
|
||||
|
||||
template <typename Policy, typename InputIt, typename Init, typename Func>
|
||||
auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce_op) {
|
||||
size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2;
|
||||
size_t size = std::distance(first, second);
|
||||
using Ty = std::remove_cv_t<Init>;
|
||||
Ty aggregate = init;
|
||||
for (size_t offset = 0; offset < size; offset += kLimit) {
|
||||
auto begin_it = first + offset;
|
||||
auto end_it = first + std::min(offset + kLimit, size);
|
||||
size_t batch_size = std::distance(begin_it, end_it);
|
||||
CHECK_LE(batch_size, size);
|
||||
auto ret = thrust::reduce(policy, begin_it, end_it, init, reduce_op);
|
||||
aggregate = reduce_op(aggregate, ret);
|
||||
}
|
||||
return aggregate;
|
||||
}
|
||||
} // namespace dh
|
||||
|
||||
Reference in New Issue
Block a user