From bdcb036592a0b012d72db0e976bad690ef53f9c2 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 07:34:19 +0100 Subject: [PATCH] add context.hip --- src/common/cuda_context.cuh | 4 ++++ src/common/device_helpers.hip.h | 2 +- src/context.cc | 8 ++++---- src/context.hip | 2 ++ 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/common/cuda_context.cuh b/src/common/cuda_context.cuh index 372b49dde..47b51c009 100644 --- a/src/common/cuda_context.cuh +++ b/src/common/cuda_context.cuh @@ -5,7 +5,11 @@ #define XGBOOST_COMMON_CUDA_CONTEXT_CUH_ #include +#if defined(XGBOOST_USE_HIP) +#include "device_helpers.hip.h" +#elif defined(XGBOOST_USE_CUDA) #include "device_helpers.cuh" +#endif namespace xgboost { struct CUDAContext { diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h index 975702d77..0452d6626 100644 --- a/src/common/device_helpers.hip.h +++ b/src/common/device_helpers.hip.h @@ -23,7 +23,7 @@ #include #include // for size_t #include -#include +#include #include #include #include diff --git a/src/context.cc b/src/context.cc index 28fda9c45..6d4eb6d8a 100644 --- a/src/context.cc +++ b/src/context.cc @@ -18,7 +18,7 @@ std::int64_t constexpr Context::kDefaultSeed; Context::Context() : cfs_cpu_count_{common::GetCfsCPUCount()} {} void Context::ConfigureGpuId(bool require_gpu) { -#if defined(XGBOOST_USE_CUDA) +#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) if (gpu_id == kCpuId) { // 0. User didn't specify the `gpu_id' if (require_gpu) { // 1. `tree_method' or `predictor' or both are using // GPU. @@ -47,7 +47,7 @@ void Context::ConfigureGpuId(bool require_gpu) { // Just set it to CPU, don't think about it. this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(kCpuId)}}); (void)(require_gpu); -#endif // defined(XGBOOST_USE_CUDA) +#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_ common::SetDevice(this->gpu_id); } @@ -60,10 +60,10 @@ std::int32_t Context::Threads() const { return n_threads; } -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) CUDAContext const* Context::CUDACtx() const { common::AssertGPUSupport(); return nullptr; } -#endif // defined(XGBOOST_USE_CUDA) +#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) } // namespace xgboost diff --git a/src/context.hip b/src/context.hip index e69de29bb..487feeccb 100644 --- a/src/context.hip +++ b/src/context.hip @@ -0,0 +1,2 @@ + +#include "context.cu"