diff --git a/plugin/federated/federated_comm.cc b/plugin/federated/federated_comm.cc index 8a649340f..ec1287413 100644 --- a/plugin/federated/federated_comm.cc +++ b/plugin/federated/federated_comm.cc @@ -60,7 +60,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_ } } -FederatedComm::FederatedComm(Json const& config) { +FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id, + Json const& config) { /** * Topology */ @@ -93,6 +94,13 @@ FederatedComm::FederatedComm(Json const& config) { CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required."; CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required."; + /** + * Basic config + */ + this->retry_ = retry; + this->timeout_ = timeout; + this->task_id_ = task_id; + /** * Certificates */ diff --git a/plugin/federated/federated_comm.cu b/plugin/federated/federated_comm.cu index b05d38b1b..3eb8eb4f7 100644 --- a/plugin/federated/federated_comm.cu +++ b/plugin/federated/federated_comm.cu @@ -11,6 +11,8 @@ namespace xgboost::collective { CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr impl) : FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} { CHECK(impl); + CHECK(ctx->IsCUDA()); + dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); } Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr) const { diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h index fb97a78b0..a24798626 100644 --- a/plugin/federated/federated_comm.h +++ b/plugin/federated/federated_comm.h @@ -27,6 +27,10 @@ class FederatedComm : public Comm { this->rank_ = that->Rank(); this->world_ = that->World(); + this->retry_ = that->Retry(); + this->timeout_ = that->Timeout(); + this->task_id_ = that->TaskID(); + this->tracker_ = that->TrackerInfo(); } @@ -41,7 +45,8 @@ class FederatedComm : public Comm { * - federated_client_key_path * - federated_client_cert_path */ - explicit FederatedComm(Json const& config); + explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id, + Json const& config); explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank) { this->Init(host, port, world, rank, {}, {}, {}); diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 241dca2ce..964137ff1 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -5,13 +5,17 @@ #include // for copy #include // for seconds +#include // for exit #include // for shared_ptr +#include // for unique_lock #include // for string #include // for move, forward #include "../common/common.h" // for AssertGPUSupport +#include "../common/json_utils.h" // for OptionalArg #include "allgather.h" // for RingAllgather #include "protocol.h" // for kMagic +#include "tracker.h" // for GetHostAddress #include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE #include "xgboost/collective/socket.h" // for TCPSocket #include "xgboost/json.h" // for Json, Object @@ -209,24 +213,18 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se std::shared_ptr error_sock{TCPSocket::CreatePtr(domain)}; auto eport = error_sock->BindHost(); error_sock->Listen(); - error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] { + error_worker_ = std::thread{[error_sock = std::move(error_sock)] { auto conn = error_sock->Accept(); - // On Windows accept returns an invalid socket after network is shutdown. + // On Windows, accept returns a closed socket after finalize. if (conn.IsClosed()) { return; } LOG(WARNING) << "Another worker is running into error."; - std::string scmd; - conn.Recv(&scmd); - auto jcmd = Json::Load(scmd); - auto rc = this->Shutdown(); - if (!rc.OK()) { - LOG(WARNING) << "Fail to shutdown worker:" << rc.Report(); - } #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 - exit(-1); + // exit is nicer than abort as the former performs cleanups. + std::exit(-1); #else - LOG(FATAL) << rc.Report(); + LOG(FATAL) << "abort"; #endif }}; error_worker_.detach(); diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc new file mode 100644 index 000000000..570500843 --- /dev/null +++ b/src/collective/comm_group.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "comm_group.h" + +#include // for transform +#include // for seconds +#include // for int32_t +#include // for shared_ptr, unique_ptr +#include // for string +#include // for vector + +#include "../common/json_utils.h" // for OptionalArg +#include "coll.h" // for Coll +#include "comm.h" // for Comm +#include "tracker.h" // for GetHostAddress +#include "xgboost/collective/result.h" // for Result +#include "xgboost/context.h" // for DeviceOrd +#include "xgboost/json.h" // for Json + +#if defined(XGBOOST_USE_FEDERATED) +#include "../../plugin/federated/federated_coll.h" +#include "../../plugin/federated/federated_comm.h" +#endif + +namespace xgboost::collective { +[[nodiscard]] std::shared_ptr CommGroup::Backend(DeviceOrd device) const { + if (device.IsCUDA()) { + if (!gpu_coll_) { + gpu_coll_.reset(backend_->MakeCUDAVar()); + } + return gpu_coll_; + } + return backend_; +} + +[[nodiscard]] Comm const& CommGroup::Ctx(Context const* ctx, DeviceOrd device) const { + if (device.IsCUDA()) { + CHECK(ctx->IsCUDA()); + if (!gpu_comm_) { + gpu_comm_.reset(comm_->MakeCUDAVar(ctx, backend_)); + } + return *gpu_comm_; + } + return *comm_; +} + +CommGroup::CommGroup() + : comm_{std::shared_ptr(new RabitComm{})}, // NOLINT + backend_{std::shared_ptr(new Coll{})} {} // NOLINT + +[[nodiscard]] CommGroup* CommGroup::Create(Json config) { + if (IsA(config)) { + return new CommGroup; + } + + std::string type = OptionalArg(config, "dmlc_communicator", std::string{"rabit"}); + std::vector keys; + // Try both lower and upper case for compatibility + auto get_param = [&](std::string name, auto dft, auto t) { + std::string upper; + std::transform(name.cbegin(), name.cend(), std::back_inserter(upper), + [](char c) { return std::toupper(c); }); + std::transform(name.cbegin(), name.cend(), name.begin(), + [](char c) { return std::tolower(c); }); + keys.push_back(upper); + keys.push_back(name); + + auto const& obj = get(config); + auto it = obj.find(upper); + if (it != obj.cend()) { + return OptionalArg(config, upper, dft); + } else { + return OptionalArg(config, name, dft); + } + }; + // Common args + auto retry = + OptionalArg(config, "dmlc_retry", static_cast(DefaultRetry())); + auto timeout = OptionalArg(config, "dmlc_timeout_sec", + static_cast(DefaultTimeoutSec())); + auto task_id = get_param("dmlc_task_id", std::string{}, String{}); + + if (type == "rabit") { + auto host = get_param("dmlc_tracker_uri", std::string{}, String{}); + auto port = get_param("dmlc_tracker_port", static_cast(0), Integer{}); + auto ptr = + new CommGroup{std::shared_ptr{new RabitComm{ // NOLINT + host, static_cast(port), std::chrono::seconds{timeout}, + static_cast(retry), task_id}}, + std::shared_ptr(new Coll{})}; // NOLINT + return ptr; + } else if (type == "federated") { +#if defined(XGBOOST_USE_FEDERATED) + auto ptr = new CommGroup{ + std::make_shared(retry, std::chrono::seconds{timeout}, task_id, config), + std::make_shared()}; + return ptr; +#endif // defined(XGBOOST_USE_FEDERATED) + } else { + LOG(FATAL) << "Invalid communicator type"; + } + + return nullptr; +} + +std::unique_ptr& GlobalCommGroup() { + static std::unique_ptr sptr; + if (!sptr) { + Json config{Null{}}; + sptr.reset(CommGroup::Create(config)); + } + return sptr; +} + +void GlobalCommGroupInit(Json config) { + auto& sptr = GlobalCommGroup(); + sptr.reset(CommGroup::Create(std::move(config))); +} + +void GlobalCommGroupFinalize() { + auto& sptr = GlobalCommGroup(); + sptr.reset(); +} +} // namespace xgboost::collective diff --git a/src/collective/comm_group.h b/src/collective/comm_group.h new file mode 100644 index 000000000..62f3e565f --- /dev/null +++ b/src/collective/comm_group.h @@ -0,0 +1,53 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for shared_ptr, unique_ptr +#include // for string +#include // for move + +#include "coll.h" // for Comm +#include "comm.h" // for Coll +#include "xgboost/collective/result.h" // for Result +#include "xgboost/collective/socket.h" // for GetHostName + +namespace xgboost::collective { +/** + * @brief Communicator group used for double dispatching between communicators and + * collective implementations. + */ +class CommGroup { + std::shared_ptr comm_; + mutable std::shared_ptr gpu_comm_; + + std::shared_ptr backend_; + mutable std::shared_ptr gpu_coll_; // lazy initialization + + CommGroup(std::shared_ptr comm, std::shared_ptr coll) + : comm_{std::move(comm)}, backend_{std::move(coll)} {} + + public: + CommGroup(); + + [[nodiscard]] auto World() const { return comm_->World(); } + [[nodiscard]] auto Rank() const { return comm_->Rank(); } + [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); } + + [[nodiscard]] static CommGroup* Create(Json config); + + [[nodiscard]] std::shared_ptr Backend(DeviceOrd device) const; + [[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const; + [[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); } + + [[nodiscard]] Result ProcessorName(std::string* out) const { + auto rc = GetHostName(out); + return rc; + } +}; + +std::unique_ptr& GlobalCommGroup(); + +void GlobalCommGroupInit(Json config); + +void GlobalCommGroupFinalize(); +} // namespace xgboost::collective diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 4837e2ace..88c51d8a9 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -58,36 +58,35 @@ Result Tracker::WaitUntilReady() const { RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr) : sock_{std::move(sock)} { - auto host = addr.Addr(); - std::int32_t rank{0}; - rc_ = Success() - << [&] { return proto::Magic{}.Verify(&sock_); } - << [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); }; - if (!rc_.OK()) { - return; - } - - std::string cmd; - sock_.Recv(&cmd); - auto jcmd = Json::Load(StringView{cmd}); - cmd_ = static_cast(get(jcmd["cmd"])); + Json jcmd; std::int32_t port{0}; - if (cmd_ == proto::CMD::kStart) { - proto::Start start; - rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_); - } else if (cmd_ == proto::CMD::kPrint) { - proto::Print print; - rc_ = print.TrackerHandle(jcmd, &msg_); - } else if (cmd_ == proto::CMD::kError) { - proto::ErrorCMD error; - rc_ = error.TrackerHandle(jcmd, &msg_, &code_); - } - if (!rc_.OK()) { - return; - } - info_ = proto::PeerInfo{host, port, rank}; + rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] { + return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); + } << [&] { + std::string cmd; + sock_.Recv(&cmd); + jcmd = Json::Load(StringView{cmd}); + cmd_ = static_cast(get(jcmd["cmd"])); + return Success(); + } << [&] { + if (cmd_ == proto::CMD::kStart) { + proto::Start start; + return start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_); + } else if (cmd_ == proto::CMD::kPrint) { + proto::Print print; + return print.TrackerHandle(jcmd, &msg_); + } else if (cmd_ == proto::CMD::kError) { + proto::ErrorCMD error; + return error.TrackerHandle(jcmd, &msg_, &code_); + } + return Success(); + } << [&] { + auto host = addr.Addr(); + info_ = proto::PeerInfo{host, port, rank}; + return Success(); + }; } RabitTracker::RabitTracker(Json const& config) : Tracker{config} { @@ -137,15 +136,18 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { std::int32_t n_shutdown{0}; bool during_restart{false}; + bool running{false}; std::vector pending; explicit State(std::int32_t world) : n_workers{world} {} State(State const& that) = delete; State& operator=(State&& that) = delete; + // modifiers void Start(WorkerProxy&& worker) { CHECK_LT(pending.size(), n_workers); CHECK_LE(n_shutdown, n_workers); + CHECK(!running); pending.emplace_back(std::forward(worker)); @@ -155,6 +157,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { CHECK_GE(n_shutdown, 0); CHECK_LT(n_shutdown, n_workers); + running = false; ++n_shutdown; CHECK_LE(n_shutdown, n_workers); @@ -163,21 +166,26 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { CHECK_LE(pending.size(), n_workers); CHECK_LE(n_shutdown, n_workers); + running = false; during_restart = true; } - [[nodiscard]] bool Ready() const { - CHECK_LE(pending.size(), n_workers); - return static_cast(pending.size()) == n_workers; - } void Bootstrap() { CHECK_EQ(pending.size(), n_workers); CHECK_LE(n_shutdown, n_workers); + running = true; + // A reset. n_shutdown = 0; during_restart = false; pending.clear(); } + + // observers + [[nodiscard]] bool Ready() const { + CHECK_LE(pending.size(), n_workers); + return static_cast(pending.size()) == n_workers; + } [[nodiscard]] bool ShouldContinue() const { CHECK_LE(pending.size(), n_workers); CHECK_LE(n_shutdown, n_workers); @@ -187,7 +195,31 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { } }; - return std::async(std::launch::async, [this] { + auto handle_error = [&](WorkerProxy const& worker) { + auto msg = worker.Msg(); + auto code = worker.Code(); + LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank() << "]: " << msg + << " code:" << code; + auto host = worker.Host(); + // We signal all workers for the error, if they haven't aborted already. + for (auto& w : worker_error_handles_) { + if (w.first == host) { + continue; + } + TCPSocket out; + // Connecting to the error port as a signal for exit. + // + // retry is set to 1, just let the worker timeout or error. Otherwise the + // tracker and the worker might be waiting for each other. + auto rc = Connect(w.first, w.second, 1, timeout_, &out); + if (!rc.OK()) { + return Fail("Failed to inform workers to stop."); + } + } + return Success(); + }; + + return std::async(std::launch::async, [this, handle_error] { State state{this->n_workers_}; while (state.ShouldContinue()) { @@ -205,6 +237,16 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { } switch (worker.Command()) { case proto::CMD::kStart: { + if (state.running) { + // Something went wrong with one of the workers. It got disconnected without + // notice. + state.Error(); + rc = handle_error(worker); + if (!rc.OK()) { + return Fail("Failed to handle abort.", std::move(rc)); + } + } + state.Start(std::move(worker)); if (state.Ready()) { rc = this->Bootstrap(&state.pending); @@ -216,36 +258,20 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { continue; } case proto::CMD::kShutdown: { + if (state.during_restart) { + // The worker can still send shutdown after call to `std::exit`. + continue; + } state.Shutdown(); continue; } case proto::CMD::kError: { if (state.during_restart) { + // Ignore further errors. continue; } state.Error(); - auto msg = worker.Msg(); - auto code = worker.Code(); - LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank() - << "]: " << msg << " code:" << code; - auto host = worker.Host(); - // We signal all workers for the error, if they haven't aborted already. - for (auto& w : worker_error_handles_) { - if (w.first == host) { - continue; - } - TCPSocket out; - // retry is set to 1, just let the worker timeout or error. Otherwise the - // tracker and the worker might be waiting for each other. - auto rc = Connect(w.first, w.second, 1, timeout_, &out); - // send signal to stop the worker. - proto::ShutdownCMD shutdown; - rc = shutdown.Send(&out); - if (!rc.OK()) { - return Fail("Failed to inform workers to stop."); - } - } - + rc = handle_error(worker); continue; } case proto::CMD::kPrint: { diff --git a/tests/cpp/collective/test_comm_group.cc b/tests/cpp/collective/test_comm_group.cc new file mode 100644 index 000000000..0f6bc23a2 --- /dev/null +++ b/tests/cpp/collective/test_comm_group.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include +#include // for Json + +#include // for seconds +#include // for int32_t +#include // for string +#include // for thread + +#include "../../../src/collective/comm.h" +#include "../../../src/collective/comm_group.h" +#include "../../../src/common/common.h" // for AllVisibleGPUs +#include "../helpers.h" // for MakeCUDACtx +#include "test_worker.h" // for TestDistributed + +namespace xgboost::collective { +namespace { +auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { + Json config{Object{}}; + config["dmlc_communicator"] = std::string{"rabit"}; + config["DMLC_TRACKER_URI"] = host; + config["DMLC_TRACKER_PORT"] = port; + config["dmlc_timeout_sec"] = static_cast(timeout.count()); + config["DMLC_TASK_ID"] = std::to_string(r); + config["dmlc_retry"] = 2; + return config; +} + +class CommGroupTest : public SocketTest {}; +} // namespace + +TEST_F(CommGroupTest, Basic) { + std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 5u); + TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Context ctx; + auto config = MakeConfig(host, port, timeout, r); + std::unique_ptr ptr{CommGroup::Create(config)}; + ASSERT_TRUE(ptr->IsDistributed()); + ASSERT_EQ(ptr->World(), n_workers); + auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CPU()); + ASSERT_EQ(comm.TaskID(), std::to_string(r)); + ASSERT_EQ(comm.Retry(), 2); + }); +} + +#if defined(XGBOOST_USE_NCCL) +TEST_F(CommGroupTest, BasicGPU) { + std::int32_t n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + auto ctx = MakeCUDACtx(r); + auto config = MakeConfig(host, port, timeout, r); + std::unique_ptr ptr{CommGroup::Create(config)}; + auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0)); + ASSERT_EQ(comm.TaskID(), std::to_string(r)); + ASSERT_EQ(comm.Retry(), 2); + }); +} +#endif // for defined(XGBOOST_USE_NCCL) +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index 6578ff142..ad3213e81 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -95,7 +95,8 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { std::chrono::seconds timeout{1}; std::string host; - ASSERT_TRUE(GetHostAddress(&host).OK()); + auto rc = GetHostAddress(&host); + ASSERT_TRUE(rc.OK()) << rc.Report(); RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; auto fut = tracker.Run(); diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc index f345b3a78..21c5ad30d 100644 --- a/tests/cpp/common/test_linalg.cc +++ b/tests/cpp/common/test_linalg.cc @@ -15,6 +15,15 @@ namespace xgboost::linalg { namespace { DeviceOrd CPU() { return DeviceOrd::CPU(); } + +template +void ConstView(linalg::VectorView v1, linalg::VectorView> v2) { + // compile test for being able to pass non-const view to const view. + auto s = v1.Slice(linalg::All()); + ASSERT_EQ(s.Size(), v1.Size()); + auto s2 = v2.Slice(linalg::All()); + ASSERT_EQ(s2.Size(), v2.Size()); +} } // namespace auto MakeMatrixFromTest(HostDeviceVector *storage, std::size_t n_rows, std::size_t n_cols) { @@ -206,6 +215,11 @@ TEST(Linalg, TensorView) { ASSERT_TRUE(t.FContiguous()); ASSERT_FALSE(t.CContiguous()); } + { + // const + TensorView t{data, {data.size()}, CPU()}; + ConstView(t, t); + } } TEST(Linalg, Tensor) { diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu index 44211f8d7..a6ec7e352 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cu +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -124,6 +124,9 @@ TEST_F(FederatedCollTestGPU, Allgather) { TEST_F(FederatedCollTestGPU, AllgatherV) { std::int32_t n_workers = 2; + if (common::AllVisibleGPUs() < n_workers) { + GTEST_SKIP_("At least 2 GPUs are required for the test."); + } TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t rank) { TestAllgatherV(comm, rank); }); diff --git a/tests/cpp/plugin/federated/test_federated_comm.cc b/tests/cpp/plugin/federated/test_federated_comm.cc index b45b00910..0d0692b5f 100644 --- a/tests/cpp/plugin/federated/test_federated_comm.cc +++ b/tests/cpp/plugin/federated/test_federated_comm.cc @@ -1,6 +1,7 @@ /** * Copyright 2022-2023, XGBoost contributors */ +#include #include #include // for string @@ -19,12 +20,14 @@ class FederatedCommTest : public SocketTest {}; TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) { auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; }; - ExpectThrow("Invalid world size.", construct); + ASSERT_THAT(construct, + ::testing::ThrowsMessage(::testing::HasSubstr("Invalid world size"))); } TEST_F(FederatedCommTest, ThrowOnRankTooSmall) { auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; }; - ExpectThrow("Invalid worker rank.", construct); + ASSERT_THAT(construct, + ::testing::ThrowsMessage(::testing::HasSubstr("Invalid worker rank."))); } TEST_F(FederatedCommTest, ThrowOnRankTooBig) { @@ -38,7 +41,7 @@ TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) { config["federated_server_address"] = std::string("localhost:0"); config["federated_world_size"] = std::string("1"); config["federated_rank"] = Integer(0); - FederatedComm comm(config); + FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config}; }; ExpectThrow("got: `String`", construct); } @@ -49,7 +52,7 @@ TEST_F(FederatedCommTest, ThrowOnRankNotInteger) { config["federated_server_address"] = std::string("localhost:0"); config["federated_world_size"] = 1; config["federated_rank"] = std::string("0"); - FederatedComm comm(config); + FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config); }; ExpectThrow("got: `String`", construct); } @@ -59,7 +62,7 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) { config["federated_world_size"] = 6; config["federated_rank"] = 3; config["federated_server_address"] = String{"localhost:0"}; - FederatedComm comm{config}; + FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config}; EXPECT_EQ(comm.World(), 6); EXPECT_EQ(comm.Rank(), 3); } diff --git a/tests/cpp/plugin/federated/test_federated_comm_group.cc b/tests/cpp/plugin/federated/test_federated_comm_group.cc new file mode 100644 index 000000000..9bfbdd3ae --- /dev/null +++ b/tests/cpp/plugin/federated/test_federated_comm_group.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include +#include // for Json + +#include "../../../../src/collective/comm_group.h" +#include "../../helpers.h" +#include "test_worker.h" + +namespace xgboost::collective { +TEST(CommGroup, Federated) { + std::int32_t n_workers = common::AllVisibleGPUs(); + TestFederatedGroup(n_workers, [&](std::shared_ptr comm_group, std::int32_t r) { + Context ctx; + ASSERT_EQ(comm_group->Rank(), r); + auto const& comm = comm_group->Ctx(&ctx, DeviceOrd::CPU()); + ASSERT_EQ(comm.TaskID(), std::to_string(r)); + ASSERT_EQ(comm.Retry(), 2); + }); +} +} // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_federated_comm_group.cu b/tests/cpp/plugin/federated/test_federated_comm_group.cu new file mode 100644 index 000000000..747adb6fd --- /dev/null +++ b/tests/cpp/plugin/federated/test_federated_comm_group.cu @@ -0,0 +1,22 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include +#include // for Json + +#include "../../../../src/collective/comm_group.h" +#include "../../helpers.h" +#include "test_worker.h" + +namespace xgboost::collective { +TEST(CommGroup, FederatedGPU) { + std::int32_t n_workers = common::AllVisibleGPUs(); + TestFederatedGroup(n_workers, [&](std::shared_ptr comm_group, std::int32_t r) { + Context ctx = MakeCUDACtx(0); + auto const& comm = comm_group->Ctx(&ctx, DeviceOrd::CUDA(0)); + ASSERT_EQ(comm_group->Rank(), r); + ASSERT_EQ(comm.TaskID(), std::to_string(r)); + ASSERT_EQ(comm.Retry(), 2); + }); +} +} // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_worker.h b/tests/cpp/plugin/federated/test_worker.h index 38bc32c60..d0edecc15 100644 --- a/tests/cpp/plugin/federated/test_worker.h +++ b/tests/cpp/plugin/federated/test_worker.h @@ -5,10 +5,12 @@ #include -#include // for ms +#include // for ms, seconds +#include // for shared_ptr #include // for thread #include "../../../../plugin/federated/federated_tracker.h" +#include "../../../../src/collective/comm_group.h" #include "federated_comm.h" // for FederatedComm #include "xgboost/json.h" // for Json @@ -23,9 +25,8 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { std::vector workers; using namespace std::chrono_literals; - while (tracker.Port() == 0) { - std::this_thread::sleep_for(100ms); - } + auto rc = tracker.WaitUntilReady(); + ASSERT_TRUE(rc.OK()) << rc.Report(); std::int32_t port = tracker.Port(); for (std::int32_t i = 0; i < n_workers; ++i) { @@ -34,7 +35,8 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { config["federated_world_size"] = n_workers; config["federated_rank"] = i; config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); - auto comm = std::make_shared(config); + auto comm = std::make_shared( + DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config); fn(comm, i); }); @@ -44,7 +46,43 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { t.join(); } - auto rc = tracker.Shutdown(); + rc = tracker.Shutdown(); + ASSERT_TRUE(rc.OK()) << rc.Report(); + ASSERT_TRUE(fut.get().OK()); +} + +template +void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) { + Json config{Object()}; + config["federated_secure"] = Boolean{false}; + config["n_workers"] = Integer{n_workers}; + FederatedTracker tracker{config}; + auto fut = tracker.Run(); + + std::vector workers; + auto rc = tracker.WaitUntilReady(); + ASSERT_TRUE(rc.OK()) << rc.Report(); + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { + Json config{Object{}}; + config["dmlc_communicator"] = std::string{"federated"}; + config["dmlc_task_id"] = std::to_string(i); + config["dmlc_retry"] = 2; + config["federated_world_size"] = n_workers; + config["federated_rank"] = i; + config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); + std::shared_ptr comm_group{CommGroup::Create(config)}; + fn(comm_group, i); + }); + } + + for (auto& t : workers) { + t.join(); + } + + rc = tracker.Shutdown(); ASSERT_TRUE(rc.OK()) << rc.Report(); ASSERT_TRUE(fut.get().OK()); }