Loop over copy_if (#6201)

* Loop over copy_if

* Catch OOM.

Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
Rory Mitchell
2020-10-14 10:23:16 +13:00
committed by GitHub
parent 0fc263ead5
commit 734a911a26
3 changed files with 64 additions and 18 deletions

View File

@@ -129,6 +129,12 @@ inline size_t AvailableMemory(int device_idx) {
return device_free;
}
inline int32_t CurrentDevice() {
int32_t device = 0;
safe_cuda(cudaGetDevice(&device));
return device;
}
inline size_t TotalMemory(int device_idx) {
size_t device_free = 0;
size_t device_total = 0;
@@ -384,6 +390,16 @@ template <typename T>
using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
inline void ThrowOOMError(std::string const& err, size_t bytes) {
auto device = CurrentDevice();
auto rank = rabit::GetRank();
std::stringstream ss;
ss << "Memory allocation error on worker " << rank << ": " << err << "\n"
<< "- Free memory: " << AvailableMemory(device) << "\n"
<< "- Requested memory: " << bytes << std::endl;
LOG(FATAL) << ss.str();
}
/**
* \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
*/
@@ -397,7 +413,13 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
using other = XGBDefaultDeviceAllocatorImpl<U>; // NOLINT
};
pointer allocate(size_t n) { // NOLINT
pointer ptr = SuperT::allocate(n);
pointer ptr;
try {
ptr = SuperT::allocate(n);
dh::safe_cuda(cudaGetLastError());
} catch (const std::exception &e) {
ThrowOOMError(e.what(), n * sizeof(T));
}
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
return ptr;
}
@@ -432,8 +454,11 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
}
pointer allocate(size_t n) { // NOLINT
T* ptr;
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
n * sizeof(T));
auto errc = GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
n * sizeof(T));
if (errc != cudaSuccess) {
ThrowOOMError("Caching allocator", n * sizeof(T));
}
pointer thrust_ptr{ ptr };
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
return thrust_ptr;