fix lambdarank_obj.cc, support HIP

This commit is contained in:
amdsc21 2023-05-02 19:03:18 +02:00
parent e4538cb13c
commit 83e6fceb5c
2 changed files with 9 additions and 9 deletions

View File

@@ -414,7 +414,7 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
}; };
namespace cuda_impl { namespace cuda_impl {
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
void LambdaRankGetGradientNDCG(Context const*, std::int32_t, HostDeviceVector<float> const&, void LambdaRankGetGradientNDCG(Context const*, std::int32_t, HostDeviceVector<float> const&,
const MetaInfo&, std::shared_ptr<ltr::NDCGCache>, const MetaInfo&, std::shared_ptr<ltr::NDCGCache>,
linalg::VectorView<double const>, // input bias ratio linalg::VectorView<double const>, // input bias ratio
@@ -430,7 +430,7 @@ void LambdaRankUpdatePositionBias(Context const*, linalg::VectorView<double cons
linalg::Vector<double>*, std::shared_ptr<ltr::RankingCache>) { linalg::Vector<double>*, std::shared_ptr<ltr::RankingCache>) {
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
} // namespace cuda_impl } // namespace cuda_impl
namespace cpu_impl { namespace cpu_impl {
@@ -533,7 +533,7 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
} }
}; };
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
namespace cuda_impl { namespace cuda_impl {
void MAPStat(Context const*, MetaInfo const&, common::Span<std::size_t const>, void MAPStat(Context const*, MetaInfo const&, common::Span<std::size_t const>,
std::shared_ptr<ltr::MAPCache>) { std::shared_ptr<ltr::MAPCache>) {
@@ -549,7 +549,7 @@ void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector<flo
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
} // namespace cuda_impl } // namespace cuda_impl
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
/** /**
* \brief The RankNet loss. * \brief The RankNet loss.
@@ -604,7 +604,7 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
} }
}; };
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
namespace cuda_impl { namespace cuda_impl {
void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVector<float> const&, void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVector<float> const&,
const MetaInfo&, std::shared_ptr<ltr::RankingCache>, const MetaInfo&, std::shared_ptr<ltr::RankingCache>,
@@ -615,7 +615,7 @@ void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVecto
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
} // namespace cuda_impl } // namespace cuda_impl
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name()) XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name())
.describe("LambdaRank with NDCG loss as objective") .describe("LambdaRank with NDCG loss as objective")

View File

@@ -518,9 +518,9 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
Launch(ctx, iter, predt, info, p_cache, delta, ti_plus, tj_minus, li, lj, out_gpair); Launch(ctx, iter, predt, info, p_cache, delta, ti_plus, tj_minus, li, lj, out_gpair);
} }
struct ReduceOp : thrust::binary_function<thrust::tuple<double, double> const&, thrust::tuple<double, double> struct ReduceOp {
const&, thrust::tuple<double, double>> { template <typename Tup>
thrust::tuple<double, double> __host__ XGBOOST_DEVICE operator()(thrust::tuple<double, double> const& l, thrust::tuple<double, double> const& r) { Tup XGBOOST_DEVICE operator()(Tup const& l, Tup const& r) const {
return thrust::make_tuple(thrust::get<0>(l) + thrust::get<0>(r), return thrust::make_tuple(thrust::get<0>(l) + thrust::get<0>(r),
thrust::get<1>(l) + thrust::get<1>(r)); thrust::get<1>(l) + thrust::get<1>(r));
} }