fix memoryType
This commit is contained in:
parent
2d7ffbdf3d
commit
c42c7d99f1
@ -214,12 +214,12 @@ function(xgboost_link_rccl target)
|
||||
endif()
|
||||
|
||||
if(BUILD_STATIC_LIB)
|
||||
target_include_directories(${target} PUBLIC ${RCCL_INCLUDE_DIR})
|
||||
target_include_directories(${target} PUBLIC ${RCCL_INCLUDE_DIR}/rccl)
|
||||
target_compile_definitions(${target} PUBLIC ${xgboost_rccl_flags})
|
||||
target_link_directories(${target} PUBLIC ${HIP_LIB_INSTALL_DIR})
|
||||
target_link_libraries(${target} PUBLIC ${RCCL_LIBRARY})
|
||||
else()
|
||||
target_include_directories(${target} PRIVATE ${RCCL_INCLUDE_DIR})
|
||||
target_include_directories(${target} PRIVATE ${RCCL_INCLUDE_DIR}/rccl)
|
||||
target_compile_definitions(${target} PRIVATE ${xgboost_rccl_flags})
|
||||
target_link_directories(${target} PUBLIC ${HIP_LIB_INSTALL_DIR})
|
||||
if(NOT USE_DLOPEN_RCCL)
|
||||
|
||||
@ -1 +1 @@
|
||||
Subproject commit 6ceffde024f8752954550ebcca98caa24b5d158d
|
||||
Subproject commit 2fea6734e83cf147c1bbe580ac4713cd50abcad5
|
||||
@ -20,7 +20,6 @@ void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
|
||||
* case where 0 might be given should either use None, 1, or 2 instead for
|
||||
* clarity.
|
||||
*/
|
||||
/* ignored for HIP */
|
||||
#if !defined(XGBOOST_USE_HIP)
|
||||
LOG(FATAL) << "Invalid stream ID in array interface: " << stream;
|
||||
#endif
|
||||
@ -42,7 +41,7 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
cudaPointerAttributes attr;
|
||||
auto err = cudaPointerGetAttributes(&attr, ptr);
|
||||
// reset error
|
||||
@ -64,6 +63,35 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
||||
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
||||
return false;
|
||||
}
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
hipPointerAttribute_t attr;
|
||||
auto err = hipPointerGetAttributes(&attr, ptr);
|
||||
// reset error
|
||||
CHECK_EQ(err, hipGetLastError());
|
||||
if (err == hipErrorInvalidValue) {
|
||||
return false;
|
||||
} else if (err == hipSuccess) {
|
||||
#if HIP_VERSION_MAJOR < 6
|
||||
switch (attr.memoryType) {
|
||||
case hipMemoryTypeUnified:
|
||||
case hipMemoryTypeHost:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
#else
|
||||
switch (attr.type) {
|
||||
case hipMemoryTypeUnified:
|
||||
case hipMemoryTypeHost:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user