[coll] Define interface for bridging. (#9695)

* Define the basic interface that will shared by nccl, federated and native.
This commit is contained in:
Jiaming Yuan
2023-10-20 16:20:48 +08:00
committed by GitHub
parent 6fbe6248f4
commit b771f58453
7 changed files with 174 additions and 3 deletions

View File

@@ -4,6 +4,7 @@
#include <gtest/gtest.h>
#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for WorkerForTest, TestDistributed
@@ -47,6 +48,19 @@ class AllreduceWorker : public WorkerForTest {
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
}
}
void BitOr() {
Context ctx;
std::vector<std::uint32_t> data(comm_.World(), 0);
data[comm_.Rank()] = ~std::uint32_t{0};
auto pcoll = std::make_shared<Coll>();
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (auto v : data) {
ASSERT_EQ(v, ~std::uint32_t{0});
}
}
};
class AllreduceTest : public SocketTest {};
@@ -69,4 +83,13 @@ TEST_F(AllreduceTest, Sum) {
worker.Acc();
});
}
TEST_F(AllreduceTest, BitOr) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
AllreduceWorker worker{host, port, timeout, n_workers, r};
worker.BitOr();
});
}
} // namespace xgboost::collective