[coll] Improvements and fixes for tracker and allreduce. (#9745)
- Allow the tracker to wait. - Fix allreduce type cast - Return args from the federated tracker.
This commit is contained in:
@@ -27,9 +27,15 @@ class PrintWorker : public WorkerForTest {
|
||||
|
||||
TEST_F(TrackerTest, Bootstrap) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
ASSERT_FALSE(tracker.Ready());
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
|
||||
auto args = tracker.WorkerArgs();
|
||||
ASSERT_TRUE(tracker.Ready());
|
||||
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
||||
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
@@ -47,6 +53,9 @@ TEST_F(TrackerTest, Print) {
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK());
|
||||
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
|
||||
Reference in New Issue
Block a user