[coll] Implement shutdown for tracker and comm. (#10208)
- Force shutdown of the tracker. - Implement a shutdown notice for the error-handling thread in comm.
This commit is contained in:
@@ -5,9 +5,11 @@
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for sleep_for
|
||||
#include <type_traits> // for is_same_v, remove_pointer_t
|
||||
#include <utility> // for pair
|
||||
|
||||
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
@@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT
|
||||
|
||||
namespace {
|
||||
using TrackerHandleT =
|
||||
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
|
||||
std::pair<std::shared_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
|
||||
|
||||
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
|
||||
xgboost_CHECK_C_ARG_PTR(handle);
|
||||
@@ -41,12 +43,14 @@ struct CollAPIEntry {
|
||||
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||
|
||||
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
constexpr std::int64_t kDft{60};
|
||||
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
|
||||
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
||||
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
|
||||
auto ref = ptr->first; // Hold a reference so that XGTrackerFree doesn't destroy the tracker while we are waiting.
|
||||
|
||||
auto fut = ptr->second;
|
||||
while (fut.valid()) {
|
||||
auto res = fut.wait_for(wait_for);
|
||||
@@ -72,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
|
||||
Json jconfig = Json::Load(config);
|
||||
|
||||
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
|
||||
std::unique_ptr<collective::Tracker> tptr;
|
||||
std::shared_ptr<collective::Tracker> tptr;
|
||||
if (type == "federated") {
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
tptr = std::make_unique<collective::FederatedTracker>(jconfig);
|
||||
tptr = std::make_shared<collective::FederatedTracker>(jconfig);
|
||||
#else
|
||||
LOG(FATAL) << error::NoFederated();
|
||||
#endif // defined(XGBOOST_USE_FEDERATED)
|
||||
} else if (type == "rabit") {
|
||||
tptr = std::make_unique<collective::RabitTracker>(jconfig);
|
||||
tptr = std::make_shared<collective::RabitTracker>(jconfig);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown communicator:" << type;
|
||||
}
|
||||
@@ -103,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
CHECK(!ptr->second.valid()) << "Tracker is already running.";
|
||||
@@ -111,13 +115,14 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
|
||||
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
// Internally, 0 indicates no timeout, which is the default since we don't want to
|
||||
// interrupt the model training.
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
|
||||
WaitImpl(ptr, std::chrono::seconds{timeout});
|
||||
API_END();
|
||||
@@ -125,8 +130,24 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
|
||||
|
||||
/**
 * @brief Stop the tracker behind `handle` and release the handle.
 *
 * The tracker is asked to stop, then `WaitImpl` is invoked with the tracker's own
 * timeout. Afterwards this function polls (every 64ms) until the handle holds the only
 * reference to the tracker, so that a concurrent waiter that copied the shared pointer
 * can finish first. If the tracker's timeout elapses before that happens, a warning is
 * logged and the handle is deleted anyway.
 *
 * @param handle Tracker handle; validated and converted by `GetTrackerHandle`.
 *
 * @return The status code produced by the `API_BEGIN`/`API_END` macros.
 */
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
  API_BEGIN();
  using namespace std::chrono_literals;  // NOLINT
  auto *ptr = GetTrackerHandle(handle);
  ptr->first->Stop();
  // The wait is not necessary since we just called stop, just reusing the function to do
  // any potential cleanups.
  WaitImpl(ptr, ptr->first->Timeout());

  common::Timer timer;
  timer.Start();
  // Make sure no one else is waiting on the tracker. `shared_ptr::unique()` is deprecated
  // in C++17 and removed in C++20; `use_count() == 1` is its documented equivalent.
  while (ptr->first.use_count() != 1) {
    auto elapsed = timer.Duration().count();
    if (elapsed > ptr->first->Timeout().count()) {
      LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
                   << " seconds reached for TrackerFree, killing the tracker.";
      break;
    }
    // Short sleep between checks to avoid busy-spinning.
    std::this_thread::sleep_for(64ms);
  }
  delete ptr;
  API_END();
}
|
||||
|
||||
Reference in New Issue
Block a user