Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhausted tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
|
||||
*
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
|
||||
char const *c_json_config, DMatrixHandle m,
|
||||
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values,
|
||||
char const *config, DMatrixHandle m,
|
||||
bst_ulong const **out_shape, bst_ulong *out_dim,
|
||||
const float **out_result);
|
||||
|
||||
@@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
|
||||
*
|
||||
* @brief Experimental support for exposing internal communicator in XGBoost.
|
||||
*
|
||||
* @note This is still under development.
|
||||
*
|
||||
* The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has
|
||||
* changed significantly since its adoption. It consists of a tracker and a set of
|
||||
* workers. The tracker is responsible for bootstrapping the communication group and
|
||||
* handling centralized tasks like logging. The workers are actual communicators
|
||||
* performing collective tasks like allreduce.
|
||||
*
|
||||
* To use the collective implementation, one needs to first create a tracker with
|
||||
* corresponding parameters, then get the arguments for workers using
|
||||
* XGTrackerWorkerArgs(). The obtained arguments can then be passed to the
|
||||
* XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a
|
||||
* XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses
|
||||
* `std::thread` in C++, which has undefined behavior in a C++ destructor due to the
|
||||
* runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the
|
||||
* runtime is shutting down. This requirement is similar to a Python thread or socket,
|
||||
* which should not be relied upon in a `__del__` function.
|
||||
*
|
||||
* Since it's used as a part of XGBoost, errors will be returned when a XGBoost function
|
||||
* is called, for instance, training a booster might return a connection error.
|
||||
*
|
||||
* @{
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Handle to tracker.
|
||||
* @brief Handle to the tracker.
|
||||
*
|
||||
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
|
||||
* other one is `federated`.
|
||||
* other one is `federated`. `rabit` is used for normal collective communication, while
|
||||
* `federated` is used for federated learning.
|
||||
*
|
||||
* This is still under development.
|
||||
*/
|
||||
typedef void *TrackerHandle; /* NOLINT */
|
||||
|
||||
@@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */
|
||||
*
|
||||
* @param config JSON encoded parameters.
|
||||
*
|
||||
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
|
||||
* and `federated`.
|
||||
* - dmlc_communicator: String, the type of tracker to create. Available options are
|
||||
* `rabit` and `federated`. See @ref TrackerHandle for more info.
|
||||
* - n_workers: Integer, the number of workers.
|
||||
* - port: (Optional) Integer, the port this tracker should listen to.
|
||||
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
|
||||
* - timeout: (Optional) Integer, timeout in seconds for various networking
|
||||
operations. Default is 300 seconds.
|
||||
*
|
||||
* Some configurations are `rabit` specific:
|
||||
*
|
||||
* - host: (Optional) String, Used by the the `rabit` tracker to specify the address of the host.
|
||||
* This can be useful when the communicator cannot reliably obtain the host address.
|
||||
* - sortby: (Optional) Integer.
|
||||
* + 0: Sort workers by their host name.
|
||||
* + 1: Sort workers by task IDs.
|
||||
*
|
||||
* Some `federated` specific configurations:
|
||||
* - federated_secure: Boolean, whether this is a secure server.
|
||||
* - federated_secure: Boolean, whether this is a secure server. False for testing.
|
||||
* - server_key_path: Path to the server key. Used only if this is a secure server.
|
||||
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
|
||||
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
|
||||
@@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
|
||||
*/
|
||||
XGB_DLL int XGTrackerFree(TrackerHandle handle);
|
||||
|
||||
/*!
|
||||
* \brief Initialize the collective communicator.
|
||||
/**
|
||||
* @brief Initialize the collective communicator.
|
||||
*
|
||||
* Currently the communicator API is experimental, function signatures may change in the future
|
||||
* without notice.
|
||||
*
|
||||
* Call this once before using anything.
|
||||
* Call this once in the worker process before using anything. Please make sure
|
||||
* XGCommunicatorFinalize() is called after use. The initialized commuicator is a global
|
||||
* thread-local variable.
|
||||
*
|
||||
* The additional configuration is not required. Usually the communicator will detect settings
|
||||
* from environment variables.
|
||||
*
|
||||
* \param config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||
* @param config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - dmlc_communicator: The type of the communicator, this should match the tracker type.
|
||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* * federated: Use the gRPC interface for Federated Learning.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||
* - rabit_tracker_uri: Hostname of the tracker.
|
||||
* - rabit_tracker_port: Port number of the tracker.
|
||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - rabit_world_size: Total number of workers.
|
||||
* - rabit_timeout: Enable timeout.
|
||||
* - rabit_timeout_sec: Timeout in seconds.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
* environment variables):
|
||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
|
||||
* `USE_DLOPEN_NCCL`.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
*
|
||||
* Only applicable to the `rabit` communicator:
|
||||
* - dmlc_tracker_uri: Hostname or IP address of the tracker.
|
||||
* - dmlc_tracker_port: Port number of the tracker.
|
||||
* - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - dmlc_retry: The number of retries for connection failure.
|
||||
* - dmlc_timeout: Timeout in seconds.
|
||||
* - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`.
|
||||
*
|
||||
* Only applicable to the `federated` communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
* - federated_world_size: Number of federated workers.
|
||||
* - federated_rank: Rank of the current worker.
|
||||
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key_path: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorInit(char const* config);
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
/**
|
||||
* @brief Finalize the collective communicator.
|
||||
*
|
||||
* Call this function after you finished all jobs.
|
||||
* Call this function after you have finished all jobs.
|
||||
*
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorFinalize(void);
|
||||
|
||||
/*!
|
||||
* \brief Get rank of current process.
|
||||
/**
|
||||
* @brief Get rank of the current process.
|
||||
*
|
||||
* \return Rank of the worker.
|
||||
* @return Rank of the worker.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetRank(void);
|
||||
|
||||
/*!
|
||||
* \brief Get total number of processes.
|
||||
/**
|
||||
* @brief Get the total number of processes.
|
||||
*
|
||||
* \return Total world size.
|
||||
* @return Total world size.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetWorldSize(void);
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is distributed.
|
||||
/**
|
||||
* @brief Get if the communicator is distributed.
|
||||
*
|
||||
* \return True if the communicator is distributed.
|
||||
* @return True if the communicator is distributed.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorIsDistributed(void);
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
/**
|
||||
* @brief Print the message to the tracker.
|
||||
*
|
||||
* This function can be used to communicate the information of the progress to the user who monitors
|
||||
* the communicator.
|
||||
* This function can be used to communicate the information of the progress to the user
|
||||
* who monitors the tracker.
|
||||
*
|
||||
* \param message The message to be printed.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @param message The message to be printed.
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorPrint(char const *message);
|
||||
|
||||
/*!
|
||||
* \brief Get the name of the processor.
|
||||
/**
|
||||
* @brief Get the name of the processor.
|
||||
*
|
||||
* \param name_str Pointer to received returned processor name.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @param name_str Pointer to received returned processor name.
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);
|
||||
|
||||
/*!
|
||||
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||
/**
|
||||
* @brief Broadcast a memory region to all others from root. This function is NOT
|
||||
* thread-safe.
|
||||
*
|
||||
* Example:
|
||||
* \code
|
||||
* @code
|
||||
* int a = 1;
|
||||
* Broadcast(&a, sizeof(a), root);
|
||||
* \endcode
|
||||
* @endcode
|
||||
*
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param root The process rank to broadcast from.
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);
|
||||
|
||||
/*!
|
||||
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
/**
|
||||
* @brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* \code
|
||||
* vector<int> data(10);
|
||||
* @code
|
||||
* enum class Op {
|
||||
* kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
|
||||
* };
|
||||
* std::vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
|
||||
* Allreduce(data.data(), data.size(), DataType:kInt32, Op::kSum);
|
||||
* ...
|
||||
* \endcode
|
||||
* @endcode
|
||||
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* @param count Number of elements to be reduced.
|
||||
* @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);
|
||||
|
||||
|
||||
@@ -55,10 +55,9 @@ struct ResultImpl {
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
#define __builtin_FILE() nullptr
|
||||
#define __builtin_LINE() (-1)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
|
||||
#endif
|
||||
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
|
||||
@@ -16,6 +16,10 @@
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
#include <utility> // std::swap
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <sys/ioctl.h> // for TIOCOUTQ, FIONREAD
|
||||
#endif // defined(__linux__)
|
||||
|
||||
#if !defined(xgboost_IS_MINGW)
|
||||
|
||||
#if defined(__MINGW32__)
|
||||
@@ -319,7 +323,8 @@ class TCPSocket {
|
||||
std::int32_t domain;
|
||||
socklen_t len = sizeof(domain);
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
|
||||
getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len),
|
||||
0);
|
||||
return ret_iafamily(domain);
|
||||
#else
|
||||
struct sockaddr sa;
|
||||
@@ -426,6 +431,35 @@ class TCPSocket {
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) {
|
||||
socklen_t optlen;
|
||||
auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(n_bytes),
|
||||
&optlen);
|
||||
if (rc != 0 || optlen != sizeof(std::int32_t)) {
|
||||
return system::FailWithCode("getsockopt");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) {
|
||||
socklen_t optlen;
|
||||
auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(n_bytes),
|
||||
&optlen);
|
||||
if (rc != 0 || optlen != sizeof(std::int32_t)) {
|
||||
return system::FailWithCode("getsockopt");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
#if defined(__linux__)
|
||||
[[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const {
|
||||
return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success()
|
||||
: system::FailWithCode("ioctl");
|
||||
}
|
||||
[[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const {
|
||||
return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success()
|
||||
: system::FailWithCode("ioctl");
|
||||
}
|
||||
#endif // defined(__linux__)
|
||||
|
||||
[[nodiscard]] Result SetKeepAlive() {
|
||||
std::int32_t keepalive = 1;
|
||||
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
|
||||
@@ -436,10 +470,9 @@ class TCPSocket {
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result SetNoDelay() {
|
||||
std::int32_t tcp_no_delay = 1;
|
||||
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
||||
sizeof(tcp_no_delay));
|
||||
[[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) {
|
||||
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&no_delay),
|
||||
sizeof(no_delay));
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to set TCP no delay.");
|
||||
}
|
||||
@@ -602,45 +635,47 @@ class TCPSocket {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Send data, without error then all data should be sent.
|
||||
* @brief Send data, without error then all data should be sent.
|
||||
*/
|
||||
[[nodiscard]] auto SendAll(void const *buf, std::size_t len) {
|
||||
[[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) {
|
||||
char const *_buf = reinterpret_cast<const char *>(buf);
|
||||
std::size_t ndone = 0;
|
||||
std::size_t &ndone = *n_sent;
|
||||
ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = send(handle_, _buf, len - ndone, 0);
|
||||
if (ret == -1) {
|
||||
if (system::LastErrorWouldBlock()) {
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
system::ThrowAtError("send");
|
||||
return system::FailWithCode("send");
|
||||
}
|
||||
_buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
/**
|
||||
* \brief Receive data, without error then all data should be received.
|
||||
* @brief Receive data, without error then all data should be received.
|
||||
*/
|
||||
[[nodiscard]] auto RecvAll(void *buf, std::size_t len) {
|
||||
[[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) {
|
||||
char *_buf = reinterpret_cast<char *>(buf);
|
||||
std::size_t ndone = 0;
|
||||
std::size_t &ndone = *n_recv;
|
||||
ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
|
||||
if (ret == -1) {
|
||||
if (system::LastErrorWouldBlock()) {
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
system::ThrowAtError("recv");
|
||||
return system::FailWithCode("recv");
|
||||
}
|
||||
if (ret == 0) {
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
_buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
/**
|
||||
* \brief Send data using the socket
|
||||
|
||||
Reference in New Issue
Block a user