enable federated

This commit is contained in:
Hui Liu 2023-10-31 16:31:56 -07:00
parent 123af45327
commit 129bb76941
6 changed files with 15 additions and 3 deletions

View File

@ -51,6 +51,10 @@ target_sources(
if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
endif()
if(USE_HIP)
target_sources(objxgboost PRIVATE federated_comm.hip federated_coll.hip)
endif()
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@ -54,7 +54,7 @@ namespace {
}
} // namespace
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
Coll *FederatedColl::MakeCUDAVar() {
common::AssertGPUSupport();
return nullptr;

View File

@ -0,0 +1,4 @@
#ifdef XGBOOST_USE_HIP
#include "federated_coll.cu"
#endif

View File

@ -120,7 +120,7 @@ FederatedComm::FederatedComm(Json const& config) {
client_cert);
}
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
return nullptr;

View File

@ -0,0 +1,4 @@
#ifdef XGBOOST_USE_HIP
#include "federated_comm.cu"
#endif

View File

@ -49,7 +49,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
this->Rank(), this->World());
}
#if !defined(XGBOOST_USE_NCCL)
#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
common::AssertNCCLSupport();