more logging

This commit is contained in:
Hendrik Groove 2024-10-20 20:59:23 +02:00
parent c964dd62b4
commit 58a27ba968

View File

@ -394,13 +394,14 @@ inline void ThrowOOMError(std::string const& err, size_t bytes) {
template <class T> template <class T>
struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> { struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
using SuperT = XGBBaseDeviceAllocator<T>; using SuperT = XGBBaseDeviceAllocator<T>;
using pointer = thrust::device_ptr<T>; // NOLINT using pointer = thrust::device_ptr<T>;
template<typename U> template<typename U>
struct rebind // NOLINT struct rebind {
{ using other = XGBDefaultDeviceAllocatorImpl<U>;
using other = XGBDefaultDeviceAllocatorImpl<U>; // NOLINT
}; };
pointer allocate(size_t n) { // NOLINT
pointer allocate(size_t n) {
pointer ptr; pointer ptr;
try { try {
ptr = SuperT::allocate(n); ptr = SuperT::allocate(n);
@ -408,17 +409,22 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
} catch (const std::exception &e) { } catch (const std::exception &e) {
ThrowOOMError(e.what(), n * sizeof(T)); ThrowOOMError(e.what(), n * sizeof(T));
} }
std::cerr << "XGBDefaultDeviceAllocatorImpl: Allocated " << n * sizeof(T)
<< " bytes at " << ptr.get() << std::endl;
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T)); GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
return ptr; return ptr;
} }
void deallocate(pointer ptr, size_t n) { // NOLINT
void deallocate(pointer ptr, size_t n) {
std::cerr << "XGBDefaultDeviceAllocatorImpl: Deallocating " << n * sizeof(T)
<< " bytes at " << ptr.get() << std::endl;
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T)); GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
SuperT::deallocate(ptr, n); SuperT::deallocate(ptr, n);
} }
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
XGBDefaultDeviceAllocatorImpl() XGBDefaultDeviceAllocatorImpl() : SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()) {}
: SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()) {} #endif
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
}; };
/** /**