Fix inclusive scan for large sizes (#6234)
This commit is contained in:
parent
7756192906
commit
29745c6df2
@ -144,10 +144,8 @@ function(xgboost_set_cuda_flags target)
|
||||
endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
|
||||
|
||||
if (USE_DEVICE_DEBUG)
|
||||
if (CMAKE_BUILD_TYPE MATCHES "Debug")
|
||||
target_compile_options(${target} PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-G;-src-in-ptx>)
|
||||
endif(CMAKE_BUILD_TYPE MATCHES "Debug")
|
||||
target_compile_options(${target} PRIVATE
|
||||
$<$<AND:$<CONFIG:DEBUG>,$<COMPILE_LANGUAGE:CUDA>>:-G;-src-in-ptx>)
|
||||
else (USE_DEVICE_DEBUG)
|
||||
target_compile_options(${target} PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-lineinfo>)
|
||||
@ -157,10 +155,8 @@ function(xgboost_set_cuda_flags target)
|
||||
enable_nvtx(${target})
|
||||
endif (USE_NVTX)
|
||||
|
||||
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1)
|
||||
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
|
||||
target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/cub/)
|
||||
endif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
|
||||
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1 -DTHRUST_IGNORE_CUB_VERSION_CHECK=1)
|
||||
target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/cub/)
|
||||
|
||||
if (MSVC)
|
||||
target_compile_options(${target} PRIVATE
|
||||
|
||||
2
cub
2
cub
@ -1 +1 @@
|
||||
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304
|
||||
Subproject commit af39ee264f4627608072bf54730bf3a862e56875
|
||||
@ -870,9 +870,6 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
|
||||
|
||||
.. versionadded:: 1.1.0
|
||||
|
||||
Known limitation:
|
||||
The data size (rows * cols) can not exceed 2 ** 31 - 1000
|
||||
"""
|
||||
|
||||
def __init__(self, data, label=None, weight=None, # pylint: disable=W0231
|
||||
|
||||
@ -500,10 +500,6 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
max_bin: Number of bins for histogram construction.
|
||||
|
||||
|
||||
Know issue:
|
||||
The size of each chunk (rows * cols for a single dask chunk/partition) can
|
||||
not exceed 2 ** 31 - 1000
|
||||
|
||||
'''
|
||||
def __init__(self, client,
|
||||
data,
|
||||
|
||||
@ -307,7 +307,7 @@ class MemoryLogger {
|
||||
void RegisterDeallocation(void *ptr, size_t n, int current_device) {
|
||||
auto itr = device_allocations.find(ptr);
|
||||
if (itr == device_allocations.end()) {
|
||||
LOG(FATAL) << "Attempting to deallocate " << n << " bytes on device "
|
||||
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device "
|
||||
<< current_device << " that was never allocated ";
|
||||
}
|
||||
num_deallocations++;
|
||||
|
||||
@ -161,6 +161,26 @@ struct WriteCompressedEllpackFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
struct TupleScanOp {
|
||||
__device__ Tuple operator()(Tuple a, Tuple b) {
|
||||
// Key equal
|
||||
if (a.template get<0>() == b.template get<0>()) {
|
||||
b.template get<1>() += a.template get<1>();
|
||||
return b;
|
||||
}
|
||||
// Not equal
|
||||
return b;
|
||||
}
|
||||
};
|
||||
|
||||
// Change the value type of thrust discard iterator so we can use it with cub
|
||||
template <typename T>
|
||||
class TypedDiscard : public thrust::discard_iterator<T> {
|
||||
public:
|
||||
using value_type = T; // NOLINT
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
@ -201,30 +221,23 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||
// We redirect the scan output into this functor to do the actual writing
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
|
||||
d_compressed_buffer, writer, batch, device_accessor, is_valid);
|
||||
thrust::discard_iterator<size_t> discard;
|
||||
TypedDiscard<Tuple> discard;
|
||||
thrust::transform_output_iterator<
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
|
||||
out(discard, functor);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
// 1000 as a safe factor for inclusive_scan, otherwise it might generate overflow and
|
||||
// lead to oom error.
|
||||
// or:
|
||||
// after reduction step 2: cudaErrorInvalidConfiguration: invalid configuration argument
|
||||
// https://github.com/NVIDIA/thrust/issues/1299
|
||||
CHECK_LE(batch.Size(), std::numeric_limits<int32_t>::max() - 1000)
|
||||
<< "Known limitation, size (rows * cols) of quantile based DMatrix "
|
||||
"cannot exceed the limit of 32-bit integer.";
|
||||
thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
|
||||
key_value_index_iter + batch.Size(), out,
|
||||
[=] __device__(Tuple a, Tuple b) {
|
||||
// Key equal
|
||||
if (a.get<0>() == b.get<0>()) {
|
||||
b.get<1>() += a.get<1>();
|
||||
return b;
|
||||
}
|
||||
// Not equal
|
||||
return b;
|
||||
});
|
||||
// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
|
||||
// So we don't crash on n > 2^31
|
||||
size_t temp_storage_bytes = 0;
|
||||
using DispatchScan =
|
||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr, false);
|
||||
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr, false);
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
|
||||
21
tests/python-gpu/test_large_input.py
Normal file
21
tests/python-gpu/test_large_input.py
Normal file
@ -0,0 +1,21 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import cupy as cp
|
||||
import time
|
||||
import pytest
|
||||
|
||||
|
||||
# Test for integer overflow or out of memory exceptions
|
||||
def test_large_input():
|
||||
available_bytes, _ = cp.cuda.runtime.memGetInfo()
|
||||
# 15 GB
|
||||
required_bytes = 1.5e+10
|
||||
if available_bytes < required_bytes:
|
||||
pytest.skip("Not enough memory on this device")
|
||||
n = 1000
|
||||
m = ((1 << 31) + n - 1) // n
|
||||
assert (np.log2(m * n) > 31)
|
||||
X = cp.ones((m, n), dtype=np.float32)
|
||||
y = cp.ones(m)
|
||||
dmat = xgb.DeviceQuantileDMatrix(X, y)
|
||||
xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
|
||||
Loading…
x
Reference in New Issue
Block a user