[coll] Improvements and fixes for tracker and allreduce. (#9745)

- Allow the tracker to wait.
- Fix allreduce type cast
- Return args from the federated tracker.
This commit is contained in:
Jiaming Yuan
2023-11-02 04:06:46 +08:00
committed by GitHub
parent 0ff8572737
commit 4da4e092b5
8 changed files with 184 additions and 57 deletions

View File

@@ -6,7 +6,6 @@
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
#include <grpcpp/server_builder.h> // for ServerBuilder
#include <chrono> // for ms
#include <cstdint> // for int32_t
#include <exception> // for exception
#include <limits> // for numeric_limits
@@ -61,7 +60,7 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
}
std::future<Result> FederatedTracker::Run() {
return std::async([this]() {
return std::async(std::launch::async, [this]() {
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
xgboost::collective::federated::FederatedService service{
static_cast<std::int32_t>(this->n_workers_)};
@@ -98,10 +97,13 @@ std::future<Result> FederatedTracker::Run() {
try {
server_ = builder.BuildAndStart();
ready_ = true;
server_->Wait();
} catch (std::exception const& e) {
return collective::Fail(std::string{e.what()});
}
ready_ = false;
return collective::Success();
});
}
@@ -109,18 +111,8 @@ std::future<Result> FederatedTracker::Run() {
FederatedTracker::~FederatedTracker() = default;
Result FederatedTracker::Shutdown() {
common::Timer timer;
timer.Start();
using namespace std::chrono_literals;
while (!server_) {
timer.Stop();
auto ela = timer.ElapsedSeconds();
if (ela > this->Timeout().count()) {
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
" seconds.");
}
std::this_thread::sleep_for(10ms);
}
auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report();
try {
server_->Shutdown();
@@ -130,4 +122,17 @@ Result FederatedTracker::Shutdown() {
return Success();
}
[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report();
std::string host;
rc = GetHostAddress(&host);
CHECK(rc.OK());
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host};
args["DMLC_TRACKER_PORT"] = this->Port();
return args;
}
} // namespace xgboost::collective

View File

@@ -57,9 +57,8 @@ class FederatedTracker : public collective::Tracker {
explicit FederatedTracker(Json const& config);
~FederatedTracker() override;
std::future<Result> Run() override;
// federated tracker do not provide initialization parameters, users have to provide it
// themseleves.
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
[[nodiscard]] Json WorkerArgs() const override;
[[nodiscard]] Result Shutdown();
};
} // namespace xgboost::collective