merge latest changes
This commit is contained in:
@@ -47,7 +47,7 @@ class Worker : public WorkerForTest {
|
||||
|
||||
std::size_t n = 8192; // n_bytes = 8192 * sizeof(int)
|
||||
std::vector<std::int32_t> data(comm_.World() * n, 0);
|
||||
auto s_data = common::Span{data.data(), data.size()};
|
||||
auto s_data = common::Span<std::int32_t>{data};
|
||||
auto seg = s_data.subspan(comm_.Rank() * n, n);
|
||||
std::iota(seg.begin(), seg.end(), comm_.Rank());
|
||||
|
||||
|
||||
@@ -90,10 +90,10 @@ class Worker : public NCCLWorkerForTest {
|
||||
}
|
||||
};
|
||||
|
||||
class AllgatherTestGPU : public SocketTest {};
|
||||
class MGPUAllgatherTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllgatherTestGPU, MGPUTestVRing) {
|
||||
TEST_F(MGPUAllgatherTest, MGPUTestVRing) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
@@ -104,7 +104,7 @@ TEST_F(AllgatherTestGPU, MGPUTestVRing) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTestGPU, MGPUTestVBcast) {
|
||||
TEST_F(MGPUAllgatherTest, MGPUTestVBcast) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
|
||||
@@ -18,31 +18,34 @@ class AllreduceWorker : public WorkerForTest {
|
||||
void Basic() {
|
||||
{
|
||||
std::vector<double> data(13, 0.0);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
|
||||
}
|
||||
{
|
||||
std::vector<double> data(1, 1.0);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
|
||||
}
|
||||
}
|
||||
|
||||
void Acc() {
|
||||
std::vector<double> data(314, 1.5);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_TRUE(rc.OK());
|
||||
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||
auto v = data[i];
|
||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||
|
||||
@@ -5,17 +5,15 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h> // for host_vector
|
||||
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/common/common.h"
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "../helpers.h" // for MakeCUDACtx
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class AllreduceTestGPU : public SocketTest {};
|
||||
class MGPUAllreduceTest : public SocketTest {};
|
||||
|
||||
class Worker : public NCCLWorkerForTest {
|
||||
public:
|
||||
@@ -47,7 +45,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllreduceTestGPU, BitOr) {
|
||||
TEST_F(MGPUAllreduceTest, BitOr) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
@@ -57,7 +55,7 @@ TEST_F(AllreduceTestGPU, BitOr) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllreduceTestGPU, Sum) {
|
||||
TEST_F(MGPUAllreduceTest, Sum) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
|
||||
63
tests/cpp/collective/test_coll_c_api.cc
Normal file
63
tests/cpp/collective/test_coll_c_api.cc
Normal file
@@ -0,0 +1,63 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
#include <chrono> // for ""s
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "test_worker.h" // for SocketTest
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class TrackerAPITest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(TrackerAPITest, CAPI) {
|
||||
TrackerHandle handle;
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = String{"rabit"};
|
||||
config["n_workers"] = 2;
|
||||
config["timeout"] = 1;
|
||||
auto config_str = Json::Dump(config);
|
||||
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
|
||||
ASSERT_EQ(rc, 0);
|
||||
rc = XGTrackerRun(handle);
|
||||
ASSERT_EQ(rc, 0);
|
||||
|
||||
std::thread bg_wait{[&] {
|
||||
Json config{Object{}};
|
||||
auto config_str = Json::Dump(config);
|
||||
auto rc = XGTrackerWait(handle, config_str.c_str());
|
||||
ASSERT_EQ(rc, 0);
|
||||
}};
|
||||
|
||||
char const* cargs;
|
||||
rc = XGTrackerWorkerArgs(handle, &cargs);
|
||||
ASSERT_EQ(rc, 0);
|
||||
auto args = Json::Load(StringView{cargs});
|
||||
|
||||
std::string host;
|
||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||
ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
|
||||
auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
|
||||
ASSERT_NE(port, 0);
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
using namespace std::chrono_literals; // NOLINT
|
||||
for (std::int32_t r = 0; r < 2; ++r) {
|
||||
workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
|
||||
}
|
||||
for (auto& w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
rc = XGTrackerFree(handle);
|
||||
ASSERT_EQ(rc, 0);
|
||||
|
||||
bg_wait.join();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -25,15 +25,18 @@ TEST_F(CommTest, Channel) {
|
||||
WorkerForTest worker{host, port, timeout, n_workers, i};
|
||||
if (i % 2 == 0) {
|
||||
auto p_chan = worker.Comm().Chan(i + 1);
|
||||
p_chan->SendAll(
|
||||
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
auto rc = Success() << [&] {
|
||||
return p_chan->SendAll(
|
||||
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
||||
} << [&] { return p_chan->Block(); };
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
} else {
|
||||
auto p_chan = worker.Comm().Chan(i - 1);
|
||||
std::int32_t r{-1};
|
||||
p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
auto rc = Success() << [&] {
|
||||
return p_chan->RecvAll(
|
||||
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||
} << [&] { return p_chan->Block(); };
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(r, i - 1);
|
||||
}
|
||||
|
||||
63
tests/cpp/collective/test_comm_group.cc
Normal file
63
tests/cpp/collective/test_comm_group.cc
Normal file
@@ -0,0 +1,63 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h> // for Json
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "../../../src/collective/comm_group.h"
|
||||
#include "../../../src/common/common.h" // for AllVisibleGPUs
|
||||
#include "../helpers.h" // for MakeCUDACtx
|
||||
#include "test_worker.h" // for TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) {
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = std::string{"rabit"};
|
||||
config["DMLC_TRACKER_URI"] = host;
|
||||
config["DMLC_TRACKER_PORT"] = port;
|
||||
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
|
||||
config["DMLC_TASK_ID"] = std::to_string(r);
|
||||
config["dmlc_retry"] = 2;
|
||||
return config;
|
||||
}
|
||||
|
||||
class CommGroupTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(CommGroupTest, Basic) {
|
||||
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 5u);
|
||||
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Context ctx;
|
||||
auto config = MakeConfig(host, port, timeout, r);
|
||||
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
|
||||
ASSERT_TRUE(ptr->IsDistributed());
|
||||
ASSERT_EQ(ptr->World(), n_workers);
|
||||
auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CPU());
|
||||
ASSERT_EQ(comm.TaskID(), std::to_string(r));
|
||||
ASSERT_EQ(comm.Retry(), 2);
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
TEST_F(CommGroupTest, BasicGPU) {
|
||||
std::int32_t n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
auto ctx = MakeCUDACtx(r);
|
||||
auto config = MakeConfig(host, port, timeout, r);
|
||||
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
|
||||
auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0));
|
||||
ASSERT_EQ(comm.TaskID(), std::to_string(r));
|
||||
ASSERT_EQ(comm.Retry(), 2);
|
||||
});
|
||||
}
|
||||
#endif // for defined(XGBOOST_USE_NCCL)
|
||||
} // namespace xgboost::collective
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <bitset>
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../../src/collective/comm.cuh"
|
||||
#include "../../../src/collective/communicator-inl.cuh"
|
||||
#include "../../../src/collective/nccl_device_communicator.cuh"
|
||||
#include "../helpers.h"
|
||||
@@ -16,17 +17,15 @@ namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
|
||||
auto construct = []() { NcclDeviceCommunicator comm{-1, false}; };
|
||||
auto construct = []() { NcclDeviceCommunicator comm{-1, false, DefaultNcclName()}; };
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
|
||||
try {
|
||||
dh::safe_nccl(ncclSystemError);
|
||||
} catch (dmlc::Error const& e) {
|
||||
auto str = std::string{e.what()};
|
||||
ASSERT_TRUE(str.find("environment variables") != std::string::npos);
|
||||
}
|
||||
auto stub = std::make_shared<NcclStub>(DefaultNcclName());
|
||||
auto rc = stub->GetNcclResult(ncclSystemError);
|
||||
auto msg = rc.Report();
|
||||
ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -33,7 +33,7 @@ class WorkerForTest {
|
||||
tracker_port_{port},
|
||||
world_size_{world},
|
||||
task_id_{"t:" + std::to_string(rank)},
|
||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} {
|
||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} {
|
||||
CHECK_EQ(world_size_, comm_.World());
|
||||
}
|
||||
virtual ~WorkerForTest() = default;
|
||||
@@ -92,10 +92,12 @@ class TrackerTest : public SocketTest {
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
||||
std::chrono::seconds timeout{1};
|
||||
std::chrono::seconds timeout{2};
|
||||
|
||||
std::string host;
|
||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
LOG(INFO) << "Using " << n_workers << " workers for test.";
|
||||
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user