Fix inclusive scan for large sizes (#6234)

This commit is contained in:
Rory Mitchell
2020-11-03 17:01:43 +13:00
committed by GitHub
parent 7756192906
commit 29745c6df2
7 changed files with 61 additions and 38 deletions

View File

@@ -161,6 +161,26 @@ struct WriteCompressedEllpackFunctor {
}
};
template <typename Tuple>
struct TupleScanOp {
__device__ Tuple operator()(Tuple a, Tuple b) {
// Key equal
if (a.template get<0>() == b.template get<0>()) {
b.template get<1>() += a.template get<1>();
return b;
}
// Not equal
return b;
}
};
// Change the value type of thrust discard iterator so we can use it with cub
template <typename T>
class TypedDiscard : public thrust::discard_iterator<T> {
public:
using value_type = T; // NOLINT
};
// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterBatchT>
@@ -201,30 +221,23 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
// We redirect the scan output into this functor to do the actual writing
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
d_compressed_buffer, writer, batch, device_accessor, is_valid);
thrust::discard_iterator<size_t> discard;
TypedDiscard<Tuple> discard;
thrust::transform_output_iterator<
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
out(discard, functor);
dh::XGBCachingDeviceAllocator<char> alloc;
// 1000 as a safe factor for inclusive_scan, otherwise it might generate overflow and
// lead to oom error.
// or:
// after reduction step 2: cudaErrorInvalidConfiguration: invalid configuration argument
// https://github.com/NVIDIA/thrust/issues/1299
CHECK_LE(batch.Size(), std::numeric_limits<int32_t>::max() - 1000)
<< "Known limitation, size (rows * cols) of quantile based DMatrix "
"cannot exceed the limit of 32-bit integer.";
thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
key_value_index_iter + batch.Size(), out,
[=] __device__(Tuple a, Tuple b) {
// Key equal
if (a.get<0>() == b.get<0>()) {
b.get<1>() += a.get<1>();
return b;
}
// Not equal
return b;
});
// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
// So we don't crash on n > 2^31
size_t temp_storage_bytes = 0;
using DispatchScan =
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
TupleScanOp<Tuple>, cub::NullType, int64_t>;
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr, false);
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr, false);
}
void WriteNullValues(EllpackPageImpl* dst, int device_idx,