[coll] Improvements and fixes for tracker and allreduce. (#9745)

- Allow the tracker to wait.
- Fix allreduce type cast
- Return args from the federated tracker.
This commit is contained in:
Jiaming Yuan
2023-11-02 04:06:46 +08:00
committed by GitHub
parent 0ff8572737
commit 4da4e092b5
8 changed files with 184 additions and 57 deletions

View File

@@ -27,9 +27,15 @@ class PrintWorker : public WorkerForTest {
TEST_F(TrackerTest, Bootstrap) {
RabitTracker tracker{host, n_workers, 0, timeout};
ASSERT_FALSE(tracker.Ready());
auto fut = tracker.Run();
std::vector<std::thread> workers;
auto args = tracker.WorkerArgs();
ASSERT_TRUE(tracker.Ready());
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
@@ -47,6 +53,9 @@ TEST_F(TrackerTest, Print) {
auto fut = tracker.Run();
std::vector<std::thread> workers;
auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK());
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {