Update collective implementation. (#10152)
* Update collective implementation. - Cleanup resource during `Finalize` to avoid handling threads in destructor. - Calculate the size for allgather automatically. - Use simple allgather for small (smaller than the number of worker) allreduce.
This commit is contained in:
@@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {
|
||||
|
||||
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
if (comm.Rank() == root) {
|
||||
return BroadcastImpl(comm, &this->sequence_number_, data, root);
|
||||
} else {
|
||||
return BroadcastImpl(comm, &this->sequence_number_, data, root);
|
||||
}
|
||||
return BroadcastImpl(comm, &this->sequence_number_, data, root);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
|
||||
using namespace federated; // NOLINT
|
||||
auto fed = dynamic_cast<FederatedComm const *>(&comm);
|
||||
CHECK(fed);
|
||||
auto stub = fed->Handle();
|
||||
auto size = data.size_bytes() / comm.World();
|
||||
|
||||
auto offset = comm.Rank() * size;
|
||||
auto segment = data.subspan(offset, size);
|
||||
|
||||
@@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
|
||||
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
|
||||
CHECK(cufed);
|
||||
std::vector<std::int8_t> h_data(data.size());
|
||||
@@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
|
||||
return GetCUDAResult(
|
||||
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
|
||||
} << [&] {
|
||||
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
|
||||
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
|
||||
} << [&] {
|
||||
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
|
||||
cudaMemcpyHostToDevice, cufed->Stream()));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*/
|
||||
#include "../../src/collective/comm.h" // for Comm, Coll
|
||||
#include "federated_coll.h" // for FederatedColl
|
||||
@@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
|
||||
std::int64_t size) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include "../../src/collective/coll.h" // for Coll
|
||||
#include "../../src/collective/comm.h" // for Comm
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedColl : public Coll {
|
||||
@@ -20,8 +17,7 @@ class FederatedColl : public Coll {
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
|
||||
std::int64_t) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class CUDAFederatedComm : public FederatedComm {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/comm.h" // for HostComm
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
@@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
|
||||
std::int32_t rank) {
|
||||
this->Init(host, port, world, rank, {}, {}, {});
|
||||
}
|
||||
[[nodiscard]] Result Shutdown() final {
|
||||
this->ResetState();
|
||||
return Success();
|
||||
}
|
||||
~FederatedComm() override { stub_.reset(); }
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
|
||||
@@ -65,5 +68,13 @@ class FederatedComm : public HostComm {
|
||||
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
|
||||
|
||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
/**
|
||||
* @brief Get a string ID for the current process.
|
||||
*/
|
||||
[[nodiscard]] Result ProcessorName(std::string* out) const final {
|
||||
auto rank = this->Rank();
|
||||
*out = "rank:" + std::to_string(rank);
|
||||
return Success();
|
||||
};
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
* Copyright 2022-2024, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.old.grpc.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <future> // for future
|
||||
|
||||
#include "../../src/collective/in_memory_handler.h"
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost::federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::int32_t world_size)
|
||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
@@ -17,8 +17,7 @@ namespace xgboost::collective {
|
||||
namespace federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::int32_t world_size)
|
||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
Reference in New Issue
Block a user