xgboost/plugin/federated/federated_comm.h
Jiaming Yuan 0715ab3c10
Use dlopen to load NCCL. (#9796)
This PR adds optional support for loading NCCL with `dlopen` as an alternative to compile-time linking. This is to address the size-bloat issue with the PyPI binary release.
- Add CMake option to load `nccl` at runtime.
- Add an NCCL stub.

After this, `nccl` will be fetched from PyPI when using pip to install XGBoost, either directly by a user or via `pyproject.toml`. Others who want to link NCCL at compile time can continue to do so without any change.

At the moment, this is Linux only since we only support MNMG on Linux.
2023-11-22 19:27:31 +08:00

70 lines
2.3 KiB
C++

/**
* Copyright 2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.pb.h>

#include <chrono>   // for seconds
#include <cstdint>  // for int32_t
#include <memory>   // for shared_ptr, unique_ptr
#include <string>   // for string

#include "../../src/collective/comm.h"    // for HostComm
#include "../../src/common/json_utils.h"  // for OptionalArg
#include "xgboost/json.h"
namespace xgboost::collective {
/**
 * @brief Communicator implementation for federated learning, backed by a gRPC stub.
 *
 * Workers talk only to the federated server (tracker); direct peer-to-peer channels
 * are forbidden (see @ref Chan).
 */
class FederatedComm : public HostComm {
  // gRPC stub connected to the federated server; shared with CUDA variants created
  // by MakeCUDAVar via the protected copy-like constructor below.
  std::shared_ptr<federated::Federated::Stub> stub_;

  // Establishes the gRPC channel. Empty cert/key paths select an insecure channel.
  // Defined in the accompanying .cc file.
  void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
            std::string const& server_cert, std::string const& client_key,
            std::string const& client_cert);

 protected:
  /**
   * @brief Share the connection and configuration of an existing communicator.
   *
   * Used by derived (device) communicators so they reuse the same stub instead of
   * opening a new connection.
   */
  explicit FederatedComm(std::shared_ptr<FederatedComm const> that) : stub_{that->stub_} {
    this->rank_ = that->Rank();
    this->world_ = that->World();
    this->retry_ = that->Retry();
    this->timeout_ = that->Timeout();
    this->task_id_ = that->TaskID();
    this->tracker_ = that->TrackerInfo();
  }

 public:
  /**
   * @param config
   *
   * - federated_server_address: Tracker address
   * - federated_world_size: The number of workers
   * - federated_rank: Rank of federated worker
   * - federated_server_cert_path
   * - federated_client_key_path
   * - federated_client_cert_path
   */
  explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
                         Json const& config);
  /**
   * @brief Connect without TLS; empty credentials select an insecure channel.
   */
  explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
                         std::int32_t rank) {
    this->Init(host, port, world, rank, {}, {}, {});
  }
  // Rule of Zero: the defaulted destructor already releases stub_; no manual reset needed.
  ~FederatedComm() override = default;

  /**
   * @brief Peer channels are disallowed in federated mode; always aborts.
   */
  [[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
    LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
    return nullptr;
  }
  /**
   * @brief Log locally to the console; never fails.
   */
  [[nodiscard]] Result LogTracker(std::string msg) const override {
    LOG(CONSOLE) << msg;
    return Success();
  }
  [[nodiscard]] bool IsFederated() const override { return true; }
  /**
   * @brief Non-owning access to the underlying gRPC stub.
   */
  [[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }

  // Create the device (CUDA) variant of this communicator; defined in the .cu/.cc file.
  [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
};
} // namespace xgboost::collective