finish sparse_page_dmatrix.cu

This commit is contained in:
amdsc21 2023-03-10 05:04:57 +01:00
parent 080fc35c4b
commit fa9f69dd85
2 changed files with 8 additions and 4 deletions

View File

@ -20,7 +20,7 @@ const MetaInfo &SparsePageDMatrix::Info() const { return info_; }
namespace detail {
// Use device dispatch
std::size_t NSamplesDevice(DMatrixProxy *) // NOLINT
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
; // NOLINT
#else
{
@ -29,7 +29,7 @@ std::size_t NSamplesDevice(DMatrixProxy *) // NOLINT
}
#endif
std::size_t NFeaturesDevice(DMatrixProxy *) // NOLINT
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
; // NOLINT
#else
{
@ -188,12 +188,12 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(begin_iter));
}
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam &) {
common::AssertGPUSupport();
auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
}
#endif // !defined(XGBOOST_USE_CUDA)
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
} // namespace data
} // namespace xgboost

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "sparse_page_dmatrix.cu"
#endif