From 500428cc0f37180bed615f35a3fce0ad1b3c7cd9 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Thu, 9 Mar 2023 22:31:11 +0100 Subject: [PATCH] finish row_partitioner.cu --- src/tree/gpu_hist/row_partitioner.cu | 21 +++++++++++++++++++++ src/tree/gpu_hist/row_partitioner.hip | 4 ++++ 2 files changed, 25 insertions(+) diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu index 015d817f3..137999acc 100644 --- a/src/tree/gpu_hist/row_partitioner.cu +++ b/src/tree/gpu_hist/row_partitioner.cu @@ -7,7 +7,12 @@ #include +#if defined(XGBOOST_USE_CUDA) #include "../../common/device_helpers.cuh" +#elif defined(XGBOOST_USE_HIP) +#include "../../common/device_helpers.hip.h" +#endif + #include "row_partitioner.cuh" namespace xgboost { @@ -15,15 +20,31 @@ namespace tree { RowPartitioner::RowPartitioner(int device_idx, size_t num_rows) : device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) { + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaSetDevice(device_idx_)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_idx_)); +#endif + ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)}); thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size()); + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaStreamCreate(&stream_)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipStreamCreate(&stream_)); +#endif } RowPartitioner::~RowPartitioner() { +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaSetDevice(device_idx_)); dh::safe_cuda(cudaStreamDestroy(stream_)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_idx_)); + dh::safe_cuda(hipStreamDestroy(stream_)); +#endif } common::Span RowPartitioner::GetRows(bst_node_t nidx) { diff --git a/src/tree/gpu_hist/row_partitioner.hip b/src/tree/gpu_hist/row_partitioner.hip index e69de29bb..ac03ac0d7 100644 --- a/src/tree/gpu_hist/row_partitioner.hip +++ b/src/tree/gpu_hist/row_partitioner.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "row_partitioner.cu" +#endif