[coll] Improvements and fixes for tracker and allreduce. (#9745)
- Allow the tracker to wait. - Fix allreduce type cast - Return args from the federated tracker.
This commit is contained in:
@@ -27,9 +27,15 @@ class PrintWorker : public WorkerForTest {
|
||||
|
||||
TEST_F(TrackerTest, Bootstrap) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
ASSERT_FALSE(tracker.Ready());
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
|
||||
auto args = tracker.WorkerArgs();
|
||||
ASSERT_TRUE(tracker.Ready());
|
||||
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
||||
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
@@ -47,6 +53,9 @@ TEST_F(TrackerTest, Print) {
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK());
|
||||
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
|
||||
36
tests/cpp/plugin/federated/test_federated_tracker.cc
Normal file
36
tests/cpp/plugin/federated/test_federated_tracker.cc
Normal file
@@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory> // for make_unique
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "federated_tracker.h"
|
||||
#include "test_worker.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
TEST(FederatedTrackerTest, Basic) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{3};
|
||||
|
||||
auto tracker = std::make_unique<FederatedTracker>(config);
|
||||
ASSERT_FALSE(tracker->Ready());
|
||||
auto fut = tracker->Run();
|
||||
auto args = tracker->WorkerArgs();
|
||||
ASSERT_TRUE(tracker->Ready());
|
||||
|
||||
ASSERT_GE(tracker->Port(), 1);
|
||||
std::string host;
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
||||
|
||||
rc = tracker->Shutdown();
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
ASSERT_FALSE(tracker->Ready());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
Reference in New Issue
Block a user