fix macro
This commit is contained in:
parent
80961039d7
commit
22525c002a
@ -23,6 +23,12 @@
|
|||||||
#include "xgboost/logging.h" // for CHECK
|
#include "xgboost/logging.h" // for CHECK
|
||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
#include <hipcub/hipcub.hpp>
|
||||||
|
|
||||||
|
namespace cub = hipcub;
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace xgboost::ltr {
|
namespace xgboost::ltr {
|
||||||
namespace cuda_impl {
|
namespace cuda_impl {
|
||||||
void CalcQueriesDCG(Context const* ctx, linalg::VectorView<float const> d_labels,
|
void CalcQueriesDCG(Context const* ctx, linalg::VectorView<float const> d_labels,
|
||||||
@ -141,8 +147,13 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
|
|||||||
auto const& h_group_ptr = info.group_ptr_;
|
auto const& h_group_ptr = info.group_ptr_;
|
||||||
group_ptr_.Resize(h_group_ptr.size());
|
group_ptr_.Resize(h_group_ptr.size());
|
||||||
auto d_group_ptr = group_ptr_.DeviceSpan();
|
auto d_group_ptr = group_ptr_.DeviceSpan();
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
|
dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
|
||||||
cudaMemcpyHostToDevice, cuctx->Stream()));
|
cudaMemcpyHostToDevice, cuctx->Stream()));
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
|
||||||
|
hipMemcpyHostToDevice, cuctx->Stream()));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
auto d_group_ptr = DataGroupPtr(ctx);
|
auto d_group_ptr = DataGroupPtr(ctx);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user