Merge branch 'master'
This commit is contained in:
@@ -6,7 +6,6 @@
|
||||
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
|
||||
#include <grpcpp/server_builder.h> // for ServerBuilder
|
||||
|
||||
#include <chrono> // for ms
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception
|
||||
#include <limits> // for numeric_limits
|
||||
@@ -61,7 +60,7 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
|
||||
}
|
||||
|
||||
std::future<Result> FederatedTracker::Run() {
|
||||
return std::async([this]() {
|
||||
return std::async(std::launch::async, [this]() {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
|
||||
xgboost::collective::federated::FederatedService service{
|
||||
static_cast<std::int32_t>(this->n_workers_)};
|
||||
@@ -98,10 +97,13 @@ std::future<Result> FederatedTracker::Run() {
|
||||
|
||||
try {
|
||||
server_ = builder.BuildAndStart();
|
||||
ready_ = true;
|
||||
server_->Wait();
|
||||
} catch (std::exception const& e) {
|
||||
return collective::Fail(std::string{e.what()});
|
||||
}
|
||||
|
||||
ready_ = false;
|
||||
return collective::Success();
|
||||
});
|
||||
}
|
||||
@@ -109,18 +111,8 @@ std::future<Result> FederatedTracker::Run() {
|
||||
FederatedTracker::~FederatedTracker() = default;
|
||||
|
||||
Result FederatedTracker::Shutdown() {
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
using namespace std::chrono_literals;
|
||||
while (!server_) {
|
||||
timer.Stop();
|
||||
auto ela = timer.ElapsedSeconds();
|
||||
if (ela > this->Timeout().count()) {
|
||||
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
|
||||
" seconds.");
|
||||
}
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
auto rc = this->WaitUntilReady();
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
|
||||
try {
|
||||
server_->Shutdown();
|
||||
@@ -130,4 +122,17 @@ Result FederatedTracker::Shutdown() {
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
|
||||
auto rc = this->WaitUntilReady();
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
|
||||
std::string host;
|
||||
rc = GetHostAddress(&host);
|
||||
CHECK(rc.OK());
|
||||
Json args{Object{}};
|
||||
args["DMLC_TRACKER_URI"] = String{host};
|
||||
args["DMLC_TRACKER_PORT"] = this->Port();
|
||||
return args;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -57,9 +57,8 @@ class FederatedTracker : public collective::Tracker {
|
||||
explicit FederatedTracker(Json const& config);
|
||||
~FederatedTracker() override;
|
||||
std::future<Result> Run() override;
|
||||
// federated tracker do not provide initialization parameters, users have to provide it
|
||||
// themseleves.
|
||||
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
|
||||
|
||||
[[nodiscard]] Json WorkerArgs() const override;
|
||||
[[nodiscard]] Result Shutdown();
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user