fix Pointer Attr

This commit is contained in:
amdsc21
2023-03-10 19:06:02 +01:00
parent 9f072b50ba
commit 5e8b1842b9
3 changed files with 27 additions and 6 deletions

View File

@@ -59,7 +59,24 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
return false;
}
#elif defined(XGBOOST_USE_HIP)
return false;
hipPointerAttribute_t attr;
auto err = hipPointerGetAttributes(&attr, ptr);
// reset error
CHECK_EQ(err, hipGetLastError());
if (err == hipErrorInvalidValue) {
return false;
} else if (err == hipSuccess) {
switch (attr.memoryType) {
case hipMemoryTypeUnified:
case hipMemoryTypeHost:
return false;
default:
return true;
}
return true;
} else {
return false;
}
#endif
}
} // namespace xgboost

View File

@@ -35,7 +35,11 @@ auto SetDeviceToPtr(void const* ptr) {
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
#elif defined(XGBOOST_USE_HIP) /* this is wrong, need to figure out */
return 0;
hipPointerAttribute_t attr;
dh::safe_cuda(hipPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device;
dh::safe_cuda(hipSetDevice(ptr_device));
return ptr_device;
#endif
}