merge latest changes

This commit is contained in:
Hui Liu
2023-12-13 21:06:28 -08:00
194 changed files with 4859 additions and 2838 deletions

View File

@@ -29,7 +29,7 @@ namespace {
auto stub = fed->Handle();
BroadcastRequest request;
request.set_sequence_number(*sequence_number++);
request.set_sequence_number((*sequence_number)++);
request.set_rank(comm.Rank());
if (comm.Rank() != root) {
request.set_send_buffer(nullptr, 0);
@@ -90,9 +90,9 @@ Coll *FederatedColl::MakeCUDAVar() {
// Broadcast `data` from rank `root` by delegating to BroadcastImpl.
// Both the root and the non-root branch forward the same arguments: the
// communicator, a pointer to this object's sequence counter, the buffer and
// the root rank.
// NOTE(review): this listing appears to be a rendered diff without +/- markers;
// in each branch the first `return` looks like the pre-change line and the
// second (`this->sequence_number_`) like its replacement - confirm against the
// actual file.
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (comm.Rank() == root) {
return BroadcastImpl(comm, &sequence_number_, data, root);
return BroadcastImpl(comm, &this->sequence_number_, data, root);
} else {
return BroadcastImpl(comm, &sequence_number_, data, root);
return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
}

View File

@@ -60,7 +60,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
}
}
FederatedComm::FederatedComm(Json const& config) {
FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
Json const& config) {
/**
* Topology
*/
@@ -93,6 +94,13 @@ FederatedComm::FederatedComm(Json const& config) {
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
/**
* Basic config
*/
this->retry_ = retry;
this->timeout_ = timeout;
this->task_id_ = task_id;
/**
* Certificates
*/

View File

@@ -11,6 +11,8 @@ namespace xgboost::collective {
// Build the CUDA variant of a federated communicator from an existing one.
//   ctx  - supplies the stream (ctx->CUDACtx()->Stream()) and the device ordinal.
//   impl - existing federated communicator, handed to the FederatedComm base
//          constructor.
// NOTE(review): the member-initializer list runs before the constructor body,
// so `impl` has already been passed to the base and `ctx` already dereferenced
// by the time the CHECKs below execute; they cannot guard those earlier uses.
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
: FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
CHECK(impl);
// Fail unless the context reports IsCUDA().
CHECK(ctx->IsCUDA());
// Make the context's ordinal the current CUDA device. dh::safe_cuda presumably
// validates the returned status - confirm in device_helpers.cuh.
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
}
Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {

View File

@@ -5,9 +5,11 @@
#include <memory> // for shared_ptr
#include "../../src/collective/coll.h" // for Coll
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
#include "federated_comm.h" // for FederatedComm
#include "xgboost/context.h" // for Context
#include "xgboost/logging.h"
namespace xgboost::collective {
class CUDAFederatedComm : public FederatedComm {
@@ -16,5 +18,9 @@ class CUDAFederatedComm : public FederatedComm {
public:
explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);
// Accessor for the stream_ member; the deduced `auto` return type yields it by value.
[[nodiscard]] auto Stream() const { return stream_; }
// Override that rejects the request: it ignores both (unnamed) parameters and
// logs a fatal "[Internal Error]" message.
Comm* MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const override {
LOG(FATAL) << "[Internal Error]: Invalid request for CUDA variant.";
// Gives the non-void function a return statement; presumably never reached if
// LOG(FATAL) aborts or throws - confirm against the logging implementation.
return nullptr;
}
};
} // namespace xgboost::collective

View File

@@ -10,12 +10,12 @@
#include <memory> // for unique_ptr
#include <string> // for string
#include "../../src/collective/comm.h" // for Comm
#include "../../src/collective/comm.h" // for HostComm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"
namespace xgboost::collective {
class FederatedComm : public Comm {
class FederatedComm : public HostComm {
std::shared_ptr<federated::Federated::Stub> stub_;
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
@@ -27,6 +27,10 @@ class FederatedComm : public Comm {
this->rank_ = that->Rank();
this->world_ = that->World();
this->retry_ = that->Retry();
this->timeout_ = that->Timeout();
this->task_id_ = that->TaskID();
this->tracker_ = that->TrackerInfo();
}
@@ -41,7 +45,8 @@ class FederatedComm : public Comm {
* - federated_client_key_path
* - federated_client_cert_path
*/
explicit FederatedComm(Json const& config);
explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
Json const& config);
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
@@ -59,6 +64,6 @@ class FederatedComm : public Comm {
// Always returns true for this communicator type.
[[nodiscard]] bool IsFederated() const override { return true; }
// Returns the raw pointer obtained from stub_.get(); stub_ itself is left unchanged.
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
};
} // namespace xgboost::collective