Fix inclusive scan for large sizes (#6234)

This commit is contained in:
Rory Mitchell 2020-11-03 17:01:43 +13:00 committed by GitHub
parent 7756192906
commit 29745c6df2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 38 deletions

View File

@ -144,10 +144,8 @@ function(xgboost_set_cuda_flags target)
endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
if (USE_DEVICE_DEBUG)
if (CMAKE_BUILD_TYPE MATCHES "Debug")
target_compile_options(${target} PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-G;-src-in-ptx>)
endif(CMAKE_BUILD_TYPE MATCHES "Debug")
target_compile_options(${target} PRIVATE
$<$<AND:$<CONFIG:DEBUG>,$<COMPILE_LANGUAGE:CUDA>>:-G;-src-in-ptx>)
else (USE_DEVICE_DEBUG)
target_compile_options(${target} PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-lineinfo>)
@ -157,10 +155,8 @@ function(xgboost_set_cuda_flags target)
enable_nvtx(${target})
endif (USE_NVTX)
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1)
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/cub/)
endif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1 -DTHRUST_IGNORE_CUB_VERSION_CHECK=1)
target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/cub/)
if (MSVC)
target_compile_options(${target} PRIVATE

2
cub

@ -1 +1 @@
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304
Subproject commit af39ee264f4627608072bf54730bf3a862e56875

View File

@ -870,9 +870,6 @@ class DeviceQuantileDMatrix(DMatrix):
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
.. versionadded:: 1.1.0
Known limitation:
The data size (rows * cols) can not exceed 2 ** 31 - 1000
"""
def __init__(self, data, label=None, weight=None, # pylint: disable=W0231

View File

@ -500,10 +500,6 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
max_bin: Number of bins for histogram construction.
Know issue:
The size of each chunk (rows * cols for a single dask chunk/partition) can
not exceed 2 ** 31 - 1000
'''
def __init__(self, client,
data,

View File

@ -307,7 +307,7 @@ class MemoryLogger {
void RegisterDeallocation(void *ptr, size_t n, int current_device) {
auto itr = device_allocations.find(ptr);
if (itr == device_allocations.end()) {
LOG(FATAL) << "Attempting to deallocate " << n << " bytes on device "
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device "
<< current_device << " that was never allocated ";
}
num_deallocations++;

View File

@ -161,6 +161,26 @@ struct WriteCompressedEllpackFunctor {
}
};
template <typename Tuple>
struct TupleScanOp {
__device__ Tuple operator()(Tuple a, Tuple b) {
// Key equal
if (a.template get<0>() == b.template get<0>()) {
b.template get<1>() += a.template get<1>();
return b;
}
// Not equal
return b;
}
};
// Change the value type of thrust discard iterator so we can use it with cub
template <typename T>
class TypedDiscard : public thrust::discard_iterator<T> {
public:
using value_type = T; // NOLINT
};
// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterBatchT>
@ -201,30 +221,23 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
// We redirect the scan output into this functor to do the actual writing
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
d_compressed_buffer, writer, batch, device_accessor, is_valid);
thrust::discard_iterator<size_t> discard;
TypedDiscard<Tuple> discard;
thrust::transform_output_iterator<
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
out(discard, functor);
dh::XGBCachingDeviceAllocator<char> alloc;
// 1000 as a safe factor for inclusive_scan, otherwise it might generate overflow and
// lead to oom error.
// or:
// after reduction step 2: cudaErrorInvalidConfiguration: invalid configuration argument
// https://github.com/NVIDIA/thrust/issues/1299
CHECK_LE(batch.Size(), std::numeric_limits<int32_t>::max() - 1000)
<< "Known limitation, size (rows * cols) of quantile based DMatrix "
"cannot exceed the limit of 32-bit integer.";
thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
key_value_index_iter + batch.Size(), out,
[=] __device__(Tuple a, Tuple b) {
// Key equal
if (a.get<0>() == b.get<0>()) {
b.get<1>() += a.get<1>();
return b;
}
// Not equal
return b;
});
// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
// So we don't crash on n > 2^31
size_t temp_storage_bytes = 0;
using DispatchScan =
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
TupleScanOp<Tuple>, cub::NullType, int64_t>;
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr, false);
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr, false);
}
void WriteNullValues(EllpackPageImpl* dst, int device_idx,

View File

@ -0,0 +1,21 @@
import numpy as np
import xgboost as xgb
import cupy as cp
import time
import pytest
# Test for integer overflow or out of memory exceptions
def test_large_input():
available_bytes, _ = cp.cuda.runtime.memGetInfo()
# 15 GB
required_bytes = 1.5e+10
if available_bytes < required_bytes:
pytest.skip("Not enough memory on this device")
n = 1000
m = ((1 << 31) + n - 1) // n
assert (np.log2(m * n) > 31)
X = cp.ones((m, n), dtype=np.float32)
y = cp.ones(m)
dmat = xgb.DeviceQuantileDMatrix(X, y)
xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)