diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu
index 471e09fc9..ebcea1c2f 100644
--- a/src/c_api/c_api.cu
+++ b/src/c_api/c_api.cu
@@ -17,6 +17,8 @@
 #include "xgboost/learner.h"
 #if defined(XGBOOST_USE_NCCL)
 #include <nccl.h>
+#elif defined(XGBOOST_USE_RCCL)
+#include <rccl.h>
 #endif
 
 namespace xgboost {
diff --git a/src/collective/nccl_stub.cc b/src/collective/nccl_stub.cc
index 5101234a4..408432438 100644
--- a/src/collective/nccl_stub.cc
+++ b/src/collective/nccl_stub.cc
@@ -1,15 +1,25 @@
 /**
  * Copyright 2023, XGBoost Contributors
  */
-#if defined(XGBOOST_USE_NCCL) || (defined(XGBOOST_USE_RCCL) && 0)
+#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
 #include "nccl_stub.h"
 
+#if defined(XGBOOST_USE_NCCL)
 #include <cuda.h>              // for CUDA_VERSION
 #include <cuda_runtime_api.h>  // for cudaPeekAtLastError
 #include <dlfcn.h>             // for dlclose, dlsym, dlopen
 #include <nccl.h>
 #include <thrust/system/cuda/error.h>  // for cuda_category
 #include <thrust/system_error.h>       // for system_error
+#elif defined(XGBOOST_USE_RCCL)
+#include "../common/cuda_to_hip.h"
+#include "../common/device_helpers.hip.h"
+#include <hip/hip_runtime_api.h>       // for cudaPeekAtLastError
+#include <dlfcn.h>                     // for dlclose, dlsym, dlopen
+#include <rccl.h>
+#include <thrust/system/hip/error.h>   // for cuda_category
+#include <thrust/system_error.h>       // for system_error
+#endif
 
 #include <cstdint>  // for int32_t
 #include <sstream>  // for stringstream
diff --git a/src/collective/nccl_stub.h b/src/collective/nccl_stub.h
index 6bf2ecae6..978f34028 100644
--- a/src/collective/nccl_stub.h
+++ b/src/collective/nccl_stub.h
@@ -2,10 +2,17 @@
  * Copyright 2023, XGBoost Contributors
  */
 #pragma once
-#if defined(XGBOOST_USE_NCCL) || (defined(XGBOOST_USE_RCCL) && 0)
+#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
+#if defined(XGBOOST_USE_NCCL)
 #include <cuda_runtime_api.h>
 #include <nccl.h>
+#elif defined(XGBOOST_USE_RCCL)
+#include "../common/cuda_to_hip.h"
+#include "../common/device_helpers.cuh"
+#include <hip/hip_runtime_api.h>
+#include <rccl.h>
+#endif
 
 #include <string>  // for string
diff --git a/src/common/algorithm.cuh b/src/common/algorithm.cuh
index bce9ba5de..e1e9c8bf4 100644
--- a/src/common/algorithm.cuh
+++ b/src/common/algorithm.cuh
@@ -226,6 +226,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
   });
 }
 
+#if defined(XGBOOST_USE_CUDA)
 template <bool accending, typename IdxT, typename U>
 void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
              xgboost::common::Span<IdxT> sorted_idx) {
@@ -295,5 +296,51 @@ void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
                                  sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice, cuctx->Stream()));
 }
+#elif defined(XGBOOST_USE_HIP)
+template <bool accending, typename IdxT, typename U>
+void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
+             xgboost::common::Span<IdxT> sorted_idx) {
+  std::size_t bytes = 0;
+  auto cuctx = ctx->CUDACtx();
+  dh::Iota(sorted_idx, cuctx->Stream());
+
+  using KeyT = typename decltype(keys)::value_type;
+  using ValueT = std::remove_const_t<IdxT>;
+
+  dh::TemporaryArray<KeyT> out(keys.size());
+  dh::TemporaryArray<ValueT> sorted_idx_out(sorted_idx.size());
+
+  // track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
+  using OffsetT = std::conditional_t<!dh::BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
+  CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
+  if (accending) {
+    void *d_temp_storage = nullptr;
+
+    dh::safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
+        bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
+        sizeof(KeyT) * 8, cuctx->Stream(), false)));
+
+    dh::TemporaryArray<char> storage(bytes);
+    d_temp_storage = storage.data().get();
+    dh::safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
+        bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
+        sizeof(KeyT) * 8, cuctx->Stream(), false)));
+  } else {
+    void *d_temp_storage = nullptr;
+
+    dh::safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
+        bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
+        sizeof(KeyT) * 8, cuctx->Stream(), false)));
+    dh::TemporaryArray<char> storage(bytes);
+    d_temp_storage = storage.data().get();
+    dh::safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
+        bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
+        sizeof(KeyT) * 8, cuctx->Stream(), false)));
+  }
+
+  dh::safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
+                               sorted_idx.size_bytes(), hipMemcpyDeviceToDevice, cuctx->Stream()));
+}
+#endif
 }  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_ALGORITHM_CUH_
diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h
index fcfe2bdd4..79f2f3390 100644
--- a/src/common/device_helpers.hip.h
+++ b/src/common/device_helpers.hip.h
@@ -40,10 +40,6 @@
 #include "xgboost/logging.h"
 #include "xgboost/span.h"
 
-#ifdef XGBOOST_USE_RCCL
-#include "rccl.h"
-#endif  // XGBOOST_USE_RCCL
-
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 #include "rmm/mr/device/per_device_resource.hpp"
 #include "rmm/mr/device/thrust_allocator_adaptor.hpp"
@@ -98,30 +94,6 @@ XGBOOST_DEV_INLINE T atomicAdd(T *addr, T v) {  // NOLINT
 }
 
 namespace dh {
-#ifdef XGBOOST_USE_RCCL
-#define safe_nccl(ans) ThrowOnNcclError((ans), __FILE__, __LINE__)
-
-inline ncclResult_t ThrowOnNcclError(ncclResult_t code, const char *file, int line) {
-  if (code != ncclSuccess) {
-    std::stringstream ss;
-    ss << "RCCL failure: " << ncclGetErrorString(code) << ".";
-    ss << " " << file << "(" << line << ")\n";
-    if (code == ncclUnhandledCudaError) {
-      // nccl usually preserves the last error so we can get more details.
-      auto err = hipPeekAtLastError();
-      ss << " CUDA error: " << thrust::system_error(err, thrust::hip_category()).what() << "\n";
-    } else if (code == ncclSystemError) {
-      ss << " This might be caused by a network configuration issue. Please consider specifying "
-            "the network interface for RCCL via environment variables listed in its reference: "
-            "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
-    }
-    LOG(FATAL) << ss.str();
-  }
-
-  return code;
-}
-#endif
-
 inline int32_t CudaGetPointerDevice(void const *ptr) {
   int32_t device = -1;
   hipPointerAttribute_t attr;
@@ -298,8 +270,8 @@ inline void LaunchN(size_t n, L lambda) {
 }
 
 template <typename Container>
-void Iota(Container array) {
-  LaunchN(array.size(), [=] __device__(size_t i) { array[i] = i; });
+void Iota(Container array, cudaStream_t stream) {
+  LaunchN(array.size(), stream, [=] __device__(size_t i) { array[i] = i; });
 }
 
 namespace detail {
@@ -465,7 +437,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
   hipcub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
     // Configure allocator with maximum cached bin size of ~1GB and no limit on
    // maximum cached bytes
-    static hipcub::CachingDeviceAllocator *allocator = new hipcub::CachingDeviceAllocator(2, 9, 29);
+    thread_local std::unique_ptr<hipcub::CachingDeviceAllocator> allocator{
+        std::make_unique<hipcub::CachingDeviceAllocator>(2, 9, 29)};
     return *allocator;
   }
   pointer allocate(size_t n) {  // NOLINT
@@ -581,6 +554,16 @@ class DoubleBuffer {
   T *Other() { return buff.Alternate(); }
 };
 
+template <typename T>
+xgboost::common::Span<T> LazyResize(xgboost::Context const *ctx,
+                                    xgboost::HostDeviceVector<T> *buffer, std::size_t n) {
+  buffer->SetDevice(ctx->Device());
+  if (buffer->Size() < n) {
+    buffer->Resize(n);
+  }
+  return buffer->DeviceSpan().subspan(0, n);
+}
+
 /**
  * \brief Copies device span to std::vector.
  *
@@ -1017,49 +1000,6 @@ void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items)
   InclusiveScan(d_in, d_out, hipcub::Sum(), num_items);
 }
 
-template <bool accending, typename IdxT, typename U>
-void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_idx) {
-  size_t bytes = 0;
-  Iota(sorted_idx);
-
-  using KeyT = typename decltype(keys)::value_type;
-  using ValueT = std::remove_const_t<IdxT>;
-
-  TemporaryArray<KeyT> out(keys.size());
-  TemporaryArray<ValueT> sorted_idx_out(sorted_idx.size());
-
-  // track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
-  using OffsetT = std::conditional_t<!BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
-  CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
-  if (accending) {
-    void *d_temp_storage = nullptr;
-
-    safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
-        bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
-        sizeof(KeyT) * 8)));
-
-    TemporaryArray<char> storage(bytes);
-    d_temp_storage = storage.data().get();
-    safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
-        bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
-        sizeof(KeyT) * 8)));
-  } else {
-    void *d_temp_storage = nullptr;
-
-    safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
-        bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
-        sizeof(KeyT) * 8)));
-    TemporaryArray<char> storage(bytes);
-    d_temp_storage = storage.data().get();
-    safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
-        bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
-        sizeof(KeyT) * 8)));
-  }
-
-  safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
-                           sorted_idx.size_bytes(), hipMemcpyDeviceToDevice));
-}
-
 class CUDAStreamView;
 
 class CUDAEvent {
@@ -1105,6 +1045,8 @@ inline void CUDAEvent::Record(CUDAStreamView stream) {  // NOLINT
   dh::safe_cuda(hipEventRecord(event_, hipStream_t{stream}));
 }
 
+// Changing this has effect on prediction return, where we need to pass the pointer to
+// third-party libraries like cuPy
 inline CUDAStreamView DefaultStream() {
 #ifdef HIP_API_PER_THREAD_DEFAULT_STREAM
   return CUDAStreamView{hipStreamPerThread};
diff --git a/src/data/array_interface.cc b/src/data/array_interface.cc
index 06b9ed00c..c6d9eda74 100644
--- a/src/data/array_interface.cc
+++ b/src/data/array_interface.cc
@@ -6,7 +6,7 @@
 #include "../common/common.h"  // for AssertGPUSupport
 
 namespace xgboost {
-#if !defined(XGBOOST_USE_CUDA)
+#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
 void ArrayInterfaceHandler::SyncCudaStream(int64_t) { common::AssertGPUSupport(); }
 bool ArrayInterfaceHandler::IsCudaPtr(void const *) { return false; }
 #endif  // !defined(XGBOOST_USE_CUDA)
diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu
index 589b91acc..37e88f838 100644
--- a/src/objective/hinge.cu
+++ b/src/objective/hinge.cu
@@ -9,7 +9,7 @@
 #include <cstdint>  // for int32_t
 
 #include "../common/common.h"  // for Range
-#if defined(XGBOOST_USE_CUDA)
+#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
 #include "../common/linalg_op.cuh"
 #endif
 #include "../common/linalg_op.h"
diff --git a/tests/cpp/objective/test_multiclass_obj_gpu.hip b/tests/cpp/objective/test_multiclass_obj_gpu.hip
index 6bf3f66b0..938ddd9d8 100644
--- a/tests/cpp/objective/test_multiclass_obj_gpu.hip
+++ b/tests/cpp/objective/test_multiclass_obj_gpu.hip
@@ -1,2 +1,2 @@
-#include "test_multiclass_obj.cc"
+#include "test_multiclass_obj_gpu.cu"
diff --git a/tests/cpp/objective/test_regression_obj_cpu.cc b/tests/cpp/objective/test_regression_obj_cpu.cc
index 3613d0d90..afc8cbb73 100644
--- a/tests/cpp/objective/test_regression_obj_cpu.cc
+++ b/tests/cpp/objective/test_regression_obj_cpu.cc
@@ -193,7 +193,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
   ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"tweedie-nloglik@1.1"});
 }
 
-#if defined(__CUDACC__)
+#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
 TEST(Objective, CPU_vs_CUDA) {
   Context ctx = MakeCUDACtx(GPUIDX);
 
@@ -271,7 +271,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) {
 }
 
 // CoxRegression not implemented in GPU code, no need for testing.
-#if !defined(__CUDACC__)
+#if !defined(__CUDACC__) && !defined(__HIP_PLATFORM_AMD__)
 TEST(Objective, CoxRegressionGPair) {
   Context ctx = MakeCUDACtx(GPUIDX);
   std::vector<std::pair<std::string, std::string>> args;
diff --git a/tests/cpp/objective/test_regression_obj_gpu.hip b/tests/cpp/objective/test_regression_obj_gpu.hip
index b5a636e26..62154585e 100644
--- a/tests/cpp/objective/test_regression_obj_gpu.hip
+++ b/tests/cpp/objective/test_regression_obj_gpu.hip
@@ -1,2 +1,2 @@
-#include "test_regression_obj.cc"
+#include "test_regression_obj_gpu.cu"
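
Note on the rocPrim calls introduced in src/common/algorithm.cuh and removed from
src/common/device_helpers.hip.h: rocprim::radix_sort_pairs follows the same two-phase
convention as CUB's DeviceRadixSort. The first call, issued with a null temporary-storage
pointer, only reports the required scratch size in bytes; the caller then allocates that
buffer and repeats the call to perform the actual sort. The self-contained sketch below
illustrates the pattern outside of XGBoost; it is not part of the patch, and the file
name, variable names, and the four-element input are illustrative only.

// rocprim_argsort_sketch.hip.cpp -- hypothetical standalone example, built with
//   hipcc rocprim_argsort_sketch.hip.cpp
#include <hip/hip_runtime.h>
#include <rocprim/rocprim.hpp>

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> h_keys{0.5f, 0.1f, 0.9f, 0.3f};
  std::vector<int> h_idx{0, 1, 2, 3};
  std::size_t n = h_keys.size();

  float *d_keys_in = nullptr, *d_keys_out = nullptr;
  int *d_idx_in = nullptr, *d_idx_out = nullptr;
  hipMalloc(reinterpret_cast<void **>(&d_keys_in), n * sizeof(float));
  hipMalloc(reinterpret_cast<void **>(&d_keys_out), n * sizeof(float));
  hipMalloc(reinterpret_cast<void **>(&d_idx_in), n * sizeof(int));
  hipMalloc(reinterpret_cast<void **>(&d_idx_out), n * sizeof(int));
  hipMemcpy(d_keys_in, h_keys.data(), n * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(d_idx_in, h_idx.data(), n * sizeof(int), hipMemcpyHostToDevice);

  // Phase 1: query the temporary-storage size; no sorting happens yet.
  std::size_t bytes = 0;
  rocprim::radix_sort_pairs(nullptr, bytes, d_keys_in, d_keys_out, d_idx_in, d_idx_out, n);

  // Phase 2: allocate the scratch buffer and run the ascending key/value sort.
  void *d_temp_storage = nullptr;
  hipMalloc(&d_temp_storage, bytes);
  rocprim::radix_sort_pairs(d_temp_storage, bytes, d_keys_in, d_keys_out, d_idx_in, d_idx_out, n);

  // The sorted values form the argsort permutation of the keys; expected output: 1 3 0 2.
  hipMemcpy(h_idx.data(), d_idx_out, n * sizeof(int), hipMemcpyDeviceToHost);
  for (int i : h_idx) std::printf("%d ", i);
  std::printf("\n");

  hipFree(d_keys_in);
  hipFree(d_keys_out);
  hipFree(d_idx_in);
  hipFree(d_idx_out);
  hipFree(d_temp_storage);
  return 0;
}

The ArgSort added by the patch follows the same sequence, with dh::TemporaryArray providing
the scratch buffer, an explicit HIP stream passed through to each call, and
rocprim::radix_sort_pairs_desc used for the descending case.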