add context.hip
This commit is contained in:
parent
7a3a9b682a
commit
bdcb036592
@ -5,7 +5,11 @@
|
||||
#define XGBOOST_COMMON_CUDA_CONTEXT_CUH_
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
#include "device_helpers.hip.h"
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
#include "device_helpers.cuh"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
struct CUDAContext {
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
#include <chrono>
|
||||
#include <cstddef> // for size_t
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <cub/util_allocator.cuh>
|
||||
#include <hipcub/util_allocator.hpp>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
@ -18,7 +18,7 @@ std::int64_t constexpr Context::kDefaultSeed;
|
||||
Context::Context() : cfs_cpu_count_{common::GetCfsCPUCount()} {}
|
||||
|
||||
void Context::ConfigureGpuId(bool require_gpu) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
if (gpu_id == kCpuId) { // 0. User didn't specify the `gpu_id'
|
||||
if (require_gpu) { // 1. `tree_method' or `predictor' or both are using
|
||||
// GPU.
|
||||
@ -47,7 +47,7 @@ void Context::ConfigureGpuId(bool require_gpu) {
|
||||
// Just set it to CPU, don't think about it.
|
||||
this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(kCpuId)}});
|
||||
(void)(require_gpu);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_
|
||||
|
||||
common::SetDevice(this->gpu_id);
|
||||
}
|
||||
@ -60,10 +60,10 @@ std::int32_t Context::Threads() const {
|
||||
return n_threads;
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||
CUDAContext const* Context::CUDACtx() const {
|
||||
common::AssertGPUSupport();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
} // namespace xgboost
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
|
||||
#include "context.cu"
|
||||
Loading…
x
Reference in New Issue
Block a user