xgboost/src/collective/tracker.h
Jiaming Yuan 4da4e092b5
[coll] Improvements and fixes for tracker and allreduce. (#9745)
- Allow the tracker to wait.
- Fix allreduce type cast
- Return args from the federated tracker.
2023-11-02 04:06:46 +08:00

144 lines
5.1 KiB
C++

/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <atomic> // for atomic
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <future> // for future
#include <string> // for string
#include <utility> // for pair
#include <vector> // for vector
#include "protocol.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
/**
*
* @brief Implementation of RABIT tracker.
*
* * What is a tracker
*
* The implementation of collective follows what RABIT did in the past. It requires a
* tracker to coordinate initialization and error recovery of workers. The original
* implementation attempted to attain error resilience inside the collective module,
* which turned out to be too challenging due to the large amount of external
* state. The new implementation here differs from RABIT in that neither state
* recovery nor resilience is handled inside the collective; it merely provides the
* mechanism to signal errors to other workers through the use of a centralized tracker.
*
* There are three major functionalities provided by a tracker, namely:
* - Initialization. Share the node addresses among all workers.
* - Logging.
* - Signal error. If an exception is thrown in one (or many) of the workers, it can
* signal an error to the tracker and the tracker will notify other workers.
*/
class Tracker {
 protected:
  // Number of workers, as supplied at construction.
  std::int32_t n_workers_{0};
  // Port of the tracker; stays at -1 unless a port is supplied or assigned later.
  std::int32_t port_{-1};
  // Timeout, in seconds; exposed through Timeout().
  std::chrono::seconds timeout_{0};
  // Readiness flag reported by Ready(). It is atomic, so reading it does not need a lock.
  std::atomic<bool> ready_{false};

 public:
  /**
   * @brief Construct the tracker from a JSON configuration (defined out of line).
   */
  explicit Tracker(Json const& config);
  /**
   * @brief Construct the tracker from explicit parameters.
   *
   * @param n_workers Stored as the number of workers.
   * @param port      Stored as the tracker port.
   * @param timeout   Stored as the timeout, in seconds.
   */
  Tracker(std::int32_t n_workers, std::int32_t port, std::chrono::seconds timeout)
      : n_workers_{n_workers}, port_{port}, timeout_{timeout} {}
  // The exception specification is intentionally left as noexcept(false): overriding
  // destructors in derived classes must not have a looser specification than this one.
  virtual ~Tracker() noexcept(false) {}  // NOLINT

  /**
   * @brief Defined out of line.
   *
   * NOTE(review): presumably blocks until Ready() is true or the timeout expires --
   * confirm against the definition.
   */
  [[nodiscard]] Result WaitUntilReady() const;
  /**
   * @brief Start the tracker. The outcome is delivered through the returned future.
   */
  [[nodiscard]] virtual std::future<Result> Run() = 0;
  /**
   * @brief Arguments, in JSON form, provided by the concrete tracker for its workers.
   */
  [[nodiscard]] virtual Json WorkerArgs() const = 0;
  /**
   * @brief The configured timeout.
   */
  [[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
  /**
   * @brief The port currently stored in the tracker.
   */
  [[nodiscard]] virtual std::int32_t Port() const { return port_; }
  /**
   * @brief Flag to indicate whether the server is running.
   */
  [[nodiscard]] bool Ready() const { return ready_.load(); }
};
class RabitTracker : public Tracker {
  // A wrapper for a connected worker socket, together with the state kept for that worker.
  class WorkerProxy {
    TCPSocket sock_;
    proto::PeerInfo info_;
    std::int32_t eport_{0};
    std::int32_t world_{-1};
    std::string task_id_;

    proto::CMD cmd_{proto::CMD::kInvalid};
    std::string msg_;
    std::int32_t code_{0};
    Result rc_;

   public:
    explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
    // Move-only: the proxy holds the socket of the connected worker.
    WorkerProxy(WorkerProxy const& that) = delete;
    WorkerProxy(WorkerProxy&& that) = default;
    WorkerProxy& operator=(WorkerProxy const&) = delete;
    WorkerProxy& operator=(WorkerProxy&&) = default;

    // Host, task ID and message are handed out by const reference so that reading them
    // (e.g. inside the sort comparator below) does not copy. The references are valid
    // only as long as this proxy is alive and has not been moved from.
    [[nodiscard]] auto const& Host() const { return info_.host; }
    [[nodiscard]] auto const& TaskID() const { return task_id_; }
    [[nodiscard]] auto Port() const { return info_.port; }
    [[nodiscard]] auto Rank() const { return info_.rank; }
    [[nodiscard]] auto ErrorPort() const { return eport_; }
    [[nodiscard]] auto Command() const { return cmd_; }
    [[nodiscard]] auto const& Msg() const { return msg_; }
    [[nodiscard]] auto Code() const { return code_; }

    [[nodiscard]] Result const& Status() const { return rc_; }
    [[nodiscard]] Result& Status() { return rc_; }

    // Forward `value` to the worker through the wrapped socket.
    // NOTE(review): whatever the socket's Send returns is discarded here -- confirm
    // that ignoring it is intended.
    void Send(StringView value) { this->sock_.Send(value); }
  };
  // Provide an ordering for workers, this helps us get a deterministic topology.
  // Workers are ordered by host first, then by task ID.
  struct WorkerCmp {
    // const-qualified so the comparator can be invoked through a const object.
    [[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) const {
      auto const& lh = lhs.Host();
      auto const& rh = rhs.Host();
      if (lh != rh) {
        return lh < rh;
      }
      return lhs.TaskID() < rhs.TaskID();
    }
  };

 private:
  // Host string supplied at construction.
  std::string host_;
  // Record for how to reach out to workers if an error happens.
  std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
  // Listening socket for incoming workers.
  TCPSocket listener_;

  // Defined out of line.
  // NOTE(review): presumably performs the initial handshake with the workers in
  // `p_workers` -- confirm against the definition.
  Result Bootstrap(std::vector<WorkerProxy>* p_workers);

 public:
  /**
   * @brief Create the tracker and start listening on the given host.
   *
   * Creates an IPv4 socket, binds it to `host`, and puts it into listening mode. The
   * address of the stored port is handed to Bind; a bind failure aborts through CHECK.
   *
   * @param host      Host to bind the listening socket to.
   * @param n_workers Number of workers, forwarded to the base class.
   * @param port      Port, forwarded to the base class and then passed to Bind.
   * @param timeout   Timeout, forwarded to the base class.
   */
  explicit RabitTracker(StringView host, std::int32_t n_workers, std::int32_t port,
                        std::chrono::seconds timeout)
      : Tracker{n_workers, port, timeout}, host_{host.c_str(), host.size()} {
    listener_ = TCPSocket::Create(SockDomain::kV4);
    auto rc = listener_.Bind(host, &this->port_);
    CHECK(rc.OK()) << rc.Report();
    listener_.Listen();
  }

  /**
   * @brief Construct the tracker from a JSON configuration (defined out of line).
   */
  explicit RabitTracker(Json const& config);
  ~RabitTracker() noexcept(false) override = default;

  // [[nodiscard]] mirrors the base-class declaration of Run().
  [[nodiscard]] std::future<Result> Run() override;
  [[nodiscard]] Json WorkerArgs() const override;
};
// Probe the public IP address of the host; a better method is needed.
//
// This is directly translated from the previous Python implementation. We should find a
// more rigorous approach, and could use some expertise in network programming.
//
// NOTE(review): `out` presumably receives the address on success, with failures
// reported through the returned Result -- confirm against the definition.
[[nodiscard]] Result GetHostAddress(std::string* out);
} // namespace xgboost::collective