unify cuda to hip
This commit is contained in:
parent
6df27eadc9
commit
4eb371b3f0
@ -32,6 +32,7 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
// Wrapper around cub sort to define is_decending
|
// Wrapper around cub sort to define is_decending
|
||||||
template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
|
template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
|
||||||
typename EndOffsetIteratorT>
|
typename EndOffsetIteratorT>
|
||||||
@ -56,13 +57,13 @@ static void DeviceSegmentedRadixSortKeys(CUDAContext const *ctx, void *d_temp_st
|
|||||||
end_bit, false, ctx->Stream(), debug_synchronous)));
|
end_bit, false, ctx->Stream(), debug_synchronous)));
|
||||||
#elif defined(XGBOOST_USE_HIP)
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
if (IS_DESCENDING) {
|
if (IS_DESCENDING) {
|
||||||
rocprim::segmented_radix_sort_pairs_desc<KeyT, hipcub::NullType, BeginOffsetIteratorT>(d_temp_storage,
|
rocprim::segmented_radix_sort_pairs_desc<KeyT, cub::NullType, BeginOffsetIteratorT>(d_temp_storage,
|
||||||
temp_storage_bytes, d_keys_in, d_keys_out, nullptr, nullptr, num_items,
|
temp_storage_bytes, d_keys_in, d_keys_out, nullptr, nullptr, num_items,
|
||||||
num_segments, d_begin_offsets, d_end_offsets,
|
num_segments, d_begin_offsets, d_end_offsets,
|
||||||
begin_bit, end_bit, ctx->Stream(), debug_synchronous);
|
begin_bit, end_bit, ctx->Stream(), debug_synchronous);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
rocprim::segmented_radix_sort_pairs<KeyT, hipcub::NullType, BeginOffsetIteratorT>(d_temp_storage,
|
rocprim::segmented_radix_sort_pairs<KeyT, cub::NullType, BeginOffsetIteratorT>(d_temp_storage,
|
||||||
temp_storage_bytes, d_keys_in, d_keys_out, nullptr, nullptr, num_items,
|
temp_storage_bytes, d_keys_in, d_keys_out, nullptr, nullptr, num_items,
|
||||||
num_segments, d_begin_offsets, d_end_offsets,
|
num_segments, d_begin_offsets, d_end_offsets,
|
||||||
begin_bit, end_bit, ctx->Stream(), debug_synchronous);
|
begin_bit, end_bit, ctx->Stream(), debug_synchronous);
|
||||||
|
|||||||
@ -26,6 +26,7 @@
|
|||||||
#define WITH_CUDA() true
|
#define WITH_CUDA() true
|
||||||
|
|
||||||
#elif defined(__HIP_PLATFORM_AMD__)
|
#elif defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#include "cuda_to_hip.h"
|
||||||
#include <thrust/system/hip/error.h>
|
#include <thrust/system/hip/error.h>
|
||||||
#include <thrust/system_error.h>
|
#include <thrust/system_error.h>
|
||||||
|
|
||||||
@ -38,7 +39,7 @@
|
|||||||
#endif // defined(__CUDACC__)
|
#endif // defined(__CUDACC__)
|
||||||
|
|
||||||
namespace dh {
|
namespace dh {
|
||||||
#if defined(__CUDACC__)
|
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
|
||||||
/*
|
/*
|
||||||
* Error handling functions
|
* Error handling functions
|
||||||
*/
|
*/
|
||||||
@ -53,22 +54,6 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, int line
|
|||||||
}
|
}
|
||||||
return code;
|
return code;
|
||||||
}
|
}
|
||||||
|
|
||||||
#elif defined(__HIP_PLATFORM_AMD__)
|
|
||||||
/*
|
|
||||||
* Error handling functions
|
|
||||||
*/
|
|
||||||
#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__)
|
|
||||||
|
|
||||||
inline hipError_t ThrowOnCudaError(hipError_t code, const char *file, int line)
|
|
||||||
{
|
|
||||||
if (code != hipSuccess) {
|
|
||||||
LOG(FATAL) << thrust::system_error(code, thrust::hip_category(),
|
|
||||||
std::string{file} + ": " + // NOLINT
|
|
||||||
std::to_string(line)).what();
|
|
||||||
}
|
|
||||||
return code;
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
} // namespace dh
|
} // namespace dh
|
||||||
|
|
||||||
|
|||||||
@ -7,8 +7,10 @@
|
|||||||
|
|
||||||
#define cudaSuccess hipSuccess
|
#define cudaSuccess hipSuccess
|
||||||
#define cudaError hipError_t
|
#define cudaError hipError_t
|
||||||
|
#define cudaError_t hipError_t
|
||||||
#define cudaGetLastError hipGetLastError
|
#define cudaGetLastError hipGetLastError
|
||||||
#define cudaPeekAtLastError hipPeekAtLastError
|
#define cudaPeekAtLastError hipPeekAtLastError
|
||||||
|
#define cudaErrorInvalidValue hipErrorInvalidValue
|
||||||
|
|
||||||
#define cudaStream_t hipStream_t
|
#define cudaStream_t hipStream_t
|
||||||
#define cudaStreamCreate hipStreamCreate
|
#define cudaStreamCreate hipStreamCreate
|
||||||
@ -17,7 +19,10 @@
|
|||||||
#define cudaStreamWaitEvent hipStreamWaitEvent
|
#define cudaStreamWaitEvent hipStreamWaitEvent
|
||||||
#define cudaStreamSynchronize hipStreamSynchronize
|
#define cudaStreamSynchronize hipStreamSynchronize
|
||||||
#define cudaStreamPerThread hipStreamPerThread
|
#define cudaStreamPerThread hipStreamPerThread
|
||||||
#define cudaStreamLegacy hipStreamLegacy
|
|
||||||
|
/* not compatible */
|
||||||
|
#define cudaStreamLegacy hipStreamDefault
|
||||||
|
#define hipStreamLegacy hipStreamDefault
|
||||||
|
|
||||||
#define cudaEvent_t hipEvent_t
|
#define cudaEvent_t hipEvent_t
|
||||||
#define cudaEventCreate hipEventCreate
|
#define cudaEventCreate hipEventCreate
|
||||||
@ -50,6 +55,11 @@
|
|||||||
#define cudaPointerAttributes hipPointerAttribute_t
|
#define cudaPointerAttributes hipPointerAttribute_t
|
||||||
#define cudaPointerGetAttributes hipPointerGetAttributes
|
#define cudaPointerGetAttributes hipPointerGetAttributes
|
||||||
|
|
||||||
|
/* hipMemoryTypeUnregistered not supported */
|
||||||
|
#define cudaMemoryTypeUnregistered hipMemoryTypeUnified
|
||||||
|
#define cudaMemoryTypeHost hipMemoryTypeHost
|
||||||
|
#define cudaMemoryTypeUnified hipMemoryTypeUnified
|
||||||
|
|
||||||
#define cudaMemGetInfo hipMemGetInfo
|
#define cudaMemGetInfo hipMemGetInfo
|
||||||
#define cudaFuncSetAttribute hipFuncSetAttribute
|
#define cudaFuncSetAttribute hipFuncSetAttribute
|
||||||
|
|
||||||
@ -59,10 +69,11 @@
|
|||||||
namespace thrust {
|
namespace thrust {
|
||||||
namespace hip {
|
namespace hip {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace cuda = thrust::hip;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace thrust {
|
namespace thrust {
|
||||||
namespace cuda = thrust::hip;
|
|
||||||
#define cuda_category hip_category
|
#define cuda_category hip_category
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1109,7 +1109,7 @@ inline CUDAStreamView DefaultStream() {
|
|||||||
#ifdef HIP_API_PER_THREAD_DEFAULT_STREAM
|
#ifdef HIP_API_PER_THREAD_DEFAULT_STREAM
|
||||||
return CUDAStreamView{hipStreamPerThread};
|
return CUDAStreamView{hipStreamPerThread};
|
||||||
#else
|
#else
|
||||||
return CUDAStreamView{hipStreamDefault};
|
return CUDAStreamView{hipStreamLegacy};
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -42,7 +42,7 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||||
cudaPointerAttributes attr;
|
cudaPointerAttributes attr;
|
||||||
auto err = cudaPointerGetAttributes(&attr, ptr);
|
auto err = cudaPointerGetAttributes(&attr, ptr);
|
||||||
// reset error
|
// reset error
|
||||||
@ -64,25 +64,6 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
|||||||
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
#elif defined(XGBOOST_USE_HIP)
|
|
||||||
hipPointerAttribute_t attr;
|
|
||||||
auto err = hipPointerGetAttributes(&attr, ptr);
|
|
||||||
// reset error
|
|
||||||
CHECK_EQ(err, hipGetLastError());
|
|
||||||
if (err == hipErrorInvalidValue) {
|
|
||||||
return false;
|
|
||||||
} else if (err == hipSuccess) {
|
|
||||||
switch (attr.memoryType) {
|
|
||||||
case hipMemoryTypeUnified:
|
|
||||||
case hipMemoryTypeHost:
|
|
||||||
return false;
|
|
||||||
default:
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user