array interface

This commit is contained in:
Hendrik Groove 2024-10-20 01:46:48 +02:00
parent 206f305b65
commit 971d3ca8cd

View File

@ -68,33 +68,37 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
#if defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_HIP)
hipPointerAttribute_t attr; hipPointerAttribute_t attr;
std::cerr << "Calling hipPointerGetAttributes" << std::endl;
auto err = hipPointerGetAttributes(&attr, ptr); auto err = hipPointerGetAttributes(&attr, ptr);
// Reset error std::cerr << "hipPointerGetAttributes returned: " << hipGetErrorString(err) << std::endl;
hipError_t last_error = hipGetLastError();
if (last_error != hipSuccess) {
LOG(WARNING) << "HIP error after hipPointerGetAttributes: "
<< hipGetErrorString(last_error);
}
if (err == hipErrorInvalidValue) { if (err == hipErrorInvalidValue) {
std::cerr << "Invalid pointer, returning false" << std::endl;
return false; return false;
} else if (err == hipSuccess) { } else if (err == hipSuccess) {
// For ROCm 6.2.2, we use the `type` field std::cerr << "Pointer attributes obtained successfully" << std::endl;
std::cerr << "Memory type: " << attr.type << std::endl;
switch (attr.type) { switch (attr.type) {
case hipMemoryTypeUnregistered: case hipMemoryTypeUnregistered:
std::cerr << "Memory type is Unregistered, returning false" << std::endl;
return false;
case hipMemoryTypeHost: case hipMemoryTypeHost:
std::cerr << "Memory type is Host, returning false" << std::endl;
return false; return false;
case hipMemoryTypeDevice: case hipMemoryTypeDevice:
std::cerr << "Memory type is Device, returning true" << std::endl;
return true;
case hipMemoryTypeManaged: case hipMemoryTypeManaged:
std::cerr << "Memory type is Managed, returning true" << std::endl;
return true; return true;
default: default:
LOG(WARNING) << "Unknown memory type: " << attr.type; std::cerr << "Unknown memory type: " << attr.type << std::endl;
return false; return false;
} }
} else { } else {
LOG(WARNING) << "hipPointerGetAttributes failed with error: " std::cerr << "hipPointerGetAttributes failed with error: "
<< hipGetErrorString(err); << hipGetErrorString(err) << std::endl;
return false; return false;
} }
#elif defined(XGBOOST_USE_CUDA) #elif defined(XGBOOST_USE_CUDA)