fix macro

This commit is contained in:
amdsc21 2023-03-25 05:08:30 +01:00
parent 80961039d7
commit 22525c002a

View File

@ -23,6 +23,12 @@
#include "xgboost/logging.h" // for CHECK #include "xgboost/logging.h" // for CHECK
#include "xgboost/span.h" // for Span #include "xgboost/span.h" // for Span
#if defined(XGBOOST_USE_HIP)
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
namespace xgboost::ltr { namespace xgboost::ltr {
namespace cuda_impl { namespace cuda_impl {
void CalcQueriesDCG(Context const* ctx, linalg::VectorView<float const> d_labels, void CalcQueriesDCG(Context const* ctx, linalg::VectorView<float const> d_labels,
@ -141,8 +147,13 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
auto const& h_group_ptr = info.group_ptr_; auto const& h_group_ptr = info.group_ptr_;
group_ptr_.Resize(h_group_ptr.size()); group_ptr_.Resize(h_group_ptr.size());
auto d_group_ptr = group_ptr_.DeviceSpan(); auto d_group_ptr = group_ptr_.DeviceSpan();
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(), dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
cudaMemcpyHostToDevice, cuctx->Stream())); cudaMemcpyHostToDevice, cuctx->Stream()));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
hipMemcpyHostToDevice, cuctx->Stream()));
#endif
} }
auto d_group_ptr = DataGroupPtr(ctx); auto d_group_ptr = DataGroupPtr(ctx);