diff --git a/src/common/ranking_utils.cu b/src/common/ranking_utils.cu index 8fbf89818..9eb54394c 100644 --- a/src/common/ranking_utils.cu +++ b/src/common/ranking_utils.cu @@ -23,6 +23,12 @@ #include "xgboost/logging.h" // for CHECK #include "xgboost/span.h" // for Span +#if defined(XGBOOST_USE_HIP) +#include + +namespace cub = hipcub; +#endif + namespace xgboost::ltr { namespace cuda_impl { void CalcQueriesDCG(Context const* ctx, linalg::VectorView d_labels, @@ -141,8 +147,13 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { auto const& h_group_ptr = info.group_ptr_; group_ptr_.Resize(h_group_ptr.size()); auto d_group_ptr = group_ptr_.DeviceSpan(); +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(), cudaMemcpyHostToDevice, cuctx->Stream())); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(), + hipMemcpyHostToDevice, cuctx->Stream())); +#endif } auto d_group_ptr = DataGroupPtr(ctx);