merge latest changes

This commit is contained in:
Hui Liu
2023-12-13 21:06:28 -08:00
194 changed files with 4859 additions and 2838 deletions

View File

@@ -47,7 +47,7 @@ class Worker : public WorkerForTest {
std::size_t n = 8192; // n_bytes = 8192 * sizeof(int)
std::vector<std::int32_t> data(comm_.World() * n, 0);
auto s_data = common::Span{data.data(), data.size()};
auto s_data = common::Span<std::int32_t>{data};
auto seg = s_data.subspan(comm_.Rank() * n, n);
std::iota(seg.begin(), seg.end(), comm_.Rank());

View File

@@ -90,10 +90,10 @@ class Worker : public NCCLWorkerForTest {
}
};
class AllgatherTestGPU : public SocketTest {};
class MGPUAllgatherTest : public SocketTest {};
} // namespace
TEST_F(AllgatherTestGPU, MGPUTestVRing) {
TEST_F(MGPUAllgatherTest, MGPUTestVRing) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
@@ -104,7 +104,7 @@ TEST_F(AllgatherTestGPU, MGPUTestVRing) {
});
}
TEST_F(AllgatherTestGPU, MGPUTestVBcast) {
TEST_F(MGPUAllgatherTest, MGPUTestVBcast) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {

View File

@@ -18,31 +18,34 @@ class AllreduceWorker : public WorkerForTest {
void Basic() {
{
std::vector<double> data(13, 0.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
}
{
std::vector<double> data(1, 1.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
}
}
void Acc() {
std::vector<double> data(314, 1.5);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;

View File

@@ -5,17 +5,15 @@
#include <gtest/gtest.h>
#include <thrust/host_vector.h> // for host_vector
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/common/common.h"
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
#include "../../../src/common/type.h" // for EraseType
#include "../helpers.h" // for MakeCUDACtx
#include "test_worker.cuh" // for NCCLWorkerForTest
#include "test_worker.h" // for WorkerForTest, TestDistributed
namespace xgboost::collective {
namespace {
class AllreduceTestGPU : public SocketTest {};
class MGPUAllreduceTest : public SocketTest {};
class Worker : public NCCLWorkerForTest {
public:
@@ -47,7 +45,7 @@ class Worker : public NCCLWorkerForTest {
};
} // namespace
TEST_F(AllreduceTestGPU, BitOr) {
TEST_F(MGPUAllreduceTest, BitOr) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
@@ -57,7 +55,7 @@ TEST_F(AllreduceTestGPU, BitOr) {
});
}
TEST_F(AllreduceTestGPU, Sum) {
TEST_F(MGPUAllreduceTest, Sum) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {

View File

@@ -0,0 +1,63 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/c_api.h>
#include <chrono> // for ""s
#include <thread> // for thread
#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for SocketTest
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
namespace {
class TrackerAPITest : public SocketTest {};
} // namespace
TEST_F(TrackerAPITest, CAPI) {
TrackerHandle handle;
Json config{Object{}};
config["dmlc_communicator"] = String{"rabit"};
config["n_workers"] = 2;
config["timeout"] = 1;
auto config_str = Json::Dump(config);
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
ASSERT_EQ(rc, 0);
rc = XGTrackerRun(handle);
ASSERT_EQ(rc, 0);
std::thread bg_wait{[&] {
Json config{Object{}};
auto config_str = Json::Dump(config);
auto rc = XGTrackerWait(handle, config_str.c_str());
ASSERT_EQ(rc, 0);
}};
char const* cargs;
rc = XGTrackerWorkerArgs(handle, &cargs);
ASSERT_EQ(rc, 0);
auto args = Json::Load(StringView{cargs});
std::string host;
ASSERT_TRUE(GetHostAddress(&host).OK());
ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
ASSERT_NE(port, 0);
std::vector<std::thread> workers;
using namespace std::chrono_literals; // NOLINT
for (std::int32_t r = 0; r < 2; ++r) {
workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
}
for (auto& w : workers) {
w.join();
}
rc = XGTrackerFree(handle);
ASSERT_EQ(rc, 0);
bg_wait.join();
}
} // namespace xgboost::collective

View File

@@ -25,15 +25,18 @@ TEST_F(CommTest, Channel) {
WorkerForTest worker{host, port, timeout, n_workers, i};
if (i % 2 == 0) {
auto p_chan = worker.Comm().Chan(i + 1);
p_chan->SendAll(
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
auto rc = p_chan->Block();
auto rc = Success() << [&] {
return p_chan->SendAll(
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
} << [&] { return p_chan->Block(); };
ASSERT_TRUE(rc.OK()) << rc.Report();
} else {
auto p_chan = worker.Comm().Chan(i - 1);
std::int32_t r{-1};
p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
auto rc = p_chan->Block();
auto rc = Success() << [&] {
return p_chan->RecvAll(
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
} << [&] { return p_chan->Block(); };
ASSERT_TRUE(rc.OK()) << rc.Report();
ASSERT_EQ(r, i - 1);
}

View File

@@ -0,0 +1,63 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/json.h> // for Json
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <string> // for string
#include <thread> // for thread
#include "../../../src/collective/comm.h"
#include "../../../src/collective/comm_group.h"
#include "../../../src/common/common.h" // for AllVisibleGPUs
#include "../helpers.h" // for MakeCUDACtx
#include "test_worker.h" // for TestDistributed
namespace xgboost::collective {
namespace {
auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) {
Json config{Object{}};
config["dmlc_communicator"] = std::string{"rabit"};
config["DMLC_TRACKER_URI"] = host;
config["DMLC_TRACKER_PORT"] = port;
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
config["DMLC_TASK_ID"] = std::to_string(r);
config["dmlc_retry"] = 2;
return config;
}
class CommGroupTest : public SocketTest {};
} // namespace
TEST_F(CommGroupTest, Basic) {
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 5u);
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Context ctx;
auto config = MakeConfig(host, port, timeout, r);
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
ASSERT_TRUE(ptr->IsDistributed());
ASSERT_EQ(ptr->World(), n_workers);
auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CPU());
ASSERT_EQ(comm.TaskID(), std::to_string(r));
ASSERT_EQ(comm.Retry(), 2);
});
}
#if defined(XGBOOST_USE_NCCL)
TEST_F(CommGroupTest, BasicGPU) {
std::int32_t n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
auto ctx = MakeCUDACtx(r);
auto config = MakeConfig(host, port, timeout, r);
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0));
ASSERT_EQ(comm.TaskID(), std::to_string(r));
ASSERT_EQ(comm.Retry(), 2);
});
}
#endif // for defined(XGBOOST_USE_NCCL)
} // namespace xgboost::collective

View File

@@ -8,6 +8,7 @@
#include <bitset>
#include <string> // for string
#include "../../../src/collective/comm.cuh"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/collective/nccl_device_communicator.cuh"
#include "../helpers.h"
@@ -16,17 +17,15 @@ namespace xgboost {
namespace collective {
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { NcclDeviceCommunicator comm{-1, false}; };
auto construct = []() { NcclDeviceCommunicator comm{-1, false, DefaultNcclName()}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
try {
dh::safe_nccl(ncclSystemError);
} catch (dmlc::Error const& e) {
auto str = std::string{e.what()};
ASSERT_TRUE(str.find("environment variables") != std::string::npos);
}
auto stub = std::make_shared<NcclStub>(DefaultNcclName());
auto rc = stub->GetNcclResult(ncclSystemError);
auto msg = rc.Report();
ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
}
namespace {

View File

@@ -33,7 +33,7 @@ class WorkerForTest {
tracker_port_{port},
world_size_{world},
task_id_{"t:" + std::to_string(rank)},
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} {
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} {
CHECK_EQ(world_size_, comm_.World());
}
virtual ~WorkerForTest() = default;
@@ -92,10 +92,12 @@ class TrackerTest : public SocketTest {
template <typename WorkerFn>
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
std::chrono::seconds timeout{1};
std::chrono::seconds timeout{2};
std::string host;
ASSERT_TRUE(GetHostAddress(&host).OK());
auto rc = GetHostAddress(&host);
ASSERT_TRUE(rc.OK()) << rc.Report();
LOG(INFO) << "Using " << n_workers << " workers for test.";
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
auto fut = tracker.Run();