enable federated

This commit is contained in:
Hui Liu 2023-10-31 16:31:56 -07:00
parent 123af45327
commit 129bb76941
6 changed files with 15 additions and 3 deletions

View File

@ -51,6 +51,10 @@ target_sources(
if(USE_CUDA) if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu) target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
endif() endif()
if(USE_HIP)
target_sources(objxgboost PRIVATE federated_comm.hip federated_coll.hip)
endif()
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL") target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@ -54,7 +54,7 @@ namespace {
} }
} // namespace } // namespace
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
Coll *FederatedColl::MakeCUDAVar() { Coll *FederatedColl::MakeCUDAVar() {
common::AssertGPUSupport(); common::AssertGPUSupport();
return nullptr; return nullptr;

View File

@ -0,0 +1,4 @@
#ifdef XGBOOST_USE_HIP
#include "federated_coll.cu"
#endif

View File

@ -120,7 +120,7 @@ FederatedComm::FederatedComm(Json const& config) {
client_cert); client_cert);
} }
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const { Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport(); common::AssertGPUSupport();
return nullptr; return nullptr;

View File

@ -0,0 +1,4 @@
#ifdef XGBOOST_USE_HIP
#include "federated_comm.cu"
#endif

View File

@ -49,7 +49,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
this->Rank(), this->World()); this->Rank(), this->World());
} }
#if !defined(XGBOOST_USE_NCCL) #if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const { Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport(); common::AssertGPUSupport();
common::AssertNCCLSupport(); common::AssertNCCLSupport();