fix macro
This commit is contained in:
parent
80961039d7
commit
22525c002a
@ -23,6 +23,12 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
#include <hipcub/hipcub.hpp>
|
||||
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
namespace xgboost::ltr {
|
||||
namespace cuda_impl {
|
||||
void CalcQueriesDCG(Context const* ctx, linalg::VectorView<float const> d_labels,
|
||||
@ -141,8 +147,13 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
|
||||
auto const& h_group_ptr = info.group_ptr_;
|
||||
group_ptr_.Resize(h_group_ptr.size());
|
||||
auto d_group_ptr = group_ptr_.DeviceSpan();
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
|
||||
cudaMemcpyHostToDevice, cuctx->Stream()));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
|
||||
hipMemcpyHostToDevice, cuctx->Stream()));
|
||||
#endif
|
||||
}
|
||||
|
||||
auto d_group_ptr = DataGroupPtr(ctx);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user