fix Pointer Attr
This commit is contained in:
@@ -59,7 +59,24 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
||||
return false;
|
||||
}
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
return false;
|
||||
hipPointerAttribute_t attr;
|
||||
auto err = hipPointerGetAttributes(&attr, ptr);
|
||||
// reset error
|
||||
CHECK_EQ(err, hipGetLastError());
|
||||
if (err == hipErrorInvalidValue) {
|
||||
return false;
|
||||
} else if (err == hipSuccess) {
|
||||
switch (attr.memoryType) {
|
||||
case hipMemoryTypeUnified:
|
||||
case hipMemoryTypeHost:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -35,7 +35,11 @@ auto SetDeviceToPtr(void const* ptr) {
|
||||
dh::safe_cuda(cudaSetDevice(ptr_device));
|
||||
return ptr_device;
|
||||
#elif defined(XGBOOST_USE_HIP) /* this is wrong, need to figure out */
|
||||
return 0;
|
||||
hipPointerAttribute_t attr;
|
||||
dh::safe_cuda(hipPointerGetAttributes(&attr, ptr));
|
||||
int32_t ptr_device = attr.device;
|
||||
dh::safe_cuda(hipSetDevice(ptr_device));
|
||||
return ptr_device;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user