/** * Copyright 2023, XGBoost Contributors */ #pragma once #include // for shared_ptr, unique_ptr #include // for string #include // for move #include "coll.h" // for Comm #include "comm.h" // for Coll #include "xgboost/collective/result.h" // for Result #include "xgboost/collective/socket.h" // for GetHostName namespace xgboost::collective { /** * @brief Communicator group used for double dispatching between communicators and * collective implementations. */ class CommGroup { std::shared_ptr comm_; mutable std::shared_ptr gpu_comm_; std::shared_ptr backend_; mutable std::shared_ptr gpu_coll_; // lazy initialization CommGroup(std::shared_ptr comm, std::shared_ptr coll) : comm_{std::dynamic_pointer_cast(comm)}, backend_{std::move(coll)} { CHECK(comm_); } public: CommGroup(); [[nodiscard]] auto World() const { return comm_->World(); } [[nodiscard]] auto Rank() const { return comm_->Rank(); } [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); } [[nodiscard]] static CommGroup* Create(Json config); [[nodiscard]] std::shared_ptr Backend(DeviceOrd device) const; [[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const; [[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); } [[nodiscard]] Result ProcessorName(std::string* out) const { auto rc = GetHostName(out); return rc; } }; std::unique_ptr& GlobalCommGroup(); void GlobalCommGroupInit(Json config); void GlobalCommGroupFinalize(); } // namespace xgboost::collective