log reduce function

This commit is contained in:
Hendrik Groove 2024-10-20 23:26:21 +02:00
parent 58a27ba968
commit bf2ef6c586

View File

@ -965,18 +965,32 @@ size_t SegmentedUniqueByKey(
template <typename Policy, typename InputIt, typename Init, typename Func> template <typename Policy, typename InputIt, typename Init, typename Func>
auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce_op) { auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce_op) {
std::cerr << "Entering Reduce function" << std::endl;
size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2; size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2;
size_t size = std::distance(first, second); size_t size = std::distance(first, second);
std::cerr << "Total size for reduction: " << size << std::endl;
using Ty = std::remove_cv_t<Init>; using Ty = std::remove_cv_t<Init>;
Ty aggregate = init; Ty aggregate = init;
for (size_t offset = 0; offset < size; offset += kLimit) { for (size_t offset = 0; offset < size; offset += kLimit) {
auto begin_it = first + offset; auto begin_it = first + offset;
auto end_it = first + std::min(offset + kLimit, size); auto end_it = first + std::min(offset + kLimit, size);
size_t batch_size = std::distance(begin_it, end_it); size_t batch_size = std::distance(begin_it, end_it);
CHECK_LE(batch_size, size); CHECK_LE(batch_size, size);
auto ret = thrust::reduce(policy, begin_it, end_it, init, reduce_op);
aggregate = reduce_op(aggregate, ret); std::cerr << "Processing batch: offset=" << offset << ", batch_size=" << batch_size << std::endl;
try {
auto ret = thrust::reduce(policy, begin_it, end_it, init, reduce_op);
aggregate = reduce_op(aggregate, ret);
} catch (const std::exception& e) {
std::cerr << "Exception in thrust::reduce: " << e.what() << std::endl;
throw;
}
} }
std::cerr << "Exiting Reduce function" << std::endl;
return aggregate; return aggregate;
} }