[coll] Implement shutdown for tracker and comm. (#10208)

- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
This commit is contained in:
Jiaming Yuan
2024-04-20 04:08:17 +08:00
committed by GitHub
parent 8fb05c8c95
commit 3fbb221fec
24 changed files with 553 additions and 199 deletions

View File

@@ -5,9 +5,11 @@
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include <thread> // for sleep_for
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
@@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT
namespace {
using TrackerHandleT =
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
std::pair<std::shared_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
xgboost_CHECK_C_ARG_PTR(handle);
@@ -41,12 +43,14 @@ struct CollAPIEntry {
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{60};
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
common::Timer timer;
timer.Start();
auto ref = ptr->first; // hold a reference so that free doesn't delete the tracker while we are waiting.
auto fut = ptr->second;
while (fut.valid()) {
auto res = fut.wait_for(wait_for);
@@ -72,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
Json jconfig = Json::Load(config);
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
std::unique_ptr<collective::Tracker> tptr;
std::shared_ptr<collective::Tracker> tptr;
if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
tptr = std::make_unique<collective::FederatedTracker>(jconfig);
tptr = std::make_shared<collective::FederatedTracker>(jconfig);
#else
LOG(FATAL) << error::NoFederated();
#endif // defined(XGBOOST_USE_FEDERATED)
} else if (type == "rabit") {
tptr = std::make_unique<collective::RabitTracker>(jconfig);
tptr = std::make_shared<collective::RabitTracker>(jconfig);
} else {
LOG(FATAL) << "Unknown communicator:" << type;
}
@@ -103,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
API_END();
}
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
CHECK(!ptr->second.valid()) << "Tracker is already running.";
@@ -111,13 +115,14 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) {
API_END();
}
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
// Internally, 0 indicates no timeout, which is the default since we don't want to
// interrupt the model training.
xgboost_CHECK_C_ARG_PTR(config);
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
WaitImpl(ptr, std::chrono::seconds{timeout});
API_END();
@@ -125,8 +130,24 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
/**
 * Stop the tracker referenced by `handle` and release the handle.
 *
 * Steps, in order:
 *  1. Ask the tracker to stop.
 *  2. Call WaitImpl with the tracker's own timeout.
 *  3. Poll every 64ms until this handle holds the only reference to the tracker, giving
 *     up with a warning once the tracker's timeout has elapsed.
 *  4. Delete the handle.
 *
 * @param handle Tracker handle; it must not be used after this call.
 */
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
  API_BEGIN();
  using namespace std::chrono_literals;  // NOLINT
  auto *ptr = GetTrackerHandle(handle);
  ptr->first->Stop();
  // The wait is not necessary since we just called stop; the function is reused only to
  // perform any potential cleanups.
  WaitImpl(ptr, ptr->first->Timeout());
  common::Timer timer;
  timer.Start();
  // Make sure no one else is waiting on the tracker. `use_count() == 1` is the exact
  // equivalent of `shared_ptr::unique()`, which is deprecated in C++17 and removed in
  // C++20.
  // NOTE(review): with a timeout of 0 the check below appears to trip on the first
  // iteration -- confirm whether 0 is meant to be "no timeout" here as well.
  while (ptr->first.use_count() != 1) {
    auto ela = timer.Duration().count();
    if (ela > ptr->first->Timeout().count()) {
      LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
                   << " seconds reached for TrackerFree, killing the tracker.";
      break;
    }
    std::this_thread::sleep_for(64ms);
  }
  delete ptr;
  API_END();
}