/** * Copyright 2023, XGBoost Contributors */ #include #include "../../../src/collective/allreduce.h" #include "../../../src/collective/coll.h" // for Coll #include "../../../src/collective/tracker.h" #include "../../../src/common/type.h" // for EraseType #include "test_worker.h" // for WorkerForTest, TestDistributed namespace xgboost::collective { namespace { class AllreduceWorker : public WorkerForTest { public: using WorkerForTest::WorkerForTest; void Basic() { { std::vector data(13, 0.0); Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { for (std::size_t i = 0; i < rhs.size(); ++i) { rhs[i] += lhs[i]; } }); ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0); } { std::vector data(1, 1.0); Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { for (std::size_t i = 0; i < rhs.size(); ++i) { rhs[i] += lhs[i]; } }); ASSERT_EQ(data[0], static_cast(comm_.World())); } } void Acc() { std::vector data(314, 1.5); Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { for (std::size_t i = 0; i < rhs.size(); ++i) { rhs[i] += lhs[i]; } }); for (std::size_t i = 0; i < data.size(); ++i) { auto v = data[i]; ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; } } void BitOr() { std::vector data(comm_.World(), 0); data[comm_.Rank()] = ~std::uint32_t{0}; auto pcoll = std::shared_ptr{new Coll{}}; auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); ASSERT_TRUE(rc.OK()) << rc.Report(); for (auto v : data) { ASSERT_EQ(v, ~std::uint32_t{0}); } } }; class AllreduceTest : public SocketTest {}; } // namespace TEST_F(AllreduceTest, Basic) { std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { AllreduceWorker worker{host, port, timeout, n_workers, r}; worker.Basic(); }); } TEST_F(AllreduceTest, Sum) { std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { AllreduceWorker worker{host, port, timeout, n_workers, r}; worker.Acc(); }); } TEST_F(AllreduceTest, BitOr) { std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { AllreduceWorker worker{host, port, timeout, n_workers, r}; worker.BitOr(); }); } } // namespace xgboost::collective