[coll] Improvements and fixes for tracker and allreduce. (#9745)

- Allow the tracker to wait.
- Fix allreduce type cast
- Return args from the federated tracker.
This commit is contained in:
Jiaming Yuan
2023-11-02 04:06:46 +08:00
committed by GitHub
parent 0ff8572737
commit 4da4e092b5
8 changed files with 184 additions and 57 deletions

View File

@@ -3,19 +3,35 @@
*/
#include "coll.h"
#include <algorithm> // for min, max
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus
#include <algorithm> // for min, max, copy_n
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus
#include <type_traits> // for is_floating_point_v, is_same_v
#include <utility> // for move
#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
#if defined(XGBOOST_USE_CUDA)
#include "cuda_fp16.h" // for __half
#endif
namespace xgboost::collective {
template <typename T>
bool constexpr IsFloatingPointV() {
#if defined(XGBOOST_USE_CUDA)
return std::is_floating_point_v<T> || std::is_same_v<T, __half>;
#else
return std::is_floating_point_v<T>;
#endif // defined(XGBOOST_USE_CUDA)
}
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type, Op op) {
ArrayInterfaceHandler::Type type, Op op) {
namespace coll = ::xgboost::collective;
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
@@ -25,32 +41,59 @@ namespace xgboost::collective {
p_out[i] = elem_op(p_lhs[i], p_out[i]);
}
};
auto fn = [&](auto elem_op) {
return coll::Allreduce(
comm, data, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); });
auto fn = [&](auto elem_op, auto t) {
using T = decltype(t);
auto erased_fn = [redop_fn, elem_op](common::Span<std::int8_t const> lhs,
common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = common::RestoreType<T const>(lhs);
auto rhs_t = common::RestoreType<T>(out);
redop_fn(lhs_t, rhs_t, elem_op);
};
return cpu_impl::RingAllreduce(comm, data, erased_fn, type);
};
switch (op) {
case Op::kMax: {
return fn([](auto l, auto r) { return std::max(l, r); });
auto rc = DispatchDType(type, [&](auto t) {
using T = decltype(t);
switch (op) {
case Op::kMax: {
return fn([](auto l, auto r) { return std::max(l, r); }, t);
}
case Op::kMin: {
return fn([](auto l, auto r) { return std::min(l, r); }, t);
}
case Op::kSum: {
return fn(std::plus<>{}, t);
}
case Op::kBitwiseAND: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
} else {
return fn(std::bit_and<>{}, t);
}
}
case Op::kBitwiseOR: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
} else {
return fn(std::bit_or<>{}, t);
}
}
case Op::kBitwiseXOR: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
} else {
return fn(std::bit_xor<>{}, t);
}
}
}
case Op::kMin: {
return fn([](auto l, auto r) { return std::min(l, r); });
}
case Op::kSum: {
return fn(std::plus<>{});
}
case Op::kBitwiseAND: {
return fn(std::bit_and<>{});
}
case Op::kBitwiseOR: {
return fn(std::bit_or<>{});
}
case Op::kBitwiseXOR: {
return fn(std::bit_xor<>{});
}
}
return comm.Block();
return Fail("Invalid op.");
});
return std::move(rc) << [&] { return comm.Block(); };
}
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,

View File

@@ -16,7 +16,7 @@
#endif // defined(_WIN32)
#include <algorithm> // for sort
#include <chrono> // for seconds
#include <chrono> // for seconds, ms
#include <cstdint> // for int32_t
#include <string> // for string
#include <utility> // for move, forward
@@ -37,6 +37,25 @@ Tracker::Tracker(Json const& config)
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
Result Tracker::WaitUntilReady() const {
using namespace std::chrono_literals; // NOLINT
// Busy waiting. The function is mostly for waiting for the OS to launch an async
// thread, which should be reasonably fast.
common::Timer timer;
timer.Start();
while (!this->Ready()) {
auto ela = timer.Duration().count();
if (ela > this->Timeout().count()) {
return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) +
" seconds.");
}
std::this_thread::sleep_for(100ms);
}
return Success();
}
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
: sock_{std::move(sock)} {
auto host = addr.Addr();
@@ -76,6 +95,7 @@ RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
auto rc = collective::GetHostAddress(&self);
auto host = OptionalArg<String>(config, "host", self);
host_ = host;
listener_ = TCPSocket::Create(SockDomain::kV4);
rc = listener_.Bind(host, &this->port_);
CHECK(rc.OK()) << rc.Report();
@@ -173,6 +193,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
while (state.ShouldContinue()) {
TCPSocket sock;
SockAddrV4 addr;
this->ready_ = true;
auto rc = listener_.Accept(&sock, &addr);
if (!rc.OK()) {
return Fail("Failed to accept connection.", std::move(rc));
@@ -237,10 +258,21 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
}
}
}
ready_ = false;
return Success();
});
}
[[nodiscard]] Json RabitTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report();
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};
args["DMLC_TRACKER_PORT"] = this->Port();
return args;
}
[[nodiscard]] Result GetHostAddress(std::string* out) {
auto rc = GetHostName(out);
if (!rc.OK()) {

View File

@@ -40,6 +40,7 @@ class Tracker {
std::int32_t n_workers_{0};
std::int32_t port_{-1};
std::chrono::seconds timeout_{0};
std::atomic<bool> ready_{false};
public:
explicit Tracker(Json const& config);
@@ -47,10 +48,17 @@ class Tracker {
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
virtual ~Tracker() noexcept(false){}; // NOLINT
[[nodiscard]] Result WaitUntilReady() const;
[[nodiscard]] virtual std::future<Result> Run() = 0;
[[nodiscard]] virtual Json WorkerArgs() const = 0;
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
[[nodiscard]] virtual std::int32_t Port() const { return port_; }
/**
* @brief Flag to indicate whether the server is running.
*/
[[nodiscard]] bool Ready() const { return ready_; }
};
class RabitTracker : public Tracker {
@@ -124,13 +132,7 @@ class RabitTracker : public Tracker {
~RabitTracker() noexcept(false) override = default;
std::future<Result> Run() override;
[[nodiscard]] Json WorkerArgs() const override {
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};
args["DMLC_TRACKER_PORT"] = this->Port();
return args;
}
[[nodiscard]] Json WorkerArgs() const override;
};
// Prob the public IP address of the host, need a better method.

View File

@@ -29,6 +29,7 @@ struct Timer {
void Start() { start = ClockT::now(); }
void Stop() { elapsed += ClockT::now() - start; }
double ElapsedSeconds() const { return SecondsT(elapsed).count(); }
SecondsT Duration() const { return ClockT::now() - start; }
void PrintElapsed(std::string label) {
char buffer[255];
snprintf(buffer, sizeof(buffer), "%s:\t %fs", label.c_str(),