Update collective implementation. (#10152)
* Update collective implementation.
  - Clean up resources during `Finalize` to avoid handling threads in the destructor.
  - Calculate the size for allgather automatically.
  - Use simple allgather for small (smaller than the number of workers) allreduce.
This commit is contained in:
@@ -694,9 +694,9 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
|
||||
common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
|
||||
{size}, DeviceOrd::CPU());
|
||||
CHECK(t.CContiguous());
|
||||
Json interface{linalg::ArrayInterface(t)};
|
||||
CHECK(ArrayInterface<1>{interface}.is_contiguous);
|
||||
str = Json::Dump(interface);
|
||||
Json iface{linalg::ArrayInterface(t)};
|
||||
CHECK(ArrayInterface<1>{iface}.is_contiguous);
|
||||
str = Json::Dump(iface);
|
||||
return str;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <chrono> // for seconds
|
||||
#include <cstddef> // for size_t
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
@@ -10,6 +9,7 @@
|
||||
#include <utility> // for pair
|
||||
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
@@ -40,17 +40,27 @@ struct CollAPIEntry {
|
||||
};
|
||||
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||
|
||||
void WaitImpl(TrackerHandleT *ptr) {
|
||||
std::chrono::seconds wait_for{100};
|
||||
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
constexpr std::int64_t kDft{60};
|
||||
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
||||
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
|
||||
auto fut = ptr->second;
|
||||
while (fut.valid()) {
|
||||
auto res = fut.wait_for(wait_for);
|
||||
CHECK(res != std::future_status::deferred);
|
||||
|
||||
if (res == std::future_status::ready) {
|
||||
auto const &rc = ptr->second.get();
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
collective::SafeColl(rc);
|
||||
break;
|
||||
}
|
||||
|
||||
if (timer.Duration() > timeout && timeout.count() != 0) {
|
||||
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
@@ -106,14 +116,17 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
WaitImpl(ptr);
|
||||
// Internally, 0 indicates no timeout, which is the default since we don't want to
|
||||
// interrupt the model training.
|
||||
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
|
||||
WaitImpl(ptr, std::chrono::seconds{timeout});
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
WaitImpl(ptr);
|
||||
WaitImpl(ptr, ptr->first->Timeout());
|
||||
delete ptr;
|
||||
API_END();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user