merge latest change from upstream
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h> // for ASSERT_EQ
|
||||
#include <xgboost/span.h> // for Span, oper...
|
||||
@@ -34,8 +34,8 @@ class Worker : public WorkerForTest {
|
||||
std::vector<std::int32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = comm_.Rank();
|
||||
|
||||
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()});
|
||||
SafeColl(rc);
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
ASSERT_EQ(data[r], r);
|
||||
@@ -51,8 +51,8 @@ class Worker : public WorkerForTest {
|
||||
auto seg = s_data.subspan(comm_.Rank() * n, n);
|
||||
std::iota(seg.begin(), seg.end(), comm_.Rank());
|
||||
|
||||
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()});
|
||||
SafeColl(rc);
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
auto seg = s_data.subspan(r * n, n);
|
||||
@@ -81,7 +81,7 @@ class Worker : public WorkerForTest {
|
||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
|
||||
CheckV(result);
|
||||
}
|
||||
@@ -91,7 +91,7 @@ class Worker : public WorkerForTest {
|
||||
std::int32_t n{comm_.Rank()};
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
}
|
||||
@@ -104,8 +104,8 @@ class Worker : public WorkerForTest {
|
||||
|
||||
std::vector<std::int64_t> sizes(comm_.World(), 0);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
SafeColl(rc);
|
||||
std::shared_ptr<Coll> pcoll{new Coll{}};
|
||||
|
||||
std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
#include <gtest/gtest.h>
|
||||
@@ -33,8 +33,8 @@ class Worker : public NCCLWorkerForTest {
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(comm_.World(), -1);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
SafeColl(rc);
|
||||
// create result
|
||||
dh::device_vector<std::int32_t> result(comm_.World(), -1);
|
||||
auto s_result = common::EraseType(dh::ToSpan(result));
|
||||
@@ -42,7 +42,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
|
||||
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
@@ -57,8 +57,8 @@ class Worker : public NCCLWorkerForTest {
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
|
||||
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
SafeColl(rc);
|
||||
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||
// create result
|
||||
dh::device_vector<std::int32_t> result(n_bytes / sizeof(std::int32_t), -1);
|
||||
@@ -67,7 +67,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
|
||||
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
// check segment size
|
||||
if (algo != AllgatherVAlgo::kBcast) {
|
||||
auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <numeric> // for iota
|
||||
|
||||
#include "../../../src/collective/allreduce.h"
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
@@ -58,7 +59,7 @@ class AllreduceWorker : public WorkerForTest {
|
||||
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
|
||||
auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
|
||||
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
for (auto v : data) {
|
||||
ASSERT_EQ(v, ~std::uint32_t{0});
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h> // for host_vector
|
||||
|
||||
#include "../../../src/common/common.h"
|
||||
#include "../../../src/common/common.h" // for AllVisibleGPUs
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
@@ -24,7 +24,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
|
||||
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
thrust::host_vector<std::uint32_t> h_data(data.size());
|
||||
thrust::copy(data.cbegin(), data.cend(), h_data.begin());
|
||||
for (auto v : h_data) {
|
||||
@@ -36,7 +36,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
dh::device_vector<double> data(314, 1.5);
|
||||
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
|
||||
ArrayInterfaceHandler::kF8, Op::kSum);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||
auto v = data[i];
|
||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/socket.h>
|
||||
@@ -10,7 +10,6 @@
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/broadcast.h" // for Broadcast
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
@@ -24,14 +23,14 @@ class Worker : public WorkerForTest {
|
||||
// basic test
|
||||
std::vector<std::int32_t> data(1, comm_.Rank());
|
||||
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
ASSERT_EQ(data[0], r);
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
std::vector<std::int32_t> data(1 << 16, comm_.Rank());
|
||||
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
ASSERT_EQ(data[0], r);
|
||||
}
|
||||
}
|
||||
@@ -41,11 +40,11 @@ class BroadcastTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(BroadcastTest, Basic) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
std::int32_t n_workers = std::min(2u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.Run();
|
||||
});
|
||||
} // namespace
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -25,13 +25,13 @@ TEST_F(TrackerAPITest, CAPI) {
|
||||
auto config_str = Json::Dump(config);
|
||||
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
|
||||
ASSERT_EQ(rc, 0);
|
||||
rc = XGTrackerRun(handle);
|
||||
rc = XGTrackerRun(handle, nullptr);
|
||||
ASSERT_EQ(rc, 0);
|
||||
|
||||
std::thread bg_wait{[&] {
|
||||
Json config{Object{}};
|
||||
auto config_str = Json::Dump(config);
|
||||
auto rc = XGTrackerWait(handle, config_str.c_str());
|
||||
auto rc = XGTrackerWaitFor(handle, config_str.c_str());
|
||||
ASSERT_EQ(rc, 0);
|
||||
}};
|
||||
|
||||
@@ -42,8 +42,8 @@ TEST_F(TrackerAPITest, CAPI) {
|
||||
|
||||
std::string host;
|
||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||
ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
|
||||
auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
|
||||
ASSERT_EQ(host, get<String const>(args["dmlc_tracker_uri"]));
|
||||
auto port = get<Integer const>(args["dmlc_tracker_port"]);
|
||||
ASSERT_NE(port, 0);
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@@ -14,7 +14,7 @@ class CommTest : public TrackerTest {};
|
||||
|
||||
TEST_F(CommTest, Channel) {
|
||||
auto n_workers = 4;
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
@@ -29,7 +29,7 @@ TEST_F(CommTest, Channel) {
|
||||
return p_chan->SendAll(
|
||||
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
||||
} << [&] { return p_chan->Block(); };
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
} else {
|
||||
auto p_chan = worker.Comm().Chan(i - 1);
|
||||
std::int32_t r{-1};
|
||||
@@ -37,7 +37,7 @@ TEST_F(CommTest, Channel) {
|
||||
return p_chan->RecvAll(
|
||||
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||
} << [&] { return p_chan->Block(); };
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
ASSERT_EQ(r, i - 1);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -17,17 +17,6 @@
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) {
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = std::string{"rabit"};
|
||||
config["DMLC_TRACKER_URI"] = host;
|
||||
config["DMLC_TRACKER_PORT"] = port;
|
||||
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
|
||||
config["DMLC_TASK_ID"] = std::to_string(r);
|
||||
config["dmlc_retry"] = 2;
|
||||
return config;
|
||||
}
|
||||
|
||||
class CommGroupTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
@@ -36,7 +25,7 @@ TEST_F(CommGroupTest, Basic) {
|
||||
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Context ctx;
|
||||
auto config = MakeConfig(host, port, timeout, r);
|
||||
auto config = MakeDistributedTestConfig(host, port, timeout, r);
|
||||
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
|
||||
ASSERT_TRUE(ptr->IsDistributed());
|
||||
ASSERT_EQ(ptr->World(), n_workers);
|
||||
@@ -52,7 +41,7 @@ TEST_F(CommGroupTest, BasicGPU) {
|
||||
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
auto ctx = MakeCUDACtx(r);
|
||||
auto config = MakeConfig(host, port, timeout, r);
|
||||
auto config = MakeDistributedTestConfig(host, port, timeout, r);
|
||||
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
|
||||
auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0));
|
||||
ASSERT_EQ(comm.TaskID(), std::to_string(r));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h> // for ASSERT_TRUE, ASSERT_EQ
|
||||
#include <xgboost/collective/socket.h> // for TCPSocket, Connect, SocketFinalize, SocketStartup
|
||||
@@ -28,18 +28,23 @@ class LoopTest : public ::testing::Test {
|
||||
|
||||
auto domain = SockDomain::kV4;
|
||||
pair_.first = TCPSocket::Create(domain);
|
||||
auto port = pair_.first.BindHost();
|
||||
pair_.first.Listen();
|
||||
std::int32_t port{0};
|
||||
auto rc = Success() << [&] {
|
||||
return pair_.first.BindHost(&port);
|
||||
} << [&] {
|
||||
return pair_.first.Listen();
|
||||
};
|
||||
SafeColl(rc);
|
||||
|
||||
auto const& addr = SockAddrV4::Loopback().Addr();
|
||||
auto rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
|
||||
SafeColl(rc);
|
||||
rc = pair_.second.NonBlocking(true);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
|
||||
pair_.first = pair_.first.Accept();
|
||||
rc = pair_.first.NonBlocking(true);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
|
||||
loop_ = std::shared_ptr<Loop>{new Loop{timeout}};
|
||||
}
|
||||
@@ -74,8 +79,26 @@ TEST_F(LoopTest, Op) {
|
||||
loop_->Submit(rop);
|
||||
|
||||
auto rc = loop_->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
|
||||
ASSERT_EQ(rbuf[0], wbuf[0]);
|
||||
}
|
||||
|
||||
TEST_F(LoopTest, Block) {
|
||||
// We need to ensure that a blocking call doesn't go unanswered.
|
||||
auto op = Loop::Op::Sleep(2);
|
||||
|
||||
common::Timer t;
|
||||
t.Start();
|
||||
loop_->Submit(op);
|
||||
t.Stop();
|
||||
// submit is non-blocking
|
||||
ASSERT_LT(t.ElapsedSeconds(), 1);
|
||||
|
||||
t.Start();
|
||||
auto rc = loop_->Block();
|
||||
t.Stop();
|
||||
SafeColl(rc);
|
||||
ASSERT_GE(t.ElapsedSeconds(), 1);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
31
tests/cpp/collective/test_result.cc
Normal file
31
tests/cpp/collective/test_result.cc
Normal file
@@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright 2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/result.h>
|
||||
|
||||
namespace xgboost::collective {
|
||||
TEST(Result, Concat) {
|
||||
auto rc0 = Fail("foo");
|
||||
auto rc1 = Fail("bar");
|
||||
auto rc = std::move(rc0) + std::move(rc1);
|
||||
ASSERT_NE(rc.Report().find("foo"), std::string::npos);
|
||||
ASSERT_NE(rc.Report().find("bar"), std::string::npos);
|
||||
|
||||
auto rc2 = Fail("Another", std::move(rc));
|
||||
auto assert_that = [](Result const& rc) {
|
||||
ASSERT_NE(rc.Report().find("Another"), std::string::npos);
|
||||
ASSERT_NE(rc.Report().find("foo"), std::string::npos);
|
||||
ASSERT_NE(rc.Report().find("bar"), std::string::npos);
|
||||
};
|
||||
assert_that(rc2);
|
||||
|
||||
auto empty = Success();
|
||||
auto rc3 = std::move(empty) + std::move(rc2);
|
||||
assert_that(rc3);
|
||||
|
||||
empty = Success();
|
||||
auto rc4 = std::move(rc3) + std::move(empty);
|
||||
assert_that(rc4);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost Contributors
|
||||
* Copyright 2022-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/socket.h>
|
||||
@@ -21,14 +21,19 @@ TEST_F(SocketTest, Basic) {
|
||||
auto run_test = [msg](SockDomain domain) {
|
||||
auto server = TCPSocket::Create(domain);
|
||||
ASSERT_EQ(server.Domain(), domain);
|
||||
auto port = server.BindHost();
|
||||
server.Listen();
|
||||
std::int32_t port{0};
|
||||
auto rc = Success() << [&] {
|
||||
return server.BindHost(&port);
|
||||
} << [&] {
|
||||
return server.Listen();
|
||||
};
|
||||
SafeColl(rc);
|
||||
|
||||
TCPSocket client;
|
||||
if (domain == SockDomain::kV4) {
|
||||
auto const& addr = SockAddrV4::Loopback().Addr();
|
||||
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
} else {
|
||||
auto const& addr = SockAddrV6::Loopback().Addr();
|
||||
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
||||
@@ -45,7 +50,8 @@ TEST_F(SocketTest, Basic) {
|
||||
accepted.Send(msg);
|
||||
|
||||
std::string str;
|
||||
client.Recv(&str);
|
||||
rc = client.Recv(&str);
|
||||
SafeColl(rc);
|
||||
ASSERT_EQ(StringView{str}, msg);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
@@ -10,6 +11,7 @@
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "../helpers.h" // for GMockThrow
|
||||
#include "test_worker.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
@@ -20,13 +22,13 @@ class PrintWorker : public WorkerForTest {
|
||||
|
||||
void Print() {
|
||||
auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank()));
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(TrackerTest, Bootstrap) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||
ASSERT_FALSE(tracker.Ready());
|
||||
auto fut = tracker.Run();
|
||||
|
||||
@@ -34,7 +36,7 @@ TEST_F(TrackerTest, Bootstrap) {
|
||||
|
||||
auto args = tracker.WorkerArgs();
|
||||
ASSERT_TRUE(tracker.Ready());
|
||||
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
||||
ASSERT_EQ(get<String const>(args["dmlc_tracker_uri"]), host);
|
||||
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
@@ -44,12 +46,11 @@ TEST_F(TrackerTest, Bootstrap) {
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
SafeColl(fut.get());
|
||||
}
|
||||
|
||||
TEST_F(TrackerTest, Print) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
@@ -73,4 +74,47 @@ TEST_F(TrackerTest, Print) {
|
||||
}
|
||||
|
||||
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
|
||||
|
||||
/**
|
||||
* Test connecting the tracker after it has finished. This should not hang the workers.
|
||||
*/
|
||||
TEST_F(TrackerTest, AfterShutdown) {
|
||||
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK());
|
||||
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
// Launch no-op workers to cause the tracker to shutdown.
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; });
|
||||
}
|
||||
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
|
||||
// Launch workers again, they should fail.
|
||||
workers.clear();
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
auto assert_that = [=] {
|
||||
WorkerForTest worker{host, port, timeout, n_workers, i};
|
||||
};
|
||||
// On a Linux platform, the connection will be refused, on Apple platform, this gets
|
||||
// an operation now in progress poll failure, on Windows, it's a timeout error.
|
||||
#if defined(__linux__)
|
||||
workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Connection refused")); });
|
||||
#else
|
||||
workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Failed to connect to")); });
|
||||
#endif
|
||||
}
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <fstream> // for ifstream
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <utility> // for move
|
||||
@@ -36,7 +37,7 @@ class WorkerForTest {
|
||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} {
|
||||
CHECK_EQ(world_size_, comm_.World());
|
||||
}
|
||||
virtual ~WorkerForTest() = default;
|
||||
virtual ~WorkerForTest() noexcept(false) { SafeColl(comm_.Shutdown()); }
|
||||
auto& Comm() { return comm_; }
|
||||
|
||||
void LimitSockBuf(std::int32_t n_bytes) {
|
||||
@@ -86,19 +87,30 @@ class TrackerTest : public SocketTest {
|
||||
void SetUp() override {
|
||||
SocketTest::SetUp();
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
}
|
||||
};
|
||||
|
||||
inline Json MakeTrackerConfig(std::string host, std::int32_t n_workers,
|
||||
std::chrono::seconds timeout) {
|
||||
Json config{Object{}};
|
||||
config["host"] = host;
|
||||
config["port"] = Integer{0};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
config["sortby"] = Integer{static_cast<std::int32_t>(Tracker::SortBy::kHost)};
|
||||
config["timeout"] = timeout.count();
|
||||
return config;
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
||||
std::chrono::seconds timeout{2};
|
||||
|
||||
std::string host;
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
LOG(INFO) << "Using " << n_workers << " workers for test.";
|
||||
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
||||
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
@@ -114,4 +126,15 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
|
||||
std::chrono::seconds timeout, std::int32_t r) {
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = std::string{"rabit"};
|
||||
config["dmlc_tracker_uri"] = host;
|
||||
config["dmlc_tracker_port"] = port;
|
||||
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
|
||||
config["dmlc_task_id"] = std::to_string(r);
|
||||
config["dmlc_retry"] = 2;
|
||||
return config;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user