[coll] Add comm group. (#9759)
- Implement `CommGroup` for double dispatching. - Small cleanup to tracker for handling abort.
This commit is contained in:
63
tests/cpp/collective/test_comm_group.cc
Normal file
63
tests/cpp/collective/test_comm_group.cc
Normal file
@@ -0,0 +1,63 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h> // for Json
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "../../../src/collective/comm_group.h"
|
||||
#include "../../../src/common/common.h" // for AllVisibleGPUs
|
||||
#include "../helpers.h" // for MakeCUDACtx
|
||||
#include "test_worker.h" // for TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) {
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = std::string{"rabit"};
|
||||
config["DMLC_TRACKER_URI"] = host;
|
||||
config["DMLC_TRACKER_PORT"] = port;
|
||||
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
|
||||
config["DMLC_TASK_ID"] = std::to_string(r);
|
||||
config["dmlc_retry"] = 2;
|
||||
return config;
|
||||
}
|
||||
|
||||
// Fixture for the CommGroup tests below. It adds nothing of its own; everything is inherited
// from SocketTest, which is declared outside this file (presumably in test_worker.h -- confirm).
class CommGroupTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
// Creates a CommGroup in each worker spawned by TestDistributed and checks that the group
// reports itself as distributed with the expected world size, and that the CPU communicator
// carries the task id and retry count written by MakeConfig.
TEST_F(CommGroupTest, Basic) {
  // `hardware_concurrency` is allowed to return 0 when the value cannot be determined. Fall
  // back to a single worker in that case instead of starting the test with no workers at all.
  auto const n_threads = std::thread::hardware_concurrency();
  std::int32_t n_workers =
      n_threads == 0u ? 1 : static_cast<std::int32_t>(std::min(n_threads, 5u));
  TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    Context ctx;
    auto config = MakeConfig(host, port, timeout, r);
    std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
    ASSERT_TRUE(ptr->IsDistributed());
    ASSERT_EQ(ptr->World(), n_workers);
    auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CPU());
    ASSERT_EQ(comm.TaskID(), std::to_string(r));
    // MakeConfig sets `dmlc_retry` to 2.
    ASSERT_EQ(comm.Retry(), 2);
  });
}
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
// GPU counterpart of the Basic test: one worker per visible GPU, each creating a CommGroup and
// checking the task id and retry count of the communicator returned for a CUDA device.
TEST_F(CommGroupTest, BasicGPU) {
  std::int32_t n_workers = common::AllVisibleGPUs();
  TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    auto cfg = MakeConfig(host, port, timeout, r);
    // NOTE(review): the context is created for device `r` while the communicator below is
    // requested for CUDA ordinal 0 -- confirm this is intended.
    auto ctx = MakeCUDACtx(r);
    std::unique_ptr<CommGroup> group{CommGroup::Create(cfg)};
    auto const& comm = group->Ctx(&ctx, DeviceOrd::CUDA(0));
    ASSERT_EQ(comm.TaskID(), std::to_string(r));
    // MakeConfig sets `dmlc_retry` to 2.
    ASSERT_EQ(comm.Retry(), 2);
  });
}
|
||||
#endif // for defined(XGBOOST_USE_NCCL)
|
||||
} // namespace xgboost::collective
|
||||
@@ -95,7 +95,8 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
||||
std::chrono::seconds timeout{1};
|
||||
|
||||
std::string host;
|
||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
|
||||
@@ -15,6 +15,15 @@
|
||||
namespace xgboost::linalg {
|
||||
namespace {
|
||||
// Shorthand for DeviceOrd::CPU(), used when constructing views in the tests of this file.
DeviceOrd CPU() { return DeviceOrd::CPU(); }
|
||||
|
||||
template <typename T>
|
||||
void ConstView(linalg::VectorView<T> v1, linalg::VectorView<std::add_const_t<T>> v2) {
|
||||
// compile test for being able to pass non-const view to const view.
|
||||
auto s = v1.Slice(linalg::All());
|
||||
ASSERT_EQ(s.Size(), v1.Size());
|
||||
auto s2 = v2.Slice(linalg::All());
|
||||
ASSERT_EQ(s2.Size(), v2.Size());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, std::size_t n_cols) {
|
||||
@@ -206,6 +215,11 @@ TEST(Linalg, TensorView) {
|
||||
ASSERT_TRUE(t.FContiguous());
|
||||
ASSERT_FALSE(t.CContiguous());
|
||||
}
|
||||
{
|
||||
// const
|
||||
TensorView<double, 1> t{data, {data.size()}, CPU()};
|
||||
ConstView(t, t);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Linalg, Tensor) {
|
||||
|
||||
@@ -124,6 +124,9 @@ TEST_F(FederatedCollTestGPU, Allgather) {
|
||||
|
||||
TEST_F(FederatedCollTestGPU, AllgatherV) {
|
||||
std::int32_t n_workers = 2;
|
||||
if (common::AllVisibleGPUs() < n_workers) {
|
||||
GTEST_SKIP_("At least 2 GPUs are required for the test.");
|
||||
}
|
||||
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
|
||||
TestAllgatherV(comm, rank);
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string> // for string
|
||||
@@ -19,12 +20,14 @@ class FederatedCommTest : public SocketTest {};
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid world size.", construct);
|
||||
ASSERT_THAT(construct,
|
||||
::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid world size")));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
|
||||
ASSERT_THAT(construct,
|
||||
::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid worker rank.")));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
||||
@@ -38,7 +41,7 @@ TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = std::string("1");
|
||||
config["federated_rank"] = Integer(0);
|
||||
FederatedComm comm(config);
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
}
|
||||
@@ -49,7 +52,7 @@ TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = 1;
|
||||
config["federated_rank"] = std::string("0");
|
||||
FederatedComm comm(config);
|
||||
FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config);
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
}
|
||||
@@ -59,7 +62,7 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
||||
config["federated_world_size"] = 6;
|
||||
config["federated_rank"] = 3;
|
||||
config["federated_server_address"] = String{"localhost:0"};
|
||||
FederatedComm comm{config};
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||
EXPECT_EQ(comm.World(), 6);
|
||||
EXPECT_EQ(comm.Rank(), 3);
|
||||
}
|
||||
|
||||
22
tests/cpp/plugin/federated/test_federated_comm_group.cc
Normal file
22
tests/cpp/plugin/federated/test_federated_comm_group.cc
Normal file
@@ -0,0 +1,22 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h> // for Json
|
||||
|
||||
#include "../../../../src/collective/comm_group.h"
|
||||
#include "../../helpers.h"
|
||||
#include "test_worker.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
// For every worker started by TestFederatedGroup, checks the rank reported by the group and
// the task id / retry count of the communicator returned for the CPU device.
TEST(CommGroup, Federated) {
  // NOTE(review): the worker count is taken from the number of visible GPUs although the test
  // body only uses a default Context and the CPU device -- confirm the behaviour on hosts
  // without GPUs.
  std::int32_t n_workers = common::AllVisibleGPUs();
  TestFederatedGroup(n_workers, [&](std::shared_ptr<CommGroup> group, std::int32_t rank) {
    Context ctx;
    ASSERT_EQ(group->Rank(), rank);
    auto const& comm = group->Ctx(&ctx, DeviceOrd::CPU());
    ASSERT_EQ(comm.TaskID(), std::to_string(rank));
    ASSERT_EQ(comm.Retry(), 2);
  });
}
|
||||
} // namespace xgboost::collective
|
||||
22
tests/cpp/plugin/federated/test_federated_comm_group.cu
Normal file
22
tests/cpp/plugin/federated/test_federated_comm_group.cu
Normal file
@@ -0,0 +1,22 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h> // for Json
|
||||
|
||||
#include "../../../../src/collective/comm_group.h"
|
||||
#include "../../helpers.h"
|
||||
#include "test_worker.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
// GPU variant of the federated CommGroup test: one worker per visible GPU, each requesting the
// communicator for CUDA ordinal 0 and checking rank, task id and retry count.
TEST(CommGroup, FederatedGPU) {
  std::int32_t n_workers = common::AllVisibleGPUs();
  TestFederatedGroup(n_workers, [&](std::shared_ptr<CommGroup> group, std::int32_t rank) {
    // Every worker builds its context with MakeCUDACtx(0).
    Context ctx = MakeCUDACtx(0);
    auto const& comm = group->Ctx(&ctx, DeviceOrd::CUDA(0));
    ASSERT_EQ(group->Rank(), rank);
    ASSERT_EQ(comm.TaskID(), std::to_string(rank));
    ASSERT_EQ(comm.Retry(), 2);
  });
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -5,10 +5,12 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for ms
|
||||
#include <chrono> // for ms, seconds
|
||||
#include <memory> // for shared_ptr
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_tracker.h"
|
||||
#include "../../../../src/collective/comm_group.h"
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
@@ -23,9 +25,8 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
using namespace std::chrono_literals;
|
||||
while (tracker.Port() == 0) {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
}
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
@@ -34,7 +35,8 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = i;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
auto comm = std::make_shared<FederatedComm>(config);
|
||||
auto comm = std::make_shared<FederatedComm>(
|
||||
DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config);
|
||||
|
||||
fn(comm, i);
|
||||
});
|
||||
@@ -44,7 +46,43 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
auto rc = tracker.Shutdown();
|
||||
rc = tracker.Shutdown();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
FederatedTracker tracker{config};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = std::string{"federated"};
|
||||
config["dmlc_task_id"] = std::to_string(i);
|
||||
config["dmlc_retry"] = 2;
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = i;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
std::shared_ptr<CommGroup> comm_group{CommGroup::Create(config)};
|
||||
fn(comm_group, i);
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
rc = tracker.Shutdown();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user