Use default allocator in sketching. (#6182)

This commit is contained in:
Jiaming Yuan
2020-09-30 14:55:59 +08:00
committed by GitHub
parent 444131a2e6
commit f0c63902ff
5 changed files with 33 additions and 53 deletions

View File

@@ -430,38 +430,18 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
return *allocator;
}
pointer allocate(size_t n) { // NOLINT
pointer ptr;
if (use_cub_allocator_) {
T* raw_ptr{nullptr};
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void**>(&raw_ptr), n * sizeof(T));
ptr = pointer(raw_ptr);
} else {
ptr = SuperT::allocate(n);
}
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
return ptr;
T* ptr;
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
n * sizeof(T));
pointer thrust_ptr{ ptr };
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
return thrust_ptr;
}
void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
if (use_cub_allocator_) {
GetGlobalCachingAllocator().DeviceFree(ptr.get());
} else {
SuperT::deallocate(ptr, n);
}
GetGlobalCachingAllocator().DeviceFree(ptr.get());
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
XGBCachingDeviceAllocatorImpl()
: SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{nullptr}) {
std::string symbol{typeid(*SuperT::resource()).name()};
if (symbol.find("pool_memory_resource") != std::string::npos
|| symbol.find("binning_memory_resource") != std::string::npos
|| symbol.find("arena_memory_resource") != std::string::npos) {
use_cub_allocator_ = false;
}
}
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
private:
bool use_cub_allocator_{true};
XGBOOST_DEVICE void construct(T *) {} // NOLINT
};
} // namespace detail