Fall back to CUB allocator if RMM memory pool is not set up (#6150)

* Fall back to CUB allocator if RMM memory pool is not set up

* Fix build

* Prevent memory leak

* Add note about lack of memory initialisation

* Add check for other fast allocators

* Set use_cub_allocator_ to true when RMM is not enabled

* Fix clang-tidy

* Do not demangle symbol; add check to ensure Linux+Clang/GCC combo
This commit is contained in:
Philip Hyunsu Cho
2020-09-24 11:04:50 -07:00
committed by GitHub
parent 5b05f88ba9
commit 72ef553550
2 changed files with 49 additions and 33 deletions

View File

@@ -402,7 +402,7 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
}
void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
return SuperT::deallocate(ptr, n);
SuperT::deallocate(ptr, n);
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
XGBDefaultDeviceAllocatorImpl()
@@ -410,49 +410,59 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
};
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
template <typename T>
using XGBCachingDeviceAllocatorImpl = XGBDefaultDeviceAllocatorImpl<T>;
#else
/**
* \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs
* allocations if verbose. Does not initialise memory on construction.
* \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end, unless
* RMM pool allocator is enabled. Does not initialise memory on construction.
*/
template <class T>
struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
using SuperT = XGBBaseDeviceAllocator<T>;
using pointer = thrust::device_ptr<T>; // NOLINT
template<typename U>
struct rebind // NOLINT
{
using other = XGBCachingDeviceAllocatorImpl<U>; // NOLINT
};
cub::CachingDeviceAllocator& GetGlobalCachingAllocator ()
{
// Configure allocator with maximum cached bin size of ~1GB and no limit on
// maximum cached bytes
static cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
return *allocator;
}
pointer allocate(size_t n) { // NOLINT
T *ptr;
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
n * sizeof(T));
pointer thrust_ptr(ptr);
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
return thrust_ptr;
}
void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
GetGlobalCachingAllocator().DeviceFree(ptr.get());
}
__host__ __device__
void construct(T *) // NOLINT
{
// no-op
cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
// Configure allocator with maximum cached bin size of ~1GB and no limit on
// maximum cached bytes
static cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
return *allocator;
}
pointer allocate(size_t n) { // NOLINT
pointer ptr;
if (use_cub_allocator_) {
T* raw_ptr{nullptr};
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void**>(&raw_ptr), n * sizeof(T));
ptr = pointer(raw_ptr);
} else {
ptr = SuperT::allocate(n);
}
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
return ptr;
}
void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
if (use_cub_allocator_) {
GetGlobalCachingAllocator().DeviceFree(ptr.get());
} else {
SuperT::deallocate(ptr, n);
}
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
XGBCachingDeviceAllocatorImpl()
: SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{nullptr}) {
std::string symbol{typeid(*SuperT::resource()).name()};
if (symbol.find("pool_memory_resource") != std::string::npos
|| symbol.find("binning_memory_resource") != std::string::npos
|| symbol.find("arena_memory_resource") != std::string::npos) {
use_cub_allocator_ = false;
}
}
};
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
private:
bool use_cub_allocator_{true};
};
} // namespace detail
// Declare xgboost allocators