diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index 22c85f3ad..1b629f6f6 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -50,7 +50,7 @@ void Communicator::Init(Json const& config) { } } -#ifndef XGBOOST_USE_CUDA +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) void Communicator::Finalize() { communicator_->Shutdown(); communicator_.reset(new NoOpCommunicator()); diff --git a/src/collective/communicator.h b/src/collective/communicator.h index de8a0e7d7..2c19f9576 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -228,7 +228,7 @@ class Communicator { static thread_local std::unique_ptr communicator_; static thread_local CommunicatorType type_; -#if defined(XGBOOST_USE_CUDA) +#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) static thread_local int device_ordinal_; static thread_local std::unique_ptr device_communicator_; #endif diff --git a/src/collective/communicator.hip b/src/collective/communicator.hip index e69de29bb..5a438771c 100644 --- a/src/collective/communicator.hip +++ b/src/collective/communicator.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "communicator.cu" +#endif