Loop over copy_if (#6201)
* Loop over copy_if * Catch OOM. Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
@@ -129,6 +129,12 @@ inline size_t AvailableMemory(int device_idx) {
|
||||
return device_free;
|
||||
}
|
||||
|
||||
/**
 * \brief Return the device ordinal reported by cudaGetDevice.
 *
 * The status code returned by cudaGetDevice is handed to safe_cuda for checking.
 *
 * \return The ordinal written by cudaGetDevice (starts at 0 before the query).
 */
inline int32_t CurrentDevice() {
  int32_t ordinal{0};
  safe_cuda(cudaGetDevice(&ordinal));
  return ordinal;
}
|
||||
|
||||
inline size_t TotalMemory(int device_idx) {
|
||||
size_t device_free = 0;
|
||||
size_t device_total = 0;
|
||||
@@ -384,6 +390,16 @@ template <typename T>
|
||||
using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;
|
||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
|
||||
inline void ThrowOOMError(std::string const& err, size_t bytes) {
|
||||
auto device = CurrentDevice();
|
||||
auto rank = rabit::GetRank();
|
||||
std::stringstream ss;
|
||||
ss << "Memory allocation error on worker " << rank << ": " << err << "\n"
|
||||
<< "- Free memory: " << AvailableMemory(device) << "\n"
|
||||
<< "- Requested memory: " << bytes << std::endl;
|
||||
LOG(FATAL) << ss.str();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
|
||||
*/
|
||||
@@ -397,7 +413,13 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
||||
using other = XGBDefaultDeviceAllocatorImpl<U>; // NOLINT
|
||||
};
|
||||
/**
 * \brief Allocate storage for n elements of T and record it with the global memory logger.
 *
 * The allocation and the subsequent cudaGetLastError() check both run inside a
 * try block; any std::exception they raise is forwarded to ThrowOOMError together
 * with the request size in bytes.
 *
 * \param n Number of elements of T to allocate.
 * \return The pointer produced by SuperT::allocate.
 */
pointer allocate(size_t n) {  // NOLINT
  // Declare first and allocate only inside the try block: this is the single
  // allocation made by this function, so every failure goes through the handler.
  pointer ptr;
  try {
    ptr = SuperT::allocate(n);
    dh::safe_cuda(cudaGetLastError());
  } catch (const std::exception &e) {
    // NOTE(review): the code below assumes ThrowOOMError does not return normally
    // (it ends in LOG(FATAL)); confirm against the logging configuration.
    ThrowOOMError(e.what(), n * sizeof(T));
  }
  GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
  return ptr;
}
|
||||
@@ -432,8 +454,11 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
||||
}
|
||||
pointer allocate(size_t n) { // NOLINT
|
||||
T* ptr;
|
||||
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
|
||||
n * sizeof(T));
|
||||
auto errc = GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
|
||||
n * sizeof(T));
|
||||
if (errc != cudaSuccess) {
|
||||
ThrowOOMError("Caching allocator", n * sizeof(T));
|
||||
}
|
||||
pointer thrust_ptr{ ptr };
|
||||
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
|
||||
return thrust_ptr;
|
||||
|
||||
Reference in New Issue
Block a user