Update collective implementation. (#10152)

* Update collective implementation.

- Cleanup resource during `Finalize` to avoid handling threads in destructor.
- Calculate the size for allgather automatically.
- Use simple allgather for small (smaller than the number of worker) allreduce.
This commit is contained in:
Jiaming Yuan
2024-03-30 18:57:31 +08:00
committed by GitHub
parent 230010d9a0
commit 8bad677c2f
31 changed files with 233 additions and 127 deletions

View File

@@ -694,9 +694,9 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
{size}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json interface{linalg::ArrayInterface(t)};
CHECK(ArrayInterface<1>{interface}.is_contiguous);
str = Json::Dump(interface);
Json iface{linalg::ArrayInterface(t)};
CHECK(ArrayInterface<1>{iface}.is_contiguous);
str = Json::Dump(iface);
return str;
};

View File

@@ -1,8 +1,7 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
@@ -10,6 +9,7 @@
#include <utility> // for pair
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result
@@ -40,17 +40,27 @@ struct CollAPIEntry {
};
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
void WaitImpl(TrackerHandleT *ptr) {
std::chrono::seconds wait_for{100};
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{60};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
common::Timer timer;
timer.Start();
auto fut = ptr->second;
while (fut.valid()) {
auto res = fut.wait_for(wait_for);
CHECK(res != std::future_status::deferred);
if (res == std::future_status::ready) {
auto const &rc = ptr->second.get();
CHECK(rc.OK()) << rc.Report();
collective::SafeColl(rc);
break;
}
if (timer.Duration() > timeout && timeout.count() != 0) {
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
}
}
}
} // namespace
@@ -106,14 +116,17 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
auto *ptr = GetTrackerHandle(handle);
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
WaitImpl(ptr);
// Internally, 0 indicates no timeout, which is the default since we don't want to
// interrupt the model training.
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
WaitImpl(ptr, std::chrono::seconds{timeout});
API_END();
}
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
WaitImpl(ptr);
WaitImpl(ptr, ptr->first->Timeout());
delete ptr;
API_END();
}