/** * Copyright 2022-2023, XGBoost contributors */ #include #include // for Span #include // for array #include "../../../../src/common/type.h" // for EraseType #include "../../collective/test_worker.h" // for SocketTest #include "federated_coll.h" // for FederatedColl #include "federated_comm.h" // for FederatedComm #include "test_worker.h" // for TestFederated namespace xgboost::collective { namespace { class FederatedCollTest : public SocketTest {}; } // namespace TEST_F(FederatedCollTest, Allreduce) { std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { std::array buffer = {1, 2, 3, 4, 5}; std::array expected; std::transform(buffer.cbegin(), buffer.cend(), expected.begin(), [=](auto i) { return i * n_workers; }); auto coll = std::make_shared(); auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), ArrayInterfaceHandler::kI4, Op::kSum); ASSERT_TRUE(rc.OK()); for (auto i = 0; i < 5; i++) { ASSERT_EQ(buffer[i], expected[i]); } }); } TEST_F(FederatedCollTest, Broadcast) { std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { FederatedColl coll{}; auto rc = Success(); if (comm->Rank() == 0) { std::string buffer{"hello"}; rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); ASSERT_EQ(buffer, "hello"); } else { std::string buffer{" "}; rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); ASSERT_EQ(buffer, "hello"); } ASSERT_TRUE(rc.OK()); }); } TEST_F(FederatedCollTest, Allgather) { std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { FederatedColl coll{}; std::vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()})); ASSERT_TRUE(rc.OK()); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); } }); } TEST_F(FederatedCollTest, AllgatherV) { std::int32_t n_workers = 2; TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { FederatedColl coll{}; std::vector inputs{"Federated", " Learning!!!"}; std::vector recv_segments(inputs.size() + 1, 0); std::string r; std::vector sizes{static_cast(inputs[0].size()), static_cast(inputs[1].size())}; r.resize(sizes[0] + sizes[1]); auto rc = coll.AllgatherV( *comm, common::EraseType(common::Span{inputs[comm->Rank()].data(), inputs[comm->Rank()].size()}), common::Span{sizes.data(), sizes.size()}, recv_segments, common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing); EXPECT_EQ(r, "Federated Learning!!!"); ASSERT_TRUE(rc.OK()); }); } } // namespace xgboost::collective