From 134cbfddbe1777bc1e36fe5034217cb74ff3727c Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Fri, 10 Mar 2023 04:40:33 +0100 Subject: [PATCH] finish gradient_index.cu --- src/data/array_interface.cu | 10 +++++++--- src/data/gradient_index.cc | 4 ++-- src/data/gradient_index.hip | 4 ++++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/data/array_interface.cu b/src/data/array_interface.cu index 875a10606..5a72d66d7 100644 --- a/src/data/array_interface.cu +++ b/src/data/array_interface.cu @@ -23,7 +23,11 @@ void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { case 2: // default per-thread stream default: +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast(stream))); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipStreamSynchronize(reinterpret_cast(stream))); +#endif } } @@ -50,12 +54,12 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { return true; } return true; -#elif defined(XGBOOST_USE_HIP) - return false; -#endif } else { // other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc. return false; } +#elif defined(XGBOOST_USE_HIP) + return false; +#endif } } // namespace xgboost diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 0a606ecd5..4d7dbe9b5 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -67,12 +67,12 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts & max_numeric_bins_per_feat(max_bin_per_feat), isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {} -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &, BatchParam const &) { common::AssertGPUSupport(); } -#endif // defined(XGBOOST_USE_CUDA) +#endif // defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) GHistIndexMatrix::~GHistIndexMatrix() = default; diff --git a/src/data/gradient_index.hip b/src/data/gradient_index.hip index e69de29bb..7cc0c154d 100644 --- a/src/data/gradient_index.hip +++ b/src/data/gradient_index.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "gradient_index.cu" +#endif