Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhausted tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
*
* @return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
char const *c_json_config, DMatrixHandle m,
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values,
char const *config, DMatrixHandle m,
bst_ulong const **out_shape, bst_ulong *out_dim,
const float **out_result);
@@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
*
* @brief Experimental support for exposing internal communicator in XGBoost.
*
* @note This is still under development.
*
* The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has
* changed significantly since its adoption. It consists of a tracker and a set of
* workers. The tracker is responsible for bootstrapping the communication group and
* handling centralized tasks like logging. The workers are actual communicators
* performing collective tasks like allreduce.
*
* To use the collective implementation, one needs to first create a tracker with
* corresponding parameters, then get the arguments for workers using
* XGTrackerWorkerArgs(). The obtained arguments can then be passed to the
* XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a
* XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses
* `std::thread` in C++, which has undefined behavior in a C++ destructor due to the
* runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the
* runtime is shutting down. This requirement is similar to a Python thread or socket,
* which should not be relied upon in a `__del__` function.
*
* Since it's used as a part of XGBoost, errors will be returned when a XGBoost function
* is called, for instance, training a booster might return a connection error.
*
* @{
*/
/**
* @brief Handle to tracker.
* @brief Handle to the tracker.
*
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
* other one is `federated`.
* other one is `federated`. `rabit` is used for normal collective communication, while
* `federated` is used for federated learning.
*
* This is still under development.
*/
typedef void *TrackerHandle; /* NOLINT */
@@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */
*
* @param config JSON encoded parameters.
*
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
* and `federated`.
* - dmlc_communicator: String, the type of tracker to create. Available options are
* `rabit` and `federated`. See @ref TrackerHandle for more info.
* - n_workers: Integer, the number of workers.
* - port: (Optional) Integer, the port this tracker should listen to.
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
* - timeout: (Optional) Integer, timeout in seconds for various networking
operations. Default is 300 seconds.
*
* Some configurations are `rabit` specific:
*
* - host: (Optional) String, Used by the the `rabit` tracker to specify the address of the host.
* This can be useful when the communicator cannot reliably obtain the host address.
* - sortby: (Optional) Integer.
* + 0: Sort workers by their host name.
* + 1: Sort workers by task IDs.
*
* Some `federated` specific configurations:
* - federated_secure: Boolean, whether this is a secure server.
* - federated_secure: Boolean, whether this is a secure server. False for testing.
* - server_key_path: Path to the server key. Used only if this is a secure server.
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
@@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
*/
XGB_DLL int XGTrackerFree(TrackerHandle handle);
/*!
* \brief Initialize the collective communicator.
/**
* @brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
* Call this once in the worker process before using anything. Please make sure
* XGCommunicatorFinalize() is called after use. The initialized commuicator is a global
* thread-local variable.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* @param config JSON encoded configuration. Accepted JSON keys are:
* - dmlc_communicator: The type of the communicator, this should match the tracker type.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
* `USE_DLOPEN_NCCL`.
* Only applicable to the Federated communicator (use upper case for environment variables, use
*
* Only applicable to the `rabit` communicator:
* - dmlc_tracker_uri: Hostname or IP address of the tracker.
* - dmlc_tracker_port: Port number of the tracker.
* - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - dmlc_retry: The number of retries for connection failure.
* - dmlc_timeout: Timeout in seconds.
* - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`.
*
* Only applicable to the `federated` communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
* \return 0 for success, -1 for failure.
* - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key_path: Client key file path. Only needed for the SSL mode.
* - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorInit(char const* config);
/*!
* \brief Finalize the collective communicator.
/**
* @brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
* Call this function after you have finished all jobs.
*
* \return 0 for success, -1 for failure.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorFinalize(void);
/*!
* \brief Get rank of current process.
/**
* @brief Get rank of the current process.
*
* \return Rank of the worker.
* @return Rank of the worker.
*/
XGB_DLL int XGCommunicatorGetRank(void);
/*!
* \brief Get total number of processes.
/**
* @brief Get the total number of processes.
*
* \return Total world size.
* @return Total world size.
*/
XGB_DLL int XGCommunicatorGetWorldSize(void);
/*!
* \brief Get if the communicator is distributed.
/**
* @brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
* @return True if the communicator is distributed.
*/
XGB_DLL int XGCommunicatorIsDistributed(void);
/*!
* \brief Print the message to the communicator.
/**
* @brief Print the message to the tracker.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
* This function can be used to communicate the information of the progress to the user
* who monitors the tracker.
*
* \param message The message to be printed.
* \return 0 for success, -1 for failure.
* @param message The message to be printed.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorPrint(char const *message);
/*!
* \brief Get the name of the processor.
/**
* @brief Get the name of the processor.
*
* \param name_str Pointer to received returned processor name.
* \return 0 for success, -1 for failure.
* @param name_str Pointer to received returned processor name.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);
/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
/**
* @brief Broadcast a memory region to all others from root. This function is NOT
* thread-safe.
*
* Example:
* \code
* @code
* int a = 1;
* Broadcast(&a, sizeof(a), root);
* \endcode
* @endcode
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
* \return 0 for success, -1 for failure.
* @param send_receive_buffer Pointer to the send or receive buffer.
* @param size Size of the data in bytes.
* @param root The process rank to broadcast from.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);
/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
/**
* @brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* \code
* vector<int> data(10);
* @code
* enum class Op {
* kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
* };
* std::vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
* Allreduce(data.data(), data.size(), DataType:kInt32, Op::kSum);
* ...
* \endcode
* @endcode
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
* \return 0 for success, -1 for failure.
* @param send_receive_buffer Buffer for both sending and receiving data.
* @param count Number of elements to be reduced.
* @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);

View File

@@ -55,10 +55,9 @@ struct ResultImpl {
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
#define __builtin_FILE() nullptr
#define __builtin_LINE() (-1)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
#endif
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
} // namespace detail
/**

View File

@@ -16,6 +16,10 @@
#include <system_error> // std::error_code, std::system_category
#include <utility> // std::swap
#if defined(__linux__)
#include <sys/ioctl.h> // for TIOCOUTQ, FIONREAD
#endif // defined(__linux__)
#if !defined(xgboost_IS_MINGW)
#if defined(__MINGW32__)
@@ -319,7 +323,8 @@ class TCPSocket {
std::int32_t domain;
socklen_t len = sizeof(domain);
xgboost_CHECK_SYS_CALL(
getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len),
0);
return ret_iafamily(domain);
#else
struct sockaddr sa;
@@ -426,6 +431,35 @@ class TCPSocket {
return Success();
}
[[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) {
socklen_t optlen;
auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(n_bytes),
&optlen);
if (rc != 0 || optlen != sizeof(std::int32_t)) {
return system::FailWithCode("getsockopt");
}
return Success();
}
[[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) {
socklen_t optlen;
auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(n_bytes),
&optlen);
if (rc != 0 || optlen != sizeof(std::int32_t)) {
return system::FailWithCode("getsockopt");
}
return Success();
}
#if defined(__linux__)
[[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const {
return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success()
: system::FailWithCode("ioctl");
}
[[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const {
return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success()
: system::FailWithCode("ioctl");
}
#endif // defined(__linux__)
[[nodiscard]] Result SetKeepAlive() {
std::int32_t keepalive = 1;
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
@@ -436,10 +470,9 @@ class TCPSocket {
return Success();
}
[[nodiscard]] Result SetNoDelay() {
std::int32_t tcp_no_delay = 1;
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
sizeof(tcp_no_delay));
[[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) {
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&no_delay),
sizeof(no_delay));
if (rc != 0) {
return system::FailWithCode("Failed to set TCP no delay.");
}
@@ -602,45 +635,47 @@ class TCPSocket {
}
/**
* \brief Send data, without error then all data should be sent.
* @brief Send data, without error then all data should be sent.
*/
[[nodiscard]] auto SendAll(void const *buf, std::size_t len) {
[[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) {
char const *_buf = reinterpret_cast<const char *>(buf);
std::size_t ndone = 0;
std::size_t &ndone = *n_sent;
ndone = 0;
while (ndone < len) {
ssize_t ret = send(handle_, _buf, len - ndone, 0);
if (ret == -1) {
if (system::LastErrorWouldBlock()) {
return ndone;
return Success();
}
system::ThrowAtError("send");
return system::FailWithCode("send");
}
_buf += ret;
ndone += ret;
}
return ndone;
return Success();
}
/**
* \brief Receive data, without error then all data should be received.
* @brief Receive data, without error then all data should be received.
*/
[[nodiscard]] auto RecvAll(void *buf, std::size_t len) {
[[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) {
char *_buf = reinterpret_cast<char *>(buf);
std::size_t ndone = 0;
std::size_t &ndone = *n_recv;
ndone = 0;
while (ndone < len) {
ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
if (ret == -1) {
if (system::LastErrorWouldBlock()) {
return ndone;
return Success();
}
system::ThrowAtError("recv");
return system::FailWithCode("recv");
}
if (ret == 0) {
return ndone;
return Success();
}
_buf += ret;
ndone += ret;
}
return ndone;
return Success();
}
/**
* \brief Send data using the socket