diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 872cb0cc6..c9a79dfda 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -10,7 +10,11 @@ namespace xgboost { namespace data { void EllpackPageSource::Fetch() { +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaSetDevice(param_.gpu_id)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(param_.gpu_id)); +#endif if (!this->ReadCache()) { if (count_ != 0 && !sync_) { // source is initialized to be the 0th page during construction, so when count_ is 0 diff --git a/src/data/ellpack_page_source.hip b/src/data/ellpack_page_source.hip index e69de29bb..fe26c1cb2 100644 --- a/src/data/ellpack_page_source.hip +++ b/src/data/ellpack_page_source.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "ellpack_page_source.cu" +#endif