diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index c4d5ea378..4b9734c4e 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -51,6 +51,10 @@ target_sources( if(USE_CUDA) target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu) endif() +if(USE_HIP) + target_sources(objxgboost PRIVATE federated_comm.hip federated_coll.hip) +endif() + target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL") target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/plugin/federated/federated_coll.cc b/plugin/federated/federated_coll.cc index 7c25eeba5..0982166a4 100644 --- a/plugin/federated/federated_coll.cc +++ b/plugin/federated/federated_coll.cc @@ -54,7 +54,7 @@ namespace { } } // namespace -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) Coll *FederatedColl::MakeCUDAVar() { common::AssertGPUSupport(); return nullptr; diff --git a/plugin/federated/federated_coll.hip b/plugin/federated/federated_coll.hip new file mode 100644 index 000000000..e7065297c --- /dev/null +++ b/plugin/federated/federated_coll.hip @@ -0,0 +1,4 @@ + +#ifdef XGBOOST_USE_HIP +#include "federated_coll.cu" +#endif diff --git a/plugin/federated/federated_comm.cc b/plugin/federated/federated_comm.cc index 8a649340f..581b63b7c 100644 --- a/plugin/federated/federated_comm.cc +++ b/plugin/federated/federated_comm.cc @@ -120,7 +120,7 @@ FederatedComm::FederatedComm(Json const& config) { client_cert); } -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP) Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr) const { common::AssertGPUSupport(); return nullptr; diff --git a/plugin/federated/federated_comm.hip b/plugin/federated/federated_comm.hip new file mode 100644 index 000000000..5da36ffff --- /dev/null +++ b/plugin/federated/federated_comm.hip @@ -0,0 +1,4 @@ + +#ifdef XGBOOST_USE_HIP +#include "federated_comm.cu" +#endif diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 241dca2ce..1af15805b 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -49,7 +49,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st this->Rank(), this->World()); } -#if !defined(XGBOOST_USE_NCCL) +#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL) Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr) const { common::AssertGPUSupport(); common::AssertNCCLSupport();