merge latest changes
This commit is contained in:
@@ -29,7 +29,7 @@ namespace {
|
||||
auto stub = fed->Handle();
|
||||
|
||||
BroadcastRequest request;
|
||||
request.set_sequence_number(*sequence_number++);
|
||||
request.set_sequence_number((*sequence_number)++);
|
||||
request.set_rank(comm.Rank());
|
||||
if (comm.Rank() != root) {
|
||||
request.set_send_buffer(nullptr, 0);
|
||||
@@ -90,9 +90,9 @@ Coll *FederatedColl::MakeCUDAVar() {
|
||||
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
if (comm.Rank() == root) {
|
||||
return BroadcastImpl(comm, &sequence_number_, data, root);
|
||||
return BroadcastImpl(comm, &this->sequence_number_, data, root);
|
||||
} else {
|
||||
return BroadcastImpl(comm, &sequence_number_, data, root);
|
||||
return BroadcastImpl(comm, &this->sequence_number_, data, root);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
|
||||
}
|
||||
}
|
||||
|
||||
FederatedComm::FederatedComm(Json const& config) {
|
||||
FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
|
||||
Json const& config) {
|
||||
/**
|
||||
* Topology
|
||||
*/
|
||||
@@ -93,6 +94,13 @@ FederatedComm::FederatedComm(Json const& config) {
|
||||
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
|
||||
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
|
||||
|
||||
/**
|
||||
* Basic config
|
||||
*/
|
||||
this->retry_ = retry;
|
||||
this->timeout_ = timeout;
|
||||
this->task_id_ = task_id;
|
||||
|
||||
/**
|
||||
* Certificates
|
||||
*/
|
||||
|
||||
@@ -11,6 +11,8 @@ namespace xgboost::collective {
|
||||
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
|
||||
: FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
|
||||
CHECK(impl);
|
||||
CHECK(ctx->IsCUDA());
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
}
|
||||
|
||||
Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "../../src/collective/coll.h" // for Coll
|
||||
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class CUDAFederatedComm : public FederatedComm {
|
||||
@@ -16,5 +18,9 @@ class CUDAFederatedComm : public FederatedComm {
|
||||
public:
|
||||
explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);
|
||||
[[nodiscard]] auto Stream() const { return stream_; }
|
||||
Comm* MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const override {
|
||||
LOG(FATAL) << "[Internal Error]: Invalid request for CUDA variant.";
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -10,12 +10,12 @@
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/comm.h" // for Comm
|
||||
#include "../../src/collective/comm.h" // for HostComm
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedComm : public Comm {
|
||||
class FederatedComm : public HostComm {
|
||||
std::shared_ptr<federated::Federated::Stub> stub_;
|
||||
|
||||
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
|
||||
@@ -27,6 +27,10 @@ class FederatedComm : public Comm {
|
||||
this->rank_ = that->Rank();
|
||||
this->world_ = that->World();
|
||||
|
||||
this->retry_ = that->Retry();
|
||||
this->timeout_ = that->Timeout();
|
||||
this->task_id_ = that->TaskID();
|
||||
|
||||
this->tracker_ = that->TrackerInfo();
|
||||
}
|
||||
|
||||
@@ -41,7 +45,8 @@ class FederatedComm : public Comm {
|
||||
* - federated_client_key_path
|
||||
* - federated_client_cert_path
|
||||
*/
|
||||
explicit FederatedComm(Json const& config);
|
||||
explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
|
||||
Json const& config);
|
||||
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
|
||||
std::int32_t rank) {
|
||||
this->Init(host, port, world, rank, {}, {}, {});
|
||||
@@ -59,6 +64,6 @@ class FederatedComm : public Comm {
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
|
||||
|
||||
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user