Update collective implementation. (#10152)

* Update collective implementation.

- Clean up resources during `Finalize` to avoid handling threads in the destructor.
- Calculate the size for allgather automatically.
- Use simple allgather for small allreduce (input smaller than the number of workers).
This commit is contained in:
Jiaming Yuan
2024-03-30 18:57:31 +08:00
committed by GitHub
parent 230010d9a0
commit 8bad677c2f
31 changed files with 233 additions and 127 deletions

View File

@@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (comm.Rank() == root) {
return BroadcastImpl(comm, &this->sequence_number_, data, root);
} else {
return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
std::int64_t size) {
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
auto stub = fed->Handle();
auto size = data.size_bytes() / comm.World();
auto offset = comm.Rank() * size;
auto segment = data.subspan(offset, size);

View File

@@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
};
}
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
std::int64_t size) {
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
std::vector<std::int8_t> h_data(data.size());
@@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include "../../src/collective/comm.h" // for Comm, Coll
#include "federated_coll.h" // for FederatedColl
@@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
std::int64_t size) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,

View File

@@ -1,12 +1,9 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#pragma once
#include "../../src/collective/coll.h" // for Coll
#include "../../src/collective/comm.h" // for Comm
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
class FederatedColl : public Coll {
@@ -20,8 +17,7 @@ class FederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
std::int64_t) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
@@ -9,7 +9,6 @@
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
#include "federated_comm.h" // for FederatedComm
#include "xgboost/context.h" // for Context
#include "xgboost/logging.h"
namespace xgboost::collective {
class CUDAFederatedComm : public FederatedComm {

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#pragma once
@@ -11,7 +11,6 @@
#include <string> // for string
#include "../../src/collective/comm.h" // for HostComm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"
namespace xgboost::collective {
@@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
/**
* @brief Shut down this communicator.
*
* Calls ResetState() and then unconditionally reports success; no error path
* is visible here. NOTE(review): what ResetState() actually releases is defined
* outside this hunk -- confirm against the base class.
*
* @return Always Success().
*/
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
}
~FederatedComm() override { stub_.reset(); }
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
@@ -65,5 +68,13 @@ class FederatedComm : public HostComm {
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
/**
* @brief Get a string ID for the current process.
*
* The ID is the literal prefix "rank:" followed by the decimal value of Rank(),
* e.g. "rank:0".
*
* @param out Destination string, overwritten with the ID. It is dereferenced
*            without a null check, so callers must pass a valid pointer.
*
* @return Always Success().
*/
[[nodiscard]] Result ProcessorName(std::string* out) const final {
auto rank = this->Rank();
*out = "rank:" + std::to_string(rank);
return Success();
};
};
} // namespace xgboost::collective

View File

@@ -1,22 +1,18 @@
/**
* Copyright 2022-2023, XGBoost contributors
* Copyright 2022-2024, XGBoost contributors
*/
#pragma once
#include <federated.old.grpc.pb.h>
#include <cstdint> // for int32_t
#include <future> // for future
#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;

View File

@@ -17,8 +17,7 @@ namespace xgboost::collective {
namespace federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;