[coll] Add comm group. (#9759)

- Implement `CommGroup` for double dispatching.
- Small cleanup to tracker for handling abort.
This commit is contained in:
Jiaming Yuan
2023-11-07 11:12:31 +08:00
committed by GitHub
parent c3a0622b49
commit 6c0a190f6d
15 changed files with 462 additions and 79 deletions

View File

@@ -60,7 +60,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
}
}
FederatedComm::FederatedComm(Json const& config) {
FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
Json const& config) {
/**
* Topology
*/
@@ -93,6 +94,13 @@ FederatedComm::FederatedComm(Json const& config) {
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
/**
* Basic config
*/
this->retry_ = retry;
this->timeout_ = timeout;
this->task_id_ = task_id;
/**
* Certificates
*/

View File

@@ -11,6 +11,8 @@ namespace xgboost::collective {
/**
 * Build a CUDA federated communicator on top of an existing federated communicator.
 *
 * The base-class sub-object is initialized from `impl`, and `stream_` is taken from the
 * CUDA context exposed by `ctx`.
 *
 * @param ctx  Context to run under; it must report IsCUDA(), and its Ordinal() is the
 *             device passed to cudaSetDevice.
 * @param impl Existing communicator handed to the base-class constructor; checked to be
 *             non-null.
 */
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
: FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
// NOTE(review): the base-class initializer above has already received `impl` by the time
// this check runs. If that constructor dereferences `impl`, a null pointer would fail
// there before this CHECK is ever reached -- confirm against the base constructor.
CHECK(impl);
// The context must be a CUDA one; `stream_` was already obtained from ctx->CUDACtx() in
// the initializer list, so this check also runs after that access.
CHECK(ctx->IsCUDA());
// Make the device identified by the context's ordinal current. The result of
// cudaSetDevice is routed through dh::safe_cuda (presumably error checking -- verify).
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
}
Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {

View File

@@ -27,6 +27,10 @@ class FederatedComm : public Comm {
this->rank_ = that->Rank();
this->world_ = that->World();
this->retry_ = that->Retry();
this->timeout_ = that->Timeout();
this->task_id_ = that->TaskID();
this->tracker_ = that->TrackerInfo();
}
@@ -41,7 +45,8 @@ class FederatedComm : public Comm {
* - federated_client_key_path
* - federated_client_cert_path
*/
explicit FederatedComm(Json const& config);
explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
Json const& config);
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});