Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhaustive tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
parent
ba9b4cb1ee
commit
a5a58102e5
@ -69,7 +69,6 @@ option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF)
|
||||
option(USE_DEVICE_DEBUG "Generate CUDA device debug info." OFF)
|
||||
option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
|
||||
set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header")
|
||||
option(RABIT_MOCK "Build rabit with mock" OFF)
|
||||
option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
|
||||
option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
|
||||
## CUDA
|
||||
@ -282,9 +281,6 @@ if(MSVC)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# rabit
|
||||
add_subdirectory(rabit)
|
||||
|
||||
# core xgboost
|
||||
add_subdirectory(${xgboost_SOURCE_DIR}/src)
|
||||
target_link_libraries(objxgboost PUBLIC dmlc)
|
||||
|
||||
@ -106,10 +106,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/collective/comm.o \
|
||||
$(PKGROOT)/src/collective/comm_group.o \
|
||||
$(PKGROOT)/src/collective/coll.o \
|
||||
$(PKGROOT)/src/collective/communicator-inl.o \
|
||||
$(PKGROOT)/src/collective/tracker.o \
|
||||
$(PKGROOT)/src/collective/communicator.o \
|
||||
$(PKGROOT)/src/collective/in_memory_communicator.o \
|
||||
$(PKGROOT)/src/collective/in_memory_handler.o \
|
||||
$(PKGROOT)/src/collective/loop.o \
|
||||
$(PKGROOT)/src/collective/socket.o \
|
||||
@ -134,7 +131,4 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/common/version.o \
|
||||
$(PKGROOT)/src/c_api/c_api.o \
|
||||
$(PKGROOT)/src/c_api/c_api_error.o \
|
||||
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
|
||||
$(PKGROOT)/rabit/src/engine.o \
|
||||
$(PKGROOT)/rabit/src/rabit_c_api.o \
|
||||
$(PKGROOT)/rabit/src/allreduce_base.o
|
||||
$(PKGROOT)/amalgamation/dmlc-minimum0.o
|
||||
|
||||
@ -106,10 +106,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/collective/comm.o \
|
||||
$(PKGROOT)/src/collective/comm_group.o \
|
||||
$(PKGROOT)/src/collective/coll.o \
|
||||
$(PKGROOT)/src/collective/communicator-inl.o \
|
||||
$(PKGROOT)/src/collective/tracker.o \
|
||||
$(PKGROOT)/src/collective/communicator.o \
|
||||
$(PKGROOT)/src/collective/in_memory_communicator.o \
|
||||
$(PKGROOT)/src/collective/in_memory_handler.o \
|
||||
$(PKGROOT)/src/collective/loop.o \
|
||||
$(PKGROOT)/src/collective/socket.o \
|
||||
@ -134,7 +131,4 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/common/version.o \
|
||||
$(PKGROOT)/src/c_api/c_api.o \
|
||||
$(PKGROOT)/src/c_api/c_api_error.o \
|
||||
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
|
||||
$(PKGROOT)/rabit/src/engine.o \
|
||||
$(PKGROOT)/rabit/src/rabit_c_api.o \
|
||||
$(PKGROOT)/rabit/src/allreduce_base.o
|
||||
$(PKGROOT)/amalgamation/dmlc-minimum0.o
|
||||
|
||||
@ -151,6 +151,7 @@ function(xgboost_set_cuda_flags target)
|
||||
target_include_directories(
|
||||
${target} PRIVATE
|
||||
${xgboost_SOURCE_DIR}/gputreeshap
|
||||
${xgboost_SOURCE_DIR}/rabit/include
|
||||
${CUDAToolkit_INCLUDE_DIRS})
|
||||
|
||||
if(MSVC)
|
||||
|
||||
@ -16,7 +16,7 @@ def main(client: Client) -> None:
|
||||
m = 100000
|
||||
n = 100
|
||||
rng = da.random.default_rng(1)
|
||||
X = rng.normal(size=(m, n))
|
||||
X = rng.normal(size=(m, n), chunks=(10000, -1))
|
||||
y = X.sum(axis=1)
|
||||
|
||||
# DaskDMatrix acts like normal DMatrix, works as a proxy for local
|
||||
|
||||
@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
|
||||
*
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
|
||||
char const *c_json_config, DMatrixHandle m,
|
||||
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values,
|
||||
char const *config, DMatrixHandle m,
|
||||
bst_ulong const **out_shape, bst_ulong *out_dim,
|
||||
const float **out_result);
|
||||
|
||||
@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
|
||||
*
|
||||
* @brief Experimental support for exposing internal communicator in XGBoost.
|
||||
*
|
||||
* @note This is still under development.
|
||||
*
|
||||
* The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has
|
||||
* changed significantly since its adoption. It consists of a tracker and a set of
|
||||
* workers. The tracker is responsible for bootstrapping the communication group and
|
||||
* handling centralized tasks like logging. The workers are actual communicators
|
||||
* performing collective tasks like allreduce.
|
||||
*
|
||||
* To use the collective implementation, one needs to first create a tracker with
|
||||
* corresponding parameters, then get the arguments for workers using
|
||||
* XGTrackerWorkerArgs(). The obtained arguments can then be passed to the
|
||||
* XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a
|
||||
* XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses
|
||||
* `std::thread` in C++, which has undefined behavior in a C++ destructor due to the
|
||||
* runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the
|
||||
* runtime is shutting down. This requirement is similar to a Python thread or socket,
|
||||
* which should not be relied upon in a `__del__` function.
|
||||
*
|
||||
* Since it's used as a part of XGBoost, errors will be returned when a XGBoost function
|
||||
* is called, for instance, training a booster might return a connection error.
|
||||
*
|
||||
* @{
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Handle to tracker.
|
||||
* @brief Handle to the tracker.
|
||||
*
|
||||
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
|
||||
* other one is `federated`.
|
||||
* other one is `federated`. `rabit` is used for normal collective communication, while
|
||||
* `federated` is used for federated learning.
|
||||
*
|
||||
* This is still under development.
|
||||
*/
|
||||
typedef void *TrackerHandle; /* NOLINT */
|
||||
|
||||
@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */
|
||||
*
|
||||
* @param config JSON encoded parameters.
|
||||
*
|
||||
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
|
||||
* and `federated`.
|
||||
* - dmlc_communicator: String, the type of tracker to create. Available options are
|
||||
* `rabit` and `federated`. See @ref TrackerHandle for more info.
|
||||
* - n_workers: Integer, the number of workers.
|
||||
* - port: (Optional) Integer, the port this tracker should listen to.
|
||||
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
|
||||
* - timeout: (Optional) Integer, timeout in seconds for various networking
|
||||
*   operations. Default is 300 seconds.
|
||||
*
|
||||
* Some configurations are `rabit` specific:
|
||||
*
|
||||
* - host: (Optional) String, used by the `rabit` tracker to specify the address of the host.
|
||||
* This can be useful when the communicator cannot reliably obtain the host address.
|
||||
* - sortby: (Optional) Integer.
|
||||
* + 0: Sort workers by their host name.
|
||||
* + 1: Sort workers by task IDs.
|
||||
*
|
||||
* Some `federated` specific configurations:
|
||||
* - federated_secure: Boolean, whether this is a secure server.
|
||||
* - federated_secure: Boolean, whether this is a secure server. False for testing.
|
||||
* - server_key_path: Path to the server key. Used only if this is a secure server.
|
||||
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
|
||||
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
|
||||
@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
|
||||
*/
|
||||
XGB_DLL int XGTrackerFree(TrackerHandle handle);
|
||||
|
||||
/*!
|
||||
* \brief Initialize the collective communicator.
|
||||
/**
|
||||
* @brief Initialize the collective communicator.
|
||||
*
|
||||
* Currently the communicator API is experimental, function signatures may change in the future
|
||||
* without notice.
|
||||
*
|
||||
* Call this once before using anything.
|
||||
* Call this once in the worker process before using anything. Please make sure
|
||||
* XGCommunicatorFinalize() is called after use. The initialized communicator is a global
|
||||
* thread-local variable.
|
||||
*
|
||||
* The additional configuration is not required. Usually the communicator will detect settings
|
||||
* from environment variables.
|
||||
*
|
||||
* \param config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||
* @param config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - dmlc_communicator: The type of the communicator, this should match the tracker type.
|
||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* * federated: Use the gRPC interface for Federated Learning.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||
* - rabit_tracker_uri: Hostname of the tracker.
|
||||
* - rabit_tracker_port: Port number of the tracker.
|
||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - rabit_world_size: Total number of workers.
|
||||
* - rabit_timeout: Enable timeout.
|
||||
* - rabit_timeout_sec: Timeout in seconds.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
* environment variables):
|
||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
|
||||
* `USE_DLOPEN_NCCL`.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
*
|
||||
* Only applicable to the `rabit` communicator:
|
||||
* - dmlc_tracker_uri: Hostname or IP address of the tracker.
|
||||
* - dmlc_tracker_port: Port number of the tracker.
|
||||
* - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - dmlc_retry: The number of retries for connection failure.
|
||||
* - dmlc_timeout: Timeout in seconds.
|
||||
* - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`.
|
||||
*
|
||||
* Only applicable to the `federated` communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
* - federated_world_size: Number of federated workers.
|
||||
* - federated_rank: Rank of the current worker.
|
||||
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key_path: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorInit(char const* config);
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
/**
|
||||
* @brief Finalize the collective communicator.
|
||||
*
|
||||
* Call this function after you finished all jobs.
|
||||
* Call this function after you have finished all jobs.
|
||||
*
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorFinalize(void);
|
||||
|
||||
/*!
|
||||
* \brief Get rank of current process.
|
||||
/**
|
||||
* @brief Get rank of the current process.
|
||||
*
|
||||
* \return Rank of the worker.
|
||||
* @return Rank of the worker.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetRank(void);
|
||||
|
||||
/*!
|
||||
* \brief Get total number of processes.
|
||||
/**
|
||||
* @brief Get the total number of processes.
|
||||
*
|
||||
* \return Total world size.
|
||||
* @return Total world size.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetWorldSize(void);
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is distributed.
|
||||
/**
|
||||
* @brief Get if the communicator is distributed.
|
||||
*
|
||||
* \return True if the communicator is distributed.
|
||||
* @return True if the communicator is distributed.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorIsDistributed(void);
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
/**
|
||||
* @brief Print the message to the tracker.
|
||||
*
|
||||
* This function can be used to communicate the information of the progress to the user who monitors
|
||||
* the communicator.
|
||||
* This function can be used to communicate the information of the progress to the user
|
||||
* who monitors the tracker.
|
||||
*
|
||||
* \param message The message to be printed.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @param message The message to be printed.
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorPrint(char const *message);
|
||||
|
||||
/*!
|
||||
* \brief Get the name of the processor.
|
||||
/**
|
||||
* @brief Get the name of the processor.
|
||||
*
|
||||
* \param name_str Pointer to received returned processor name.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @param name_str Pointer that receives the returned processor name.
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);
|
||||
|
||||
/*!
|
||||
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||
/**
|
||||
* @brief Broadcast a memory region to all others from root. This function is NOT
|
||||
* thread-safe.
|
||||
*
|
||||
* Example:
|
||||
* \code
|
||||
* @code
|
||||
* int a = 1;
|
||||
* Broadcast(&a, sizeof(a), root);
|
||||
* \endcode
|
||||
* @endcode
|
||||
*
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param root The process rank to broadcast from.
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);
|
||||
|
||||
/*!
|
||||
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
/**
|
||||
* @brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* \code
|
||||
* vector<int> data(10);
|
||||
* @code
|
||||
* enum class Op {
|
||||
* kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
|
||||
* };
|
||||
* std::vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
|
||||
* Allreduce(data.data(), data.size(), DataType::kInt32, Op::kSum);
|
||||
* ...
|
||||
* \endcode
|
||||
* @endcode
|
||||
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
* \return 0 for success, -1 for failure.
|
||||
* @param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* @param count Number of elements to be reduced.
|
||||
* @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
*
|
||||
* @return 0 for success, -1 for failure.
|
||||
*/
|
||||
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);
|
||||
|
||||
|
||||
@ -55,10 +55,9 @@ struct ResultImpl {
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
#define __builtin_FILE() nullptr
|
||||
#define __builtin_LINE() (-1)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
|
||||
#endif
|
||||
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
|
||||
@ -16,6 +16,10 @@
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
#include <utility> // std::swap
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <sys/ioctl.h> // for TIOCOUTQ, FIONREAD
|
||||
#endif // defined(__linux__)
|
||||
|
||||
#if !defined(xgboost_IS_MINGW)
|
||||
|
||||
#if defined(__MINGW32__)
|
||||
@ -319,7 +323,8 @@ class TCPSocket {
|
||||
std::int32_t domain;
|
||||
socklen_t len = sizeof(domain);
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
|
||||
getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len),
|
||||
0);
|
||||
return ret_iafamily(domain);
|
||||
#else
|
||||
struct sockaddr sa;
|
||||
@ -426,6 +431,35 @@ class TCPSocket {
|
||||
return Success();
|
||||
}
|
||||
|
||||
/**
 * @brief Query the size of the socket's send buffer (`SO_SNDBUF`).
 *
 * @param n_bytes Output, receives the value reported by `getsockopt`.
 *
 * @return Success, or a failure carrying the system error code when `getsockopt` fails
 *         or reports an option length other than `sizeof(std::int32_t)`.
 */
[[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) {
  // The length argument of getsockopt is value-result: on entry it must hold the size
  // of the output buffer, on return it holds the number of bytes actually written.
  // Leaving it uninitialized is an indeterminate read.
  socklen_t optlen = sizeof(std::int32_t);
  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(n_bytes),
                       &optlen);
  if (rc != 0 || optlen != sizeof(std::int32_t)) {
    return system::FailWithCode("getsockopt");
  }
  return Success();
}
|
||||
/**
 * @brief Query the size of the socket's receive buffer (`SO_RCVBUF`).
 *
 * @param n_bytes Output, receives the value reported by `getsockopt`.
 *
 * @return Success, or a failure carrying the system error code when `getsockopt` fails
 *         or reports an option length other than `sizeof(std::int32_t)`.
 */
[[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) {
  // The length argument of getsockopt is value-result: on entry it must hold the size
  // of the output buffer, on return it holds the number of bytes actually written.
  // Leaving it uninitialized is an indeterminate read.
  socklen_t optlen = sizeof(std::int32_t);
  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(n_bytes),
                       &optlen);
  if (rc != 0 || optlen != sizeof(std::int32_t)) {
    return system::FailWithCode("getsockopt");
  }
  return Success();
}
|
||||
#if defined(__linux__)
|
||||
/**
 * @brief Issue the `TIOCOUTQ` ioctl on the socket and store the result in `n_bytes`.
 *
 * NOTE(review): on Linux this is expected to be the amount of not-yet-sent data in the
 * send queue; confirm against tcp(7).
 *
 * @param n_bytes Output, filled in by `ioctl`.
 *
 * @return Success when `ioctl` returns 0, otherwise a failure with the system error code.
 */
[[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const {
  auto const status = ioctl(this->Handle(), TIOCOUTQ, n_bytes);
  if (status != 0) {
    return system::FailWithCode("ioctl");
  }
  return Success();
}
|
||||
/**
 * @brief Issue the `FIONREAD` ioctl on the socket and store the result in `n_bytes`.
 *
 * NOTE(review): on Linux this is expected to be the amount of received data not yet
 * read by the application; confirm against tcp(7).
 *
 * @param n_bytes Output, filled in by `ioctl`.
 *
 * @return Success when `ioctl` returns 0, otherwise a failure with the system error code.
 */
[[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const {
  auto const status = ioctl(this->Handle(), FIONREAD, n_bytes);
  if (status != 0) {
    return system::FailWithCode("ioctl");
  }
  return Success();
}
|
||||
#endif // defined(__linux__)
|
||||
|
||||
[[nodiscard]] Result SetKeepAlive() {
|
||||
std::int32_t keepalive = 1;
|
||||
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
|
||||
@ -436,10 +470,9 @@ class TCPSocket {
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result SetNoDelay() {
|
||||
std::int32_t tcp_no_delay = 1;
|
||||
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
||||
sizeof(tcp_no_delay));
|
||||
[[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) {
|
||||
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&no_delay),
|
||||
sizeof(no_delay));
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to set TCP no delay.");
|
||||
}
|
||||
@ -602,45 +635,47 @@ class TCPSocket {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Send data, without error then all data should be sent.
|
||||
* @brief Send data, without error then all data should be sent.
|
||||
*/
|
||||
[[nodiscard]] auto SendAll(void const *buf, std::size_t len) {
|
||||
[[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) {
|
||||
char const *_buf = reinterpret_cast<const char *>(buf);
|
||||
std::size_t ndone = 0;
|
||||
std::size_t &ndone = *n_sent;
|
||||
ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = send(handle_, _buf, len - ndone, 0);
|
||||
if (ret == -1) {
|
||||
if (system::LastErrorWouldBlock()) {
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
system::ThrowAtError("send");
|
||||
return system::FailWithCode("send");
|
||||
}
|
||||
_buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
/**
|
||||
* \brief Receive data, without error then all data should be received.
|
||||
* @brief Receive data, without error then all data should be received.
|
||||
*/
|
||||
[[nodiscard]] auto RecvAll(void *buf, std::size_t len) {
|
||||
[[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) {
|
||||
char *_buf = reinterpret_cast<char *>(buf);
|
||||
std::size_t ndone = 0;
|
||||
std::size_t &ndone = *n_recv;
|
||||
ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
|
||||
if (ret == -1) {
|
||||
if (system::LastErrorWouldBlock()) {
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
system::ThrowAtError("recv");
|
||||
return system::FailWithCode("recv");
|
||||
}
|
||||
if (ret == 0) {
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
_buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
return ndone;
|
||||
return Success();
|
||||
}
|
||||
/**
|
||||
* \brief Send data using the socket
|
||||
|
||||
@ -23,6 +23,7 @@ CONFIG = {
|
||||
"USE_NCCL": "OFF",
|
||||
"JVM_BINDINGS": "ON",
|
||||
"LOG_CAPI_INVOCATION": "OFF",
|
||||
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
|
||||
}
|
||||
|
||||
|
||||
@ -97,10 +98,6 @@ def native_build(args):
|
||||
|
||||
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
|
||||
|
||||
# if enviorment set rabit_mock
|
||||
if os.getenv("RABIT_MOCK", None) is not None:
|
||||
args.append("-DRABIT_MOCK:BOOL=ON")
|
||||
|
||||
# if environment set GPU_ARCH_FLAG
|
||||
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
|
||||
if gpu_arch_flag is not None:
|
||||
@ -162,12 +159,6 @@ def native_build(args):
|
||||
maybe_makedirs(output_folder)
|
||||
cp("../lib/" + library_name, output_folder)
|
||||
|
||||
print("copying pure-Python tracker")
|
||||
cp(
|
||||
"../python-package/xgboost/tracker.py",
|
||||
"{}/src/main/resources".format(xgboost4j),
|
||||
)
|
||||
|
||||
print("copying train/test files")
|
||||
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
|
||||
with cd("../demo/CLI/regression"):
|
||||
|
||||
@ -489,6 +489,11 @@
|
||||
<artifactId>kryo</artifactId>
|
||||
<version>5.6.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>2.14.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-logging</groupId>
|
||||
<artifactId>commons-logging</artifactId>
|
||||
|
||||
@ -54,9 +54,9 @@ public class XGBoost {
|
||||
|
||||
private final Map<String, Object> params;
|
||||
private final int round;
|
||||
private final Map<String, String> workerEnvs;
|
||||
private final Map<String, Object> workerEnvs;
|
||||
|
||||
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
|
||||
public MapFunction(Map<String, Object> params, int round, Map<String, Object> workerEnvs) {
|
||||
this.params = params;
|
||||
this.round = round;
|
||||
this.workerEnvs = workerEnvs;
|
||||
@ -174,9 +174,9 @@ public class XGBoost {
|
||||
int numBoostRound) throws Exception {
|
||||
final RabitTracker tracker =
|
||||
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
|
||||
if (tracker.start(0L)) {
|
||||
if (tracker.start()) {
|
||||
return dtrain
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
|
||||
.reduce((x, y) -> x)
|
||||
.collect()
|
||||
.get(0);
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -22,7 +22,7 @@ import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
@ -38,21 +38,17 @@ import org.apache.spark.sql.SparkSession
|
||||
/**
|
||||
* Rabit tracker configurations.
|
||||
*
|
||||
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
|
||||
* Set timeout length to zero to disable timeout.
|
||||
* Use a finite, non-zero timeout value to prevent tracker from
|
||||
* hanging indefinitely (in milliseconds)
|
||||
* (supported by "scala" implementation only.)
|
||||
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
|
||||
* @param timeout The number of seconds before timeout waiting for workers to connect, and
|
||||
* for the tracker to shutdown.
|
||||
* @param hostIp The Rabit Tracker host IP address.
|
||||
* This is only needed if the host IP cannot be automatically guessed.
|
||||
* @param pythonExec The python executed path for Rabit Tracker,
|
||||
* which is only used for python implementation.
|
||||
* @param port The port number for the tracker to listen to. Use a system allocated one by
|
||||
* default.
|
||||
*/
|
||||
case class TrackerConf(workerConnectionTimeout: Long,
|
||||
hostIp: String = "", pythonExec: String = "")
|
||||
case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
|
||||
|
||||
object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L)
|
||||
def apply(): TrackerConf = TrackerConf(0)
|
||||
}
|
||||
|
||||
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||
@ -421,7 +417,7 @@ object XGBoost extends XGBoostStageLevel {
|
||||
private def buildDistributedBooster(
|
||||
buildWatches: () => Watches,
|
||||
xgbExecutionParam: XGBoostExecutionParams,
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
rabitEnv: java.util.Map[String, Object],
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait,
|
||||
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
||||
@ -430,7 +426,6 @@ object XGBoost extends XGBoostStageLevel {
|
||||
val taskId = TaskContext.getPartitionId().toString
|
||||
val attempt = TaskContext.get().attemptNumber.toString
|
||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
||||
val numRounds = xgbExecutionParam.numRounds
|
||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||
|
||||
@ -481,16 +476,15 @@ object XGBoost extends XGBoostStageLevel {
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
val tracker: IRabitTracker = new PyRabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
|
||||
)
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker: ITracker = new RabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
|
||||
tracker
|
||||
}
|
||||
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker = getTracker(nWorkers, trackerConf)
|
||||
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
tracker
|
||||
}
|
||||
|
||||
@ -525,8 +519,8 @@ object XGBoost extends XGBoostStageLevel {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
val (booster, metrics) = try {
|
||||
tracker.getWorkerEnvs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
tracker.workerArgs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.workerArgs
|
||||
|
||||
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
|
||||
var optionWatches: Option[() => Watches] = None
|
||||
@ -548,11 +542,6 @@ object XGBoost extends XGBoostStageLevel {
|
||||
// of the training task fails the training stage can retry. ResultStage won't retry when
|
||||
// it fails.
|
||||
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
if (trackerReturnVal != 0) {
|
||||
throw new XGBoostError("XGBoostModel training failed.")
|
||||
}
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -145,28 +145,28 @@ private[spark] trait GeneralParams extends Params {
|
||||
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||
* TrackerConf class, which has the following definition:
|
||||
*
|
||||
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
|
||||
* trackerImpl: String)
|
||||
* case class TrackerConf(timeout: Int, hostIp: String, port: Int)
|
||||
*
|
||||
* See below for detailed explanations.
|
||||
*
|
||||
* - trackerImpl: Select the implementation of Rabit tracker.
|
||||
* default: "python"
|
||||
*
|
||||
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
|
||||
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
|
||||
* The "scala" version removes Python components, and fully supports timeout settings.
|
||||
*
|
||||
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
|
||||
* default: 0 millisecond (no timeout)
|
||||
* - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds)
|
||||
* default: 0 (no timeout)
|
||||
*
|
||||
* Timeout for constructing the communication group and waiting for the tracker to
|
||||
* shutdown when it's instructed to, doesn't apply to communication when tracking
|
||||
* is running.
|
||||
* The timeout value should take the time of data loading and pre-processing into account,
|
||||
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
|
||||
* due to potential lazy execution. Alternatively, you may force Spark to
|
||||
* perform data transformation before calling XGBoost.train(), so that this timeout truly
|
||||
* reflects the connection delay. Set a reasonable timeout value to prevent model
|
||||
* training/testing from hanging indefinitely, possibly due to network issues.
|
||||
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
|
||||
* Ignored if the tracker implementation is "python".
|
||||
*
|
||||
* - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP
|
||||
* cannot be automatically guessed.
|
||||
*
|
||||
* - port : The port number for the tracker to listen to. Use a system allocated one by
|
||||
* default.
|
||||
*/
|
||||
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
|
||||
setDefault(trackerConf, TrackerConf())
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -20,8 +20,7 @@ import java.util.concurrent.LinkedBlockingDeque
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
@ -33,50 +32,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
xgbParamsFactory.buildXGBRuntimeParams
|
||||
}
|
||||
|
||||
test("Customize host ip and python exec for Rabit tracker") {
|
||||
val hostIp = "192.168.22.111"
|
||||
val pythonExec = "/usr/bin/python3"
|
||||
|
||||
val paramMap = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, hostIp))
|
||||
val xgbExecParams = getXGBoostExecutionParams(paramMap)
|
||||
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
tracker match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.contains(hostIp))
|
||||
assert(cmd.startsWith("python"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap1 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
|
||||
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
|
||||
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
|
||||
tracker1 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(!cmd.contains(hostIp))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
|
||||
val paramMap2 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
|
||||
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
|
||||
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
|
||||
tracker2 match {
|
||||
case pyTracker: PyRabitTracker =>
|
||||
val cmd = pyTracker.getRabitTrackerCommand
|
||||
assert(cmd.startsWith(pythonExec))
|
||||
assert(cmd.contains(s" --host-ip=${hostIp}"))
|
||||
case _ => assert(false, "expected python tracker implementation")
|
||||
}
|
||||
}
|
||||
|
||||
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||
/*
|
||||
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||
@ -88,9 +43,9 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
*/
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new PyRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker. workerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
@ -99,22 +54,8 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
|
||||
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
|
||||
|
||||
The Java RabitTracker class reacts to exceptions by killing the spawned process running
|
||||
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
|
||||
it is killed, the resulted connection failure will trigger the Rabit worker to call
|
||||
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
|
||||
|
||||
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
|
||||
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
|
||||
running in separate containers. However, as unit tests are run in Spark local mode, in which
|
||||
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
|
||||
ultimately kills the entire process, which also happens to host the Spark driver, causing
|
||||
the entire Spark application to crash.
|
||||
|
||||
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
|
||||
the exception is thrown at last, ideally after all worker connections have been established.
|
||||
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
|
||||
process to ensure that pending worker connections are handled.
|
||||
*/
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Communicator.init(trackerEnvs)
|
||||
@ -137,7 +78,32 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0) != 0)
|
||||
}
|
||||
|
||||
test("Communicator allreduce works.") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker.workerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
|
||||
rdd.mapPartitions { iter =>
|
||||
val index = iter.next()
|
||||
Communicator.init(trackerEnvs)
|
||||
val a = Array(1.0f, 2.0f, 3.0f)
|
||||
System.out.println(a.mkString(", "))
|
||||
val b = Communicator.allReduce(a, Communicator.OpType.SUM)
|
||||
for (i <- 0 to 2) {
|
||||
assert(a(i) * workerCount == b(i))
|
||||
}
|
||||
val c = Communicator.allReduce(a, Communicator.OpType.MIN);
|
||||
for (i <- 0 to 2) {
|
||||
assert(a(i) == c(i))
|
||||
}
|
||||
Communicator.shutdown()
|
||||
Iterator(index)
|
||||
}.collect()
|
||||
}
|
||||
|
||||
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
|
||||
|
||||
@ -23,7 +23,6 @@ import org.apache.spark.SparkException
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
|
||||
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
|
||||
"objective_type" -> "classification")
|
||||
@ -50,7 +49,6 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
intercept[SparkException] {
|
||||
xgb.fit(trainingDF)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
test("fail training elegantly with unsupported eval metrics") {
|
||||
|
||||
@ -47,11 +47,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
|
||||
.fit(training)
|
||||
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
// check parity w/o rabit cache
|
||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
||||
@ -70,10 +65,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
|
||||
).fit(training)
|
||||
assert(Communicator.communicatorEnvs.asScala.size > 3)
|
||||
Communicator.communicatorEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
|
||||
})
|
||||
// check the equality of single instance prediction
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
// check parity w/o rabit cache
|
||||
@ -81,25 +72,4 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
assert(math.abs(p1 - p2) < predictionErrorMin)
|
||||
}
|
||||
}
|
||||
|
||||
test("test rabit timeout fail handle") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
// mock rank 0 failure during 8th allreduce synchronization
|
||||
Communicator.mockList = Array("0,8,0,0").toList.asJava
|
||||
|
||||
intercept[SparkException] {
|
||||
new XGBoostClassifier(Map(
|
||||
"eta" -> "0.1",
|
||||
"max_depth" -> "10",
|
||||
"verbosity" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> 5,
|
||||
"num_workers" -> numWorkers,
|
||||
"rabit_timeout" -> 0))
|
||||
.fit(training)
|
||||
}
|
||||
|
||||
Communicator.mockList = Array.empty.toList.asJava
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -51,6 +51,11 @@ pom_template = """
|
||||
<artifactId>commons-logging</artifactId>
|
||||
<version>1.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>2.14.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest_${{scala.binary.version}}</artifactId>
|
||||
|
||||
@ -7,6 +7,9 @@ import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
/**
|
||||
* Collective communicator global class for synchronization.
|
||||
*
|
||||
@ -30,8 +33,9 @@ public class Communicator {
|
||||
}
|
||||
|
||||
public enum DataType implements Serializable {
|
||||
INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
|
||||
INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
|
||||
FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8),
|
||||
INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8),
|
||||
UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8);
|
||||
|
||||
private final int enumOp;
|
||||
private final int size;
|
||||
@ -56,30 +60,20 @@ public class Communicator {
|
||||
}
|
||||
}
|
||||
|
||||
// used as way to test/debug passed communicator init parameters
|
||||
public static Map<String, String> communicatorEnvs;
|
||||
public static List<String> mockList = new LinkedList<>();
|
||||
|
||||
/**
|
||||
* Initialize the collective communicator on current working thread.
|
||||
*
|
||||
* @param envs The additional environment variables to pass to the communicator.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||
communicatorEnvs = envs;
|
||||
String[] args = new String[envs.size() * 2 + mockList.size() * 2];
|
||||
int idx = 0;
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey();
|
||||
args[idx++] = e.getValue();
|
||||
public static void init(Map<String, Object> envs) throws XGBoostError {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
try {
|
||||
String jconfig = mapper.writeValueAsString(envs);
|
||||
checkCall(XGBoostJNI.CommunicatorInit(jconfig));
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to read arguments for the communicator.", ex);
|
||||
}
|
||||
// pass list of rabit mock strings eg mock=0,1,0,0
|
||||
for (String mock : mockList) {
|
||||
args[idx++] = "mock";
|
||||
args[idx++] = mock;
|
||||
}
|
||||
checkCall(XGBoostJNI.CommunicatorInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Interface for Rabit tracker implementations with three public methods:
|
||||
* Interface for tracker implementations with three public methods:
|
||||
*
|
||||
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
|
||||
* timeout value (in milliseconds.)
|
||||
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
|
||||
* - start(timeout): Start the tracker awaiting for worker connections, with a given
|
||||
* timeout value (in seconds).
|
||||
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
|
||||
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
|
||||
* milliseconds.
|
||||
*
|
||||
@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit;
|
||||
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
|
||||
* brokers connections between workers.
|
||||
*/
|
||||
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
public interface ITracker extends Thread.UncaughtExceptionHandler {
|
||||
enum TrackerStatus {
|
||||
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||
|
||||
@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, String> getWorkerEnvs();
|
||||
boolean start(long workerConnectionTimeout);
|
||||
void stop();
|
||||
// taskExecutionTimeout has no effect in current version of XGBoost.
|
||||
int waitFor(long taskExecutionTimeout);
|
||||
Map<String, Object> workerArgs() throws XGBoostError;
|
||||
|
||||
boolean start() throws XGBoostError;
|
||||
|
||||
void stop() throws XGBoostError;
|
||||
|
||||
void waitFor(long taskExecutionTimeout) throws XGBoostError;
|
||||
}
|
||||
@ -1,101 +1,40 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Java implementation of the Rabit tracker to coordinate distributed workers.
|
||||
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
|
||||
* start() and waitFor() methods (i.e., the timeout is infinite.)
|
||||
*
|
||||
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
|
||||
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
|
||||
* provides timeout support.
|
||||
*
|
||||
* The tracker must be started on driver node before running distributed jobs.
|
||||
*/
|
||||
public class RabitTracker implements IRabitTracker {
|
||||
public class RabitTracker implements ITracker {
|
||||
// Maybe per tracker logger?
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
// tracker python file.
|
||||
private static String tracker_py = null;
|
||||
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
|
||||
// environment variable to be pased.
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
private int numWorkers;
|
||||
private String hostIp = "";
|
||||
private String pythonExec = "";
|
||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||
private long handle = 0;
|
||||
private Thread tracker_daemon;
|
||||
|
||||
static {
|
||||
try {
|
||||
initTrackerPy();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load tracker library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
public RabitTracker(int numWorkers) throws XGBoostError {
|
||||
this(numWorkers, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Tracker logger that logs output from tracker.
|
||||
*/
|
||||
private class TrackerProcessLogger implements Runnable {
|
||||
public void run() {
|
||||
|
||||
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(
|
||||
trackerProcess.get().getErrorStream()));
|
||||
String line;
|
||||
try {
|
||||
while ((line = reader.readLine()) != null) {
|
||||
trackerProcessLogger.info(line);
|
||||
}
|
||||
trackerProcess.get().waitFor();
|
||||
int exitValue = trackerProcess.get().exitValue();
|
||||
if (exitValue != 0) {
|
||||
trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue);
|
||||
} else {
|
||||
trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue);
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
trackerProcessLogger.error(ex.toString());
|
||||
} catch (InterruptedException ie) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
ie.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void initTrackerPy() throws IOException {
|
||||
try {
|
||||
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
|
||||
} catch (IOException ioe) {
|
||||
logger.trace("cannot access tracker python script");
|
||||
throw ioe;
|
||||
}
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers)
|
||||
public RabitTracker(int numWorkers, String hostIp)
|
||||
throws XGBoostError {
|
||||
this(numWorkers, hostIp, 0, 300);
|
||||
}
|
||||
public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError {
|
||||
if (numWorkers < 1) {
|
||||
throw new XGBoostError("numWorkers must be greater equal to one");
|
||||
}
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
|
||||
throws XGBoostError {
|
||||
this(numWorkers);
|
||||
this.hostIp = hostIp;
|
||||
this.pythonExec = pythonExec;
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out));
|
||||
this.handle = out[0];
|
||||
}
|
||||
|
||||
public void uncaughtException(Thread t, Throwable e) {
|
||||
@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker {
|
||||
} catch (InterruptedException ex) {
|
||||
logger.error(ex);
|
||||
} finally {
|
||||
trackerProcess.get().destroy();
|
||||
this.tracker_daemon.interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker {
|
||||
* Get environments that can be used to pass to worker.
|
||||
* @return The environment settings.
|
||||
*/
|
||||
public Map<String, String> getWorkerEnvs() {
|
||||
return envs;
|
||||
public Map<String, Object> workerArgs() throws XGBoostError {
|
||||
// fixme: timeout
|
||||
String[] args = new String[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
|
||||
};
|
||||
Map<String, Object> config;
|
||||
try {
|
||||
config = mapper.readValue(args[0], typeRef);
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to get worker arguments.", ex);
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
private void loadEnvs(InputStream ins) throws IOException {
|
||||
try {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
|
||||
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
||||
break;
|
||||
}
|
||||
String[] sep = line.split("=");
|
||||
if (sep.length == 2) {
|
||||
envs.put(sep[0], sep[1]);
|
||||
}
|
||||
public void stop() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle));
|
||||
}
|
||||
|
||||
public boolean start() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
|
||||
this.tracker_daemon = new Thread(() -> {
|
||||
try {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return; // exit the thread
|
||||
}
|
||||
reader.close();
|
||||
} catch (IOException ioe){
|
||||
logger.error("cannot get runtime configuration from tracker process");
|
||||
ioe.printStackTrace();
|
||||
throw ioe;
|
||||
}
|
||||
});
|
||||
this.tracker_daemon.setDaemon(true);
|
||||
this.tracker_daemon.start();
|
||||
|
||||
return this.tracker_daemon.isAlive();
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
public String getRabitTrackerCommand() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (pythonExec == null || pythonExec.isEmpty()) {
|
||||
sb.append("python ");
|
||||
} else {
|
||||
sb.append(pythonExec + " ");
|
||||
}
|
||||
sb.append(" " + tracker_py + " ");
|
||||
sb.append(" --log-level=DEBUG" + " ");
|
||||
sb.append(" --num-workers=" + numWorkers + " ");
|
||||
|
||||
// we first check the property then check the parameter
|
||||
String hostIpFromProperties = trackerProperties.getHostIp();
|
||||
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
|
||||
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
|
||||
sb.append(" --host-ip=" + hostIpFromProperties + " ");
|
||||
} else if (hostIp != null & !hostIp.isEmpty()) {
|
||||
logger.debug("Using the parametr host-ip: " + hostIp);
|
||||
sb.append(" --host-ip=" + hostIp + " ");
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
String cmd = getRabitTrackerCommand();
|
||||
trackerProcess.set(Runtime.getRuntime().exec(cmd));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
ioe.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public void stop() {
|
||||
if (trackerProcess.get() != null) {
|
||||
trackerProcess.get().destroy();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean start(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for all workers to connect indefinitely, unless " +
|
||||
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
if (startTrackerProcess()) {
|
||||
logger.debug("Tracker started, with env=" + envs.toString());
|
||||
System.out.println("Tracker started, with env=" + envs.toString());
|
||||
// also start a tracker logger
|
||||
Thread logger_thread = new Thread(new TrackerProcessLogger());
|
||||
logger_thread.setDaemon(true);
|
||||
logger_thread.start();
|
||||
return true;
|
||||
} else {
|
||||
logger.error("FAULT: failed to start tracker process");
|
||||
stop();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public int waitFor(long timeout) {
|
||||
if (timeout > 0L) {
|
||||
logger.warn("Python RabitTracker does not support timeout. " +
|
||||
"The tracker will wait for either all workers to finish tasks and send " +
|
||||
"shutdown signal, or manual interruptions. " +
|
||||
"Use the Scala RabitTracker for timeout support.");
|
||||
}
|
||||
|
||||
try {
|
||||
trackerProcess.get().waitFor();
|
||||
int returnVal = trackerProcess.get().exitValue();
|
||||
logger.info("Tracker Process ends with exit code " + returnVal);
|
||||
stop();
|
||||
return returnVal;
|
||||
} catch (InterruptedException e) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
e.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
return TrackerStatus.INTERRUPTED.getStatusCode();
|
||||
}
|
||||
public void waitFor(long timeout) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@ -54,7 +54,7 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
|
||||
float[] data, int shapeParam,
|
||||
@ -146,12 +146,24 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorInit(String args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
public final static native int CommunicatorPrint(String msg);
|
||||
public final static native int CommunicatorGetRank(int[] out);
|
||||
public final static native int CommunicatorGetWorldSize(int[] out);
|
||||
|
||||
// Tracker functions
|
||||
public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout,
|
||||
long[] out);
|
||||
|
||||
public final static native int TrackerRun(long handle);
|
||||
|
||||
public final static native int TrackerWaitFor(long handle, long timeout);
|
||||
|
||||
public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out);
|
||||
|
||||
public final static native int TrackerFree(long handle);
|
||||
|
||||
// Perform Allreduce operation on data in sendrecvbuf.
|
||||
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
@ -168,5 +180,4 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
|
||||
|
||||
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
|
||||
|
||||
}
|
||||
|
||||
@ -42,5 +42,4 @@ public final class UtilUnsafe {
|
||||
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1,20 +1,21 @@
|
||||
/**
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "./xgboost4j.h"
|
||||
|
||||
#include <rabit/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/json.h>
|
||||
@ -23,7 +24,6 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
// Initializes the collective communicator from a JSON configuration string.
//
// `jargs` is the JSON document produced on the Java side; it is handed to
// XGCommunicatorInit unchanged. Returns 0 on success and a non-zero status
// otherwise.
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv,
                                                                               jclass jcls,
                                                                               jstring jargs) {
  const char *args = jenv->GetStringUTFChars(jargs, nullptr);
  if (args == nullptr) {
    // GetStringUTFChars failed; a Java exception is pending and will surface
    // once control returns to the JVM.
    return -1;
  }
  int const status = XGCommunicatorInit(args);
  // The buffer must be handed back to the JVM on every path, otherwise it leaks.
  jenv->ReleaseStringUTFChars(jargs, args);
  JVM_CHECK_CALL(status);
  return 0;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerCreate
|
||||
* Signature: (Ljava/lang/String;IIIJ[J)I
|
||||
*/
|
||||
// Creates a tracker and stores its native handle in `jout`.
//
// The arguments are packed into a JSON object ("host", "port", "n_workers",
// "timeout", "sortby", "dmlc_communicator") and passed to XGTrackerCreate.
// "host" is only set when a non-empty host string is supplied. Returns 0 on
// success and a non-zero status otherwise.
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate(
    JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout,
    jlongArray jout) {
  using namespace xgboost;  // NOLINT

  Json config{Object{}};
  if (host != nullptr) {
    char const *host_chars = jenv->GetStringUTFChars(host, nullptr);
    if (host_chars == nullptr) {
      // A Java exception is pending; let it surface on return to the JVM.
      return -1;
    }
    // GetStringUTFLength is the byte length of the modified UTF-8 buffer,
    // whereas GetStringLength counts UTF-16 code units and would truncate
    // non-ASCII host names.
    std::string shost{host_chars,
                      static_cast<std::string::size_type>(jenv->GetStringUTFLength(host))};
    // Copy taken; release the JVM buffer so it is not leaked.
    jenv->ReleaseStringUTFChars(host, host_chars);
    if (!shost.empty()) {
      config["host"] = shost;
    }
  }
  config["port"] = Integer{static_cast<Integer::Int>(port)};
  config["n_workers"] = Integer{static_cast<Integer::Int>(n_workers)};
  config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
  config["sortby"] = Integer{static_cast<Integer::Int>(sortby)};
  config["dmlc_communicator"] = String{"rabit"};
  std::string sconfig = Json::Dump(config);

  TrackerHandle handle;
  JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle));
  setHandle(jenv, jout, handle);

  return 0;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerRun
|
||||
* Signature: (J)I
|
||||
*/
|
||||
// Starts the tracker identified by `jhandle`.
//
// Calls XGTrackerRun with a null configuration. Returns 0 when the call
// succeeds; JVM_CHECK_CALL is expected to return early on a non-zero status.
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass,
                                                                         jlong jhandle) {
  TrackerHandle tracker = reinterpret_cast<TrackerHandle>(jhandle);
  JVM_CHECK_CALL(XGTrackerRun(tracker, nullptr));
  return 0;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWaitFor
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass,
                                                                             jlong jhandle,
                                                                             jlong timeout) {
  // Serialises `timeout` into a single-entry JSON object and passes it, together with the
  // tracker handle, to XGTrackerWaitFor.
  using namespace xgboost;  // NOLINT

  Json wait_args{Object{}};
  wait_args["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
  std::string const serialized = Json::Dump(wait_args);

  auto const tracker = reinterpret_cast<TrackerHandle>(jhandle);
  JVM_CHECK_CALL(XGTrackerWaitFor(tracker, serialized.c_str()));
  return 0;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWorkerArgs
|
||||
* Signature: (JJ[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
    JNIEnv *jenv, jclass, jlong jhandle, jlong /*timeout*/, jobjectArray jout) {
  // Asks the tracker for its worker arguments (a C string produced by XGTrackerWorkerArgs) and
  // stores them, unmodified, as a Java string in element 0 of `jout`.
  // The timeout argument is accepted for interface compatibility, but XGTrackerWorkerArgs takes
  // no configuration, so it is not used.
  auto handle = reinterpret_cast<TrackerHandle>(jhandle);
  char const *args = nullptr;
  JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));

  jstring jret = jenv->NewStringUTF(args);
  if (jret == nullptr) {
    return -1;  // The Java string could not be constructed; an exception is pending.
  }
  jenv->SetObjectArrayElement(jout, 0, jret);
  return 0;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass,
                                                                          jlong jhandle) {
  // Reinterprets the jlong as the native tracker handle and passes it to XGTrackerFree.
  JVM_CHECK_CALL(XGTrackerFree(reinterpret_cast<TrackerHandle>(jhandle)));
  return 0;
}
|
||||
|
||||
@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
* Method: CommunicatorFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *,
                                                                                   jclass) {
  // Forwards to XGCommunicatorFinalize(); neither JNI argument is used, so both stay unnamed.
  JVM_CHECK_CALL(XGCommunicatorFinalize());
  return 0;
}
|
||||
|
||||
@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
|
||||
(JNIEnv *, jclass, jobjectArray);
|
||||
(JNIEnv *, jclass, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerCreate
|
||||
* Signature: (Ljava/lang/String;IIIJ[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate
|
||||
(JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerRun
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWaitFor
|
||||
* Signature: (JJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerWorkerArgs
|
||||
* Signature: (JJ[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs
|
||||
(JNIEnv *, jclass, jlong, jlong, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: TrackerFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorAllreduce
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -298,7 +298,7 @@ public class DMatrixTest {
|
||||
|
||||
@Test
|
||||
public void testTrainWithDenseMatrixRef() throws XGBoostError {
|
||||
Map<String, String> rabitEnv = new HashMap<>();
|
||||
Map<String, Object> rabitEnv = new HashMap<>();
|
||||
rabitEnv.put("DMLC_TASK_ID", "0");
|
||||
Communicator.init(rabitEnv);
|
||||
DMatrix trainMat = null;
|
||||
|
||||
@ -31,31 +31,13 @@ protobuf_generate(
|
||||
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
|
||||
add_library(federated_old_proto STATIC federated.old.proto)
|
||||
target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
|
||||
target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
|
||||
xgboost_target_properties(federated_old_proto)
|
||||
|
||||
protobuf_generate(
|
||||
TARGET federated_old_proto
|
||||
LANGUAGE cpp
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
protobuf_generate(
|
||||
TARGET federated_old_proto
|
||||
LANGUAGE grpc
|
||||
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
|
||||
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
|
||||
# Wrapper for the gRPC client.
|
||||
add_library(federated_client INTERFACE)
|
||||
target_sources(federated_client INTERFACE federated_client.h)
|
||||
target_link_libraries(federated_client INTERFACE federated_proto)
|
||||
target_link_libraries(federated_client INTERFACE federated_old_proto)
|
||||
|
||||
# Rabit engine for Federated Learning.
|
||||
# federated_server.cc is not listed: FederatedService is implemented in federated_tracker.cc.
target_sources(
  objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc
)
|
||||
if(USE_CUDA)
|
||||
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
|
||||
|
||||
@ -1,81 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
syntax = "proto3";
|
||||
|
||||
package xgboost.federated;
|
||||
|
||||
service Federated {
|
||||
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
|
||||
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
|
||||
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
|
||||
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
|
||||
}
|
||||
|
||||
enum DataType {
|
||||
INT8 = 0;
|
||||
UINT8 = 1;
|
||||
INT32 = 2;
|
||||
UINT32 = 3;
|
||||
INT64 = 4;
|
||||
UINT64 = 5;
|
||||
FLOAT = 6;
|
||||
DOUBLE = 7;
|
||||
}
|
||||
|
||||
enum ReduceOperation {
|
||||
MAX = 0;
|
||||
MIN = 1;
|
||||
SUM = 2;
|
||||
BITWISE_AND = 3;
|
||||
BITWISE_OR = 4;
|
||||
BITWISE_XOR = 5;
|
||||
}
|
||||
|
||||
message AllgatherRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
}
|
||||
|
||||
message AllgatherReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message AllgatherVRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
}
|
||||
|
||||
message AllgatherVReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message AllreduceRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
DataType data_type = 4;
|
||||
ReduceOperation reduce_operation = 5;
|
||||
}
|
||||
|
||||
message AllreduceReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message BroadcastRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
// The root rank to broadcast from.
|
||||
int32 root = 4;
|
||||
}
|
||||
|
||||
message BroadcastReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
@ -1,132 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <federated.old.grpc.pb.h>
|
||||
#include <federated.old.pb.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost::federated {
|
||||
/**
|
||||
* @brief A wrapper around the gRPC client.
|
||||
*/
|
||||
class FederatedClient {
|
||||
public:
|
||||
FederatedClient(std::string const &server_address, int rank, std::string const &server_cert,
|
||||
std::string const &client_key, std::string const &client_cert)
|
||||
: stub_{[&] {
|
||||
grpc::SslCredentialsOptions options;
|
||||
options.pem_root_certs = server_cert;
|
||||
options.pem_private_key = client_key;
|
||||
options.pem_cert_chain = client_cert;
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
auto channel =
|
||||
grpc::CreateCustomChannel(server_address, grpc::SslCredentials(options), args);
|
||||
channel->WaitForConnected(
|
||||
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
|
||||
return Federated::NewStub(channel);
|
||||
}()},
|
||||
rank_{rank} {}
|
||||
|
||||
/** @brief Insecure client for connecting to localhost only. */
|
||||
FederatedClient(std::string const &server_address, int rank)
|
||||
: stub_{[&] {
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
return Federated::NewStub(
|
||||
grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args));
|
||||
}()},
|
||||
rank_{rank} {}
|
||||
|
||||
std::string Allgather(std::string_view send_buffer) {
|
||||
AllgatherRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer.data(), send_buffer.size());
|
||||
|
||||
AllgatherReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->Allgather(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("Allgather RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::string AllgatherV(std::string_view send_buffer) {
|
||||
AllgatherVRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer.data(), send_buffer.size());
|
||||
|
||||
AllgatherVReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->AllgatherV(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("AllgatherV RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::string Allreduce(std::string const &send_buffer, DataType data_type,
|
||||
ReduceOperation reduce_operation) {
|
||||
AllreduceRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer);
|
||||
request.set_data_type(data_type);
|
||||
request.set_reduce_operation(reduce_operation);
|
||||
|
||||
AllreduceReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->Allreduce(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("Allreduce RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::string Broadcast(std::string const &send_buffer, int root) {
|
||||
BroadcastRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer);
|
||||
request.set_root(root);
|
||||
|
||||
BroadcastReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->Broadcast(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("Broadcast RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<Federated::Stub> const stub_;
|
||||
int const rank_;
|
||||
uint64_t sequence_number_{};
|
||||
};
|
||||
} // namespace xgboost::federated
|
||||
@ -1,195 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include "../../src/c_api/c_api_utils.h"
|
||||
#include "../../src/collective/communicator.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "federated_client.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief A Federated Learning communicator class that handles collective communication.
|
||||
*/
|
||||
class FederatedCommunicator : public Communicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Create a new communicator based on JSON configuration.
|
||||
* @param config JSON configuration.
|
||||
* @return Communicator as specified by the JSON configuration.
|
||||
*/
|
||||
static Communicator *Create(Json const &config) {
|
||||
std::string server_address{};
|
||||
int world_size{0};
|
||||
int rank{-1};
|
||||
std::string server_cert{};
|
||||
std::string client_key{};
|
||||
std::string client_cert{};
|
||||
|
||||
// Parse environment variables first.
|
||||
auto *value = getenv("FEDERATED_SERVER_ADDRESS");
|
||||
if (value != nullptr) {
|
||||
server_address = value;
|
||||
}
|
||||
value = getenv("FEDERATED_WORLD_SIZE");
|
||||
if (value != nullptr) {
|
||||
world_size = std::stoi(value);
|
||||
}
|
||||
value = getenv("FEDERATED_RANK");
|
||||
if (value != nullptr) {
|
||||
rank = std::stoi(value);
|
||||
}
|
||||
value = getenv("FEDERATED_SERVER_CERT");
|
||||
if (value != nullptr) {
|
||||
server_cert = value;
|
||||
}
|
||||
value = getenv("FEDERATED_CLIENT_KEY");
|
||||
if (value != nullptr) {
|
||||
client_key = value;
|
||||
}
|
||||
value = getenv("FEDERATED_CLIENT_CERT");
|
||||
if (value != nullptr) {
|
||||
client_cert = value;
|
||||
}
|
||||
|
||||
// Runtime configuration overrides, optional as users can specify them as env vars.
|
||||
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
|
||||
world_size =
|
||||
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
|
||||
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
|
||||
server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
|
||||
client_key = OptionalArg<String>(config, "federated_client_key", client_key);
|
||||
client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);
|
||||
|
||||
if (server_address.empty()) {
|
||||
LOG(FATAL) << "Federated server address must be set.";
|
||||
}
|
||||
if (world_size == 0) {
|
||||
LOG(FATAL) << "Federated world size must be set.";
|
||||
}
|
||||
if (rank == -1) {
|
||||
LOG(FATAL) << "Federated rank must be set.";
|
||||
}
|
||||
return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key,
|
||||
client_cert);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Construct a new federated communicator.
|
||||
*
|
||||
* @param world_size Total number of processes.
|
||||
* @param rank Rank of the current process.
|
||||
* @param server_address Address of the federated server (host:port).
|
||||
* @param server_cert_path Path to the server cert file.
|
||||
* @param client_key_path Path to the client key file.
|
||||
* @param client_cert_path Path to the client cert file.
|
||||
*/
|
||||
FederatedCommunicator(int world_size, int rank, std::string const &server_address,
|
||||
std::string const &server_cert_path, std::string const &client_key_path,
|
||||
std::string const &client_cert_path)
|
||||
: Communicator{world_size, rank} {
|
||||
if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) {
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
|
||||
} else {
|
||||
client_.reset(new xgboost::federated::FederatedClient(
|
||||
server_address, rank, xgboost::common::ReadAll(server_cert_path),
|
||||
xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Construct an insecure federated communicator without using SSL.
|
||||
* @param world_size Total number of processes.
|
||||
* @param rank Rank of the current process.
|
||||
* @param server_address Address of the federated server (host:port).
|
||||
*/
|
||||
FederatedCommunicator(int world_size, int rank, std::string const &server_address)
|
||||
: Communicator{world_size, rank} {
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
|
||||
}
|
||||
|
||||
~FederatedCommunicator() override { client_.reset(); }
|
||||
|
||||
/**
|
||||
* \brief Get if the communicator is distributed.
|
||||
* \return True.
|
||||
*/
|
||||
[[nodiscard]] bool IsDistributed() const override { return true; }
|
||||
|
||||
/**
|
||||
* \brief Get if the communicator is federated.
|
||||
* \return True.
|
||||
*/
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
|
||||
/**
|
||||
* \brief Perform allgather.
|
||||
* \param input Buffer for sending data.
|
||||
*/
|
||||
std::string AllGather(std::string_view input) override {
|
||||
return client_->Allgather(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform variable-length allgather.
|
||||
* \param input Buffer for sending data.
|
||||
*/
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
return client_->AllgatherV(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform in-place allreduce.
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type.
|
||||
* \param op Enumeration of operation type.
|
||||
*/
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer),
|
||||
count * GetTypeSize(data_type));
|
||||
auto const received =
|
||||
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(data_type),
|
||||
static_cast<xgboost::federated::ReduceOperation>(op));
|
||||
received.copy(reinterpret_cast<char *>(send_receive_buffer), count * GetTypeSize(data_type));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Broadcast a memory region to all others from root.
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
*/
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
|
||||
if (GetWorldSize() == 1) return;
|
||||
if (GetRank() == root) {
|
||||
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
|
||||
client_->Broadcast(send_buffer, root);
|
||||
} else {
|
||||
auto const received = client_->Broadcast("", root);
|
||||
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get the name of the processor.
|
||||
* \return Name of the processor.
|
||||
*/
|
||||
std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
|
||||
|
||||
/**
|
||||
* \brief Print the message to the communicator.
|
||||
* \param message The message to be printed.
|
||||
*/
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {}
|
||||
|
||||
private:
|
||||
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -1,86 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "federated_server.h"
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/server.h> // for Server
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "../../src/collective/comm.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "../../src/common/json_utils.h"
|
||||
|
||||
namespace xgboost::federated {
|
||||
/** @brief Forward an Allgather request to the in-memory handler; always reports OK. */
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
                                         AllgatherReply* reply) {
  auto const& payload = request->send_buffer();
  handler_.Allgather(payload.data(), payload.size(), reply->mutable_receive_buffer(),
                     request->sequence_number(), request->rank());
  return grpc::Status::OK;
}
|
||||
|
||||
/** @brief Forward an AllgatherV request to the in-memory handler; always reports OK. */
grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
                                          AllgatherVReply* reply) {
  auto const& payload = request->send_buffer();
  handler_.AllgatherV(payload.data(), payload.size(), reply->mutable_receive_buffer(),
                      request->sequence_number(), request->rank());
  return grpc::Status::OK;
}
|
||||
|
||||
/**
 * @brief Forward an Allreduce request to the in-memory handler; always reports OK.
 *
 * The data type and reduce operation of the request are cast to the corresponding
 * xgboost::collective enumerations before being passed on.
 */
grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
                                         AllreduceReply* reply) {
  auto const& payload = request->send_buffer();
  auto const dtype = static_cast<xgboost::collective::DataType>(request->data_type());
  auto const op = static_cast<xgboost::collective::Operation>(request->reduce_operation());
  handler_.Allreduce(payload.data(), payload.size(), reply->mutable_receive_buffer(),
                     request->sequence_number(), request->rank(), dtype, op);
  return grpc::Status::OK;
}
|
||||
|
||||
/** @brief Forward a Broadcast request, including its root, to the in-memory handler. */
grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request,
                                         BroadcastReply* reply) {
  auto const& payload = request->send_buffer();
  handler_.Broadcast(payload.data(), payload.size(), reply->mutable_receive_buffer(),
                     request->sequence_number(), request->rank(), request->root());
  return grpc::Status::OK;
}
|
||||
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
auto options =
|
||||
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file);
|
||||
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||
key.private_key = xgboost::common::ReadAll(server_key_file);
|
||||
key.cert_chain = xgboost::common::ReadAll(server_cert_file);
|
||||
options.pem_key_cert_pairs.push_back(key);
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
|
||||
<< world_size;
|
||||
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
void RunInsecureServer(int port, std::size_t world_size) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||
<< world_size;
|
||||
|
||||
server->Wait();
|
||||
}
|
||||
} // namespace xgboost::federated
|
||||
@ -1,37 +0,0 @@
|
||||
/**
|
||||
* Copyright 2022-2024, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.old.grpc.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
|
||||
#include "../../src/collective/in_memory_handler.h"
|
||||
|
||||
namespace xgboost::federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
|
||||
AllgatherVReply* reply) override;
|
||||
|
||||
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
||||
AllreduceReply* reply) override;
|
||||
|
||||
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
|
||||
BroadcastReply* reply) override;
|
||||
|
||||
private:
|
||||
xgboost::collective::InMemoryHandler handler_;
|
||||
};
|
||||
|
||||
// Runs a federated server secured with SSL credentials built from the three given files and
// waits on it. `port` is the listening port and `world_size` is passed to the service.
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file);
|
||||
|
||||
// Runs a federated server with insecure credentials on `port` and waits on it.
void RunInsecureServer(int port, std::size_t world_size);
|
||||
} // namespace xgboost::federated
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
* Copyright 2022-2024, XGBoost contributors
|
||||
*/
|
||||
#include "federated_tracker.h"
|
||||
|
||||
@ -8,13 +8,12 @@
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception
|
||||
#include <future> // for future, async
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string
|
||||
#include <thread> // for sleep_for
|
||||
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for RequiredArg
|
||||
#include "../../src/common/timer.h" // for Timer
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace federated {
|
||||
@ -36,8 +35,8 @@ grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest
|
||||
AllreduceReply* reply) {
|
||||
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
|
||||
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
|
||||
static_cast<xgboost::collective::DataType>(request->data_type()),
|
||||
static_cast<xgboost::collective::Operation>(request->reduce_operation()));
|
||||
static_cast<xgboost::ArrayInterfaceHandler::Type>(request->data_type()),
|
||||
static_cast<xgboost::collective::Op>(request->reduce_operation()));
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
@ -53,9 +52,13 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
|
||||
/**
 * @brief Construct the tracker from its JSON configuration.
 *
 * "federated_secure" is required. When it is true, "server_key_path", "server_cert_path" and
 * "client_cert_path" are required as well and each must be non-empty.
 */
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
  if (!RequiredArg<Boolean const>(config, "federated_secure", __func__)) {
    return;  // Insecure mode: no certificate paths are read.
  }

  StringView const empty_path_msg{"Empty certificate path."};

  server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
  CHECK(!server_key_path_.empty()) << empty_path_msg;

  server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
  CHECK(!server_cert_file_.empty()) << empty_path_msg;

  client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
  CHECK(!client_cert_file_.empty()) << empty_path_msg;
}
|
||||
|
||||
|
||||
@ -5,11 +5,12 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <rabit/rabit.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "../sycl/device_manager.h"
|
||||
|
||||
#include "../../src/collective/communicator-inl.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
|
||||
@ -21,22 +22,23 @@ namespace sycl {
|
||||
}
|
||||
|
||||
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
|
||||
(rabit::IsDistributed());
|
||||
(collective::IsDistributed());
|
||||
if (not_use_default_selector) {
|
||||
DeviceRegister& device_register = GetDevicesRegister();
|
||||
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
|
||||
const int device_idx =
|
||||
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
|
||||
if (device_spec.IsSyclDefault()) {
|
||||
auto& devices = device_register.devices;
|
||||
CHECK_LT(device_idx, devices.size());
|
||||
return devices[device_idx];
|
||||
auto& devices = device_register.devices;
|
||||
CHECK_LT(device_idx, devices.size());
|
||||
return devices[device_idx];
|
||||
} else if (device_spec.IsSyclCPU()) {
|
||||
auto& cpu_devices = device_register.cpu_devices;
|
||||
CHECK_LT(device_idx, cpu_devices.size());
|
||||
return cpu_devices[device_idx];
|
||||
auto& cpu_devices = device_register.cpu_devices;
|
||||
CHECK_LT(device_idx, cpu_devices.size());
|
||||
return cpu_devices[device_idx];
|
||||
} else {
|
||||
auto& gpu_devices = device_register.gpu_devices;
|
||||
CHECK_LT(device_idx, gpu_devices.size());
|
||||
return gpu_devices[device_idx];
|
||||
auto& gpu_devices = device_register.gpu_devices;
|
||||
CHECK_LT(device_idx, gpu_devices.size());
|
||||
return gpu_devices[device_idx];
|
||||
}
|
||||
} else {
|
||||
if (device_spec.IsSyclCPU()) {
|
||||
@ -62,24 +64,25 @@ namespace sycl {
|
||||
}
|
||||
|
||||
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
|
||||
(rabit::IsDistributed());
|
||||
(collective::IsDistributed());
|
||||
std::lock_guard<std::mutex> guard(queue_registering_mutex);
|
||||
if (not_use_default_selector) {
|
||||
DeviceRegister& device_register = GetDevicesRegister();
|
||||
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
|
||||
if (device_spec.IsSyclDefault()) {
|
||||
auto& devices = device_register.devices;
|
||||
CHECK_LT(device_idx, devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
|
||||
} else if (device_spec.IsSyclCPU()) {
|
||||
auto& cpu_devices = device_register.cpu_devices;
|
||||
CHECK_LT(device_idx, cpu_devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);;
|
||||
} else if (device_spec.IsSyclGPU()) {
|
||||
auto& gpu_devices = device_register.gpu_devices;
|
||||
CHECK_LT(device_idx, gpu_devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
|
||||
}
|
||||
DeviceRegister& device_register = GetDevicesRegister();
|
||||
const int device_idx =
|
||||
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
|
||||
if (device_spec.IsSyclDefault()) {
|
||||
auto& devices = device_register.devices;
|
||||
CHECK_LT(device_idx, devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
|
||||
} else if (device_spec.IsSyclCPU()) {
|
||||
auto& cpu_devices = device_register.cpu_devices;
|
||||
CHECK_LT(device_idx, cpu_devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);
|
||||
} else if (device_spec.IsSyclGPU()) {
|
||||
auto& gpu_devices = device_register.gpu_devices;
|
||||
CHECK_LT(device_idx, gpu_devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
|
||||
}
|
||||
} else {
|
||||
if (device_spec.IsSyclCPU()) {
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(::sycl::cpu_selector_v);
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <rabit/rabit.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -9,7 +9,6 @@
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#pragma GCC diagnostic pop
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <rabit/rabit.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
@ -1,17 +1,17 @@
|
||||
"""XGBoost collective communication related API."""
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
from enum import IntEnum, unique
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._typing import _T
|
||||
from .core import _LIB, _check_call, build_info, c_str, from_pystr_to_cstr, py_str
|
||||
from .core import _LIB, _check_call, build_info, c_str, make_jcargs, py_str
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.collective]")
|
||||
|
||||
@ -21,49 +21,35 @@ def init(**args: Any) -> None:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args: Dict[str, Any]
|
||||
args :
|
||||
Keyword arguments representing the parameters and their values.
|
||||
|
||||
Accepted parameters:
|
||||
- xgboost_communicator: The type of the communicator. Can be set as an environment
|
||||
variable.
|
||||
- dmlc_communicator: The type of the communicator.
|
||||
* rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* federated: Use the gRPC interface for Federated Learning.
|
||||
Only applicable to the Rabit communicator (these are case sensitive):
|
||||
-- rabit_tracker_uri: Hostname of the tracker.
|
||||
-- rabit_tracker_port: Port number of the tracker.
|
||||
-- rabit_task_id: ID of the current task, can be used to obtain deterministic rank
|
||||
assignment.
|
||||
-- rabit_world_size: Total number of workers.
|
||||
-- rabit_hadoop_mode: Enable Hadoop support.
|
||||
-- rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||
-- rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||
-- rabit_reduce_buffer: Size of the reduce buffer.
|
||||
-- rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||
-- rabit_debug: Enable debugging.
|
||||
-- rabit_timeout: Enable timeout.
|
||||
-- rabit_timeout_sec: Timeout in seconds.
|
||||
-- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||
Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
environment variables):
|
||||
-- DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
-- DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
-- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank
|
||||
assignment.
|
||||
-- DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||
-- DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||
-- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
lower case for runtime configuration):
|
||||
-- federated_server_address: Address of the federated server.
|
||||
-- federated_world_size: Number of federated workers.
|
||||
-- federated_rank: Rank of the current worker.
|
||||
-- federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
-- federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
-- federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
|
||||
Only applicable to the Rabit communicator:
|
||||
- dmlc_tracker_uri: Hostname of the tracker.
|
||||
- dmlc_tracker_port: Port number of the tracker.
|
||||
- dmlc_task_id: ID of the current task, can be used to obtain deterministic
|
||||
- dmlc_retry: The number of retry when handling network errors.
|
||||
- dmlc_timeout: Timeout in seconds.
|
||||
- dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication.
|
||||
|
||||
Only applicable to the Federated communicator (use upper case for environment
|
||||
variables, use lower case for runtime configuration):
|
||||
|
||||
- federated_server_address: Address of the federated server.
|
||||
- federated_world_size: Number of federated workers.
|
||||
- federated_rank: Rank of the current worker.
|
||||
- federated_server_cert: Server certificate file path. Only needed for the SSL
|
||||
mode.
|
||||
- federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
- federated_client_cert: Client certificate file path. Only needed for the SSL
|
||||
mode.
|
||||
"""
|
||||
config = from_pystr_to_cstr(json.dumps(args))
|
||||
_check_call(_LIB.XGCommunicatorInit(config))
|
||||
_check_call(_LIB.XGCommunicatorInit(make_jcargs(**args)))
|
||||
|
||||
|
||||
def finalize() -> None:
|
||||
@ -157,7 +143,7 @@ def broadcast(data: _T, root: int) -> _T:
|
||||
assert data is not None, "need to pass in data when broadcasting"
|
||||
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
length.value = len(s)
|
||||
# run first broadcast
|
||||
# Run first broadcast
|
||||
_check_call(
|
||||
_LIB.XGCommunicatorBroadcast(
|
||||
ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root
|
||||
@ -184,16 +170,27 @@ def broadcast(data: _T, root: int) -> _T:
|
||||
|
||||
|
||||
# enumeration of dtypes
|
||||
DTYPE_ENUM__ = {
|
||||
np.dtype("int8"): 0,
|
||||
np.dtype("uint8"): 1,
|
||||
np.dtype("int32"): 2,
|
||||
np.dtype("uint32"): 3,
|
||||
np.dtype("int64"): 4,
|
||||
np.dtype("uint64"): 5,
|
||||
np.dtype("float32"): 6,
|
||||
np.dtype("float64"): 7,
|
||||
}
|
||||
def _map_dtype(dtype: np.dtype) -> int:
|
||||
dtype_map = {
|
||||
np.dtype("float16"): 0,
|
||||
np.dtype("float32"): 1,
|
||||
np.dtype("float64"): 2,
|
||||
np.dtype("int8"): 4,
|
||||
np.dtype("int16"): 5,
|
||||
np.dtype("int32"): 6,
|
||||
np.dtype("int64"): 7,
|
||||
np.dtype("uint8"): 8,
|
||||
np.dtype("uint16"): 9,
|
||||
np.dtype("uint32"): 10,
|
||||
np.dtype("uint64"): 11,
|
||||
}
|
||||
if platform.system() != "Windows":
|
||||
dtype_map.update({np.dtype("float128"): 3})
|
||||
|
||||
if dtype not in dtype_map:
|
||||
raise TypeError(f"data type {dtype} is not supported on the current platform.")
|
||||
|
||||
return dtype_map[dtype]
|
||||
|
||||
|
||||
@unique
|
||||
@ -229,24 +226,23 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid
|
||||
"""
|
||||
if not isinstance(data, np.ndarray):
|
||||
raise TypeError("allreduce only takes in numpy.ndarray")
|
||||
buf = data.ravel()
|
||||
if buf.base is data.base:
|
||||
buf = buf.copy()
|
||||
if buf.dtype not in DTYPE_ENUM__:
|
||||
raise TypeError(f"data type {buf.dtype} not supported")
|
||||
buf = data.ravel().copy()
|
||||
_check_call(
|
||||
_LIB.XGCommunicatorAllreduce(
|
||||
buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size,
|
||||
DTYPE_ENUM__[buf.dtype],
|
||||
_map_dtype(buf.dtype),
|
||||
int(op),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
)
|
||||
return buf
|
||||
|
||||
|
||||
def signal_error() -> None:
    """Kill the process."""
    status = _LIB.XGCommunicatorSignalError()
    _check_call(status)
|
||||
|
||||
|
||||
class CommunicatorContext:
|
||||
"""A context controlling collective communicator initialization and finalization."""
|
||||
|
||||
|
||||
@ -295,7 +295,7 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
|
||||
if device and device.find(":") != -1:
|
||||
raise ValueError(
|
||||
"Distributed training doesn't support selecting device ordinal as GPUs are"
|
||||
" managed by the distributed framework. use `device=cuda` or `device=gpu`"
|
||||
" managed by the distributed frameworks. use `device=cuda` or `device=gpu`"
|
||||
" instead."
|
||||
)
|
||||
|
||||
|
||||
@ -71,6 +71,7 @@ from xgboost.core import (
|
||||
Metric,
|
||||
Objective,
|
||||
QuantileDMatrix,
|
||||
XGBoostError,
|
||||
_check_distributed_params,
|
||||
_deprecate_positional_args,
|
||||
_expect,
|
||||
@ -90,7 +91,7 @@ from xgboost.sklearn import (
|
||||
_wrap_evaluation_matrices,
|
||||
xgboost_model_doc,
|
||||
)
|
||||
from xgboost.tracker import RabitTracker, get_host_ip
|
||||
from xgboost.tracker import RabitTracker
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
from .utils import get_n_threads
|
||||
@ -160,36 +161,38 @@ def _try_start_tracker(
|
||||
n_workers: int,
|
||||
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
|
||||
) -> Dict[str, Union[int, str]]:
|
||||
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
|
||||
env: Dict[str, Union[int, str]] = {}
|
||||
try:
|
||||
if isinstance(addrs[0], tuple):
|
||||
host_ip = addrs[0][0]
|
||||
port = addrs[0][1]
|
||||
rabit_tracker = RabitTracker(
|
||||
host_ip=get_host_ip(host_ip),
|
||||
n_workers=n_workers,
|
||||
host_ip=host_ip,
|
||||
port=port,
|
||||
use_logger=False,
|
||||
sortby="task",
|
||||
)
|
||||
else:
|
||||
addr = addrs[0]
|
||||
assert isinstance(addr, str) or addr is None
|
||||
host_ip = get_host_ip(addr)
|
||||
rabit_tracker = RabitTracker(
|
||||
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
|
||||
n_workers=n_workers, host_ip=addr, sortby="task"
|
||||
)
|
||||
env.update(rabit_tracker.worker_envs())
|
||||
rabit_tracker.start(n_workers)
|
||||
thread = Thread(target=rabit_tracker.join)
|
||||
|
||||
rabit_tracker.start()
|
||||
thread = Thread(target=rabit_tracker.wait_for)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
except socket.error as e:
|
||||
if len(addrs) < 2 or e.errno != 99:
|
||||
env.update(rabit_tracker.worker_args())
|
||||
|
||||
except XGBoostError as e:
|
||||
if len(addrs) < 2:
|
||||
raise
|
||||
LOGGER.warning(
|
||||
"Failed to bind address '%s', trying to use '%s' instead.",
|
||||
"Failed to bind address '%s', trying to use '%s' instead. Error:\n %s",
|
||||
str(addrs[0]),
|
||||
str(addrs[1]),
|
||||
str(e),
|
||||
)
|
||||
env = _try_start_tracker(n_workers, addrs[1:])
|
||||
|
||||
|
||||
@ -1,45 +1,85 @@
|
||||
"""XGBoost Federated Learning related API."""
|
||||
"""XGBoost Experimental Federated Learning related API."""
|
||||
|
||||
from .core import _LIB, XGBoostError, _check_call, build_info, c_str
|
||||
import ctypes
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .core import _LIB, _check_call, make_jcargs
|
||||
from .tracker import RabitTracker
|
||||
|
||||
|
||||
def run_federated_server(
|
||||
port: int,
|
||||
world_size: int,
|
||||
server_key_path: str = "",
|
||||
server_cert_path: str = "",
|
||||
client_cert_path: str = "",
|
||||
) -> None:
|
||||
"""Run the Federated Learning server.
|
||||
class FederatedTracker(RabitTracker):
|
||||
"""Tracker for federated training.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
port : int
|
||||
The port to listen on.
|
||||
world_size: int
|
||||
n_workers :
|
||||
The number of federated workers.
|
||||
server_key_path: str
|
||||
Path to the server private key file. SSL is turned off if empty.
|
||||
server_cert_path: str
|
||||
Path to the server certificate file. SSL is turned off if empty.
|
||||
client_cert_path: str
|
||||
Path to the client certificate file. SSL is turned off if empty.
|
||||
|
||||
port :
|
||||
The port to listen on.
|
||||
|
||||
secure :
|
||||
Whether this is a secure instance. If True, then the following arguments for SSL
|
||||
must be provided.
|
||||
|
||||
server_key_path :
|
||||
Path to the server private key file.
|
||||
|
||||
server_cert_path :
|
||||
Path to the server certificate file.
|
||||
|
||||
client_cert_path :
|
||||
Path to the client certificate file.
|
||||
|
||||
"""
|
||||
if build_info()["USE_FEDERATED"]:
|
||||
if not server_key_path or not server_cert_path or not client_cert_path:
|
||||
_check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size))
|
||||
else:
|
||||
_check_call(
|
||||
_LIB.XGBRunFederatedServer(
|
||||
port,
|
||||
world_size,
|
||||
c_str(server_key_path),
|
||||
c_str(server_cert_path),
|
||||
c_str(client_cert_path),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise XGBoostError(
|
||||
"XGBoost needs to be built with the federated learning plugin "
|
||||
"enabled in order to use this module"
|
||||
|
||||
def __init__(  # pylint: disable=R0913, W0231
    self,
    n_workers: int,
    port: int,
    secure: bool,
    server_key_path: str = "",
    server_cert_path: str = "",
    client_cert_path: str = "",
    timeout: int = 300,
) -> None:
    # The base-class initializer is intentionally not called (hence W0231): the
    # native tracker is created here with federated-specific arguments instead.
    handle = ctypes.c_void_p()
    # All options are serialized by `make_jcargs` and handed to the C API in one
    # argument; `dmlc_communicator="federated"` selects the federated tracker.
    args = make_jcargs(
        n_workers=n_workers,
        port=port,
        dmlc_communicator="federated",
        federated_secure=secure,
        server_key_path=server_key_path,
        server_cert_path=server_cert_path,
        client_cert_path=client_cert_path,
        timeout=int(timeout),
    )
    # `_check_call` is given the status returned by XGTrackerCreate; the created
    # tracker is written into `handle`.
    _check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
    self.handle = handle
|
||||
|
||||
|
||||
def run_federated_server(  # pylint: disable=too-many-arguments
    n_workers: int,
    port: int,
    server_key_path: Optional[str] = None,
    server_cert_path: Optional[str] = None,
    client_cert_path: Optional[str] = None,
    timeout: int = 300,
) -> Dict[str, Any]:
    """Start a federated tracker in the background and return the worker arguments.

    See :py:class:`~xgboost.federated.FederatedTracker` for more info.

    Parameters
    ----------
    n_workers :
        The number of federated workers.
    port :
        The port to listen on.
    server_key_path, server_cert_path, client_cert_path :
        SSL file paths. The tracker is created as a secure instance only when all
        three are provided.
    timeout :
        Forwarded to the tracker.

    Returns
    -------
    A dictionary containing ``n_workers`` plus whatever ``tracker.worker_args()``
    reports.

    """
    args: Dict[str, Any] = {"n_workers": n_workers}
    secure = all(
        path is not None
        for path in [server_key_path, server_cert_path, client_cert_path]
    )
    # The paths must be forwarded: a secure tracker without its key/certificates
    # cannot be set up. ``None`` is mapped to the tracker's own default (``""``).
    tracker = FederatedTracker(
        n_workers=n_workers,
        port=port,
        secure=secure,
        server_key_path=server_key_path or "",
        server_cert_path=server_cert_path or "",
        client_cert_path=client_cert_path or "",
        timeout=timeout,
    )
    tracker.start()

    # Wait for the tracker on a daemon thread so that this call does not block.
    thread = Thread(target=tracker.wait_for)
    thread.daemon = True
    thread.start()
    args.update(tracker.worker_args())
    return args
|
||||
|
||||
@ -47,21 +47,21 @@ class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
|
||||
"""Context with PySpark specific task ID."""
|
||||
|
||||
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
|
||||
args["DMLC_TASK_ID"] = str(context.partitionId())
|
||||
args["dmlc_task_id"] = str(context.partitionId())
|
||||
super().__init__(**args)
|
||||
|
||||
|
||||
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
|
||||
"""Start Rabit tracker with n_workers"""
|
||||
env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
|
||||
args: Dict[str, Any] = {"n_workers": n_workers}
|
||||
host = _get_host_ip(context)
|
||||
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers, sortby="task")
|
||||
env.update(rabit_context.worker_envs())
|
||||
rabit_context.start(n_workers)
|
||||
thread = Thread(target=rabit_context.join)
|
||||
tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task")
|
||||
tracker.start()
|
||||
thread = Thread(target=tracker.wait_for)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return env
|
||||
args.update(tracker.worker_args())
|
||||
return args
|
||||
|
||||
|
||||
def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
|
||||
|
||||
@ -111,8 +111,6 @@ def no_sklearn() -> PytestSkip:
|
||||
|
||||
|
||||
def no_dask() -> PytestSkip:
|
||||
if sys.platform.startswith("win"):
|
||||
return {"reason": "Unsupported platform.", "condition": True}
|
||||
return no_mod("dask")
|
||||
|
||||
|
||||
@ -193,6 +191,10 @@ def no_multiple(*args: Any) -> PytestSkip:
|
||||
return {"condition": condition, "reason": reason}
|
||||
|
||||
|
||||
def skip_win() -> PytestSkip:
    """Skip-marker arguments whose condition is true when running on Windows."""
    on_windows = is_windows()
    return {"reason": "Unsupported platform.", "condition": on_windows}
|
||||
|
||||
|
||||
def skip_s390x() -> PytestSkip:
|
||||
condition = platform.machine() == "s390x"
|
||||
reason = "Known to fail on s390x"
|
||||
@ -968,18 +970,18 @@ def run_with_rabit(
|
||||
exception_queue.put(e)
|
||||
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
|
||||
tracker.start(world_size)
|
||||
tracker.start()
|
||||
|
||||
workers = []
|
||||
for _ in range(world_size):
|
||||
worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),))
|
||||
worker = threading.Thread(target=run_worker, args=(tracker.worker_args(),))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert exception_queue.empty(), f"Worker failed: {exception_queue.get()}"
|
||||
|
||||
tracker.join()
|
||||
tracker.wait_for()
|
||||
|
||||
|
||||
def column_split_feature_names(
|
||||
|
||||
@ -1,64 +1,12 @@
|
||||
# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches
|
||||
"""
|
||||
This script is a variant of dmlc-core/dmlc_tracker/tracker.py,
|
||||
which is a specialized version for xgboost tasks.
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
"""Tracker for XGBoost collective."""
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
from threading import Thread
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from enum import IntEnum, unique
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
_RingMap = Dict[int, Tuple[int, int]]
|
||||
_TreeMap = Dict[int, List[int]]
|
||||
|
||||
|
||||
class ExSocket:
    """
    Extension of socket to handle recv and send of special data: native-endian
    integers (``struct`` format ``"@i"``) and length-prefixed strings.
    """

    def __init__(self, sock: socket.socket) -> None:
        self.sock = sock

    def recvall(self, nbytes: int) -> bytes:
        """Receive exactly `nbytes` bytes.

        Raises
        ------
        ConnectionError
            If the peer closes the connection before `nbytes` bytes arrive.

        """
        res = []
        nread = 0
        while nread < nbytes:
            chunk = self.sock.recv(min(nbytes - nread, 1024))
            if not chunk:
                # recv() returns b"" once the peer has closed the connection;
                # without this check the loop would never terminate.
                raise ConnectionError(
                    f"connection closed after {nread} of {nbytes} bytes"
                )
            nread += len(chunk)
            res.append(chunk)
        return b"".join(res)

    def recvint(self) -> int:
        """Receive a 32-bit integer"""
        return struct.unpack("@i", self.recvall(4))[0]

    def sendint(self, value: int) -> None:
        """Send a 32-bit integer"""
        self.sock.sendall(struct.pack("@i", value))

    def sendstr(self, value: str) -> None:
        """Send a Python string, prefixed with its encoded length in bytes."""
        # The prefix must count encoded bytes, not characters, because `recvstr`
        # reads that many bytes before decoding.
        data = value.encode()
        self.sendint(len(data))
        self.sock.sendall(data)

    def recvstr(self) -> str:
        """Receive a Python string"""
        slen = self.recvint()
        return self.recvall(slen).decode()
|
||||
|
||||
|
||||
# magic number used to verify existence of data
|
||||
MAGIC_NUM = 0xFF99
|
||||
|
||||
|
||||
def get_some_ip(host: str) -> str:
|
||||
"""Get ip from host"""
|
||||
return socket.getaddrinfo(host, None)[0][4][0]
|
||||
from .core import _LIB, _check_call, make_jcargs
|
||||
|
||||
|
||||
def get_family(addr: str) -> int:
|
||||
@ -66,439 +14,95 @@ def get_family(addr: str) -> int:
|
||||
return socket.getaddrinfo(addr, None)[0][0]
|
||||
|
||||
|
||||
class WorkerEntry:
    """Handler to each worker.

    Constructing an entry performs the initial handshake on the accepted socket:
    magic number exchange, then rank, world size, task id and command are read
    from the worker, in that order.
    """

    def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
        worker = ExSocket(sock)
        self.sock = worker
        self.host = get_some_ip(s_addr[0])
        # Both sides exchange the same magic number to verify the peer.
        magic = worker.recvint()
        assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}"
        worker.sendint(MAGIC_NUM)
        self.rank = worker.recvint()
        self.world_size = worker.recvint()
        self.task_id = worker.recvstr()
        self.cmd = worker.recvstr()
        # Number of peers that still have to connect to this worker.
        self.wait_accept = 0
        # Port reported by the worker in `_get_remote`; unknown until then.
        self.port: Optional[int] = None

    def print(self, use_logger: bool) -> None:
        """Execute the print command from worker."""
        msg = self.sock.recvstr()
        # On dask we use print to avoid setting global verbosity.
        if use_logger:
            logging.info(msg.strip())
        else:
            print(msg.strip(), flush=True)

    def decide_rank(self, job_map: Dict[str, int]) -> int:
        """Get the rank of current entry.

        Returns the rank sent by the worker if it is non-negative, otherwise the
        rank recorded for its task id in `job_map`, otherwise -1.
        """
        if self.rank >= 0:
            return self.rank
        if self.task_id != "NULL" and self.task_id in job_map:
            return job_map[self.task_id]
        return -1

    def assign_rank(
        self,
        rank: int,
        wait_conn: Dict[int, "WorkerEntry"],
        tree_map: _TreeMap,
        parent_map: Dict[int, int],
        ring_map: _RingMap,
    ) -> List[int]:
        """Assign the rank for current entry.

        Sends the rank, parent rank, world size, tree neighbours and ring links to
        the worker, then negotiates peer connections via `_get_remote` and returns
        its result.
        """
        self.rank = rank
        nnset = set(tree_map[rank])
        rprev, next_rank = ring_map[rank]
        self.sock.sendint(rank)
        # send parent rank
        self.sock.sendint(parent_map[rank])
        # send world size
        self.sock.sendint(len(tree_map))
        self.sock.sendint(len(nnset))
        # send the tree neighbours
        for r in nnset:
            self.sock.sendint(r)
        # send prev link (-1 when there is none or it is this rank itself)
        if rprev not in (-1, rank):
            nnset.add(rprev)
            self.sock.sendint(rprev)
        else:
            self.sock.sendint(-1)
        # send next link (-1 when there is none or it is this rank itself)
        if next_rank not in (-1, rank):
            nnset.add(next_rank)
            self.sock.sendint(next_rank)
        else:
            self.sock.sendint(-1)

        return self._get_remote(wait_conn, nnset)

    def _get_remote(
        self, wait_conn: Dict[int, "WorkerEntry"], badset: Set[int]
    ) -> List[int]:
        # `badset` holds the ranks this worker must be linked with. The exchange is
        # repeated until the worker reports zero connection errors. Returns the
        # ranks removed from `wait_conn` because all their peers have connected.
        while True:
            # Peers that are already waiting for incoming connections.
            conset = []
            for r in badset:
                if r in wait_conn:
                    conset.append(r)
            # Number of peers to connect to, then number of peers to wait for.
            self.sock.sendint(len(conset))
            self.sock.sendint(len(badset) - len(conset))
            for r in conset:
                self.sock.sendstr(wait_conn[r].host)
                port = wait_conn[r].port
                assert port is not None
                # send port of this node to other workers so that they can call connect
                self.sock.sendint(port)
                self.sock.sendint(r)
            nerr = self.sock.recvint()
            if nerr != 0:
                continue
            self.port = self.sock.recvint()
            rmset = []
            # all connections were successfully set up
            for r in conset:
                wait_conn[r].wait_accept -= 1
                if wait_conn[r].wait_accept == 0:
                    rmset.append(r)
            for r in rmset:
                wait_conn.pop(r, None)
            self.wait_accept = len(badset) - len(conset)
            return rmset
|
||||
|
||||
|
||||
class RabitTracker:
|
||||
"""
|
||||
tracker for rabit
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_ip: str,
|
||||
n_workers: int,
|
||||
port: int = 0,
|
||||
use_logger: bool = False,
|
||||
sortby: str = "host",
|
||||
) -> None:
|
||||
"""A Python implementation of RABIT tracker.
|
||||
|
||||
Parameters
|
||||
..........
|
||||
use_logger:
|
||||
Use logging.info for tracker print command. When set to False, Python print
|
||||
function is used instead.
|
||||
|
||||
sortby:
|
||||
How to sort the workers for rank assignment. The default is host, but users
|
||||
can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
|
||||
deterministic rank assignment. Available options are:
|
||||
- host
|
||||
- task
|
||||
|
||||
"""
|
||||
sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
|
||||
sock.bind((host_ip, port))
|
||||
self.port = sock.getsockname()[1]
|
||||
sock.listen(256)
|
||||
self.sock = sock
|
||||
self.host_ip = host_ip
|
||||
self.thread: Optional[Thread] = None
|
||||
self.n_workers = n_workers
|
||||
self._use_logger = use_logger
|
||||
self._sortby = sortby
|
||||
logging.info("start listen on %s:%d", host_ip, self.port)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, "sock"):
|
||||
self.sock.close()
|
||||
|
||||
@staticmethod
|
||||
def _get_neighbor(rank: int, n_workers: int) -> List[int]:
|
||||
rank = rank + 1
|
||||
ret = []
|
||||
if rank > 1:
|
||||
ret.append(rank // 2 - 1)
|
||||
if rank * 2 - 1 < n_workers:
|
||||
ret.append(rank * 2 - 1)
|
||||
if rank * 2 < n_workers:
|
||||
ret.append(rank * 2)
|
||||
return ret
|
||||
|
||||
def worker_envs(self) -> Dict[str, Union[str, int]]:
|
||||
"""
|
||||
get environment variables for workers
|
||||
can be passed in as args or envs
|
||||
"""
|
||||
return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}
|
||||
|
||||
def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
|
||||
tree_map: _TreeMap = {}
|
||||
parent_map: Dict[int, int] = {}
|
||||
for r in range(n_workers):
|
||||
tree_map[r] = self._get_neighbor(r, n_workers)
|
||||
parent_map[r] = (r + 1) // 2 - 1
|
||||
return tree_map, parent_map
|
||||
|
||||
def find_share_ring(
|
||||
self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int
|
||||
) -> List[int]:
|
||||
"""
|
||||
get a ring structure that tends to share nodes with the tree
|
||||
return a list starting from rank
|
||||
"""
|
||||
nset = set(tree_map[rank])
|
||||
cset = nset - {parent_map[rank]}
|
||||
if not cset:
|
||||
return [rank]
|
||||
rlst = [rank]
|
||||
cnt = 0
|
||||
for v in cset:
|
||||
vlst = self.find_share_ring(tree_map, parent_map, v)
|
||||
cnt += 1
|
||||
if cnt == len(cset):
|
||||
vlst.reverse()
|
||||
rlst += vlst
|
||||
return rlst
|
||||
|
||||
def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap:
|
||||
"""
|
||||
get a ring connection used to recover local data
|
||||
"""
|
||||
assert parent_map[0] == -1
|
||||
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
||||
assert len(rlst) == len(tree_map)
|
||||
ring_map: _RingMap = {}
|
||||
n_workers = len(tree_map)
|
||||
for r in range(n_workers):
|
||||
rprev = (r + n_workers - 1) % n_workers
|
||||
rnext = (r + 1) % n_workers
|
||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||
return ring_map
|
||||
|
||||
def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]:
|
||||
"""
|
||||
get the link map, this is a bit hacky, call for better algorithm
|
||||
to place similar nodes together
|
||||
"""
|
||||
tree_map, parent_map = self._get_tree(n_workers)
|
||||
ring_map = self.get_ring(tree_map, parent_map)
|
||||
rmap = {0: 0}
|
||||
k = 0
|
||||
for i in range(n_workers - 1):
|
||||
k = ring_map[k][1]
|
||||
rmap[k] = i + 1
|
||||
|
||||
ring_map_: _RingMap = {}
|
||||
tree_map_: _TreeMap = {}
|
||||
parent_map_: Dict[int, int] = {}
|
||||
for k, v in ring_map.items():
|
||||
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
|
||||
for k, tree_nodes in tree_map.items():
|
||||
tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes]
|
||||
for k, parent in parent_map.items():
|
||||
if k != 0:
|
||||
parent_map_[rmap[k]] = rmap[parent]
|
||||
else:
|
||||
parent_map_[rmap[k]] = -1
|
||||
return tree_map_, parent_map_, ring_map_
|
||||
|
||||
def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
|
||||
if self._sortby == "host":
|
||||
pending.sort(key=lambda s: s.host)
|
||||
elif self._sortby == "task":
|
||||
pending.sort(key=lambda s: s.task_id)
|
||||
return pending
|
||||
|
||||
def accept_workers(self, n_workers: int) -> None:
|
||||
"""Wait for all workers to connect to the tracker."""
|
||||
|
||||
# set of nodes that finishes the job
|
||||
shutdown: Dict[int, WorkerEntry] = {}
|
||||
# set of nodes that is waiting for connections
|
||||
wait_conn: Dict[int, WorkerEntry] = {}
|
||||
# maps job id to rank
|
||||
job_map: Dict[str, int] = {}
|
||||
# list of workers that is pending to be assigned rank
|
||||
pending: List[WorkerEntry] = []
|
||||
# lazy initialize tree_map
|
||||
tree_map = None
|
||||
|
||||
while len(shutdown) != n_workers:
|
||||
fd, s_addr = self.sock.accept()
|
||||
s = WorkerEntry(fd, s_addr)
|
||||
if s.cmd == "print":
|
||||
s.print(self._use_logger)
|
||||
continue
|
||||
if s.cmd == "shutdown":
|
||||
assert s.rank >= 0 and s.rank not in shutdown
|
||||
assert s.rank not in wait_conn
|
||||
shutdown[s.rank] = s
|
||||
logging.debug("Received %s signal from %d", s.cmd, s.rank)
|
||||
continue
|
||||
assert s.cmd == "start"
|
||||
# lazily initialize the workers
|
||||
if tree_map is None:
|
||||
assert s.cmd == "start"
|
||||
if s.world_size > 0:
|
||||
n_workers = s.world_size
|
||||
tree_map, parent_map, ring_map = self.get_link_map(n_workers)
|
||||
# set of nodes that is pending for getting up
|
||||
todo_nodes = list(range(n_workers))
|
||||
else:
|
||||
assert s.world_size in (-1, n_workers)
|
||||
if s.cmd == "recover":
|
||||
assert s.rank >= 0
|
||||
|
||||
rank = s.decide_rank(job_map)
|
||||
# batch assignment of ranks
|
||||
if rank == -1:
|
||||
assert todo_nodes
|
||||
pending.append(s)
|
||||
if len(pending) == len(todo_nodes):
|
||||
pending = self._sort_pending(pending)
|
||||
for s in pending:
|
||||
rank = todo_nodes.pop(0)
|
||||
if s.task_id != "NULL":
|
||||
job_map[s.task_id] = rank
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
logging.debug(
|
||||
"Received %s signal from %s; assign rank %d",
|
||||
s.cmd,
|
||||
s.host,
|
||||
s.rank,
|
||||
)
|
||||
if not todo_nodes:
|
||||
logging.info("@tracker All of %d nodes getting started", n_workers)
|
||||
else:
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
logging.debug("Received %s signal from %d", s.cmd, s.rank)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
logging.info("@tracker All nodes finishes job")
|
||||
|
||||
def start(self, n_workers: int) -> None:
|
||||
"""Strat the tracker, it will wait for `n_workers` to connect."""
|
||||
|
||||
def run() -> None:
|
||||
self.accept_workers(n_workers)
|
||||
|
||||
self.thread = Thread(target=run, args=(), daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def join(self) -> None:
|
||||
"""Wait for the tracker to finish."""
|
||||
while self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(100)
|
||||
|
||||
def alive(self) -> bool:
    """Whether the tracker thread exists and is still running."""
    worker = self.thread
    if worker is None:
        return False
    return worker.is_alive()
|
||||
|
||||
|
||||
def get_host_ip(host_ip: Optional[str] = None) -> str:
    """Get the address of the current host.

    Parameters
    ----------
    host_ip :
        - ``None`` or ``"auto"``: handled exactly like ``"ip"``.
        - ``"dns"``: return ``socket.getfqdn()``.
        - ``"ip"``: resolve the fully qualified domain name with
          ``gethostbyname``, falling back to ``gethostname()`` when that raises
          ``gaierror``. If the result is a ``127.*`` loopback address, the
          local address of a UDP socket connected to ``10.255.255.255`` is used
          instead.
        - any other value is returned as is.

    Returns
    -------
    The resolved address or host name.

    """
    if host_ip is None or host_ip == "auto":
        host_ip = "ip"

    if host_ip == "dns":
        host_ip = socket.getfqdn()
    elif host_ip == "ip":
        from socket import gaierror

        try:
            host_ip = socket.gethostbyname(socket.getfqdn())
        except gaierror:
            logging.debug(
                "gethostbyname(socket.getfqdn()) failed... trying on hostname()"
            )
            host_ip = socket.gethostbyname(socket.gethostname())
        if host_ip.startswith("127."):
            # The context manager closes the probe socket even when connect()
            # or getsockname() raises, so the descriptor is never leaked.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                # doesn't have to be reachable
                s.connect(("10.255.255.255", 1))
                host_ip = s.getsockname()[0]

    assert host_ip is not None
    return host_ip
|
||||
|
||||
|
||||
def start_rabit_tracker(args: argparse.Namespace) -> None:
|
||||
"""Standalone function to start rabit tracker.
|
||||
"""Tracker for the collective used in XGBoost, acting as a coordinator between
|
||||
workers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args: arguments to start the rabit tracker.
|
||||
..........
|
||||
sortby:
|
||||
|
||||
How to sort the workers for rank assignment. The default is host, but users can
|
||||
set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
|
||||
deterministic rank assignment. Available options are:
|
||||
- host
|
||||
- task
|
||||
|
||||
timeout :
|
||||
|
||||
Timeout for constructing the communication group and waiting for the tracker to
|
||||
shutdown when it's instructed to, doesn't apply to communication when tracking
|
||||
is running.
|
||||
|
||||
The timeout value should take the time of data loading and pre-processing into
|
||||
account, due to potential lazy execution.
|
||||
|
||||
The :py:meth:`.wait_for` method has a different timeout parameter that can stop
|
||||
the tracker even if the tracker is still being used. A value error is raised
|
||||
when timeout is reached.
|
||||
|
||||
"""
|
||||
envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers}
|
||||
rabit = RabitTracker(
|
||||
host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
|
||||
)
|
||||
envs.update(rabit.worker_envs())
|
||||
rabit.start(args.num_workers)
|
||||
sys.stdout.write("DMLC_TRACKER_ENV_START\n")
|
||||
# simply write configuration to stdout
|
||||
for k, v in envs.items():
|
||||
sys.stdout.write(f"{k}={v}\n")
|
||||
sys.stdout.write("DMLC_TRACKER_ENV_END\n")
|
||||
sys.stdout.flush()
|
||||
rabit.join()
|
||||
|
||||
@unique
|
||||
class _SortBy(IntEnum):
|
||||
HOST = 0
|
||||
TASK = 1
|
||||
|
||||
def main() -> None:
|
||||
"""Main function if tracker is executed in standalone mode."""
|
||||
parser = argparse.ArgumentParser(description="Rabit Tracker start.")
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
required=True,
|
||||
type=int,
|
||||
help="Number of worker process to be launched.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-servers",
|
||||
default=0,
|
||||
type=int,
|
||||
help="Number of server process to be launched. Only used in PS jobs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host-ip",
|
||||
default=None,
|
||||
type=str,
|
||||
help=(
|
||||
"Host IP addressed, this is only needed "
|
||||
+ "if the host IP cannot be automatically guessed."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default="INFO",
|
||||
type=str,
|
||||
choices=["INFO", "DEBUG"],
|
||||
help="Logging level of the logger.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
n_workers: int,
|
||||
host_ip: Optional[str],
|
||||
port: int = 0,
|
||||
sortby: str = "host",
|
||||
timeout: int = 0,
|
||||
) -> None:
|
||||
|
||||
fmt = "%(asctime)s %(levelname)s %(message)s"
|
||||
if args.log_level == "INFO":
|
||||
level = logging.INFO
|
||||
elif args.log_level == "DEBUG":
|
||||
level = logging.DEBUG
|
||||
else:
|
||||
raise RuntimeError(f"Unknown logging level {args.log_level}")
|
||||
handle = ctypes.c_void_p()
|
||||
if sortby not in ("host", "task"):
|
||||
raise ValueError("Expecting either 'host' or 'task' for sortby.")
|
||||
if host_ip is not None:
|
||||
get_family(host_ip) # use python socket to stop early for invalid address
|
||||
args = make_jcargs(
|
||||
host=host_ip,
|
||||
n_workers=n_workers,
|
||||
port=port,
|
||||
dmlc_communicator="rabit",
|
||||
sortby=self._SortBy.HOST if sortby == "host" else self._SortBy.TASK,
|
||||
timeout=int(timeout),
|
||||
)
|
||||
_check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
|
||||
logging.basicConfig(format=fmt, level=level)
|
||||
def free(self) -> None:
|
||||
"""Internal function for testing."""
|
||||
if hasattr(self, "handle"):
|
||||
handle = self.handle
|
||||
del self.handle
|
||||
_check_call(_LIB.XGTrackerFree(handle))
|
||||
|
||||
if args.num_servers == 0:
|
||||
start_rabit_tracker(args)
|
||||
else:
|
||||
raise RuntimeError("Do not yet support start ps tracker in standalone mode.")
|
||||
def __del__(self) -> None:
|
||||
self.free()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the tracker. Once started, the client still needs to call the
|
||||
:py:meth:`wait_for` method in order to wait for it to finish (think of it as a
|
||||
thread).
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
"""
|
||||
_check_call(_LIB.XGTrackerRun(self.handle, make_jcargs()))
|
||||
|
||||
def wait_for(self, timeout: Optional[int] = None) -> None:
    """Wait for the tracker to finish all the work and shut down.

    Parameters
    ----------
    timeout :
        Forwarded to ``XGTrackerWaitFor``. When the timeout is reached, a value
        error is raised. There is no timeout by default since we don't know how
        long it takes for the model to finish training.

    """
    wait_fn = _LIB.XGTrackerWaitFor
    handle = self.handle
    config = make_jcargs(timeout=timeout)
    _check_call(wait_fn(handle, config))
|
||||
|
||||
def worker_args(self) -> Dict[str, Union[str, int]]:
    """Get arguments for workers.

    The native library fills a C string via ``XGTrackerWorkerArgs``; that
    string is decoded as JSON and returned.

    """
    raw_args = ctypes.c_char_p()
    _check_call(_LIB.XGTrackerWorkerArgs(self.handle, ctypes.byref(raw_args)))
    assert raw_args.value is not None
    return json.loads(raw_args.value)
|
||||
|
||||
@ -1,15 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.18)

find_package(Threads REQUIRED)

# Sources that are part of every rabit build.
set(RABIT_SOURCES
  ${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
  ${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc)

# Exactly one engine source is appended, selected by the RABIT_MOCK option.
if(NOT RABIT_MOCK)
  list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine.cc)
else()
  list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)
endif()

# Make the list visible to the directory that included this file.
set(RABIT_SOURCES ${RABIT_SOURCES} PARENT_SCOPE)
|
||||
@ -1,28 +0,0 @@
|
||||
Copyright (c) 2014 by Contributors
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of rabit nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
@ -1 +0,0 @@
|
||||
# This directory contains the CPU network module for XGBoost. The library originates from [RABIT](https://github.com/dmlc/rabit)
|
||||
@ -1,19 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2020 by Contributors
|
||||
* \file base.h
|
||||
* \brief Macros common to all headers
|
||||
*
|
||||
* \author Hyunsu Cho
|
||||
*/
|
||||
|
||||
#ifndef RABIT_BASE_H_
|
||||
#define RABIT_BASE_H_
|
||||
|
||||
#ifndef _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#endif // _CRT_SECURE_NO_WARNINGS
|
||||
#ifndef _CRT_SECURE_NO_DEPRECATE
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#endif // _CRT_SECURE_NO_DEPRECATE
|
||||
|
||||
#endif // RABIT_BASE_H_
|
||||
@ -1,157 +0,0 @@
|
||||
/*!
|
||||
* Copyright by Contributors
|
||||
* \file c_api.h
|
||||
* \author Tianqi Chen
|
||||
* \brief a C style API of rabit.
|
||||
*/
|
||||
#ifndef RABIT_C_API_H_
|
||||
#define RABIT_C_API_H_
|
||||
|
||||
#ifdef __cplusplus
|
||||
#define RABIT_EXTERN_C extern "C"
|
||||
#include <cstdio>
|
||||
#else
|
||||
#define RABIT_EXTERN_C
|
||||
#include <stdio.h>
|
||||
#endif // __cplusplus
|
||||
|
||||
#if defined(_MSC_VER) || defined(_WIN32)
|
||||
#define RABIT_DLL RABIT_EXTERN_C __declspec(dllexport)
|
||||
#else
|
||||
#define RABIT_DLL RABIT_EXTERN_C __attribute__ ((visibility ("default")))
|
||||
#endif // defined(_MSC_VER) || defined(_WIN32)
|
||||
|
||||
/*! \brief rabit unsigned long type */
|
||||
typedef unsigned long rbt_ulong; // NOLINT(*)
|
||||
|
||||
/*!
|
||||
* \brief initialize the rabit module,
|
||||
* call this once before using anything
|
||||
* The additional arguments are not necessary.
|
||||
* Usually rabit will detect settings
|
||||
* from environment variables.
|
||||
* \param argc number of arguments in argv
|
||||
* \param argv the array of input arguments
|
||||
* \return true if rabit is initialized successfully otherwise false
|
||||
*/
|
||||
RABIT_DLL bool RabitInit(int argc, char *argv[]);
|
||||
|
||||
/*!
|
||||
* \brief finalize the rabit engine,
|
||||
* call this function after you finished all jobs.
|
||||
* \return true if rabit is finalized successfully otherwise false
|
||||
*/
|
||||
RABIT_DLL int RabitFinalize(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of previous process in ring topology
|
||||
* \return rank number of worker
|
||||
* */
|
||||
RABIT_DLL int RabitGetRingPrevRank(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of current process
|
||||
* \return rank number of worker
|
||||
* */
|
||||
RABIT_DLL int RabitGetRank(void);
|
||||
|
||||
/*!
|
||||
* \brief get total number of process
|
||||
* \return total world size
|
||||
* */
|
||||
RABIT_DLL int RabitGetWorldSize(void);
|
||||
|
||||
/*!
|
||||
* \brief check whether rabit is running in distributed mode
|
||||
* \return if rabit is distributed
|
||||
* */
|
||||
RABIT_DLL int RabitIsDistributed(void);
|
||||
|
||||
/*!
|
||||
* \brief print the msg to the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg the message to be printed
|
||||
*/
|
||||
RABIT_DLL int RabitTrackerPrint(const char *msg);
|
||||
/*!
|
||||
* \brief get name of processor
|
||||
* \param out_name hold output string
|
||||
* \param out_len hold length of output string
|
||||
* \param max_len maximum buffer length of input
|
||||
*/
|
||||
RABIT_DLL void RabitGetProcessorName(char *out_name,
|
||||
rbt_ulong *out_len,
|
||||
rbt_ulong max_len);
|
||||
/*!
|
||||
* \brief broadcast a memory region to all others from root
|
||||
*
|
||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||
* \param sendrecv_data the pointer to send or receive buffer,
|
||||
* \param size the size of the data
|
||||
* \param root the root of process
|
||||
*/
|
||||
RABIT_DLL int RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root);
|
||||
|
||||
/*!
|
||||
* \brief Allgather function, each node has a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param beginIndex beginning of the current slice in sendrecvbuf of type enum_dtype
|
||||
* \param size_node_slice size of the current node slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
RABIT_DLL int RabitAllgather(void *sendrecvbuf, size_t total_size,
|
||||
size_t beginIndex, size_t size_node_slice,
|
||||
size_t size_prev_slice, int enum_dtype);
|
||||
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce<op::Sum>(&data[0], data.size());
|
||||
* ...
|
||||
* \param sendrecvbuf buffer for both sending and receiving data
|
||||
* \param count number of elements to be reduced
|
||||
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
|
||||
* \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit
|
||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \param prepare_arg argument passed into the lazy preprocessing function
|
||||
*/
|
||||
RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
|
||||
int enum_op, void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg);
|
||||
|
||||
/*!
|
||||
* \return version number of current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \return rabit version number
|
||||
*/
|
||||
RABIT_DLL int RabitVersionNumber(void);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief a Dummy function,
|
||||
* used to cause force link of C API into the DLL.
|
||||
* \code
|
||||
* \/\/force link rabit C API library.
|
||||
* static int must_link_rabit_ = RabitLinkTag();
|
||||
* \endcode
|
||||
* \return a dummy integer.
|
||||
*/
|
||||
RABIT_DLL int RabitLinkTag(void);
|
||||
|
||||
#endif // RABIT_C_API_H_
|
||||
@ -1,197 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine.h
|
||||
* \brief This file defines the core interface of rabit library
|
||||
* \author Tianqi Chen, Nacho, Tianyi
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_ENGINE_H_
|
||||
#define RABIT_INTERNAL_ENGINE_H_
|
||||
#include <string>
|
||||
#include "rabit/serializable.h"
|
||||
|
||||
namespace MPI { // NOLINT
|
||||
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
||||
class Datatype;
|
||||
}
|
||||
|
||||
/*! \brief namespace of rabit */
|
||||
namespace rabit {
|
||||
/*! \brief core interface of the engine */
|
||||
namespace engine {
|
||||
/*! \brief interface of core Allreduce engine */
|
||||
class IEngine {
|
||||
public:
|
||||
/*!
|
||||
* \brief Preprocessing function, that is called before AllReduce,
|
||||
* used to prepare the data used by AllReduce
|
||||
* \param arg additional possible argument used to invoke the preprocessor
|
||||
*/
|
||||
typedef void (PreprocFunction) (void *arg); // NOLINT
|
||||
/*!
|
||||
* \brief reduce function, the same form of MPI reduce function is used,
|
||||
* to be compatible with MPI interface
|
||||
* In all the functions, the memory is ensured to aligned to 64-bit
|
||||
* which means it is OK to cast src,dst to double* int* etc
|
||||
* \param src pointer to source space
|
||||
* \param dst pointer to destination reduction
|
||||
* \param count total number of elements to be reduced (note this is total number of elements instead of bytes)
|
||||
* the definition of the reduce function should be type aware
|
||||
* \param dtype the data type object, to be compatible with MPI reduce
|
||||
*/
|
||||
typedef void (ReduceFunction) (const void *src, // NOLINT
|
||||
void *dst, int count,
|
||||
const MPI::Datatype &dtype);
|
||||
/*! \brief virtual destructor */
|
||||
~IEngine() = default;
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice) = 0;
|
||||
/*!
|
||||
* \brief performs in-place Allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the number of bytes the type has
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = nullptr,
|
||||
void *prepare_arg = nullptr) = 0;
|
||||
/*!
|
||||
* \brief broadcasts data from root to every other node
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0;
|
||||
/*!
|
||||
* deprecated
|
||||
*/
|
||||
virtual int LoadCheckPoint() = 0;
|
||||
/*!
|
||||
* \brief Increase internal version number. Deprecated.
|
||||
*/
|
||||
virtual void CheckPoint() = 0;
|
||||
/*!
|
||||
* \return version number of the current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
virtual int VersionNumber() const = 0;
|
||||
/*! \brief gets rank of previous node in ring topology */
|
||||
virtual int GetRingPrevRank() const = 0;
|
||||
/*! \brief gets rank of current node */
|
||||
virtual int GetRank() const = 0;
|
||||
/*! \brief gets total number of nodes */
|
||||
virtual int GetWorldSize() const = 0;
|
||||
/*! \brief whether we run in distribted mode */
|
||||
virtual bool IsDistributed() const = 0;
|
||||
/*! \brief gets the host name of the current node */
|
||||
virtual std::string GetHost() const = 0;
|
||||
/*!
|
||||
* \brief prints the msg in the tracker,
|
||||
* this function can be used to communicate progress information to
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg) = 0;
|
||||
};
|
||||
|
||||
/*! \brief initializes the engine module */
|
||||
bool Init(int argc, char *argv[]);
|
||||
/*! \brief finalizes the engine module */
|
||||
bool Finalize();
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine();
|
||||
|
||||
/*! \brief namespace that contains stubs to be compatible with MPI */
|
||||
namespace mpi {
|
||||
/*!
 * \brief enum of all reduction operators
 *
 * The numeric values are spelled out explicitly for every enumerator.
 */
enum OpType {
  kMax = 0,         // maximum
  kMin = 1,         // minimum
  kSum = 2,         // sum
  kBitwiseAND = 3,  // bitwise AND
  kBitwiseOR = 4,   // bitwise OR
  kBitwiseXOR = 5,  // bitwise XOR
};
|
||||
/*!
 * \brief enum of supported data types
 *
 * One enumerator per supported C++ arithmetic type; the numeric values are
 * spelled out explicitly.
 */
enum DataType {
  kChar = 0,      // char
  kUChar = 1,     // unsigned char
  kInt = 2,       // int
  kUInt = 3,      // unsigned int
  kLong = 4,      // long
  kULong = 5,     // unsigned long
  kFloat = 6,     // float
  kDouble = 7,    // double
  kLongLong = 8,  // long long
  kULongLong = 9  // unsigned long long
};
|
||||
} // namespace mpi
|
||||
/*!
|
||||
* \brief Allgather function, each node has a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
*/
|
||||
void Allgather(void* sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice);
|
||||
/*!
|
||||
* \brief perform in-place Allreduce, on sendrecvbuf
|
||||
* this is an internal function used by rabit to be able to compile with MPI
|
||||
* do not use this function directly
|
||||
* \param sendrecvbuf buffer for both sending and receiving data
|
||||
* \param type_nbytes the number of bytes the type has
|
||||
* \param count number of elements to be reduced
|
||||
* \param red reduce function
|
||||
* \param dtype the data type
|
||||
* \param op the reduce operator type
|
||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function.
|
||||
*/
|
||||
void Allreduce_(void *sendrecvbuf, // NOLINT
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun = nullptr,
|
||||
void *prepare_arg = nullptr);
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_ENGINE_H_
|
||||
@ -1,118 +0,0 @@
|
||||
/**
|
||||
* Copyright 2014-2023, XGBoost Contributors
|
||||
* \file io.h
|
||||
* \brief utilities with different serializable implementations
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_IO_H_
|
||||
#define RABIT_INTERNAL_IO_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdio>
|
||||
#include <cstring> // for memcpy
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dmlc/io.h"
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace rabit::utils {
|
||||
/*! \brief re-use definition of dmlc::SeekStream */
|
||||
using SeekStream = dmlc::SeekStream;
|
||||
/**
|
||||
* @brief Fixed size memory buffer as a stream.
|
||||
*/
|
||||
struct MemoryFixSizeBuffer : public SeekStream {
|
||||
public:
|
||||
// similar to SEEK_END in libc
|
||||
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Ctor
|
||||
*
|
||||
* @param p_buffer Pointer to the source buffer with size `buffer_size`.
|
||||
* @param buffer_size Size of the source buffer
|
||||
*/
|
||||
MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
|
||||
: p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
|
||||
~MemoryFixSizeBuffer() override = default;
|
||||
|
||||
std::size_t Read(void *ptr, std::size_t size) override {
|
||||
std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
return nread;
|
||||
}
|
||||
void Write(const void *ptr, std::size_t size) override {
|
||||
if (size == 0) return;
|
||||
CHECK_LE(curr_ptr_ + size, buffer_size_);
|
||||
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
void Seek(std::size_t pos) override {
|
||||
if (pos == kSeekEnd) {
|
||||
curr_ptr_ = buffer_size_;
|
||||
} else {
|
||||
curr_ptr_ = static_cast<std::size_t>(pos);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Current position in the buffer (stream).
|
||||
*/
|
||||
std::size_t Tell() override { return curr_ptr_; }
|
||||
[[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
|
||||
|
||||
protected:
|
||||
/*! \brief in memory buffer */
|
||||
char *p_buffer_{nullptr};
|
||||
/*! \brief current pointer */
|
||||
std::size_t buffer_size_{0};
|
||||
/*! \brief current pointer */
|
||||
std::size_t curr_ptr_{0};
|
||||
};
|
||||
|
||||
/*! \brief a in memory buffer that can be read and write as stream interface */
|
||||
struct MemoryBufferStream : public SeekStream {
|
||||
public:
|
||||
explicit MemoryBufferStream(std::string *p_buffer)
|
||||
: p_buffer_(p_buffer) {
|
||||
curr_ptr_ = 0;
|
||||
}
|
||||
~MemoryBufferStream() override = default;
|
||||
size_t Read(void *ptr, size_t size) override {
|
||||
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
|
||||
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
return nread;
|
||||
}
|
||||
void Write(const void *ptr, size_t size) override {
|
||||
if (size == 0) return;
|
||||
if (curr_ptr_ + size > p_buffer_->length()) {
|
||||
p_buffer_->resize(curr_ptr_+size);
|
||||
}
|
||||
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
void Seek(size_t pos) override {
|
||||
curr_ptr_ = static_cast<size_t>(pos);
|
||||
}
|
||||
size_t Tell() override {
|
||||
return curr_ptr_;
|
||||
}
|
||||
virtual bool AtEnd() const {
|
||||
return curr_ptr_ == p_buffer_->length();
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief in memory buffer */
|
||||
std::string *p_buffer_;
|
||||
/*! \brief current pointer */
|
||||
size_t curr_ptr_;
|
||||
}; // class MemoryBufferStream
|
||||
} // namespace rabit::utils
|
||||
#endif // RABIT_INTERNAL_IO_H_
|
||||
@ -1,234 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014-2019 by Contributors
|
||||
* \file rabit-inl.h
|
||||
* \brief implementation of inline template function for rabit interface
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_RABIT_INL_H_
|
||||
#define RABIT_INTERNAL_RABIT_INL_H_
|
||||
// use engine for implementation
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "rabit/internal/io.h"
|
||||
#include "rabit/internal/utils.h"
|
||||
#include "rabit/rabit.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
namespace mpi {
|
||||
// template function to translate type to enum indicator
|
||||
template<typename DType>
|
||||
inline DataType GetType();
|
||||
template<>
|
||||
inline DataType GetType<char>() {
|
||||
return kChar;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<unsigned char>() {
|
||||
return kUChar;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<int>() {
|
||||
return kInt;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<unsigned int>() { // NOLINT(*)
|
||||
return kUInt;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<long>() { // NOLINT(*)
|
||||
return kLong;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<unsigned long>() { // NOLINT(*)
|
||||
return kULong;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<float>() {
|
||||
return kFloat;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<double>() {
|
||||
return kDouble;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<long long>() { // NOLINT(*)
|
||||
return kLongLong;
|
||||
}
|
||||
template<>
|
||||
inline DataType GetType<unsigned long long>() { // NOLINT(*)
|
||||
return kULongLong;
|
||||
}
|
||||
} // namespace mpi
|
||||
} // namespace engine
|
||||
|
||||
namespace op {
|
||||
struct Max {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kMax;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
if (dst < src) dst = src;
|
||||
}
|
||||
};
|
||||
struct Min {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kMin;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
if (dst > src) dst = src;
|
||||
}
|
||||
};
|
||||
struct Sum {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kSum;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
dst += src;
|
||||
}
|
||||
};
|
||||
struct BitAND {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kBitwiseAND;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
dst &= src;
|
||||
}
|
||||
};
|
||||
struct BitOR {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
dst |= src;
|
||||
}
|
||||
};
|
||||
struct BitXOR {
|
||||
static const engine::mpi::OpType kType = engine::mpi::kBitwiseXOR;
|
||||
template<typename DType>
|
||||
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
|
||||
dst ^= src;
|
||||
}
|
||||
};
|
||||
template <typename OP, typename DType>
|
||||
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &) {
|
||||
const DType *src = static_cast<const DType *>(src_);
|
||||
DType *dst = (DType *)dst_; // NOLINT(*)
|
||||
for (int i = 0; i < len; i++) {
|
||||
OP::Reduce(dst[i], src[i]);
|
||||
}
|
||||
}
|
||||
} // namespace op
|
||||
|
||||
// initialize the rabit engine
|
||||
inline bool Init(int argc, char *argv[]) {
|
||||
return engine::Init(argc, argv);
|
||||
}
|
||||
// finalize the rabit engine
|
||||
inline bool Finalize() {
|
||||
return engine::Finalize();
|
||||
}
|
||||
// get the rank of the previous worker in ring topology
|
||||
inline int GetRingPrevRank() {
|
||||
return engine::GetEngine()->GetRingPrevRank();
|
||||
}
|
||||
// get the rank of current process
|
||||
inline int GetRank() {
|
||||
return engine::GetEngine()->GetRank();
|
||||
}
|
||||
// the the size of the world
|
||||
inline int GetWorldSize() {
|
||||
return engine::GetEngine()->GetWorldSize();
|
||||
}
|
||||
// whether rabit is distributed
|
||||
inline bool IsDistributed() {
|
||||
return engine::GetEngine()->IsDistributed();
|
||||
}
|
||||
// get the name of current processor
|
||||
inline std::string GetProcessorName() {
|
||||
return engine::GetEngine()->GetHost();
|
||||
}
|
||||
// broadcast data to all other nodes from root
|
||||
inline void Broadcast(void *sendrecv_data, size_t size, int root) {
|
||||
engine::GetEngine()->Broadcast(sendrecv_data, size, root);
|
||||
}
|
||||
template<typename DType>
|
||||
inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
|
||||
size_t size = sendrecv_data->size();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
if (sendrecv_data->size() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root);
|
||||
}
|
||||
}
|
||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
size_t size = sendrecv_data->length();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
if (sendrecv_data->length() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
|
||||
}
|
||||
}
|
||||
|
||||
// perform inplace Allreduce
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
|
||||
}
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
 * \brief trampoline that calls a std::function<void()> passed as an opaque pointer
 * \param fun pointer to a std::function<void()> object; it is invoked once
 */
inline void InvokeLambda(void *fun) {
  auto *callback = static_cast<std::function<void()> *>(fun);
  (*callback)();
}
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun) {
|
||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun);
|
||||
}
|
||||
|
||||
// Performs inplace Allgather
|
||||
template<typename DType>
|
||||
inline void Allgather(DType *sendrecvbuf,
|
||||
size_t totalSize,
|
||||
size_t beginIndex,
|
||||
size_t sizeNodeSlice,
|
||||
size_t sizePrevSlice) {
|
||||
engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
|
||||
(beginIndex + sizeNodeSlice) * sizeof(DType),
|
||||
sizePrevSlice * sizeof(DType));
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
// print message to the tracker
|
||||
inline void TrackerPrint(const std::string &msg) {
|
||||
engine::GetEngine()->TrackerPrint(msg);
|
||||
}
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
/*!
 * \brief printf-style variant of TrackerPrint
 *
 * The message is formatted into a 1 KiB buffer first. If the formatted text
 * does not fit, the buffer is grown to the exact size reported by vsnprintf
 * and the message is formatted again, so long messages are no longer cut off.
 *
 * \param fmt printf-style format string, followed by its arguments
 */
inline void TrackerPrintf(const char *fmt, ...) {
  const int kPrintBuffer = 1 << 10;
  std::string msg(kPrintBuffer, '\0');
  va_list args;
  va_start(args, fmt);
  // a va_list cannot be reused after vsnprintf consumed it; keep a copy for the retry
  va_list args_retry;
  va_copy(args_retry, args);
  const int needed = vsnprintf(&msg[0], msg.size(), fmt, args);
  va_end(args);
  if (needed < 0) {
    // encoding error: nothing meaningful was produced
    msg.clear();
  } else if (needed >= kPrintBuffer) {
    // output was truncated; `needed` excludes the terminating NUL
    msg.resize(static_cast<size_t>(needed) + 1);
    vsnprintf(&msg[0], msg.size(), fmt, args_retry);
    msg.resize(static_cast<size_t>(needed));
  } else {
    msg.resize(static_cast<size_t>(needed));
  }
  va_end(args_retry);
  TrackerPrint(msg);
}
|
||||
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
|
||||
// deprecated, planned for removal after checkpoing from JVM package is removed.
|
||||
inline int LoadCheckPoint() { return engine::GetEngine()->LoadCheckPoint(); }
|
||||
// deprecated, increase internal version number
|
||||
inline void CheckPoint() { engine::GetEngine()->CheckPoint(); }
|
||||
// return the version number of currently stored model
|
||||
inline int VersionNumber() {
|
||||
return engine::GetEngine()->VersionNumber();
|
||||
}
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_RABIT_INL_H_
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023, XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file socket.h
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
@ -95,7 +95,10 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
|
||||
template <typename E>
|
||||
std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E const& revents) {
|
||||
if ((revents & POLLERR) != 0) {
|
||||
return xgboost::system::FailWithCode("Poll error condition.");
|
||||
auto err = errno;
|
||||
auto str = strerror(err);
|
||||
return xgboost::system::FailWithCode(std::string{"Poll error condition:"} + std::string{str} +
|
||||
" code:" + std::to_string(err));
|
||||
}
|
||||
if ((revents & POLLNVAL) != 0) {
|
||||
return xgboost::system::FailWithCode("Invalid polling request.");
|
||||
@ -211,12 +214,7 @@ struct PollHelper {
|
||||
}
|
||||
|
||||
auto revents = pfd.revents & pfd.events;
|
||||
if (!revents) {
|
||||
// FIXME(jiamingy): remove this once rabit is replaced.
|
||||
fds.erase(pfd.fd);
|
||||
} else {
|
||||
fds[pfd.fd].events = revents;
|
||||
}
|
||||
fds[pfd.fd].events = revents;
|
||||
}
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
|
||||
@ -1,146 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file utils.h
|
||||
* \brief simple utils to support the code
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_UTILS_H_
|
||||
#define RABIT_INTERNAL_UTILS_H_
|
||||
|
||||
#include <rabit/base.h>
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dmlc/io.h"
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
#if !defined(__GNUC__) || defined(__FreeBSD__)
|
||||
#define fopen64 std::fopen
|
||||
#endif // !defined(__GNUC__) || defined(__FreeBSD__)
|
||||
|
||||
#ifndef _MSC_VER
|
||||
|
||||
#ifdef _FILE_OFFSET_BITS
|
||||
#if _FILE_OFFSET_BITS == 32
|
||||
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
|
||||
#endif // _FILE_OFFSET_BITS == 32
|
||||
#endif // _FILE_OFFSET_BITS
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define off64_t off_t
|
||||
#define fopen64 std::fopen
|
||||
#endif // __APPLE__
|
||||
|
||||
extern "C" {
|
||||
#include <sys/types.h>
|
||||
}
|
||||
#endif // _MSC_VER
|
||||
|
||||
#include <cinttypes>
|
||||
|
||||
namespace rabit {
|
||||
/*! \brief namespace for helper utils of the project */
|
||||
namespace utils {
|
||||
|
||||
/*! \brief error message buffer length */
|
||||
const int kPrintBuffer = 1 << 12;
|
||||
|
||||
/* \brief Case-insensitive string comparison */
|
||||
inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) {
|
||||
#ifdef _MSC_VER
|
||||
return _stricmp(s1, s2);
|
||||
#else // _MSC_VER
|
||||
return strcasecmp(s1, s2);
|
||||
#endif // _MSC_VER
|
||||
}
|
||||
|
||||
/* \brief parse config string too bool*/
|
||||
inline bool StringToBool(const char* s) {
|
||||
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
|
||||
}
|
||||
|
||||
/*! \brief printf, prints messages to the console */
|
||||
inline void Printf(const char *fmt, ...) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
LOG(CONSOLE) << msg;
|
||||
}
|
||||
|
||||
/*! \brief assert a condition is true, use this to handle debug information */
|
||||
inline void Assert(bool exp, const char *fmt, ...) {
|
||||
if (!exp) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
}
|
||||
|
||||
/*!\brief same as assert, but this is intended to be used as a message for users */
|
||||
inline void Check(bool exp, const char *fmt, ...) {
|
||||
if (!exp) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief report error message, same as check */
|
||||
inline void Error(const char *fmt, ...) {
|
||||
{
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||
va_end(args);
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
}
|
||||
} // namespace utils
|
||||
|
||||
// std::min cannot be used on Windows with msvc because of:
// error C2589: '(': illegal token on right side of '::'
/*! \brief returns a copy of l when l < r, otherwise a copy of r */
template <typename T>
auto Min(T const& l, T const& r) {
  if (l < r) {
    return l;
  }
  return r;
}
|
||||
// hand-written for the same reason as Min
/*! \brief returns a copy of l when l > r, otherwise a copy of r */
template <typename T>
auto Max(T const& l, T const& r) {
  if (l > r) {
    return l;
  }
  return r;
}
|
||||
|
||||
// easy utils that can be directly accessed in xgboost
|
||||
/*!
 * \brief get the beginning address of a vector
 * \return pointer to the first element, or nullptr when the vector is empty
 */
template <typename T>
inline T *BeginPtr(std::vector<T> &vec) {  // NOLINT(*)
  return vec.empty() ? nullptr : vec.data();
}
|
||||
/*!
 * \brief get the beginning address of a mutable string
 * \return pointer to the first character, or nullptr when the string is empty
 */
inline char* BeginPtr(std::string &str) {  // NOLINT(*)
  return str.empty() ? nullptr : &str.front();
}
|
||||
/*!
 * \brief get the beginning address of a read-only string
 * \return pointer to the first character, or nullptr when the string is empty
 */
inline const char* BeginPtr(const std::string &str) {
  return str.empty() ? nullptr : str.data();
}
|
||||
} // namespace rabit
|
||||
#endif // RABIT_INTERNAL_UTILS_H_
|
||||
@ -1,237 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file rabit.h
|
||||
* \brief This file defines rabit's Allreduce/Broadcast interface
|
||||
* The rabit engine contains the actual implementation
|
||||
* Code that only uses this header can also be compiled with MPI Allreduce (non fault-tolerant),
|
||||
*
|
||||
* rabit.h and serializable.h are all the user needs to use the rabit interface
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_RABIT_H_ // NOLINT(*)
|
||||
#define RABIT_RABIT_H_ // NOLINT(*)
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
// engine definition of rabit, defines internal implementation
|
||||
// to use rabit interface, there is no need to read engine.h
|
||||
// rabit.h and serializable.h are enough to use the interface
|
||||
#include "./internal/engine.h"
|
||||
|
||||
/*! \brief rabit namespace */
|
||||
namespace rabit {
|
||||
/*!
|
||||
* \brief defines stream used in rabit
|
||||
* see definition of Stream in dmlc/io.h
|
||||
*/
|
||||
using Stream = dmlc::Stream;
|
||||
/*!
|
||||
* \brief defines serializable objects used in rabit
|
||||
* see definition of Serializable in dmlc/io.h
|
||||
*/
|
||||
using Serializable = dmlc::Serializable;
|
||||
|
||||
/*!
|
||||
* \brief reduction operators namespace
|
||||
*/
|
||||
namespace op {
|
||||
/*!
|
||||
* \class rabit::op::Max
|
||||
* \brief maximum reduction operator
|
||||
*/
|
||||
struct Max;
|
||||
/*!
|
||||
* \class rabit::op::Min
|
||||
* \brief minimum reduction operator
|
||||
*/
|
||||
struct Min;
|
||||
/*!
|
||||
* \class rabit::op::Sum
|
||||
* \brief sum reduction operator
|
||||
*/
|
||||
struct Sum;
|
||||
/*!
|
||||
* \class rabit::op::BitAND
|
||||
* \brief bitwise AND reduction operator
|
||||
*/
|
||||
struct BitAND;
|
||||
/*!
|
||||
* \class rabit::op::BitOR
|
||||
* \brief bitwise OR reduction operator
|
||||
*/
|
||||
struct BitOR;
|
||||
/*!
|
||||
* \class rabit::op::BitXOR
|
||||
* \brief bitwise XOR reduction operator
|
||||
*/
|
||||
struct BitXOR;
|
||||
} // namespace op
|
||||
/*!
|
||||
* \brief initializes rabit, call this once at the beginning of your program
|
||||
* \param argc number of arguments in argv
|
||||
* \param argv the array of input arguments
|
||||
* \return true if initialized successfully, otherwise false
|
||||
*/
|
||||
inline bool Init(int argc, char *argv[]);
|
||||
/*!
|
||||
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
|
||||
* \return true if finalized successfully, otherwise false
|
||||
*/
|
||||
inline bool Finalize();
|
||||
/*! \brief gets rank of the current process
|
||||
* \return rank number of worker*/
|
||||
inline int GetRank();
|
||||
/*! \brief gets total number of processes
|
||||
* \return total world size*/
|
||||
inline int GetWorldSize();
|
||||
/*! \brief whether rabit env is in distributed mode
|
||||
* \return is distributed*/
|
||||
inline bool IsDistributed();
|
||||
|
||||
/*! \brief gets processor's name
|
||||
* \return processor name*/
|
||||
inline std::string GetProcessorName();
|
||||
/*!
|
||||
* \brief prints the msg to the tracker,
|
||||
* this function can be used to communicate progress information to
|
||||
* the user who monitors the tracker
|
||||
* \param msg the message to be printed
|
||||
*/
|
||||
inline void TrackerPrint(const std::string &msg);
|
||||
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
/*!
|
||||
* \brief prints the msg to the tracker, this function may not be available
|
||||
* in very strict c++98 compilers, though it usually is.
|
||||
* this function can be used to communicate progress information to
|
||||
* the user who monitors the tracker
|
||||
* \param fmt the format string
|
||||
*/
|
||||
inline void TrackerPrintf(const char *fmt, ...);
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
/*!
|
||||
* \brief broadcasts a memory region to every node from the root
|
||||
*
|
||||
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
|
||||
* \param sendrecv_data the pointer to the send/receive buffer,
|
||||
* \param size the data size
|
||||
* \param root the process root
|
||||
*/
|
||||
inline void Broadcast(void *sendrecv_data, size_t size, int root);
|
||||
|
||||
/*!
|
||||
* \brief broadcasts an std::vector<DType> to every node from root
|
||||
* \param sendrecv_data the pointer to send/receive vector,
|
||||
* for the receiver, the vector does not need to be pre-allocated
|
||||
* \param root the process root
|
||||
* \tparam DType the data type stored in the vector, has to be a simple data type
|
||||
* that can be directly transmitted by sending the sizeof(DType)
|
||||
*/
|
||||
template<typename DType>
|
||||
inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
|
||||
/*!
|
||||
* \brief broadcasts a std::string to every node from the root
|
||||
* \param sendrecv_data the pointer to the send/receive buffer,
|
||||
* for the receiver, the string does not need to be pre-allocated
|
||||
* \param root the process root
|
||||
*/
|
||||
inline void Broadcast(std::string *sendrecv_data, int root);
|
||||
/*!
|
||||
* \brief performs in-place Allreduce on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
*
|
||||
* Example Usage: the following code does an Allreduce and outputs the sum as the result
|
||||
* \code{.cpp}
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce<op::Sum>(&data[0], data.size());
|
||||
* ...
|
||||
* \endcode
|
||||
*
|
||||
* \param sendrecvbuf buffer for both sending and receiving data
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to pass into the lazy preprocessing function
|
||||
* \tparam OP see namespace op, reduce operator
|
||||
* \tparam DType data type
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *) = nullptr,
|
||||
void *prepare_arg = nullptr);
|
||||
|
||||
/*!
|
||||
* \brief Allgather function, each node has a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
*/
|
||||
template<typename DType>
|
||||
inline void Allgather(DType *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice);
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
|
||||
* \brief performs in-place Allreduce, on sendrecvbuf
|
||||
* with a prepare function specified by a lambda function
|
||||
*
|
||||
* Example Usage:
|
||||
* \code{.cpp}
|
||||
* // the following code does an Allreduce and outputs the sum as the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce<op::Sum>(&data[0], data.size(), [&]() {
|
||||
* for (int i = 0; i < 10; ++i) {
|
||||
* data[i] = i;
|
||||
* }
|
||||
* });
|
||||
* ...
|
||||
* \endcode
|
||||
* \param sendrecvbuf buffer for both sending and receiving data
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
|
||||
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \tparam OP see namespace op, reduce operator
|
||||
* \tparam DType data type
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun);
|
||||
#endif // C++11
|
||||
|
||||
/*!
|
||||
* \brief deprecated, planned for removal after checkpointing from JVM package is removed.
|
||||
*/
|
||||
inline int LoadCheckPoint();
|
||||
/*!
|
||||
* \brief deprecated, planned for removal after checkpointing from JVM package is removed.
|
||||
*/
|
||||
inline void CheckPoint();
|
||||
|
||||
/*!
|
||||
* \return version number of the current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
inline int VersionNumber();
|
||||
} // namespace rabit
|
||||
// implementation of template functions
|
||||
#include "./internal/rabit-inl.h"
|
||||
#endif // RABIT_RABIT_H_ // NOLINT(*)
|
||||
@ -1,26 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file serializable.h
|
||||
* \brief defines serializable interface of rabit
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_SERIALIZABLE_H_
|
||||
#define RABIT_SERIALIZABLE_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "rabit/internal/utils.h"
|
||||
|
||||
namespace rabit {
|
||||
/*!
|
||||
* \brief defines stream used in rabit
|
||||
* see definition of Stream in dmlc/io.h
|
||||
*/
|
||||
using Stream = dmlc::Stream ;
|
||||
/*!
|
||||
* \brief defines serializable objects used in rabit
|
||||
* see definition of Serializable in dmlc/io.h
|
||||
*/
|
||||
using Serializable = dmlc::Serializable;
|
||||
|
||||
} // namespace rabit
|
||||
#endif // RABIT_SERIALIZABLE_H_
|
||||
@ -1,997 +0,0 @@
|
||||
/**
|
||||
* Copyright 2014-2023, XGBoost Contributors
|
||||
* \file allreduce_base.cc
|
||||
* \brief Basic implementation of AllReduce
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#if !defined(NOMINMAX) && defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#endif // !defined(NOMINMAX)
|
||||
|
||||
#include "allreduce_base.h"
|
||||
|
||||
#include "rabit/base.h"
|
||||
#include "rabit/internal/rabit-inl.h"
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <netinet/tcp.h>
|
||||
#endif // _WIN32
|
||||
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
|
||||
namespace rabit::engine {
|
||||
// constructor
|
||||
AllreduceBase::AllreduceBase() {
|
||||
tracker_uri = "NULL";
|
||||
tracker_port = 9000;
|
||||
host_uri = "";
|
||||
rank = 0;
|
||||
world_size = -1;
|
||||
connect_retry = 5;
|
||||
hadoop_mode = false;
|
||||
version_number = 0;
|
||||
// 32 K items
|
||||
reduce_ring_mincount = 32 << 10;
|
||||
// 1M reducer size each time
|
||||
tree_reduce_minsize = 1 << 20;
|
||||
// tracker URL
|
||||
task_id = "NULL";
|
||||
err_link = nullptr;
|
||||
dmlc_role = "worker";
|
||||
this->SetParam("rabit_reduce_buffer", "256MB");
|
||||
// setup possible environment variable of interest
|
||||
// include dmlc support direct variables
|
||||
env_vars.emplace_back("DMLC_TASK_ID");
|
||||
env_vars.emplace_back("DMLC_ROLE");
|
||||
env_vars.emplace_back("DMLC_NUM_ATTEMPT");
|
||||
env_vars.emplace_back("DMLC_TRACKER_URI");
|
||||
env_vars.emplace_back("DMLC_TRACKER_PORT");
|
||||
env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY");
|
||||
}
|
||||
|
||||
// initialization function
|
||||
bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
// setup from environment variables
|
||||
// handler to get variables from env
|
||||
for (auto & env_var : env_vars) {
|
||||
const char *value = getenv(env_var.c_str());
|
||||
if (value != nullptr) {
|
||||
this->SetParam(env_var.c_str(), value);
|
||||
}
|
||||
}
|
||||
// pass in arguments override env variable.
|
||||
for (int i = 0; i < argc; ++i) {
|
||||
char name[256], val[256];
|
||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||
this->SetParam(name, val);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// handling for hadoop
|
||||
const char *task_id = getenv("mapred_tip_id");
|
||||
if (task_id == nullptr) {
|
||||
task_id = getenv("mapreduce_task_id");
|
||||
}
|
||||
if (hadoop_mode) {
|
||||
utils::Check(task_id != nullptr,
|
||||
"hadoop_mode is set but cannot find mapred_task_id");
|
||||
}
|
||||
if (task_id != nullptr) {
|
||||
this->SetParam("rabit_task_id", task_id);
|
||||
this->SetParam("rabit_hadoop_mode", "1");
|
||||
}
|
||||
const char *attempt_id = getenv("mapred_task_id");
|
||||
if (attempt_id != nullptr) {
|
||||
const char *att = strrchr(attempt_id, '_');
|
||||
int num_trial;
|
||||
if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) {
|
||||
this->SetParam("rabit_num_trial", att + 1);
|
||||
}
|
||||
}
|
||||
// handling for hadoop
|
||||
const char *num_task = getenv("mapred_map_tasks");
|
||||
if (num_task == nullptr) {
|
||||
num_task = getenv("mapreduce_job_maps");
|
||||
}
|
||||
if (hadoop_mode) {
|
||||
utils::Check(num_task != nullptr,
|
||||
"hadoop_mode is set but cannot find mapred_map_tasks");
|
||||
}
|
||||
if (num_task != nullptr) {
|
||||
this->SetParam("rabit_world_size", num_task);
|
||||
}
|
||||
}
|
||||
if (dmlc_role != "worker") {
|
||||
LOG(FATAL) << "Rabit Module currently only works with dmlc worker";
|
||||
}
|
||||
|
||||
// clear the setting before start reconnection
|
||||
this->rank = -1;
|
||||
//---------------------
|
||||
// start socket
|
||||
xgboost::system::SocketStartup();
|
||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||
auto rc = xgboost::collective::GetHostName(&this->host_uri);
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
// get information from tracker
|
||||
rc = this->ReConnectLinks();
|
||||
if (rc.OK()) {
|
||||
return true;
|
||||
}
|
||||
LOG(FATAL) << rc.Report();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AllreduceBase::Shutdown() {
|
||||
try {
|
||||
for (auto &all_link : all_links) {
|
||||
if (!all_link.sock.IsClosed()) {
|
||||
SafeColl(all_link.sock.Close());
|
||||
}
|
||||
}
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
|
||||
if (tracker_uri == "NULL") return true;
|
||||
// notify tracker rank i have shutdown
|
||||
xgboost::collective::TCPSocket tracker;
|
||||
auto rc = this->ConnectTracker(&tracker);
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
tracker.Send(xgboost::StringView{"shutdown"});
|
||||
SafeColl(tracker.Close());
|
||||
xgboost::system::SocketFinalize();
|
||||
return true;
|
||||
} catch (std::exception const &e) {
|
||||
LOG(WARNING) << "Failed to shutdown due to" << e.what();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
utils::Printf("%s", msg.c_str()); return;
|
||||
}
|
||||
xgboost::collective::TCPSocket tracker;
|
||||
auto rc = this->ConnectTracker(&tracker);
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
|
||||
tracker.Send(xgboost::StringView{"print"});
|
||||
tracker.Send(xgboost::StringView{msg});
|
||||
SafeColl(tracker.Close());
|
||||
}
|
||||
|
||||
// util to parse data with unit suffix
|
||||
inline size_t ParseUnit(const char *name, const char *val) {
|
||||
char unit;
|
||||
unsigned long amt; // NOLINT(*)
|
||||
int n = sscanf(val, "%lu%c", &amt, &unit);
|
||||
size_t amount = amt;
|
||||
if (n == 2) {
|
||||
switch (unit) {
|
||||
case 'B': return amount;
|
||||
case 'K': return amount << 10UL;
|
||||
case 'M': return amount << 20UL;
|
||||
case 'G': return amount << 30UL;
|
||||
default: utils::Error("invalid format for %s", name); return 0;
|
||||
}
|
||||
} else if (n == 1) {
|
||||
return amount;
|
||||
} else {
|
||||
utils::Error("invalid format for %s," \
|
||||
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
* \param val parameter value
|
||||
*/
|
||||
void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
|
||||
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
|
||||
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
||||
if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val;
|
||||
if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val);
|
||||
if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
|
||||
if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
|
||||
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
||||
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val);
|
||||
if (!strcmp(name, "rabit_tree_reduce_minsize")) tree_reduce_minsize = atoi(val);
|
||||
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
|
||||
reduce_ring_mincount = atoi(val);
|
||||
utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0");
|
||||
}
|
||||
if (!strcmp(name, "rabit_reduce_buffer")) {
|
||||
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
|
||||
}
|
||||
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
|
||||
connect_retry = atoi(val);
|
||||
}
|
||||
if (!strcmp(name, "rabit_timeout")) {
|
||||
rabit_timeout = utils::StringToBool(val);
|
||||
}
|
||||
if (!strcmp(name, "rabit_timeout_sec")) {
|
||||
timeout_sec = std::chrono::seconds(atoi(val));
|
||||
utils::Assert(timeout_sec.count() >= 0, "rabit_timeout_sec should be non negative second");
|
||||
}
|
||||
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
|
||||
if (!strcmp(val, "true")) {
|
||||
rabit_enable_tcp_no_delay = true;
|
||||
} else {
|
||||
rabit_enable_tcp_no_delay = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker(
|
||||
xgboost::collective::TCPSocket *out) const {
|
||||
int magic = kMagic;
|
||||
// get information from tracker
|
||||
xgboost::collective::TCPSocket &tracker = *out;
|
||||
|
||||
auto rc =
|
||||
Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker);
|
||||
if (!rc.OK()) {
|
||||
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
|
||||
using utils::Assert;
|
||||
if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) {
|
||||
return xgboost::collective::Fail("Failed to send the verification number.");
|
||||
}
|
||||
if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) {
|
||||
return xgboost::collective::Fail("Failed to recieve the verification number.");
|
||||
}
|
||||
if (magic != kMagic) {
|
||||
return xgboost::collective::Fail("Invalid verification number.");
|
||||
}
|
||||
if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) {
|
||||
return xgboost::collective::Fail("Failed to send the local rank back to the tracker.");
|
||||
}
|
||||
if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) {
|
||||
return xgboost::collective::Fail("Failed to send the world size back to the tracker.");
|
||||
}
|
||||
if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) {
|
||||
return xgboost::collective::Fail("Failed to send the task ID back to the tracker.");
|
||||
}
|
||||
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine start up
|
||||
*/
|
||||
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0;
|
||||
world_size = 1;
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
|
||||
xgboost::collective::TCPSocket tracker;
|
||||
auto rc = this->ConnectTracker(&tracker);
|
||||
if (!rc.OK()) {
|
||||
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
|
||||
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
||||
tracker.Send(xgboost::StringView{cmd});
|
||||
|
||||
try {
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
std::map<int, int> tree_neighbors;
|
||||
using utils::Assert;
|
||||
// get new ranks
|
||||
int newrank, num_neighbors;
|
||||
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \
|
||||
sizeof(parent_rank), "ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(rank == -1 || newrank == rank,
|
||||
"must keep rank to same if the node already have one");
|
||||
rank = newrank;
|
||||
|
||||
if (rank == -1) {
|
||||
LOG(FATAL) << "tracker got overwhelmed and not able to assign correct rank";
|
||||
}
|
||||
|
||||
LOG(CONSOLE) << "task " << task_id << " got new rank " << rank;
|
||||
|
||||
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
||||
sizeof(num_neighbors), "ReConnectLink failure 4");
|
||||
for (int i = 0; i < num_neighbors; ++i) {
|
||||
int nrank;
|
||||
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||
"ReConnectLink failure 4");
|
||||
tree_neighbors[nrank] = 1;
|
||||
}
|
||||
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||
"ReConnectLink failure 4");
|
||||
|
||||
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
|
||||
// create listening socket
|
||||
std::int32_t port{0};
|
||||
SafeColl(sock_listen.BindHost(&port));
|
||||
SafeColl(sock_listen.Listen());
|
||||
|
||||
// get number of to connect and number of to accept nodes from tracker
|
||||
int num_conn, num_accept, num_error = 1;
|
||||
do {
|
||||
for (auto & all_link : all_links) {
|
||||
SafeColl(all_link.sock.Close());
|
||||
}
|
||||
// tracker construct goodset
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
"ReConnectLink failure 7");
|
||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
|
||||
"ReConnectLink failure 8");
|
||||
num_error = 0;
|
||||
for (int i = 0; i < num_conn; ++i) {
|
||||
LinkRecord r;
|
||||
int hport, hrank;
|
||||
std::string hname;
|
||||
SafeColl(tracker.Recv(&hname));
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||
// connect to peer
|
||||
if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry,
|
||||
timeout_sec, &r.sock)
|
||||
.OK()) {
|
||||
num_error += 1;
|
||||
SafeColl(r.sock.Close());
|
||||
continue;
|
||||
}
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 12");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 13");
|
||||
utils::Check(hrank == r.rank,
|
||||
"ReConnectLink failure, link rank inconsistent");
|
||||
bool match = false;
|
||||
for (auto & all_link : all_links) {
|
||||
if (all_link.rank == hrank) {
|
||||
Assert(all_link.sock.IsClosed(), "Override a link that is active");
|
||||
all_link.sock = std::move(r.sock);
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.emplace_back(std::move(r));
|
||||
}
|
||||
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
||||
"ReConnectLink failure 14");
|
||||
} while (num_error != 0);
|
||||
// send back socket listening port to tracker
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
||||
// close connection to tracker
|
||||
SafeColl(tracker.Close());
|
||||
|
||||
// listen to incoming links
|
||||
for (int i = 0; i < num_accept; ++i) {
|
||||
LinkRecord r;
|
||||
r.sock = sock_listen.Accept();
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 15");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 15");
|
||||
bool match = false;
|
||||
for (auto & all_link : all_links) {
|
||||
if (all_link.rank == r.rank) {
|
||||
utils::Assert(all_link.sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_link.sock = std::move(r.sock);
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.emplace_back(std::move(r));
|
||||
}
|
||||
SafeColl(sock_listen.Close());
|
||||
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
for (auto &all_link : all_links) {
|
||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
CHECK(all_link.sock.NonBlocking(true).OK());
|
||||
CHECK(all_link.sock.SetKeepAlive().OK());
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
CHECK(all_link.sock.SetNoDelay().OK());
|
||||
}
|
||||
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||
if (all_link.rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
}
|
||||
tree_links.plinks.push_back(&all_link);
|
||||
}
|
||||
if (all_link.rank == prev_rank) ring_prev = &all_link;
|
||||
if (all_link.rank == next_rank) ring_next = &all_link;
|
||||
}
|
||||
Assert(parent_rank == -1 || parent_index != -1,
|
||||
"cannot find parent in the link");
|
||||
Assert(prev_rank == -1 || ring_prev != nullptr,
|
||||
"cannot find prev ring in the link");
|
||||
Assert(next_rank == -1 || ring_next != nullptr,
|
||||
"cannot find next ring in the link");
|
||||
return xgboost::collective::Success();
|
||||
} catch (const std::exception& e) {
|
||||
std::stringstream ss;
|
||||
ss << "Failed in ReconnectLink " << e.what();
|
||||
return xgboost::collective::Fail(ss.str());
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
*
|
||||
* NOTE on Allreduce:
|
||||
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
|
||||
* It only means the current node get the correct result of Allreduce.
|
||||
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
if (count > reduce_ring_mincount) {
|
||||
return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
|
||||
} else {
|
||||
return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf,
|
||||
* this function implements tree-shape reduction
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.Size() == 0 || count == 0) return kSuccess;
|
||||
// total size of message
|
||||
const size_t total_size = type_nbytes * count;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.Size());
|
||||
// send recv buffer
|
||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
// size of space that we already performs reduce in up pass
|
||||
size_t size_up_reduce = 0;
|
||||
// size of space that we have already passed to parent
|
||||
size_t size_up_out = 0;
|
||||
// size of message we received, and send in the down pass
|
||||
size_t size_down_in = 0;
|
||||
// minimal size of each reducer
|
||||
const size_t eachreduce = (tree_reduce_minsize / type_nbytes * type_nbytes);
|
||||
|
||||
// initialize the link ring-buffer and pointer
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
|
||||
}
|
||||
links[i].ResetSize();
|
||||
}
|
||||
// if no children, no need to reduce
|
||||
if (nlink == static_cast<int>(parent_index != -1)) {
|
||||
size_up_reduce = total_size;
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
// select helper
|
||||
bool finished = true;
|
||||
utils::PollHelper watcher;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i == parent_index) {
|
||||
if (size_down_in != total_size) {
|
||||
watcher.WatchRead(links[i].sock);
|
||||
// only watch for exception in live channels
|
||||
watcher.WatchException(links[i].sock);
|
||||
finished = false;
|
||||
}
|
||||
if (size_up_out != total_size && size_up_out < size_up_reduce) {
|
||||
watcher.WatchWrite(links[i].sock);
|
||||
}
|
||||
} else {
|
||||
if (links[i].size_read != total_size) {
|
||||
watcher.WatchRead(links[i].sock);
|
||||
}
|
||||
// size_write <= size_read
|
||||
if (links[i].size_write != total_size) {
|
||||
if (links[i].size_write < size_down_in) {
|
||||
watcher.WatchWrite(links[i].sock);
|
||||
}
|
||||
// only watch for exception in live channels
|
||||
watcher.WatchException(links[i].sock);
|
||||
finished = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
// finish running allreduce
|
||||
if (finished) {
|
||||
break;
|
||||
}
|
||||
// select must return
|
||||
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
|
||||
if (!poll_res.OK()) {
|
||||
LOG(FATAL) << poll_res.Report();
|
||||
}
|
||||
|
||||
// read data from childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
|
||||
// make sure to receive minimal reducer size
|
||||
// since each child reduce and sends the minimal reducer size
|
||||
while (links[i].size_read < total_size
|
||||
&& links[i].size_read - size_up_reduce < eachreduce) {
|
||||
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// this node have children, perform reduce
|
||||
if (nlink > static_cast<int>(parent_index != -1)) {
|
||||
size_t buffer_size = 0;
|
||||
// do upstream reduce
|
||||
size_t max_reduce = total_size;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
max_reduce = std::min(max_reduce, links[i].size_read);
|
||||
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
|
||||
"buffer size inconsistent");
|
||||
buffer_size = links[i].buffer_size;
|
||||
}
|
||||
}
|
||||
utils::Assert(buffer_size != 0, "must assign buffer_size");
|
||||
// round to type_n4bytes
|
||||
max_reduce = (max_reduce / type_nbytes * type_nbytes);
|
||||
|
||||
// if max reduce is less than total size, we reduce multiple times of
|
||||
// each reduce size
|
||||
if (max_reduce < total_size) {
|
||||
max_reduce = max_reduce - max_reduce % eachreduce;
|
||||
}
|
||||
|
||||
// perform reduce, can be at most two rounds
|
||||
while (size_up_reduce < max_reduce) {
|
||||
// start position
|
||||
size_t start = size_up_reduce % buffer_size;
|
||||
// perform read till end of buffer
|
||||
size_t nread = std::min(buffer_size - start,
|
||||
max_reduce - size_up_reduce);
|
||||
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
reducer(links[i].buffer_head + start,
|
||||
sendrecvbuf + size_up_reduce,
|
||||
static_cast<int>(nread / type_nbytes),
|
||||
MPI::Datatype(type_nbytes));
|
||||
}
|
||||
}
|
||||
size_up_reduce += nread;
|
||||
}
|
||||
}
|
||||
if (parent_index != -1) {
|
||||
// pass message up to parent, can pass data that are already been reduced
|
||||
if (size_up_out < size_up_reduce) {
|
||||
ssize_t len = links[parent_index].sock.
|
||||
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
|
||||
if (len != -1) {
|
||||
size_up_out += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return();
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[parent_index], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
// read data from parent
|
||||
if (watcher.CheckRead(links[parent_index].sock) &&
|
||||
total_size > size_down_in) {
|
||||
size_t left_size = total_size-size_down_in;
|
||||
size_t reduce_size_min = std::min(left_size, eachreduce);
|
||||
size_t recved = 0;
|
||||
while (recved < reduce_size_min) {
|
||||
ssize_t len = links[parent_index].sock.
|
||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||
|
||||
if (len == 0) {
|
||||
SafeColl(links[parent_index].sock.Close());
|
||||
return ReportError(&links[parent_index], kRecvZeroLen);
|
||||
}
|
||||
if (len != -1) {
|
||||
size_down_in += static_cast<size_t>(len);
|
||||
utils::Assert(size_down_in <= size_up_out,
|
||||
"Allreduce: boundary error");
|
||||
recved+=len;
|
||||
|
||||
// if it receives more data than each reduce, it means the next block is sent.
|
||||
// we double the reduce_size_min or add to left_size
|
||||
while (recved > reduce_size_min) {
|
||||
reduce_size_min += std::min(left_size-reduce_size_min, eachreduce);
|
||||
}
|
||||
} else {
|
||||
ReturnType ret = Errno2Return();
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[parent_index], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// this is root, can use reduce as most recent point
|
||||
size_down_in = size_up_out = size_up_reduce;
|
||||
}
|
||||
// can pass message down to children
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && links[i].size_write < size_down_in) {
|
||||
ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param total_size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.Size() == 0 || total_size == 0) return kSuccess;
|
||||
utils::Check(root < world_size,
|
||||
"Broadcast: root should be smaller than world size");
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.Size());
|
||||
// size of space already read from data
|
||||
size_t size_in = 0;
|
||||
// input link, -2 means unknown yet, -1 means this is root
|
||||
int in_link = -2;
|
||||
|
||||
// initialize the link statistics
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
}
|
||||
// root have all the data
|
||||
if (this->rank == root) {
|
||||
size_in = total_size;
|
||||
in_link = -1;
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
bool finished = true;
|
||||
// select helper
|
||||
utils::PollHelper watcher;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (in_link == -2) {
|
||||
watcher.WatchRead(links[i].sock); finished = false;
|
||||
}
|
||||
if (i == in_link && links[i].size_read != total_size) {
|
||||
watcher.WatchRead(links[i].sock); finished = false;
|
||||
}
|
||||
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
|
||||
if (links[i].size_write < size_in) {
|
||||
watcher.WatchWrite(links[i].sock);
|
||||
}
|
||||
finished = false;
|
||||
}
|
||||
}
|
||||
// finish running
|
||||
if (finished) break;
|
||||
// select
|
||||
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
|
||||
if (!poll_res.OK()) {
|
||||
LOG(FATAL) << poll_res.Report();
|
||||
}
|
||||
if (in_link == -2) {
|
||||
// probe in-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (watcher.CheckRead(links[i].sock)) {
|
||||
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
size_in = links[i].size_read;
|
||||
if (size_in != 0) {
|
||||
in_link = i; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// read from in link
|
||||
if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
|
||||
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[in_link], ret);
|
||||
}
|
||||
size_in = links[in_link].size_read;
|
||||
}
|
||||
}
|
||||
// send data to all out-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != in_link && links[i].size_write < size_in) {
|
||||
ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice) {
|
||||
// read from next link and send to prev one
|
||||
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
||||
// need to reply on special rank structure
|
||||
utils::Assert(next.rank == (rank + 1) % world_size &&
|
||||
rank == (prev.rank + 1) % world_size,
|
||||
"need to assume rank structure");
|
||||
// send recv buffer
|
||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
const size_t stop_read = total_size + slice_begin;
|
||||
const size_t stop_write = total_size + slice_begin - size_prev_slice;
|
||||
size_t write_ptr = slice_begin;
|
||||
size_t read_ptr = slice_end;
|
||||
|
||||
while (true) {
|
||||
// select helper
|
||||
bool finished = true;
|
||||
utils::PollHelper watcher;
|
||||
if (read_ptr != stop_read) {
|
||||
watcher.WatchRead(next.sock);
|
||||
finished = false;
|
||||
}
|
||||
if (write_ptr != stop_write) {
|
||||
if (write_ptr < read_ptr) {
|
||||
watcher.WatchWrite(prev.sock);
|
||||
}
|
||||
finished = false;
|
||||
}
|
||||
if (finished) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
|
||||
if (!poll_res.OK()) {
|
||||
LOG(FATAL) << poll_res.Report();
|
||||
}
|
||||
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
||||
size_t size = stop_read - read_ptr;
|
||||
size_t start = read_ptr % total_size;
|
||||
if (start + size > total_size) {
|
||||
size = total_size - start;
|
||||
}
|
||||
ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
|
||||
if (len != -1) {
|
||||
read_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return();
|
||||
if (ret != kSuccess) {
|
||||
auto err = ReportError(&next, ret);
|
||||
return err;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (write_ptr < read_ptr && write_ptr != stop_write) {
|
||||
size_t size = std::min(read_ptr, stop_write) - write_ptr;
|
||||
size_t start = write_ptr % total_size;
|
||||
if (start + size > total_size) {
|
||||
size = total_size - start;
|
||||
}
|
||||
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
|
||||
if (len != -1) {
|
||||
write_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return();
|
||||
if (ret != kSuccess) {
|
||||
auto err = ReportError(&prev, ret);
|
||||
return err;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
|
||||
* and will return the cause of failure
|
||||
*
|
||||
* Ring-based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryAllreduce
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
// read from next link and send to prev one
|
||||
LinkRecord &prev = *ring_prev, &next = *ring_next;
|
||||
// need to reply on special rank structure
|
||||
utils::Assert(next.rank == (rank + 1) % world_size &&
|
||||
rank == (prev.rank + 1) % world_size,
|
||||
"need to assume rank structure");
|
||||
// total size of message
|
||||
const size_t total_size = type_nbytes * count;
|
||||
size_t n = static_cast<size_t>(world_size);
|
||||
size_t step = (count + n - 1) / n;
|
||||
size_t r = static_cast<size_t>(next.rank);
|
||||
size_t write_ptr = std::min(r * step, count) * type_nbytes;
|
||||
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
|
||||
size_t reduce_ptr = read_ptr;
|
||||
// send recv buffer
|
||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
// position to stop reading
|
||||
const size_t stop_read = total_size + write_ptr;
|
||||
// position to stop writing
|
||||
size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
|
||||
if (stop_write > stop_read) {
|
||||
stop_write -= total_size;
|
||||
utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
|
||||
}
|
||||
// use ring buffer in next position
|
||||
next.InitBuffer(type_nbytes, step, reduce_buffer_size);
|
||||
// set size_read to read pointer for ring buffer to work properly
|
||||
next.size_read = read_ptr;
|
||||
|
||||
while (true) {
|
||||
// select helper
|
||||
bool finished = true;
|
||||
utils::PollHelper watcher;
|
||||
if (read_ptr != stop_read) {
|
||||
watcher.WatchRead(next.sock);
|
||||
finished = false;
|
||||
}
|
||||
if (write_ptr != stop_write) {
|
||||
if (write_ptr < reduce_ptr) {
|
||||
watcher.WatchWrite(prev.sock);
|
||||
}
|
||||
finished = false;
|
||||
}
|
||||
if (finished) {
|
||||
break;
|
||||
}
|
||||
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
|
||||
if (!poll_res.OK()) {
|
||||
LOG(FATAL) << poll_res.Report();
|
||||
}
|
||||
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
||||
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&next, ret);
|
||||
}
|
||||
// sync the rate
|
||||
read_ptr = next.size_read;
|
||||
utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank);
|
||||
const size_t buffer_size = next.buffer_size;
|
||||
size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
|
||||
while (reduce_ptr < max_reduce) {
|
||||
size_t bstart = reduce_ptr % buffer_size;
|
||||
size_t nread = std::min(buffer_size - bstart,
|
||||
max_reduce - reduce_ptr);
|
||||
size_t rstart = reduce_ptr % total_size;
|
||||
nread = std::min(nread, total_size - rstart);
|
||||
reducer(next.buffer_head + bstart,
|
||||
sendrecvbuf + rstart,
|
||||
static_cast<int>(nread / type_nbytes),
|
||||
MPI::Datatype(type_nbytes));
|
||||
reduce_ptr += nread;
|
||||
}
|
||||
}
|
||||
if (write_ptr < reduce_ptr && write_ptr != stop_write) {
|
||||
size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
|
||||
size_t start = write_ptr % total_size;
|
||||
if (start + size > total_size) {
|
||||
size = total_size - start;
|
||||
}
|
||||
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
|
||||
if (len != -1) {
|
||||
write_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return();
|
||||
if (ret != kSuccess) return ReportError(&prev, ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
|
||||
if (ret != kSuccess) return ret;
|
||||
size_t n = static_cast<size_t>(world_size);
|
||||
size_t step = (count + n - 1) / n;
|
||||
size_t begin = std::min(rank * step, count) * type_nbytes;
|
||||
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
|
||||
// previous rank
|
||||
int prank = ring_prev->rank;
|
||||
// get rank of previous
|
||||
return TryAllgatherRing
|
||||
(sendrecvbuf_, type_nbytes * count,
|
||||
begin, end,
|
||||
(std::min((prank + 1) * step, count) -
|
||||
std::min(prank * step, count)) * type_nbytes);
|
||||
}
|
||||
} // namespace rabit::engine
|
||||
@ -1,501 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_base.h
|
||||
* \brief Basic implementation of AllReduce
|
||||
* using TCP non-block socket and tree-shape reduction.
|
||||
*
|
||||
* This implementation provides basic utility of AllReduce and Broadcast
|
||||
* without considering node failure
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_BASE_H_
|
||||
#define RABIT_ALLREDUCE_BASE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "rabit/internal/socket.h"
|
||||
#include "rabit/internal/utils.h"
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
#ifdef RABIT_CXXTESTDEFS_H
|
||||
#define private public
|
||||
#define protected public
|
||||
#endif // RABIT_CXXTESTDEFS_H
|
||||
|
||||
|
||||
namespace MPI {  // NOLINT
/*!
 * \brief MPI data type stand-in, kept to be compatible with the existing MPI interface.
 *  It only records the size of the element type.
 */
class Datatype {
 public:
  /*! \brief size of the element type */
  size_t type_size;
  /*! \brief construct from the size of the element type */
  explicit Datatype(size_t nbytes) : type_size{nbytes} {}
};
}  // namespace MPI
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
|
||||
/*! \brief implementation of basic Allreduce engine */
|
||||
class AllreduceBase : public IEngine {
|
||||
public:
|
||||
// magic number to verify server
|
||||
static const int kMagic = 0xff99;
|
||||
// constant one byte out of band message to indicate error happening
|
||||
AllreduceBase();
|
||||
virtual ~AllreduceBase() = default;
|
||||
// initialize the manager
|
||||
virtual bool Init(int argc, char* argv[]);
|
||||
// shutdown the engine
|
||||
virtual bool Shutdown();
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
* \param val parameter value
|
||||
*/
|
||||
virtual void SetParam(const char *name, const char *val);
|
||||
/*!
|
||||
* \brief print the msg in the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
void TrackerPrint(const std::string &msg) override;
|
||||
|
||||
/*! \brief get rank of previous node in ring topology*/
|
||||
int GetRingPrevRank() const override {
|
||||
return ring_prev->rank;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
int GetRank() const override {
|
||||
return rank;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
int GetWorldSize() const override {
|
||||
if (world_size == -1) return 1;
|
||||
return world_size;
|
||||
}
|
||||
/*! \brief whether is distributed or not */
|
||||
bool IsDistributed() const override {
|
||||
return tracker_uri != "NULL";
|
||||
}
|
||||
/*! \brief get rank */
|
||||
std::string GetHost() const override {
|
||||
return host_uri;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
*/
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice) override {
|
||||
if (world_size == 1 || world_size == -1) {
|
||||
return;
|
||||
}
|
||||
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin,
|
||||
slice_end, size_prev_slice) == kSuccess,
|
||||
"AllgatherRing failed");
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
*/
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
|
||||
void *prepare_arg = nullptr) override {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) ==
|
||||
kSuccess,
|
||||
"Allreduce failed");
|
||||
}
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
||||
"Broadcast failed");
|
||||
}
|
||||
/*!
|
||||
* \brief deprecated
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
int LoadCheckPoint() override { return 0; }
|
||||
|
||||
// deprecated, increase internal version number
|
||||
void CheckPoint() override { version_number += 1; }
|
||||
/*!
|
||||
* \return version number of current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
int VersionNumber() const override {
|
||||
return version_number;
|
||||
}
|
||||
/*!
|
||||
* \brief report current status to the job tracker
|
||||
* depending on the job tracker we are in
|
||||
*/
|
||||
inline void ReportStatus() const {
|
||||
if (hadoop_mode != 0) {
|
||||
LOG(CONSOLE) << "reporter:status:Rabit Phase[" << version_number << "] Operation " << seq_counter << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
/*! \brief possible results returned by the Try* functions */
enum ReturnTypeEnum {
  /*! \brief the operation completed successfully */
  kSuccess = 0,
  /*! \brief a link was reset by the peer */
  kConnReset = 1,
  /*! \brief a zero-length message was received */
  kRecvZeroLen = 2,
  /*! \brief a neighbor node went down and the connection was dropped */
  kSockError = 3,
  /*!
   * \brief a node that is not a direct neighbor went down;
   *  an out-of-band exception notification arrived from a neighbor
   */
  kGetExcept = 4
};
|
||||
/*! \brief struct return type to avoid implicit conversion to int/bool */
|
||||
struct ReturnType {
|
||||
/*! \brief internal return type */
|
||||
ReturnTypeEnum value;
|
||||
// constructor
|
||||
ReturnType() = default;
|
||||
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
|
||||
inline bool operator==(const ReturnTypeEnum &v) const {
|
||||
return value == v;
|
||||
}
|
||||
inline bool operator!=(const ReturnTypeEnum &v) const {
|
||||
return value != v;
|
||||
}
|
||||
};
|
||||
/*! \brief translate errno to return type */
|
||||
static ReturnType Errno2Return() {
|
||||
int errsv = xgboost::system::LastError();
|
||||
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
|
||||
#ifdef _WIN32
|
||||
if (errsv == WSAEWOULDBLOCK) return kSuccess;
|
||||
if (errsv == WSAECONNRESET) return kConnReset;
|
||||
#endif // _WIN32
|
||||
if (errsv == ECONNRESET) return kConnReset;
|
||||
return kSockError;
|
||||
}
|
||||
// link record to a neighbor
|
||||
struct LinkRecord {
|
||||
public:
|
||||
// socket to get data from/to link
|
||||
xgboost::collective::TCPSocket sock;
|
||||
// rank of the node in this link
|
||||
int rank;
|
||||
// size of data readed from link
|
||||
size_t size_read;
|
||||
// size of data sent to the link
|
||||
size_t size_write;
|
||||
// pointer to buffer head
|
||||
char *buffer_head {nullptr};
|
||||
// buffer size, in bytes
|
||||
size_t buffer_size {0};
|
||||
// constructor
|
||||
LinkRecord() = default;
|
||||
// initialize buffer
|
||||
void InitBuffer(size_t type_nbytes, size_t count,
|
||||
size_t reduce_buffer_size) {
|
||||
size_t n = (type_nbytes * count + 7)/ 8;
|
||||
auto to = Min(reduce_buffer_size, n);
|
||||
buffer_.resize(to);
|
||||
// make sure align to type_nbytes
|
||||
buffer_size =
|
||||
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||
utils::Assert(type_nbytes <= buffer_size,
|
||||
"too large type_nbytes=%lu, buffer_size=%lu",
|
||||
type_nbytes, buffer_size);
|
||||
// set buffer head
|
||||
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
||||
}
|
||||
// reset the recv and sent size
|
||||
inline void ResetSize() {
|
||||
size_write = size_read = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief read data into ring-buffer, with care not to existing useful override data
|
||||
* position after protect_start
|
||||
* \param protect_start all data start from protect_start is still needed in buffer
|
||||
* read shall not override this
|
||||
* \param max_size_read maximum logical amount we can read, size_read cannot exceed this value
|
||||
* \return the type of reading
|
||||
*/
|
||||
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
|
||||
utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated");
|
||||
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
|
||||
size_t ngap = size_read - protect_start;
|
||||
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||
size_t offset = size_read % buffer_size;
|
||||
size_t nmax = max_size_read - size_read;
|
||||
nmax = Min(nmax, buffer_size - ngap);
|
||||
nmax = Min(nmax, buffer_size - offset);
|
||||
if (nmax == 0) return kSuccess;
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
SafeColl(sock.Close()); return kRecvZeroLen;
|
||||
}
|
||||
if (len == -1) return Errno2Return();
|
||||
size_read += static_cast<size_t>(len);
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief read data into array,
|
||||
* this function can not be used together with ReadToRingBuffer
|
||||
* a link can either read into the ring buffer, or existing array
|
||||
* \param max_size maximum size of array
|
||||
* \return true if it is a successful read, false if there is some error happens, check errno
|
||||
*/
|
||||
inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
|
||||
if (max_size == size_read) return kSuccess;
|
||||
char *p = static_cast<char*>(recvbuf_);
|
||||
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
SafeColl(sock.Close()); return kRecvZeroLen;
|
||||
}
|
||||
if (len == -1) return Errno2Return();
|
||||
size_read += static_cast<size_t>(len);
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief write data in array to sock
|
||||
* \param sendbuf_ head of array
|
||||
* \param max_size maximum size of array
|
||||
* \return true if it is a successful write, false if there is some error happens, check errno
|
||||
*/
|
||||
inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
|
||||
const char *p = static_cast<const char*>(sendbuf_);
|
||||
ssize_t len = sock.Send(p + size_write, max_size - size_write);
|
||||
if (len == -1) return Errno2Return();
|
||||
size_write += static_cast<size_t>(len);
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
private:
|
||||
// recv buffer to get data from child
|
||||
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
||||
std::vector<uint64_t> buffer_;
|
||||
};
|
||||
/*!
|
||||
* \brief simple data structure that works like a vector
|
||||
* but takes reference instead of space
|
||||
*/
|
||||
struct RefLinkVector {
|
||||
std::vector<LinkRecord*> plinks;
|
||||
inline LinkRecord &operator[](size_t i) {
|
||||
return *plinks[i];
|
||||
}
|
||||
inline size_t Size() const {
|
||||
return plinks.size();
|
||||
}
|
||||
};
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine start up
|
||||
* \param cmd possible command to sent to tracker
|
||||
*/
|
||||
[[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start");
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
*
|
||||
* NOTE on Allreduce:
|
||||
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
|
||||
* It only means the current node get the correct result of Allreduce.
|
||||
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryAllreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes, this function can fail, and will return the cause of failure
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf,
|
||||
* this function implements tree-shape reduction
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryAllreduceTree(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin, size_t slice_end,
|
||||
size_t size_prev_slice);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
|
||||
*
|
||||
* after the function, node k get k-th segment of the reduction result
|
||||
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
|
||||
* where step = ceil(count / world_size)
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryAllreduce
|
||||
*/
|
||||
ReturnType TryReduceScatterRing(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* use a ring based algorithm, reduce-scatter + allgather
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryAllreduceRing(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief function used to report error when a link goes wrong
|
||||
* \param link the pointer to the link who causes the error
|
||||
* \param err the error type
|
||||
*/
|
||||
inline ReturnType ReportError(LinkRecord *link, ReturnType err) {
|
||||
err_link = link; return err;
|
||||
}
|
||||
//---- data structure related to model ----
|
||||
// call sequence counter, records how many calls we made so far
|
||||
// from last call to CheckPoint, LoadCheckPoint
|
||||
int seq_counter{0}; // NOLINT
|
||||
// version number of model
|
||||
int version_number {0}; // NOLINT
|
||||
// whether the job is running in Hadoop
|
||||
bool hadoop_mode; // NOLINT
|
||||
//---- local data related to link ----
|
||||
// index of parent link, can be -1, meaning this is root of the tree
|
||||
int parent_index; // NOLINT
|
||||
// rank of parent node, can be -1
|
||||
int parent_rank; // NOLINT
|
||||
// sockets of all links this connects to
|
||||
std::vector<LinkRecord> all_links; // NOLINT
|
||||
// used to record the link where things goes wrong
|
||||
LinkRecord *err_link; // NOLINT
|
||||
// all the links in the reduction tree connection
|
||||
RefLinkVector tree_links; // NOLINT
|
||||
// pointer to links in the ring
|
||||
LinkRecord *ring_prev, *ring_next; // NOLINT
|
||||
//----- meta information-----
|
||||
// list of environment variables that are of possible interest
|
||||
std::vector<std::string> env_vars; // NOLINT
|
||||
// unique identifier of the possible job this process is doing
|
||||
// used to assign ranks, optional, default to NULL
|
||||
std::string task_id; // NOLINT
|
||||
// uri of current host, to be set by Init
|
||||
std::string host_uri; // NOLINT
|
||||
// uri of tracker
|
||||
std::string tracker_uri; // NOLINT
|
||||
// role in dmlc jobs
|
||||
std::string dmlc_role; // NOLINT
|
||||
// port of tracker address
|
||||
int tracker_port; // NOLINT
|
||||
// reduce buffer size
|
||||
size_t reduce_buffer_size; // NOLINT
|
||||
// reduction method
|
||||
int reduce_method; // NOLINT
|
||||
// minimum count of cells to use ring based method
|
||||
size_t reduce_ring_mincount; // NOLINT
|
||||
// minimum block size per tree reduce
|
||||
size_t tree_reduce_minsize; // NOLINT
|
||||
// current rank
|
||||
int rank; // NOLINT
|
||||
// world size
|
||||
int world_size; // NOLINT
|
||||
// connect retry time
|
||||
int connect_retry; // NOLINT
|
||||
// by default, if rabit worker not recover in half an hour exit
|
||||
std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT
|
||||
// flag to enable rabit_timeout
|
||||
bool rabit_timeout = false; // NOLINT
|
||||
// Enable TCP no-delay
|
||||
bool rabit_enable_tcp_no_delay = false; // NOLINT
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ALLREDUCE_BASE_H_
|
||||
@ -1,147 +0,0 @@
|
||||
/*!
|
||||
* Copyright by Contributors
|
||||
* \file allreduce_mock.h
|
||||
* \brief Mock test module of AllReduce engine,
|
||||
* insert failures in certain call point, to test if the engine is robust to failure
|
||||
*
|
||||
* \author Ignacio Cano, Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_MOCK_H_
|
||||
#define RABIT_ALLREDUCE_MOCK_H_
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <dmlc/timer.h>
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
class AllreduceMock : public AllreduceBase {
|
||||
public:
|
||||
// constructor
|
||||
AllreduceMock() {
|
||||
num_trial_ = 0;
|
||||
force_local_ = 0;
|
||||
report_stats_ = 0;
|
||||
tsum_allreduce_ = 0.0;
|
||||
tsum_allgather_ = 0.0;
|
||||
}
|
||||
// destructor
|
||||
~AllreduceMock() override = default;
|
||||
void SetParam(const char *name, const char *val) override {
|
||||
AllreduceBase::SetParam(name, val);
|
||||
// additional parameters
|
||||
if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
|
||||
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
|
||||
if (!strcmp(name, "report_stats")) report_stats_ = atoi(val);
|
||||
if (!strcmp(name, "force_local")) force_local_ = atoi(val);
|
||||
if (!strcmp(name, "mock")) {
|
||||
MockKey k;
|
||||
utils::Check(sscanf(val, "%d,%d,%d,%d",
|
||||
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
|
||||
"invalid mock parameter");
|
||||
mock_map_[k] = 1;
|
||||
}
|
||||
}
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun,
|
||||
void *prepare_arg) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
|
||||
double tstart = dmlc::GetTime();
|
||||
AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer,
|
||||
prepare_fun, prepare_arg);
|
||||
tsum_allreduce_ += dmlc::GetTime() - tstart;
|
||||
}
|
||||
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
|
||||
double tstart = dmlc::GetTime();
|
||||
AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end,
|
||||
size_prev_slice);
|
||||
tsum_allgather_ += dmlc::GetTime() - tstart;
|
||||
}
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
|
||||
AllreduceBase::Broadcast(sendrecvbuf_, total_size, root);
|
||||
}
|
||||
int LoadCheckPoint() override {
|
||||
tsum_allreduce_ = 0.0;
|
||||
tsum_allgather_ = 0.0;
|
||||
time_checkpoint_ = dmlc::GetTime();
|
||||
if (force_local_ == 0) {
|
||||
return AllreduceBase::LoadCheckPoint();
|
||||
} else {
|
||||
return AllreduceBase::LoadCheckPoint();
|
||||
}
|
||||
}
|
||||
void CheckPoint() override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
|
||||
double tstart = dmlc::GetTime();
|
||||
double tbet_chkpt = tstart - time_checkpoint_;
|
||||
AllreduceBase::CheckPoint();
|
||||
time_checkpoint_ = dmlc::GetTime();
|
||||
double tcost = dmlc::GetTime() - tstart;
|
||||
if (report_stats_ != 0 && rank == 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[v" << version_number << "] global_size="
|
||||
<< ",check_tcost="<< tcost <<" sec"
|
||||
<< ",allreduce_tcost=" << tsum_allreduce_ << " sec"
|
||||
<< ",allgather_tcost=" << tsum_allgather_ << " sec"
|
||||
<< ",between_chpt=" << tbet_chkpt << "sec\n";
|
||||
this->TrackerPrint(ss.str());
|
||||
}
|
||||
tsum_allreduce_ = 0.0;
|
||||
tsum_allgather_ = 0.0;
|
||||
}
|
||||
|
||||
protected:
|
||||
// force checkpoint to local
|
||||
int force_local_;
|
||||
// whether report statistics
|
||||
int report_stats_;
|
||||
// sum of allreduce
|
||||
double tsum_allreduce_;
|
||||
// sum of allgather
|
||||
double tsum_allgather_;
|
||||
double time_checkpoint_;
|
||||
|
||||
private:
|
||||
// key to identify the mock stage
|
||||
struct MockKey {
|
||||
int rank;
|
||||
int version;
|
||||
int seqno;
|
||||
int ntrial;
|
||||
MockKey() = default;
|
||||
MockKey(int rank, int version, int seqno, int ntrial)
|
||||
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
|
||||
inline bool operator==(const MockKey &b) const {
|
||||
return rank == b.rank &&
|
||||
version == b.version &&
|
||||
seqno == b.seqno &&
|
||||
ntrial == b.ntrial;
|
||||
}
|
||||
inline bool operator<(const MockKey &b) const {
|
||||
if (rank != b.rank) return rank < b.rank;
|
||||
if (version != b.version) return version < b.version;
|
||||
if (seqno != b.seqno) return seqno < b.seqno;
|
||||
return ntrial < b.ntrial;
|
||||
}
|
||||
};
|
||||
// number of failure trials
|
||||
int num_trial_;
|
||||
// record all mock actions
|
||||
std::map<MockKey, int> mock_map_;
|
||||
// used to generate all kinds of exceptions
|
||||
inline void Verify(const MockKey &key, const char *name) {
|
||||
if (mock_map_.count(key) != 0) {
|
||||
num_trial_ += 1;
|
||||
// data processing frameworks runs on shared process
|
||||
throw dmlc::Error(std::to_string(rank) + "@@@Hit Mock Error: " + name);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ALLREDUCE_MOCK_H_
|
||||
@ -1,106 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine.cc
|
||||
* \brief this file governs which implementation of engine we are actually using
|
||||
* provides an singleton of engine interface
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#include <rabit/base.h>
|
||||
#include <dmlc/thread_local.h>
|
||||
|
||||
#include <memory>
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
// singleton sync manager
|
||||
#ifndef RABIT_USE_BASE
|
||||
#ifndef RABIT_USE_MOCK
|
||||
using Manager = AllreduceBase;
|
||||
#else
|
||||
typedef AllreduceMock Manager;
|
||||
#endif // RABIT_USE_MOCK
|
||||
#else
|
||||
typedef AllreduceBase Manager;
|
||||
#endif // RABIT_USE_BASE
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct ThreadLocalEntry {
|
||||
/*! \brief stores the current engine */
|
||||
std::unique_ptr<Manager> engine;
|
||||
/*! \brief whether init has been called */
|
||||
bool initialized{false};
|
||||
/*! \brief constructor */
|
||||
ThreadLocalEntry() = default;
|
||||
};
|
||||
|
||||
// define the threadlocal store.
|
||||
using EngineThreadLocal = dmlc::ThreadLocalStore<ThreadLocalEntry>;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
bool Init(int argc, char *argv[]) {
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
if (e->engine.get() == nullptr) {
|
||||
e->initialized = true;
|
||||
e->engine.reset(new Manager());
|
||||
return e->engine->Init(argc, argv);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief finalize syncrhonization module */
|
||||
bool Finalize() {
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
if (e->engine.get() != nullptr) {
|
||||
if (e->engine->Shutdown()) {
|
||||
e->engine.reset(nullptr);
|
||||
e->initialized = false;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine() {
|
||||
// un-initialized default manager.
|
||||
static AllreduceBase default_manager;
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
IEngine* ptr = e->engine.get();
|
||||
if (ptr == nullptr) {
|
||||
utils::Check(!e->initialized, "the rabit has not been initialized");
|
||||
return &default_manager;
|
||||
} else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
// perform in-place allgather, on sendrecvbuf
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice) {
|
||||
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
|
||||
slice_end, size_prev_slice);
|
||||
}
|
||||
|
||||
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf, // NOLINT
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType,
|
||||
mpi::OpType ,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
@ -1,14 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_mock.cc
|
||||
* \brief this is an engine implementation that will
|
||||
* insert failures in certain call point, to test if the engine is robust to failure
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
// define use MOCK, so we will use mock Manager
|
||||
#define NOMINMAX
|
||||
// switch engine to AllreduceMock
|
||||
#define RABIT_USE_MOCK
|
||||
#include <rabit/base.h>
|
||||
#include "allreduce_mock.h"
|
||||
#include "engine.cc"
|
||||
@ -1,342 +0,0 @@
|
||||
// Copyright by Contributors
|
||||
// implementations in ctypes
|
||||
#include <rabit/base.h>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include "rabit/rabit.h"
|
||||
#include "rabit/c_api.h"
|
||||
|
||||
#include "../../src/c_api/c_api_error.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace c_api {
|
||||
// helper use to avoid BitOR operator
|
||||
template<typename OP, typename DType>
|
||||
struct FHelper {
|
||||
static void
|
||||
Allreduce(DType *senrecvbuf_,
|
||||
size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
rabit::Allreduce<OP>(senrecvbuf_, count,
|
||||
prepare_fun, prepare_arg);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename DType>
|
||||
struct FHelper<op::BitAND, DType> {
|
||||
static void
|
||||
Allreduce(DType *,
|
||||
size_t ,
|
||||
void (*)(void *arg),
|
||||
void *) {
|
||||
utils::Error("DataType does not support bitwise AND operation");
|
||||
}
|
||||
};
|
||||
|
||||
template<typename DType>
|
||||
struct FHelper<op::BitOR, DType> {
|
||||
static void
|
||||
Allreduce(DType *,
|
||||
size_t ,
|
||||
void (*)(void *arg),
|
||||
void *) {
|
||||
utils::Error("DataType does not support bitwise OR operation");
|
||||
}
|
||||
};
|
||||
|
||||
template<typename DType>
|
||||
struct FHelper<op::BitXOR, DType> {
|
||||
static void
|
||||
Allreduce(DType *,
|
||||
size_t ,
|
||||
void (*)(void *arg),
|
||||
void *) {
|
||||
utils::Error("DataType does not support bitwise XOR operation");
|
||||
}
|
||||
};
|
||||
|
||||
template<typename OP>
|
||||
void Allreduce(void *sendrecvbuf_,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
using namespace engine::mpi; // NOLINT
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
rabit::Allreduce<OP>
|
||||
(static_cast<char*>(sendrecvbuf_),
|
||||
count, prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kUChar:
|
||||
rabit::Allreduce<OP>
|
||||
(static_cast<unsigned char*>(sendrecvbuf_),
|
||||
count, prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kInt:
|
||||
rabit::Allreduce<OP>
|
||||
(static_cast<int*>(sendrecvbuf_),
|
||||
count, prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kUInt:
|
||||
rabit::Allreduce<OP>
|
||||
(static_cast<unsigned*>(sendrecvbuf_),
|
||||
count, prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kLong:
|
||||
rabit::Allreduce<OP>
|
||||
(static_cast<long*>(sendrecvbuf_), // NOLINT(*)
|
||||
count, prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kULong:
|
||||
rabit::Allreduce<OP>
|
||||
(static_cast<unsigned long*>(sendrecvbuf_), // NOLINT(*)
|
||||
count, prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kFloat:
|
||||
FHelper<OP, float>::Allreduce
|
||||
(static_cast<float*>(sendrecvbuf_),
|
||||
count, prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kDouble:
|
||||
FHelper<OP, double>::Allreduce
|
||||
(static_cast<double*>(sendrecvbuf_),
|
||||
count, prepare_fun, prepare_arg);
|
||||
return;
|
||||
default: utils::Error("unknown data_type");
|
||||
}
|
||||
}
|
||||
void Allreduce(void *sendrecvbuf,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
engine::mpi::OpType enum_op,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
using namespace engine::mpi; // NOLINT
|
||||
switch (enum_op) {
|
||||
case kMax:
|
||||
Allreduce<op::Max>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kMin:
|
||||
Allreduce<op::Min>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kSum:
|
||||
Allreduce<op::Sum>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kBitwiseAND:
|
||||
Allreduce<op::BitAND>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kBitwiseOR:
|
||||
Allreduce<op::BitOR>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kBitwiseXOR:
|
||||
Allreduce<op::BitXOR>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
default: utils::Error("unknown enum_op");
|
||||
}
|
||||
}
|
||||
|
||||
void Allgather(void *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t beginIndex,
|
||||
size_t size_node_slice,
|
||||
size_t size_prev_slice,
|
||||
int enum_dtype) {
|
||||
using namespace engine::mpi; // NOLINT
|
||||
size_t type_size = 0;
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
type_size = sizeof(char);
|
||||
rabit::Allgather(static_cast<char*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kUChar:
|
||||
type_size = sizeof(unsigned char);
|
||||
rabit::Allgather(static_cast<unsigned char*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kInt:
|
||||
type_size = sizeof(int);
|
||||
rabit::Allgather(static_cast<int*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kUInt:
|
||||
type_size = sizeof(unsigned);
|
||||
rabit::Allgather(static_cast<unsigned*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kLong:
|
||||
type_size = sizeof(int64_t);
|
||||
rabit::Allgather(static_cast<int64_t*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kULong:
|
||||
type_size = sizeof(uint64_t);
|
||||
rabit::Allgather(static_cast<uint64_t*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kFloat:
|
||||
type_size = sizeof(float);
|
||||
rabit::Allgather(static_cast<float*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kDouble:
|
||||
type_size = sizeof(double);
|
||||
rabit::Allgather(static_cast<double*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
default: utils::Error("unknown data_type");
|
||||
}
|
||||
}
|
||||
|
||||
// wrapper for serialization
|
||||
struct ReadWrapper : public Serializable {
|
||||
std::string *p_str;
|
||||
explicit ReadWrapper(std::string *p_str)
|
||||
: p_str(p_str) {}
|
||||
void Load(Stream *fi) override {
|
||||
uint64_t sz;
|
||||
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
|
||||
"Read pickle string");
|
||||
p_str->resize(sz);
|
||||
if (sz != 0) {
|
||||
utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0,
|
||||
"Read pickle string");
|
||||
}
|
||||
}
|
||||
void Save(Stream *) const override {
|
||||
utils::Error("not implemented");
|
||||
}
|
||||
};
|
||||
|
||||
struct WriteWrapper : public Serializable {
|
||||
const char *data;
|
||||
size_t length;
|
||||
explicit WriteWrapper(const char *data,
|
||||
size_t length)
|
||||
: data(data), length(length) {
|
||||
}
|
||||
void Load(Stream *) override {
|
||||
utils::Error("not implemented");
|
||||
}
|
||||
void Save(Stream *fo) const override {
|
||||
uint64_t sz = static_cast<uint16_t>(length);
|
||||
fo->Write(&sz, sizeof(sz));
|
||||
fo->Write(data, length * sizeof(char));
|
||||
}
|
||||
};
|
||||
} // namespace c_api
|
||||
} // namespace rabit
|
||||
|
||||
/*! \brief Forward to rabit::Init and record an API error message when it reports failure. */
RABIT_DLL bool RabitInit(int argc, char *argv[]) {
  auto const ok = rabit::Init(argc, argv);
  if (ok) {
    return ok;
  }
  XGBAPISetLastError("Failed to initialize RABIT.");
  return ok;
}
|
||||
|
||||
/*! \brief Forward to rabit::Finalize; on failure record an API error message. Returns the result as int. */
RABIT_DLL int RabitFinalize() {
  auto const finalized = rabit::Finalize();
  if (finalized) {
    return static_cast<int>(finalized);
  }
  XGBAPISetLastError("Failed to shutdown RABIT worker.");
  return static_cast<int>(finalized);
}
|
||||
|
||||
/*! \brief C wrapper returning the value of rabit::GetRingPrevRank(). */
RABIT_DLL int RabitGetRingPrevRank() {
  int const prev_rank = rabit::GetRingPrevRank();
  return prev_rank;
}
|
||||
|
||||
/*! \brief C wrapper returning the value of rabit::GetRank(). */
RABIT_DLL int RabitGetRank() {
  int const rank = rabit::GetRank();
  return rank;
}
|
||||
|
||||
/*! \brief C wrapper returning the value of rabit::GetWorldSize(). */
RABIT_DLL int RabitGetWorldSize() {
  int const world_size = rabit::GetWorldSize();
  return world_size;
}
|
||||
|
||||
/*! \brief C wrapper returning the value of rabit::IsDistributed() converted to int. */
RABIT_DLL int RabitIsDistributed() {
  int const distributed = rabit::IsDistributed();
  return distributed;
}
|
||||
|
||||
/*! \brief Copy the C string into a std::string and hand it to rabit::TrackerPrint. */
RABIT_DLL int RabitTrackerPrint(const char *msg) {
  API_BEGIN()
  std::string text(msg);
  rabit::TrackerPrint(text);
  API_END()
}
|
||||
|
||||
/*!
 * \brief Copy the processor name into a caller-provided buffer as a NUL-terminated string.
 *
 * \param out_name destination buffer
 * \param out_len  receives the number of characters written, excluding the terminator
 * \param max_len  capacity of out_name in bytes, including room for the terminator
 *                 (this is the reading implied by the `max_len - 1` truncation)
 *
 * The name is truncated to max_len - 1 characters when it does not fit.
 */
RABIT_DLL void RabitGetProcessorName(char *out_name, rbt_ulong *out_len, rbt_ulong max_len) {
  if (max_len == 0) {
    // Not even the terminator fits; also avoids the unsigned underflow of `max_len - 1`.
    *out_len = 0;
    return;
  }
  std::string s = rabit::GetProcessorName();
  // `>=` rather than `>`: a name of exactly max_len characters needs max_len + 1 bytes
  // once the terminator is counted, which would overflow the buffer.
  if (s.length() >= max_len) {
    s.resize(max_len - 1);
  }
  strcpy(out_name, s.c_str());  // NOLINT(*)
  *out_len = static_cast<rbt_ulong>(s.length());
}
|
||||
|
||||
/*! \brief Forward the buffer, its size and the root rank to rabit::Broadcast. */
RABIT_DLL int RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root) {
  API_BEGIN()
  rabit::Broadcast(sendrecv_data, size, root);
  API_END()
}
|
||||
|
||||
/*! \brief Convert the integer dtype code to the engine enum and dispatch to c_api::Allgather. */
RABIT_DLL int RabitAllgather(void *sendrecvbuf_, size_t total_size, size_t beginIndex,
                             size_t size_node_slice, size_t size_prev_slice, int enum_dtype) {
  API_BEGIN()
  auto const dtype = static_cast<rabit::engine::mpi::DataType>(enum_dtype);
  rabit::c_api::Allgather(sendrecvbuf_, total_size, beginIndex, size_node_slice,
                          size_prev_slice, dtype);
  API_END()
}
|
||||
|
||||
/*! \brief Convert the integer dtype/op codes to the engine enums and dispatch to c_api::Allreduce. */
RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, int enum_op,
                             void (*prepare_fun)(void *arg), void *prepare_arg) {
  API_BEGIN()
  auto const dtype = static_cast<rabit::engine::mpi::DataType>(enum_dtype);
  auto const op = static_cast<rabit::engine::mpi::OpType>(enum_op);
  rabit::c_api::Allreduce(sendrecvbuf, count, dtype, op, prepare_fun, prepare_arg);
  API_END()
}
|
||||
|
||||
/*! \brief C wrapper returning the value of rabit::VersionNumber(). */
RABIT_DLL int RabitVersionNumber() {
  int const version = rabit::VersionNumber();
  return version;
}
|
||||
|
||||
/*! \brief Always returns 0. */
RABIT_DLL int RabitLinkTag() {
  constexpr int kTag = 0;
  return kTag;
}
|
||||
@ -15,9 +15,9 @@
|
||||
#include <utility> // for pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
|
||||
#include "../common/error_msg.h" // for NoFederated
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
|
||||
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
|
||||
@ -27,11 +27,10 @@
|
||||
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
|
||||
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
|
||||
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
|
||||
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
|
||||
#include "dmlc/base.h" // for BeginPtr
|
||||
#include "dmlc/io.h" // for Stream
|
||||
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
|
||||
#include "dmlc/thread_local.h" // for ThreadLocalStore
|
||||
#include "rabit/c_api.h" // for RabitLinkTag
|
||||
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
|
||||
@ -46,10 +45,6 @@
|
||||
#include "xgboost/string_view.h" // for StringView, operator<<
|
||||
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_server.h"
|
||||
#endif
|
||||
|
||||
using namespace xgboost; // NOLINT(*);
|
||||
|
||||
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
||||
@ -1759,76 +1754,3 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
|
||||
*out_features = dmlc::BeginPtr(feature_names_c);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Parse the JSON configuration string and initialise the collective communicator with it.
XGB_DLL int XGCommunicatorInit(char const* json_config) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(json_config);
  auto const config = Json::Load(StringView{json_config});
  collective::Init(config);
  API_END();
}
|
||||
|
||||
// Forward to collective::Finalize inside the API_BEGIN/API_END guard.
XGB_DLL int XGCommunicatorFinalize() {
  API_BEGIN();
  collective::Finalize();
  API_END();
}
|
||||
|
||||
// Return the value of collective::GetRank().
XGB_DLL int XGCommunicatorGetRank(void) {
  int const rank = collective::GetRank();
  return rank;
}
|
||||
|
||||
// Return the value of collective::GetWorldSize().
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  int const world_size = collective::GetWorldSize();
  return world_size;
}
|
||||
|
||||
// Return the value of collective::IsDistributed() converted to int.
XGB_DLL int XGCommunicatorIsDistributed(void) {
  int const distributed = collective::IsDistributed();
  return distributed;
}
|
||||
|
||||
// Print a message through the collective communicator.
//
// `message` must be a valid NUL-terminated string. A null pointer is rejected through the same
// argument check used by the other entry points instead of being passed on to be read as a
// string.
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(message);
  collective::Print(message);
  API_END();
}
|
||||
|
||||
// Store the processor name in the thread-local API entry and expose a pointer to it.
// The returned pointer refers to the thread-local `ret_str` buffer.
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  auto *entry = GlobalConfigAPIThreadLocalStore::Get();
  entry->ret_str = collective::GetProcessorName();
  xgboost_CHECK_C_ARG_PTR(name_str);
  *name_str = entry->ret_str.c_str();
  API_END();
}
|
||||
|
||||
// Forward the buffer, its size and the root rank to collective::Broadcast.
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  collective::Broadcast(send_receive_buffer, size, root);
  API_END();
}
|
||||
|
||||
// Forward the buffer, element count and the integer dtype/op codes to collective::Allreduce.
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
  API_END();
}
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
// Forward the port, world size and key/certificate paths to federated::RunServer.
XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path,
                                  char const *server_cert_path, char const *client_cert_path) {
  API_BEGIN();
  federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
  API_END();
}
|
||||
|
||||
// Run a server without SSL for local testing.
|
||||
// Forward the port and world size to federated::RunInsecureServer (no key/certificate arguments).
XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) {
  API_BEGIN();
  federated::RunInsecureServer(port, world_size);
  API_END();
}
|
||||
#endif
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
@ -1,22 +1,28 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2023, XGBoost Contributors
|
||||
* \file c_api_error.cc
|
||||
* \brief C error handling
|
||||
*/
|
||||
#include <dmlc/thread_local.h>
|
||||
#include "xgboost/c_api.h"
|
||||
#include "./c_api_error.h"
|
||||
|
||||
#include <dmlc/thread_local.h>
|
||||
|
||||
#include "xgboost/c_api.h"
|
||||
#include "../collective/comm.h"
|
||||
#include "../collective/comm_group.h"
|
||||
|
||||
struct XGBAPIErrorEntry {
|
||||
std::string last_error;
|
||||
std::int32_t code{-1};
|
||||
};
|
||||
|
||||
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
|
||||
|
||||
XGB_DLL const char *XGBGetLastError() {
|
||||
return XGBAPIErrorStore::Get()->last_error.c_str();
|
||||
}
|
||||
XGB_DLL const char* XGBGetLastError() { return XGBAPIErrorStore::Get()->last_error.c_str(); }
|
||||
|
||||
void XGBAPISetLastError(const char* msg) {
|
||||
XGBAPIErrorStore::Get()->last_error = msg;
|
||||
XGBAPIErrorStore::Get()->code = -1;
|
||||
}
|
||||
|
||||
XGB_DLL int XGBGetLastErrorCode() { return XGBAPIErrorStore::Get()->code; }
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include <dmlc/logging.h>
|
||||
|
||||
#include "c_api_utils.h"
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
/*! \brief macro to guard beginning and end section of all functions */
|
||||
#ifdef LOG_CAPI_INVOCATION
|
||||
@ -30,7 +31,7 @@
|
||||
#define API_END() \
|
||||
} catch (dmlc::Error & _except_) { \
|
||||
return XGBAPIHandleException(_except_); \
|
||||
} catch (std::exception const &_except_) { \
|
||||
} catch (std::exception const& _except_) { \
|
||||
return XGBAPIHandleException(dmlc::Error(_except_.what())); \
|
||||
} \
|
||||
return 0; // NOLINT(*)
|
||||
@ -48,7 +49,7 @@ void XGBAPISetLastError(const char* msg);
|
||||
* \param e the exception
|
||||
* \return the return value of API after exception is handled
|
||||
*/
|
||||
inline int XGBAPIHandleException(const dmlc::Error &e) {
|
||||
inline int XGBAPIHandleException(const dmlc::Error& e) {
|
||||
XGBAPISetLastError(e.what());
|
||||
return -1;
|
||||
}
|
||||
|
||||
@ -9,10 +9,15 @@
|
||||
#include <type_traits> // for is_same_v, remove_pointer_t
|
||||
#include <utility> // for pair
|
||||
|
||||
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "../collective/allgather.h" // for Allgather
|
||||
#include "../collective/allreduce.h" // for Allreduce
|
||||
#include "../collective/broadcast.h" // for Broadcast
|
||||
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||
#include "../collective/comm_group.h" // for GlobalCommGroup
|
||||
#include "../collective/communicator-inl.h" // for GetProcessorName
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/json.h" // for Json
|
||||
@ -20,10 +25,36 @@
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
|
||||
#else
|
||||
#include "../common/error_msg.h" // for NoFederated
|
||||
#endif
|
||||
|
||||
namespace xgboost::collective {
|
||||
void Allreduce(void *send_receive_buffer, std::size_t count, std::int32_t data_type, int op) {
|
||||
Context ctx;
|
||||
DispatchDType(static_cast<ArrayInterfaceHandler::Type>(data_type), [&](auto t) {
|
||||
using T = decltype(t);
|
||||
auto data = linalg::MakeTensorView(
|
||||
&ctx, common::Span{static_cast<T *>(send_receive_buffer), count}, count);
|
||||
auto rc = Allreduce(&ctx, *GlobalCommGroup(), data, static_cast<Op>(op));
|
||||
SafeColl(rc);
|
||||
});
|
||||
}
|
||||
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) {
|
||||
Context ctx;
|
||||
auto rc = Broadcast(&ctx, *GlobalCommGroup(),
|
||||
linalg::MakeVec(static_cast<std::int8_t *>(send_receive_buffer), size), root);
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
void Allgather(void *send_receive_buffer, std::size_t size) {
|
||||
Context ctx;
|
||||
auto const &comm = GlobalCommGroup();
|
||||
auto rc = Allgather(&ctx, *comm,
|
||||
linalg::MakeVec(reinterpret_cast<std::int8_t *>(send_receive_buffer), size));
|
||||
SafeColl(rc);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
namespace {
|
||||
@ -44,7 +75,8 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||
|
||||
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
|
||||
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
||||
std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count())
|
||||
: kDft};
|
||||
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
@ -62,7 +94,7 @@ void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (timer.Duration() > timeout && timeout.count() != 0) {
|
||||
if (timer.Duration() > timeout && collective::HasTimeout(timeout)) {
|
||||
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
|
||||
}
|
||||
}
|
||||
@ -141,7 +173,7 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
// Make sure no one else is waiting on the tracker.
|
||||
while (!ptr->first.unique()) {
|
||||
auto ela = timer.Duration().count();
|
||||
if (ela > ptr->first->Timeout().count()) {
|
||||
if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) {
|
||||
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
|
||||
<< " seconds reached for TrackerFree, killing the tracker.";
|
||||
break;
|
||||
@ -151,3 +183,71 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
delete ptr;
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Parse the JSON configuration string and initialise the global communicator group with it.
XGB_DLL int XGCommunicatorInit(char const *json_config) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(json_config);
  auto const config = Json::Load(StringView{json_config});
  collective::GlobalCommGroupInit(config);
  API_END();
}
|
||||
|
||||
// Forward to collective::GlobalCommGroupFinalize inside the API_BEGIN/API_END guard.
XGB_DLL int XGCommunicatorFinalize(void) {
  API_BEGIN();
  collective::GlobalCommGroupFinalize();
  API_END();
}
|
||||
|
||||
// Return the value of collective::GetRank(); the call sits inside the API_BEGIN/API_END guard,
// so the return statement inside the guard is the one normally taken.
XGB_DLL int XGCommunicatorGetRank(void) {
  API_BEGIN();
  int const rank = collective::GetRank();
  return rank;
  API_END();
}
|
||||
|
||||
// Return the value of collective::GetWorldSize().
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  int const world_size = collective::GetWorldSize();
  return world_size;
}
|
||||
|
||||
// Return the value of collective::IsDistributed() converted to int.
XGB_DLL int XGCommunicatorIsDistributed(void) {
  int const distributed = collective::IsDistributed();
  return distributed;
}
|
||||
|
||||
// Print a message through the collective communicator.
//
// `message` must be a valid NUL-terminated string. A null pointer is rejected through the same
// argument check XGCommunicatorInit uses, instead of being passed on to be read as a string.
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(message);
  collective::Print(message);
  API_END();
}
|
||||
|
||||
// Store the processor name in the thread-local collective API entry and expose a pointer to it.
// The returned pointer refers to the thread-local `ret_str` buffer.
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  auto *entry = CollAPIThreadLocalStore::Get();
  entry->ret_str = collective::GetProcessorName();
  xgboost_CHECK_C_ARG_PTR(name_str);
  *name_str = entry->ret_str.c_str();
  API_END();
}
|
||||
|
||||
// Forward the buffer, its size and the root rank to collective::Broadcast.
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  collective::Broadcast(send_receive_buffer, size, root);
  API_END();
}
|
||||
|
||||
// Forward the buffer, element count and the integer dtype/op codes to collective::Allreduce.
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
  API_END();
}
|
||||
|
||||
// Not exposed to the public since the previous implementation didn't and we don't want to
|
||||
// add unnecessary communicator API to a machine learning library.
|
||||
// Forward the buffer and its count to collective::Allgather.
XGB_DLL int XGCommunicatorAllgather(void *send_receive_buffer, size_t count) {
  API_BEGIN();
  collective::Allgather(send_receive_buffer, count);
  API_END();
}
|
||||
|
||||
// Not yet exposed to the public, error recovery is still WIP.
|
||||
// Wrap the last recorded API error message in a failed Result and signal it through the
// global communicator group.
XGB_DLL int XGCommunicatorSignalError() {
  API_BEGIN();
  auto msg = XGBGetLastError();
  auto failure = xgboost::collective::Fail(msg);
  SafeColl(xgboost::collective::GlobalCommGroup()->SignalError(failure));
  API_END()
}
|
||||
|
||||
@ -22,7 +22,6 @@
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include "collective/communicator-inl.h"
|
||||
#include "common/common.h"
|
||||
#include "common/config.h"
|
||||
#include "common/io.h"
|
||||
@ -193,10 +192,6 @@ class CLI {
|
||||
|
||||
void CLITrain() {
|
||||
const double tstart_data_load = dmlc::GetTime();
|
||||
if (collective::IsDistributed()) {
|
||||
std::string pname = collective::GetProcessorName();
|
||||
LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank();
|
||||
}
|
||||
// load in data.
|
||||
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
|
||||
param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
@ -235,15 +230,9 @@ class CLI {
|
||||
version += 1;
|
||||
}
|
||||
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
|
||||
if (collective::IsDistributed()) {
|
||||
if (collective::GetRank() == 0) {
|
||||
LOG(TRACKER) << res;
|
||||
}
|
||||
} else {
|
||||
LOG(CONSOLE) << res;
|
||||
}
|
||||
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
|
||||
collective::GetRank() == 0) {
|
||||
LOG(CONSOLE) << res;
|
||||
|
||||
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0) {
|
||||
std::ostringstream os;
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
<< i + 1 << ".model";
|
||||
@ -256,8 +245,7 @@ class CLI {
|
||||
<< " sec";
|
||||
// always save final round
|
||||
if ((param_.save_period == 0 ||
|
||||
param_.num_round % param_.save_period != 0) &&
|
||||
collective::GetRank() == 0) {
|
||||
param_.num_round % param_.save_period != 0)) {
|
||||
std::ostringstream os;
|
||||
if (param_.model_out == CLIParam::kNull) {
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
@ -465,13 +453,6 @@ class CLI {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the collective communicator.
|
||||
Json json{JsonObject()};
|
||||
for (auto& kv : cfg) {
|
||||
json[kv.first] = String(kv.second);
|
||||
}
|
||||
collective::Init(json);
|
||||
|
||||
param_.Configure(cfg);
|
||||
}
|
||||
|
||||
@ -507,10 +488,6 @@ class CLI {
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
~CLI() {
|
||||
collective::Finalize();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*
|
||||
* Higher level functions built on top the Communicator API, taking care of behavioral differences
|
||||
* between row-split vs column-split distributed training, and horizontal vs vertical federated
|
||||
@ -13,7 +13,8 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.cuh"
|
||||
#include "allreduce.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
@ -24,15 +25,17 @@ namespace xgboost::collective {
|
||||
* column-wise (vertically), the original values are returned.
|
||||
*
|
||||
* @tparam T The type of the values.
|
||||
*
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param device The device id.
|
||||
* @param values Pointer to the inputs to sum.
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T>
|
||||
void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<T, kDim> values) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
|
||||
return collective::Allreduce(ctx, values, collective::Op::kSum);
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -11,11 +11,44 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "allreduce.h"
|
||||
#include "broadcast.h"
|
||||
#include "comm.h"
|
||||
#include "communicator-inl.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace detail {
|
||||
template <typename Fn>
|
||||
[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) {
|
||||
std::string msg;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
fn();
|
||||
} catch (dmlc::Error const& e) {
|
||||
msg = e.what();
|
||||
}
|
||||
}
|
||||
std::size_t msg_size{msg.size()};
|
||||
auto rc = Success() << [&] {
|
||||
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
|
||||
return rc;
|
||||
} << [&] {
|
||||
if (msg_size > 0) {
|
||||
msg.resize(msg_size);
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(msg.data(), msg.size()), 0);
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
if (msg_size > 0) {
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
return rc;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* @brief Apply the given function where the labels are.
|
||||
@ -30,29 +63,19 @@ namespace xgboost::collective {
|
||||
* @param size The size of the buffer.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename FN>
|
||||
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
|
||||
FN&& function) {
|
||||
template <typename Fn>
|
||||
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std::size_t size,
|
||||
Fn&& fn) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<FN>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (message.empty()) {
|
||||
collective::Broadcast(buffer, size, 0);
|
||||
} else {
|
||||
LOG(FATAL) << &message[0];
|
||||
}
|
||||
auto rc = detail::TryApplyWithLabels(ctx, fn) << [&] {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and
|
||||
// result broadcast to other workers.
|
||||
return collective::Broadcast(
|
||||
ctx, linalg::MakeVec(reinterpret_cast<std::int8_t*>(buffer), size), 0);
|
||||
};
|
||||
SafeColl(rc);
|
||||
} else {
|
||||
std::forward<FN>(function)();
|
||||
std::forward<Fn>(fn)();
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,37 +92,24 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
|
||||
* @param result The HostDeviceVector storing the results.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename T, typename Function>
|
||||
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
|
||||
Function&& function) {
|
||||
template <typename T, typename Fn>
|
||||
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
|
||||
Fn&& fn) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<Function>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
auto rc = detail::TryApplyWithLabels(ctx, fn);
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (!message.empty()) {
|
||||
LOG(FATAL) << &message[0];
|
||||
return;
|
||||
}
|
||||
|
||||
std::size_t size{};
|
||||
if (collective::GetRank() == 0) {
|
||||
size = result->Size();
|
||||
}
|
||||
collective::Broadcast(&size, sizeof(std::size_t), 0);
|
||||
|
||||
result->Resize(size);
|
||||
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
|
||||
std::size_t size{result->Size()};
|
||||
rc = std::move(rc) << [&] {
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0);
|
||||
} << [&] {
|
||||
result->Resize(size);
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(result->HostPointer(), size), 0);
|
||||
};
|
||||
SafeColl(rc);
|
||||
} else {
|
||||
std::forward<Function>(function)();
|
||||
std::forward<Fn>(fn)();
|
||||
}
|
||||
}
|
||||
|
||||
@ -115,11 +125,12 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
|
||||
* @return The global max of the input.
|
||||
*/
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
|
||||
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const* ctx,
|
||||
MetaInfo const& info,
|
||||
T value) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kMax>(&value, 1);
|
||||
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&value, 1), collective::Op::kMax);
|
||||
SafeColl(rc);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
@ -136,19 +147,14 @@ std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context co
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<T, kDim> values) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
|
||||
return collective::Allreduce(ctx, values, collective::Op::kSum);
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
|
||||
return GlobalSum(ctx, info, values->data(), values->size());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global ratio of the given two values across all workers.
|
||||
*
|
||||
|
||||
@ -47,7 +47,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring allgather failed, current iteration:" + std::to_string(r), std::move(rc));
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,7 +61,8 @@ Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> si
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Broadcast AllgatherV failed, current iteration:" + std::to_string(r),
|
||||
std::move(rc));
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
@ -102,7 +103,7 @@ namespace detail {
|
||||
return prev_ch->Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring AllgatherV failed, current iterataion:" + std::to_string(r), std::move(rc));
|
||||
}
|
||||
}
|
||||
return comm.Block();
|
||||
|
||||
@ -36,7 +36,7 @@ Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func
|
||||
auto rc = RingAllgather(comm, typed);
|
||||
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring allreduce small failed.", std::move(rc));
|
||||
}
|
||||
auto first = s_buffer.subspan(0, data.size_bytes());
|
||||
CHECK_EQ(first.size(), data.size());
|
||||
@ -64,7 +64,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto next_ch = comm.Chan(dst_rank);
|
||||
auto prev_ch = comm.Chan(src_rank);
|
||||
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, -1);
|
||||
auto s_buf = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
for (std::int32_t r = 0; r < world - 1; ++r) {
|
||||
@ -97,6 +97,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return Fail("Ring scatter reduce failed, current iteration:" + std::to_string(r),
|
||||
std::move(rc));
|
||||
}
|
||||
|
||||
// accumulate to recv_seg
|
||||
CHECK_EQ(seg.size(), recv_seg.size());
|
||||
@ -128,7 +132,7 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
auto n_bytes_in_seg = (n / world) * sizeof(T);
|
||||
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring Allreduce failed.", std::move(rc));
|
||||
}
|
||||
|
||||
auto prev = BootstrapPrev(comm.Rank(), comm.World());
|
||||
|
||||
@ -150,9 +150,12 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
}
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to send rank.");
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
} else if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to send rank.", std::move(rc));
|
||||
}
|
||||
workers[r] = std::move(worker);
|
||||
}
|
||||
@ -169,8 +172,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
return rc;
|
||||
}
|
||||
std::int32_t rank{-1};
|
||||
auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
} else if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to recv rank.");
|
||||
}
|
||||
workers[rank] = std::move(peer);
|
||||
|
||||
@ -94,7 +94,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
|
||||
void Submit(Loop::Op op) const {
|
||||
CHECK(loop_);
|
||||
loop_->Submit(op);
|
||||
loop_->Submit(std::move(op));
|
||||
}
|
||||
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ CommGroup::CommGroup()
|
||||
// Common args
|
||||
auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
|
||||
auto timeout =
|
||||
get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
|
||||
get_param("dmlc_timeout", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
|
||||
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
|
||||
|
||||
if (type == "rabit") {
|
||||
@ -123,4 +123,30 @@ void GlobalCommGroupFinalize() {
|
||||
sptr.reset();
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
/** @brief Alias for GlobalCommGroupInit. */
void Init(Json const& config) {
  GlobalCommGroupInit(config);
}
|
||||
|
||||
/** @brief Alias for GlobalCommGroupFinalize. */
void Finalize() {
  GlobalCommGroupFinalize();
}
|
||||
|
||||
/** @brief Rank reported by the global communicator group. */
std::int32_t GetRank() noexcept {
  std::int32_t const rank = GlobalCommGroup()->Rank();
  return rank;
}
|
||||
|
||||
/** @brief World size reported by the global communicator group. */
std::int32_t GetWorldSize() noexcept {
  std::int32_t const world = GlobalCommGroup()->World();
  return world;
}
|
||||
|
||||
/** @brief Whether the global communicator group reports itself as distributed. */
bool IsDistributed() noexcept {
  bool const distributed = GlobalCommGroup()->IsDistributed();
  return distributed;
}
|
||||
|
||||
[[nodiscard]] bool IsFederated() {
|
||||
return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
|
||||
}
|
||||
|
||||
void Print(std::string const& message) {
|
||||
auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
std::string GetProcessorName() {
|
||||
std::string out;
|
||||
auto rc = GlobalCommGroup()->ProcessorName(&out);
|
||||
SafeColl(rc);
|
||||
return out;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -1,34 +0,0 @@
|
||||
/**
|
||||
* Copyright 2024, XGBoost contributors
|
||||
*/
|
||||
#include "communicator-inl.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
std::vector<std::vector<char>> const &input) {
|
||||
auto n_inputs = input.size();
|
||||
std::vector<std::int64_t> sizes(n_inputs);
|
||||
std::transform(input.cbegin(), input.cend(), sizes.begin(),
|
||||
[](auto const &vec) { return vec.size(); });
|
||||
|
||||
std::vector<std::int64_t> global_sizes = AllgatherV(sizes);
|
||||
std::vector<std::int64_t> offset(global_sizes.size() + 1);
|
||||
offset[0] = 0;
|
||||
for (std::size_t i = 1; i < offset.size(); i++) {
|
||||
offset[i] = offset[i - 1] + global_sizes[i - 1];
|
||||
}
|
||||
|
||||
std::vector<char> collected;
|
||||
for (auto const &vec : input) {
|
||||
collected.insert(collected.end(), vec.cbegin(), vec.cend());
|
||||
}
|
||||
auto out = AllgatherV(collected);
|
||||
|
||||
std::vector<std::vector<char>> result;
|
||||
for (std::size_t i = 1; i < offset.size(); ++i) {
|
||||
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
|
||||
result.emplace_back(std::move(local));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@ -1,95 +0,0 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Reduce values from all processes and distribute the result back to all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather values from all all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
|
||||
std::size_t send_size) {
|
||||
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronize device operations.
|
||||
* @param device ID of the device.
|
||||
*/
|
||||
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -3,308 +3,63 @@
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Initialize the collective communicator.
|
||||
*/
|
||||
void Init(Json const& config);
|
||||
|
||||
/**
|
||||
* \brief Initialize the collective communicator.
|
||||
*
|
||||
* Currently the communicator API is experimental, function signatures may change in the future
|
||||
* without notice.
|
||||
*
|
||||
* Call this once before using anything.
|
||||
*
|
||||
* The additional configuration is not required. Usually the communicator will detect settings
|
||||
* from environment variables.
|
||||
*
|
||||
* \param json_config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* * mpi: Use MPI.
|
||||
* * federated: Use the gRPC interface for Federated Learning.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||
* - rabit_tracker_uri: Hostname of the tracker.
|
||||
* - rabit_tracker_port: Port number of the tracker.
|
||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - rabit_world_size: Total number of workers.
|
||||
* - rabit_hadoop_mode: Enable Hadoop support.
|
||||
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||
* - rabit_reduce_buffer: Size of the reduce buffer.
|
||||
* - rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||
* - rabit_debug: Enable debugging.
|
||||
* - rabit_timeout: Enable timeout.
|
||||
* - rabit_timeout_sec: Timeout in seconds.
|
||||
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
* environment variables):
|
||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
* - federated_world_size: Number of federated workers.
|
||||
* - federated_rank: Rank of the current worker.
|
||||
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
*/
|
||||
// Forwards the JSON configuration to Communicator::Init.
inline void Init(Json const &config) {
  Communicator::Init(config);
}
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
* @brief Finalize the collective communicator.
|
||||
*
|
||||
* Call this function after you finished all jobs.
|
||||
*/
|
||||
// Forwards to Communicator::Finalize.
inline void Finalize() {
  Communicator::Finalize();
}
|
||||
void Finalize();
|
||||
|
||||
/*!
|
||||
* \brief Get rank of current process.
|
||||
/**
|
||||
* @brief Get rank of current process.
|
||||
*
|
||||
* \return Rank of the worker.
|
||||
* @return Rank of the worker.
|
||||
*/
|
||||
// Rank stored in the current communicator instance.
inline int GetRank() {
  auto const *comm = Communicator::Get();
  return comm->GetRank();
}
|
||||
[[nodiscard]] std::int32_t GetRank() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get total number of processes.
|
||||
/**
|
||||
* @brief Get total number of processes.
|
||||
*
|
||||
* \return Total world size.
|
||||
* @return Total world size.
|
||||
*/
|
||||
// World size stored in the current communicator instance.
inline int GetWorldSize() {
  auto const *comm = Communicator::Get();
  return comm->GetWorldSize();
}
|
||||
[[nodiscard]] std::int32_t GetWorldSize() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is distributed.
|
||||
/**
|
||||
* @brief Get if the communicator is distributed.
|
||||
*
|
||||
* \return True if the communicator is distributed.
|
||||
* @return True if the communicator is distributed.
|
||||
*/
|
||||
// Asks the current communicator instance whether it is distributed.
inline bool IsDistributed() {
  auto const *comm = Communicator::Get();
  return comm->IsDistributed();
}
|
||||
[[nodiscard]] bool IsDistributed() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is federated.
|
||||
/**
|
||||
* @brief Get if the communicator is federated.
|
||||
*
|
||||
* \return True if the communicator is federated.
|
||||
* @return True if the communicator is federated.
|
||||
*/
|
||||
// Asks the current communicator instance whether it is federated.
inline bool IsFederated() {
  auto const *comm = Communicator::Get();
  return comm->IsFederated();
}
|
||||
[[nodiscard]] bool IsFederated();
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
/**
|
||||
* @brief Print the message to the communicator.
|
||||
*
|
||||
* This function can be used to communicate the information of the progress to the user who monitors
|
||||
* the communicator.
|
||||
*
|
||||
* \param message The message to be printed.
|
||||
* @param message The message to be printed.
|
||||
*/
|
||||
// C-string overload: the message is converted to std::string and passed to
// the current communicator's Print.
inline void Print(char const *message) {
  std::string const text{message};
  auto *comm = Communicator::Get();
  comm->Print(text);
}
|
||||
|
||||
// std::string overload: passes the message to the current communicator's Print.
inline void Print(std::string const &message) {
  auto *comm = Communicator::Get();
  comm->Print(message);
}
|
||||
|
||||
/*!
|
||||
* \brief Get the name of the processor.
|
||||
*
|
||||
* \return Name of the processor.
|
||||
*/
|
||||
// Processor name as reported by the current communicator instance.
inline std::string GetProcessorName() {
  auto *comm = Communicator::Get();
  return comm->GetProcessorName();
}
|
||||
|
||||
/*!
|
||||
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||
*
|
||||
* Example:
|
||||
* int a = 1;
|
||||
* Broadcast(&a, sizeof(a), root);
|
||||
*
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
*/
|
||||
// Raw-buffer broadcast: forwards to the current communicator's Broadcast.
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
  auto *comm = Communicator::Get();
  comm->Broadcast(send_receive_buffer, size, root);
}
|
||||
|
||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
size_t size = sendrecv_data->length();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
if (sendrecv_data->length() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
|
||||
}
|
||||
}
|
||||
|
||||
void Print(std::string const& message);
|
||||
/**
|
||||
* @brief Gathers a single value from all processes and distributes the result to all processes.
|
||||
* @brief Get the name of the processor.
|
||||
*
|
||||
* @param input The single value.
|
||||
* @return Name of the processor.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(T const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(std::vector<T> const &input) {
|
||||
if (input.empty()) {
|
||||
return input;
|
||||
}
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGatherV(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
if (!output.empty()) {
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
*
|
||||
* @param inputs All the inputs from the local worker. The number of inputs can vary
|
||||
* across different workers. Along with which, the size of each vector in
|
||||
* the input can also vary.
|
||||
*
|
||||
* @return The AllgatherV result, containing vectors from all workers.
|
||||
*/
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
std::vector<std::vector<char>> const &input);
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
|
||||
* @param input Variable-length list of variable-length strings.
|
||||
*/
|
||||
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
|
||||
std::size_t total_size{0};
|
||||
for (auto const &s : input) {
|
||||
total_size += s.length() + 1; // +1 for null-terminators
|
||||
}
|
||||
std::string flat_string;
|
||||
flat_string.reserve(total_size);
|
||||
for (auto const &s : input) {
|
||||
flat_string.append(s);
|
||||
flat_string.push_back('\0'); // Append a null-terminator after each string
|
||||
}
|
||||
|
||||
auto const output = Communicator::Get()->AllGatherV(flat_string);
|
||||
|
||||
std::vector<std::string> result;
|
||||
std::size_t start_index = 0;
|
||||
// Iterate through the output, find each null-terminated substring.
|
||||
for (std::size_t i = 0; i < output.size(); i++) {
|
||||
if (output[i] == '\0') {
|
||||
// Construct a std::string from the char* substring
|
||||
result.emplace_back(&output[start_index]);
|
||||
// Move to the next substring
|
||||
start_index = i + 1;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce(&data[0], data.size(), DataType::kInt32, Operation::kSum);
|
||||
* ...
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
*/
|
||||
// Integer-coded variant: `data_type` and `op` are cast to the DataType and
// Operation enumerations before the call is forwarded.
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
  auto const dtype = static_cast<DataType>(data_type);
  auto const reduce_op = static_cast<Operation>(op);
  Communicator::Get()->AllReduce(send_receive_buffer, count, dtype, reduce_op);
}
|
||||
|
||||
// Enum-typed variant: forwards to the current communicator's AllReduce.
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
  auto *comm = Communicator::Get();
  comm->AllReduce(send_receive_buffer, count, data_type, op);
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
// Specialization for size_t, which is implementation defined, so it might or might not
|
||||
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
|
||||
template <Operation op, typename T,
|
||||
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
|
||||
inline void Allreduce(T *send_receive_buffer, size_t count) {
|
||||
static_assert(sizeof(T) == sizeof(uint64_t));
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(float *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(double *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
std::string GetProcessorName();
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -1,63 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "comm.h"
|
||||
#include "in_memory_communicator.h"
|
||||
#include "noop_communicator.h"
|
||||
#include "rabit_communicator.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_communicator.h"
|
||||
#endif
|
||||
|
||||
namespace xgboost::collective {
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
thread_local std::string Communicator::nccl_path_{};
|
||||
|
||||
// Create the thread-local communicator from the JSON configuration.
//
// The NCCL path is read from "dmlc_nccl_path". The communicator type is taken
// from the runtime configuration when it names one, otherwise from the
// environment, and falls back to Rabit when neither specifies a type.
void Communicator::Init(Json const& config) {
  nccl_path_ = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});

  auto const from_env = GetTypeFromEnv();
  auto const from_config = GetTypeFromConfig(config);
  // A type given in the configuration overrides the environment.
  auto resolved = (from_config != CommunicatorType::kUnknown) ? from_config : from_env;
  if (resolved == CommunicatorType::kUnknown) {
    // Default to Rabit if unspecified.
    resolved = CommunicatorType::kRabit;
  }
  type_ = resolved;

  switch (resolved) {
    case CommunicatorType::kRabit: {
      communicator_.reset(RabitCommunicator::Create(config));
      break;
    }
    case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
      communicator_.reset(FederatedCommunicator::Create(config));
#else
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
      break;
    }
    // Both in-memory variants share the same host-side communicator.
    case CommunicatorType::kInMemory:
    case CommunicatorType::kInMemoryNccl: {
      communicator_.reset(InMemoryCommunicator::Create(config));
      break;
    }
    case CommunicatorType::kUnknown:
      LOG(FATAL) << "Unknown communicator type.";
  }
}
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
}
|
||||
#endif
|
||||
} // namespace xgboost::collective
|
||||
@ -1,54 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "device_communicator_adapter.cuh"
|
||||
#include "noop_communicator.h"
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl_device_communicator.cuh"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
||||
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
device_communicator_.reset(nullptr);
|
||||
}
|
||||
|
||||
// Return the thread-local device communicator, creating it on first use.
//
// It is rebuilt when the requested device ordinal or the world size differs
// from the values seen when it was last built.
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
  thread_local auto last_ordinal = -1;
  // If the number of GPUs changes, we need to re-initialize NCCL.
  thread_local auto last_world_size = -1;

  auto const world_size = communicator_->GetWorldSize();
  bool const reusable =
      device_communicator_ && device_ordinal == last_ordinal && world_size == last_world_size;
  if (reusable) {
    return device_communicator_.get();
  }

  last_ordinal = device_ordinal;
  last_world_size = world_size;
#ifdef XGBOOST_USE_NCCL
  bool const use_adapter =
      type_ == CommunicatorType::kFederated || type_ == CommunicatorType::kInMemory;
  if (use_adapter) {
    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
  } else {
    // The flag is true only for the in-memory NCCL type; Rabit and every
    // remaining type pass false.
    bool const in_memory_nccl = (type_ == CommunicatorType::kInMemoryNccl);
    device_communicator_.reset(
        new NcclDeviceCommunicator(device_ordinal, in_memory_nccl, nccl_path_));
  }
#else
  device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
  return device_communicator_.get();
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -1,247 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
 * @brief Defines the integral and floating data types.
 *
 * The numeric value of every enumerator is spelled out explicitly.
 */
enum class DataType {
  kInt8 = 0,    // std::int8_t
  kUInt8 = 1,   // std::uint8_t
  kInt32 = 2,   // std::int32_t
  kUInt32 = 3,  // std::uint32_t
  kInt64 = 4,   // std::int64_t
  kUInt64 = 5,  // std::uint64_t
  kFloat = 6,   // float
  kDouble = 7   // double
};
|
||||
|
||||
/** @brief Get the size of the data type. */
|
||||
inline std::size_t GetTypeSize(DataType data_type) {
|
||||
std::size_t size{0};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
size = sizeof(std::int8_t);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
size = sizeof(std::uint8_t);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
size = sizeof(std::int32_t);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
size = sizeof(std::uint32_t);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
size = sizeof(std::int64_t);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
size = sizeof(std::uint64_t);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
size = sizeof(float);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
 * @brief Defines the reduction operation.
 *
 * The numeric value of every enumerator is spelled out explicitly.
 */
enum class Operation {
  kMax = 0,         // maximum
  kMin = 1,         // minimum
  kSum = 2,         // sum
  kBitwiseAND = 3,  // bitwise AND
  kBitwiseOR = 4,   // bitwise OR
  kBitwiseXOR = 5   // bitwise XOR
};
|
||||
|
||||
class DeviceCommunicator;
|
||||
|
||||
/** @brief The kinds of communicator that can be selected by name. */
enum class CommunicatorType {
  kUnknown,      // no type specified
  kRabit,        // "rabit"
  kFederated,    // "federated"
  kInMemory,     // "in-memory"
  kInMemoryNccl  // "in-memory-nccl"
};
|
||||
|
||||
/**
 * @brief Case-insensitive string comparison.
 *
 * Uses `_stricmp` when built with MSVC and `strcasecmp` otherwise.
 *
 * @param s1 First null-terminated string.
 * @param s2 Second null-terminated string.
 * @return The value returned by the platform comparison function; 0 means the
 *         strings are equal ignoring case.
 */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#if defined(_MSC_VER)
  int const order = _stricmp(s1, s2);
#else   // _MSC_VER
  int const order = strcasecmp(s1, s2);
#endif  // _MSC_VER
  return order;
}
|
||||
|
||||
/**
|
||||
* @brief A communicator class that handles collective communication.
|
||||
*/
|
||||
class Communicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Initialize the communicator. This can only be done once.
|
||||
*
|
||||
* @param config JSON configuration for the communicator.
|
||||
*/
|
||||
static void Init(Json const &config);
|
||||
|
||||
/** @brief Finalize the communicator. */
|
||||
static void Finalize();
|
||||
|
||||
/** @brief Get the communicator instance. */
|
||||
static Communicator *Get() { return communicator_.get(); }
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
/**
|
||||
* @brief Get the device communicator.
|
||||
*
|
||||
* @param device_ordinal ID of the device.
|
||||
* @return An instance of device communicator.
|
||||
*/
|
||||
static DeviceCommunicator *GetDevice(int device_ordinal);
|
||||
#endif
|
||||
|
||||
virtual ~Communicator() = default;
|
||||
|
||||
/** @brief Get the total number of processes. */
|
||||
int GetWorldSize() const { return world_size_; }
|
||||
|
||||
/** @brief Get the rank of the current processes. */
|
||||
int GetRank() const { return rank_; }
|
||||
|
||||
/** @brief Whether the communicator is running in distributed mode. */
|
||||
virtual bool IsDistributed() const = 0;
|
||||
|
||||
/** @brief Whether the communicator is running in federated mode. */
|
||||
virtual bool IsFederated() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual std::string AllGather(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual std::string AllGatherV(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Broadcasts a message from the process with rank `root` to all other processes of the
|
||||
* group.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param root Rank of broadcast root.
|
||||
*/
|
||||
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets the name of the processor.
|
||||
*/
|
||||
virtual std::string GetProcessorName() = 0;
|
||||
|
||||
/**
|
||||
* @brief Prints the message.
|
||||
*/
|
||||
virtual void Print(std::string const &message) = 0;
|
||||
|
||||
/** @brief Get the communicator type from environment variables. Visible for testing. */
|
||||
static CommunicatorType GetTypeFromEnv() {
|
||||
auto *env = std::getenv("XGBOOST_COMMUNICATOR");
|
||||
if (env != nullptr) {
|
||||
return StringToType(env);
|
||||
} else {
|
||||
return CommunicatorType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief Get the communicator type from runtime configuration. Visible for testing. */
|
||||
static CommunicatorType GetTypeFromConfig(Json const &config) {
|
||||
auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
|
||||
if (IsA<String const>(j_upper)) {
|
||||
return StringToType(get<String const>(j_upper).c_str());
|
||||
}
|
||||
auto const &j_lower = config["xgboost_communicator"];
|
||||
if (IsA<String const>(j_lower)) {
|
||||
return StringToType(get<String const>(j_lower).c_str());
|
||||
}
|
||||
return CommunicatorType::kUnknown;
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Construct a new communicator.
|
||||
*
|
||||
* @param world_size Total number of processes.
|
||||
* @param rank Rank of the current process.
|
||||
*/
|
||||
Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
|
||||
if (world_size < 1) {
|
||||
LOG(FATAL) << "World size " << world_size << " is less than 1.";
|
||||
}
|
||||
if (rank < 0) {
|
||||
LOG(FATAL) << "Rank " << rank << " is less than 0.";
|
||||
}
|
||||
if (rank >= world_size) {
|
||||
LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuts down the communicator.
|
||||
*/
|
||||
virtual void Shutdown() = 0;
|
||||
|
||||
private:
|
||||
static CommunicatorType StringToType(char const *str) {
|
||||
CommunicatorType result = CommunicatorType::kUnknown;
|
||||
if (!CompareStringsCaseInsensitive("rabit", str)) {
|
||||
result = CommunicatorType::kRabit;
|
||||
} else if (!CompareStringsCaseInsensitive("federated", str)) {
|
||||
result = CommunicatorType::kFederated;
|
||||
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
|
||||
result = CommunicatorType::kInMemory;
|
||||
} else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
|
||||
result = CommunicatorType::kInMemoryNccl;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown communicator type " << str;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||
static thread_local CommunicatorType type_;
|
||||
static thread_local std::string nccl_path_;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||
#endif
|
||||
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -1,57 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <vector>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Collective communicator for device buffers.
|
||||
*/
|
||||
class DeviceCommunicator {
|
||||
public:
|
||||
virtual ~DeviceCommunicator() = default;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather values from all all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) = 0;
|
||||
|
||||
/** @brief Synchronize device operations. */
|
||||
virtual void Synchronize() = 0;
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -1,94 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
public:
|
||||
explicit DeviceCommunicatorAdapter(int device_ordinal)
|
||||
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.resize(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
Allreduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
|
||||
auto const output = Allgather(host_buffer_);
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
host_buffer_.resize(total_bytes);
|
||||
size_t offset = 0;
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
if (i == rank_) {
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
// Noop.
|
||||
}
|
||||
|
||||
private:
|
||||
int const device_ordinal_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
std::vector<char> host_buffer_{};
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -1,12 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "in_memory_communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
// Out-of-class definition of the static `handler_` member of InMemoryCommunicator
// (value-initialized).
InMemoryHandler InMemoryCommunicator::handler_{};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -15,14 +15,14 @@ namespace collective {
|
||||
/**
|
||||
* An in-memory communicator, useful for testing.
|
||||
*/
|
||||
class InMemoryCommunicator : public Communicator {
|
||||
class InMemoryCommunicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Create a new communicator based on JSON configuration.
|
||||
* @param config JSON configuration.
|
||||
* @return Communicator as specified by the JSON configuration.
|
||||
*/
|
||||
static Communicator* Create(Json const& config) {
|
||||
static InMemoryCommunicator* Create(Json const& config) {
|
||||
int world_size{0};
|
||||
int rank{-1};
|
||||
|
||||
@ -51,7 +51,7 @@ class InMemoryCommunicator : public Communicator {
|
||||
return new InMemoryCommunicator(world_size, rank);
|
||||
}
|
||||
|
||||
InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
|
||||
InMemoryCommunicator(int world_size, int rank) {
|
||||
handler_.Init(world_size, rank);
|
||||
}
|
||||
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include "in_memory_handler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "comm.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Functor for allgather.
|
||||
*/
|
||||
@ -16,7 +15,7 @@ class AllgatherFunctor {
|
||||
public:
|
||||
std::string const name{"Allgather"};
|
||||
|
||||
AllgatherFunctor(std::size_t world_size, std::size_t rank)
|
||||
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
|
||||
: world_size_{world_size}, rank_{rank} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
@ -30,8 +29,8 @@ class AllgatherFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
std::int32_t world_size_;
|
||||
std::int32_t rank_;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -41,13 +40,13 @@ class AllgatherVFunctor {
|
||||
public:
|
||||
std::string const name{"AllgatherV"};
|
||||
|
||||
AllgatherVFunctor(std::size_t world_size, std::size_t rank,
|
||||
AllgatherVFunctor(std::int32_t world_size, std::int32_t rank,
|
||||
std::map<std::size_t, std::string_view>* data)
|
||||
: world_size_{world_size}, rank_{rank}, data_{data} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
data_->emplace(rank_, std::string_view{input, bytes});
|
||||
if (data_->size() == world_size_) {
|
||||
if (data_->size() == static_cast<std::size_t>(world_size_)) {
|
||||
for (auto const& kv : *data_) {
|
||||
buffer->append(kv.second);
|
||||
}
|
||||
@ -56,8 +55,8 @@ class AllgatherVFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
std::int32_t world_size_;
|
||||
std::int32_t rank_;
|
||||
std::map<std::size_t, std::string_view>* data_;
|
||||
};
|
||||
|
||||
@ -68,7 +67,7 @@ class AllreduceFunctor {
|
||||
public:
|
||||
std::string const name{"Allreduce"};
|
||||
|
||||
AllreduceFunctor(DataType dataType, Operation operation)
|
||||
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
|
||||
: data_type_{dataType}, operation_{operation} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
@ -76,23 +75,23 @@ class AllreduceFunctor {
|
||||
// Copy the input if this is the first request.
|
||||
buffer->assign(input, bytes);
|
||||
} else {
|
||||
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
|
||||
// Apply the reduce_operation to the input and the buffer.
|
||||
Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front());
|
||||
Accumulate(input, bytes / n_bytes_type, &buffer->front());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
|
||||
void AccumulateBitwise(T* buffer, T const* input, std::size_t size,
|
||||
Operation reduce_operation) const {
|
||||
void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
|
||||
switch (reduce_operation) {
|
||||
case Operation::kBitwiseAND:
|
||||
case Op::kBitwiseAND:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_and<T>());
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
case Op::kBitwiseOR:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_or<T>());
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
case Op::kBitwiseXOR:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_xor<T>());
|
||||
break;
|
||||
default:
|
||||
@ -101,27 +100,27 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
template <class T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
|
||||
void AccumulateBitwise(T*, T const*, std::size_t, Operation) const {
|
||||
void AccumulateBitwise(T*, T const*, std::size_t, Op) const {
|
||||
LOG(FATAL) << "Floating point types do not support bitwise operations.";
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
|
||||
void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
|
||||
switch (reduce_operation) {
|
||||
case Operation::kMax:
|
||||
case Op::kMax:
|
||||
std::transform(buffer, buffer + size, input, buffer,
|
||||
[](T a, T b) { return std::max(a, b); });
|
||||
break;
|
||||
case Operation::kMin:
|
||||
case Op::kMin:
|
||||
std::transform(buffer, buffer + size, input, buffer,
|
||||
[](T a, T b) { return std::min(a, b); });
|
||||
break;
|
||||
case Operation::kSum:
|
||||
case Op::kSum:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
case Op::kBitwiseAND:
|
||||
case Op::kBitwiseOR:
|
||||
case Op::kBitwiseXOR:
|
||||
AccumulateBitwise(buffer, input, size, reduce_operation);
|
||||
break;
|
||||
default:
|
||||
@ -130,36 +129,37 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
void Accumulate(char const* input, std::size_t size, char* buffer) const {
|
||||
using Type = ArrayInterfaceHandler::Type;
|
||||
switch (data_type_) {
|
||||
case DataType::kInt8:
|
||||
case Type::kI1:
|
||||
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
|
||||
reinterpret_cast<std::int8_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
case Type::kU1:
|
||||
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
|
||||
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
case Type::kI4:
|
||||
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
|
||||
reinterpret_cast<std::int32_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
case Type::kU4:
|
||||
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
|
||||
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
case Type::kI8:
|
||||
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
|
||||
reinterpret_cast<std::int64_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
case Type::kU8:
|
||||
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
|
||||
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
case Type::kF4:
|
||||
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
|
||||
operation_);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
case Type::kF8:
|
||||
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
|
||||
operation_);
|
||||
break;
|
||||
@ -169,8 +169,8 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
DataType data_type_;
|
||||
Operation operation_;
|
||||
ArrayInterfaceHandler::Type data_type_;
|
||||
Op operation_;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -180,7 +180,7 @@ class BroadcastFunctor {
|
||||
public:
|
||||
std::string const name{"Broadcast"};
|
||||
|
||||
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
|
||||
BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (rank_ == root_) {
|
||||
@ -190,11 +190,11 @@ class BroadcastFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t rank_;
|
||||
std::size_t root_;
|
||||
std::int32_t rank_;
|
||||
std::int32_t root_;
|
||||
};
|
||||
|
||||
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
|
||||
void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) {
|
||||
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@ -204,7 +204,7 @@ void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) {
|
||||
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@ -220,29 +220,29 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
std::size_t sequence_number, std::int32_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
|
||||
}
|
||||
|
||||
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
std::size_t sequence_number, std::int32_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type,
|
||||
Operation op) {
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
ArrayInterfaceHandler::Type data_type, Op op) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root) {
|
||||
std::size_t sequence_number, std::int32_t rank, std::int32_t root) {
|
||||
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
|
||||
}
|
||||
|
||||
template <class HandlerFunctor>
|
||||
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank,
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
HandlerFunctor const& functor) {
|
||||
// Pass through if there is only 1 client.
|
||||
if (world_size_ == 1) {
|
||||
@ -287,5 +287,4 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
|
||||
cv_.notify_all();
|
||||
}
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
#include "../data/array_interface.h"
|
||||
#include "comm.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Handles collective communication primitives in memory.
|
||||
*
|
||||
@ -28,12 +27,11 @@ class InMemoryHandler {
|
||||
|
||||
/**
|
||||
* @brief Construct a handler with the given world size.
|
||||
* @param world_size Number of workers.
|
||||
* @param world Number of workers.
|
||||
*
|
||||
* This is used when the handler only needs to be initialized once with a known world size.
|
||||
*/
|
||||
explicit InMemoryHandler(std::int32_t worldSize)
|
||||
: world_size_{static_cast<std::size_t>(worldSize)} {}
|
||||
explicit InMemoryHandler(std::int32_t world) : world_size_{world} {}
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler with the world size and rank.
|
||||
@ -43,7 +41,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* initialize it collectively.
|
||||
*/
|
||||
void Init(std::size_t world_size, std::size_t rank);
|
||||
void Init(std::int32_t world_size, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Shut down the handler.
|
||||
@ -53,7 +51,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* shut it down collectively.
|
||||
*/
|
||||
void Shutdown(uint64_t sequence_number, std::size_t rank);
|
||||
void Shutdown(uint64_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allgather.
|
||||
@ -64,7 +62,7 @@ class InMemoryHandler {
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
std::size_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform variable-length allgather.
|
||||
@ -75,7 +73,7 @@ class InMemoryHandler {
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
std::size_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allreduce.
|
||||
@ -88,7 +86,8 @@ class InMemoryHandler {
|
||||
* @param op The reduce operation.
|
||||
*/
|
||||
void Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
ArrayInterfaceHandler::Type data_type, Op op);
|
||||
|
||||
/**
|
||||
* @brief Perform broadcast.
|
||||
@ -100,7 +99,7 @@ class InMemoryHandler {
|
||||
* @param root Index of the worker to broadcast from.
|
||||
*/
|
||||
void Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root);
|
||||
std::size_t sequence_number, std::int32_t rank, std::int32_t root);
|
||||
|
||||
private:
|
||||
/**
|
||||
@ -115,17 +114,15 @@ class InMemoryHandler {
|
||||
*/
|
||||
template <class HandlerFunctor>
|
||||
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
|
||||
std::size_t rank, HandlerFunctor const& functor);
|
||||
std::int32_t rank, HandlerFunctor const& functor);
|
||||
|
||||
std::size_t world_size_{}; /// Number of workers.
|
||||
std::size_t received_{}; /// Number of calls received with the current sequence.
|
||||
std::size_t sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::int32_t world_size_{}; /// Number of workers.
|
||||
std::int64_t received_{}; /// Number of calls received with the current sequence.
|
||||
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::string buffer_{}; /// A shared common buffer.
|
||||
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
|
||||
uint64_t sequence_number_{}; /// Call sequence number.
|
||||
mutable std::mutex mutex_; /// Lock.
|
||||
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -6,6 +6,8 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception, current_exception, rethrow_exception
|
||||
#include <future> // for promise
|
||||
#include <memory> // for make_shared
|
||||
#include <mutex> // for lock_guard, unique_lock
|
||||
#include <queue> // for queue
|
||||
#include <string> // for string
|
||||
@ -18,9 +20,10 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue) const {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] {
|
||||
auto error = [this](Op op) {
|
||||
op.pr->set_value();
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
|
||||
@ -38,7 +41,7 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
|
||||
// Iterate through all the ops for poll
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
auto op = std::move(qcopy.front());
|
||||
qcopy.pop();
|
||||
|
||||
switch (op.code) {
|
||||
@ -54,12 +57,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
error(op);
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
qcopy.push(op);
|
||||
qcopy.push(std::move(op));
|
||||
}
|
||||
|
||||
// poll, work on fds that are ready.
|
||||
@ -67,18 +70,18 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
if (!poll.fds.empty()) {
|
||||
auto rc = poll.Poll(timeout_);
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
timer_.Stop(__func__);
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
timer_.Stop("poll");
|
||||
|
||||
// we wonldn't be here if the queue is empty.
|
||||
// We wouldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
|
||||
// Iterate through all the ops for performing the operations
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
auto op = std::move(qcopy.front());
|
||||
qcopy.pop();
|
||||
|
||||
std::int32_t n_bytes_done{0};
|
||||
@ -93,8 +96,9 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
if (poll.CheckRead(*op.sock)) {
|
||||
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
||||
if (n_bytes_done == 0) {
|
||||
error();
|
||||
return Fail("Encountered EOF. The other end is likely closed.");
|
||||
error(op);
|
||||
return Fail("Encountered EOF. The other end is likely closed.",
|
||||
op.sock->GetSockError());
|
||||
}
|
||||
}
|
||||
break;
|
||||
@ -112,14 +116,14 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
error(op);
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
|
||||
auto rc = system::FailWithCode("Invalid socket output.");
|
||||
error();
|
||||
error(op);
|
||||
return rc;
|
||||
}
|
||||
|
||||
@ -127,14 +131,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
CHECK_LE(op.off, op.n);
|
||||
|
||||
if (op.off != op.n) {
|
||||
// not yet finished, push back to queue for next round.
|
||||
// not yet finished, push back to queue for the next round.
|
||||
qcopy.push(op);
|
||||
} else {
|
||||
op.pr->set_value();
|
||||
}
|
||||
}
|
||||
|
||||
if (!blocking) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
timer_.Stop(__func__);
|
||||
@ -148,8 +150,7 @@ void Loop::Process() {
|
||||
};
|
||||
|
||||
// This loop cannot exit unless `stop_` is set to true. There must always be a thread to
|
||||
// answer the blocking call even if there are errors, otherwise the blocking will wait
|
||||
// forever.
|
||||
// answer the call even if there are errors.
|
||||
while (true) {
|
||||
try {
|
||||
std::unique_lock lock{mu_};
|
||||
@ -170,44 +171,15 @@ void Loop::Process() {
|
||||
// Move the global queue into a local variable to unblock it.
|
||||
std::queue<Op> qcopy;
|
||||
|
||||
bool is_blocking = false;
|
||||
while (!queue_.empty()) {
|
||||
auto op = queue_.front();
|
||||
auto op = std::move(queue_.front());
|
||||
queue_.pop();
|
||||
if (op.code == Op::kBlock) {
|
||||
is_blocking = true;
|
||||
} else {
|
||||
qcopy.push(op);
|
||||
}
|
||||
qcopy.push(op);
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
// Clear the local queue, if `is_blocking` is true, this is blocking the current
|
||||
// worker thread (but not the client thread), wait until all operations are
|
||||
// finished.
|
||||
auto rc = this->ProcessQueue(&qcopy, is_blocking);
|
||||
|
||||
if (is_blocking && rc.OK()) {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
// Push back the remaining operations.
|
||||
if (rc.OK()) {
|
||||
std::unique_lock lock{mu_};
|
||||
while (!qcopy.empty()) {
|
||||
queue_.push(qcopy.front());
|
||||
qcopy.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Notify the client thread who called block after all error conditions are set.
|
||||
auto notify_if_block = [&] {
|
||||
if (is_blocking) {
|
||||
std::unique_lock lock{mu_};
|
||||
block_done_ = true;
|
||||
lock.unlock();
|
||||
block_cv_.notify_one();
|
||||
}
|
||||
};
|
||||
// Clear the local queue.
|
||||
auto rc = this->ProcessQueue(&qcopy);
|
||||
|
||||
// Handle error
|
||||
if (!rc.OK()) {
|
||||
@ -215,8 +187,6 @@ void Loop::Process() {
|
||||
} else {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
|
||||
notify_if_block();
|
||||
} catch (std::exception const& e) {
|
||||
curr_exce_ = std::current_exception();
|
||||
set_rc(Fail("Exception inside the event loop:" + std::string{e.what()}));
|
||||
@ -256,20 +226,28 @@ Result Loop::Stop() {
|
||||
stop_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!this->worker_.joinable()) {
|
||||
std::lock_guard<std::mutex> guard{rc_lock_};
|
||||
return Fail("Worker has stopped.", std::move(rc_));
|
||||
}
|
||||
|
||||
this->Submit(Op{Op::kBlock});
|
||||
{
|
||||
// Wait for the block call to finish.
|
||||
std::unique_lock lock{mu_};
|
||||
block_cv_.wait(lock, [this] { return block_done_ || stop_; });
|
||||
block_done_ = false;
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
for (auto& fut : futures_) {
|
||||
if (fut.valid()) {
|
||||
try {
|
||||
fut.get();
|
||||
} catch (std::future_error const&) {
|
||||
// Do nothing. If something went wrong in the worker, we have a std::future_error
|
||||
// due to broken promise. This function will transfer the rc back to the caller.
|
||||
}
|
||||
}
|
||||
}
|
||||
futures_.clear();
|
||||
|
||||
{
|
||||
// Transfer the rc.
|
||||
std::lock_guard<std::mutex> lock{rc_lock_};
|
||||
@ -278,13 +256,13 @@ Result Loop::Stop() {
|
||||
}
|
||||
|
||||
void Loop::Submit(Op op) {
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
op.pr = std::move(p);
|
||||
futures_.emplace_back(op.pr->get_future());
|
||||
CHECK_NE(op.n, 0);
|
||||
|
||||
std::unique_lock lock{mu_};
|
||||
if (op.code != Op::kBlock) {
|
||||
CHECK_NE(op.n, 0);
|
||||
}
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||
|
||||
@ -7,9 +7,12 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t
|
||||
#include <exception> // for exception_ptr
|
||||
#include <mutex> // for unique_lock, mutex
|
||||
#include <future> // for future
|
||||
#include <memory> // for shared_ptr
|
||||
#include <mutex> // for mutex
|
||||
#include <queue> // for queue
|
||||
#include <thread> // for thread
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/timer.h" // for Monitor
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
@ -20,14 +23,15 @@ class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
// kSleep is only for testing
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kSleep = 3 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
std::shared_ptr<std::promise<void>> pr;
|
||||
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kSleep); }
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
@ -45,12 +49,11 @@ class Loop {
|
||||
private:
|
||||
std::thread worker_; // thread worker to execute the tasks
|
||||
|
||||
std::condition_variable cv_; // CV used to notify a new submit call
|
||||
std::condition_variable block_cv_; // CV used to notify the blocking call
|
||||
bool block_done_{false}; // Flag to indicate whether the blocking call has finished.
|
||||
std::condition_variable cv_; // CV used to notify a new submit call
|
||||
|
||||
std::queue<Op> queue_; // event queue
|
||||
std::mutex mu_; // mutex to protect the queue, cv, and block_done
|
||||
std::vector<std::future<void>> futures_;
|
||||
std::mutex mu_; // mutex to protect the queue, cv, and block_done
|
||||
|
||||
std::chrono::seconds timeout_;
|
||||
|
||||
@ -61,7 +64,7 @@ class Loop {
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor mutable timer_;
|
||||
|
||||
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
|
||||
Result ProcessQueue(std::queue<Op>* p_queue) const;
|
||||
// The consumer function that runs inside a worker thread.
|
||||
void Process();
|
||||
|
||||
|
||||
@ -1,243 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "comm.cuh"
|
||||
#include "nccl_device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
 * @brief Construct a NCCL communicator bound to a single CUDA device.
 *
 * A negative device ordinal is fatal. When the world size is 1 nothing else is set up.
 * Otherwise the NCCL stub is loaded, the CUDA device UUID of every rank is collected and
 * checked for uniqueness, and the NCCL communicator is initialised for this rank.
 *
 * @param device_ordinal The CUDA device ordinal used by this rank.
 * @param needs_sync     Stored in needs_sync_.
 * @param nccl_path      Forwarded to the NcclStub constructor.
 */
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
                                               StringView nccl_path)
    : device_ordinal_{device_ordinal},
      needs_sync_{needs_sync},
      world_size_{GetWorldSize()},
      rank_{GetRank()} {
  if (device_ordinal_ < 0) {
    LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
  }
  if (world_size_ == 1) {
    return;
  }
  stub_ = std::make_shared<NcclStub>(std::move(nccl_path));

  // One zero-initialised slot per rank. Each rank fills only its own slot, so a sum
  // all-reduce leaves every slot holding the UUID of the corresponding rank.
  std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
  auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
  auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
  GetCudaUUID(s_this_uuid);

  // TODO(rongou): replace this with allgather.
  Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);

  std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
  size_t j = 0;
  for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
    converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
    j++;
  }

  // std::unique only collapses *adjacent* equal elements. Sort the views first so that
  // two ranks sharing a device are detected even when their ranks are not consecutive.
  // Only the views are reordered; the underlying `uuids` storage is left untouched.
  std::sort(converted.begin(), converted.end(), [](auto const &lhs, auto const &rhs) {
    return std::lexicographical_compare(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
  });
  auto iter = std::unique(converted.begin(), converted.end());
  auto n_uniques = std::distance(converted.begin(), iter);

  CHECK_EQ(n_uniques, world_size_)
      << "Multiple processes within communication group running on same CUDA "
      << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

  nccl_unique_id_ = GetUniqueId();
  dh::safe_cuda(cudaSetDevice(device_ordinal_));
  auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
  CHECK(rc.OK()) << rc.Report();
}
|
||||
|
||||
/**
 * @brief Tear down the NCCL communicator and, at debug verbosity, print usage statistics.
 *
 * Nothing is done for a single-worker run, mirroring the early return in the constructor.
 */
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
  if (world_size_ != 1) {
    if (nccl_comm_) {
      auto status = stub_->CommDestroy(nccl_comm_);
      CHECK(status.OK()) << status.Report();
    }

    bool const is_debug =
        xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug);
    if (is_debug) {
      LOG(CONSOLE) << "======== NCCL Statistics========";
      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
    }
  }
}
|
||||
|
||||
namespace {
|
||||
/**
 * @brief Translate a collective DataType into the matching NCCL data type.
 *
 * An unrecognised value is reported through LOG(FATAL).
 */
ncclDataType_t GetNcclDataType(DataType const &data_type) {
  switch (data_type) {
    case DataType::kInt8:
      return ncclInt8;
    case DataType::kUInt8:
      return ncclUint8;
    case DataType::kInt32:
      return ncclInt32;
    case DataType::kUInt32:
      return ncclUint32;
    case DataType::kInt64:
      return ncclInt64;
    case DataType::kUInt64:
      return ncclUint64;
    case DataType::kFloat:
      return ncclFloat;
    case DataType::kDouble:
      return ncclDouble;
    default:
      break;
  }
  LOG(FATAL) << "Unknown data type.";
  // Same fallback value as before, for the case where LOG(FATAL) returns.
  return ncclInt8;
}
|
||||
|
||||
/** @brief Whether the reduce operation is one of the bitwise AND / OR / XOR operations. */
bool IsBitwiseOp(Operation const &op) {
  switch (op) {
    case Operation::kBitwiseAND:
    case Operation::kBitwiseOR:
    case Operation::kBitwiseXOR:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
 * @brief Translate a collective Operation into the matching NCCL reduce operation.
 *
 * Only max, min and sum are mapped; anything else is reported through LOG(FATAL).
 */
ncclRedOp_t GetNcclRedOp(Operation const &op) {
  switch (op) {
    case Operation::kMax:
      return ncclMax;
    case Operation::kMin:
      return ncclMin;
    case Operation::kSum:
      return ncclSum;
    default:
      break;
  }
  LOG(FATAL) << "Unsupported reduce operation.";
  // Same fallback value as before, for the case where LOG(FATAL) returns.
  return ncclMax;
}
|
||||
|
||||
/**
 * @brief Reduce the gathered per-rank buffers element-wise with a binary functor.
 *
 * @param out_buffer    Destination; element `idx` receives the reduced value.
 * @param device_buffer Gathered input, read at `rank * size + idx` for each rank.
 * @param func          Binary functor combining two `char` values.
 * @param world_size    Number of ranks whose data is stored in `device_buffer`.
 * @param size          Number of elements contributed by each rank.
 */
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
                         std::size_t size) {
  dh::LaunchN(size, [=] __device__(std::size_t idx) {
    // Start from rank 0's value and fold in every other rank's value at the same index.
    auto acc = device_buffer[idx];
    for (auto r = 1; r < world_size; r++) {
      auto const other = device_buffer[r * size + idx];
      acc = func(acc, other);
    }
    out_buffer[idx] = acc;
  });
}
|
||||
} // anonymous namespace
|
||||
|
||||
/**
 * @brief All-reduce with a bitwise operation, implemented as all-gather plus local reduce.
 *
 * @param send_receive_buffer Buffer holding this rank's input; overwritten with the result.
 * @param count               Number of elements of `data_type` in the buffer.
 * @param data_type           Element type, used to compute the byte size.
 * @param op                  One of the bitwise operations; anything else is fatal.
 */
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
                                              DataType data_type, Operation op) {
  auto const n_bytes = count * GetTypeSize(data_type);
  dh::caching_device_vector<char> gathered(n_bytes * world_size_);
  auto *d_gathered = gathered.data().get();

  // Step 1: collect the buffers of all workers.
  auto status = stub_->Allgather(send_receive_buffer, d_gathered, count,
                                 GetNcclDataType(data_type), nccl_comm_, dh::DefaultStream());
  CHECK(status.OK()) << status.Report();
  if (needs_sync_) {
    dh::DefaultStream().Sync();
  }

  // Step 2: reduce the gathered bytes locally, writing back into the caller's buffer.
  auto *d_out = static_cast<char *>(send_receive_buffer);
  if (op == Operation::kBitwiseAND) {
    RunBitwiseAllreduce(d_out, d_gathered, thrust::bit_and<char>(), world_size_, n_bytes);
  } else if (op == Operation::kBitwiseOR) {
    RunBitwiseAllreduce(d_out, d_gathered, thrust::bit_or<char>(), world_size_, n_bytes);
  } else if (op == Operation::kBitwiseXOR) {
    RunBitwiseAllreduce(d_out, d_gathered, thrust::bit_xor<char>(), world_size_, n_bytes);
  } else {
    LOG(FATAL) << "Not a bitwise reduce operation.";
  }
}
|
||||
|
||||
/**
 * @brief In-place all-reduce of a device buffer across all workers.
 *
 * Bitwise operations are routed to BitwiseAllReduce(); the rest go to the NCCL all-reduce.
 * The byte and call counters are updated after every multi-worker call.
 *
 * @param send_receive_buffer Buffer holding the input; overwritten with the result.
 * @param count               Number of elements of `data_type`.
 * @param data_type           Element type.
 * @param op                  Reduce operation.
 */
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
                                       DataType data_type, Operation op) {
  if (world_size_ == 1) {
    return;
  }

  dh::safe_cuda(cudaSetDevice(device_ordinal_));
  if (!IsBitwiseOp(op)) {
    auto status = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
                                   GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
                                   dh::DefaultStream());
    CHECK(status.OK()) << status.Report();
  } else {
    BitwiseAllReduce(send_receive_buffer, count, data_type, op);
  }

  // Statistics printed by the destructor at debug verbosity.
  allreduce_bytes_ += count * GetTypeSize(data_type);
  allreduce_calls_ += 1;
}
|
||||
|
||||
/**
 * @brief All-gather `send_size` bytes from every worker into `receive_buffer`.
 *
 * A single-worker run returns without touching `receive_buffer`.
 *
 * @param send_buffer    This rank's input.
 * @param receive_buffer Destination for the gathered data.
 * @param send_size      Size of `send_buffer`, passed to NCCL as a count of `ncclInt8`.
 */
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
                                       std::size_t send_size) {
  if (world_size_ != 1) {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    auto status = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
                                   dh::DefaultStream());
    CHECK(status.OK()) << status.Report();
  }
}
|
||||
|
||||
/**
 * @brief All-gather buffers of varying length from every worker.
 *
 * The per-rank lengths are exchanged first via a max all-reduce over a vector in which
 * each rank fills only its own slot. The payload is then transferred as one grouped
 * sequence of NCCL broadcasts, one per rank, written back-to-back into `receive_buffer`.
 * A single-worker run returns without touching the outputs.
 *
 * @param send_buffer    This rank's input.
 * @param length_bytes   Size of `send_buffer` in bytes.
 * @param segments       Output: the byte length contributed by each rank.
 * @param receive_buffer Output: resized to the total size and filled with all payloads.
 */
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
                                        std::vector<std::size_t> *segments,
                                        dh::caching_device_vector<char> *receive_buffer) {
  if (world_size_ == 1) {
    return;
  }

  dh::safe_cuda(cudaSetDevice(device_ordinal_));

  segments->clear();
  segments->resize(world_size_, 0);
  segments->at(rank_) = length_bytes;
  Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
  // Accumulate in std::size_t: an `unsigned long` seed is only 32 bits wide on LLP64.
  auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), std::size_t{0});
  receive_buffer->resize(total_bytes);

  size_t offset = 0;
  auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
    for (int32_t i = 0; i < world_size_; ++i) {
      size_t as_bytes = segments->at(i);
      auto bcast_rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset,
                                       as_bytes, ncclChar, i, nccl_comm_, dh::DefaultStream());
      if (!bcast_rc.OK()) {
        return bcast_rc;
      }
      offset += as_bytes;
    }
    return Success();
  } << [&] { return stub_->GroupEnd(); };
  // Report failures the same way the other collectives in this class do; previously the
  // result of the grouped broadcast was silently discarded.
  CHECK(rc.OK()) << rc.Report();
}
|
||||
|
||||
/** @brief Select this communicator's device and sync the default stream (multi-worker only). */
void NcclDeviceCommunicator::Synchronize() {
  if (world_size_ != 1) {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    dh::DefaultStream().Sync();
  }
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
#endif
|
||||
@ -1,91 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "comm.cuh"
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "nccl_stub.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
 * @brief Device communicator implemented on top of NCCL.
 */
class NcclDeviceCommunicator : public DeviceCommunicator {
 public:
  /**
   * @brief Create a NCCL-backed communicator.
   *
   * @param device_ordinal GPU device id used by this worker.
   * @param needs_sync     Whether an additional CUDA stream synchronization is required.
   * @param nccl_path      Passed through to the NCCL stub.
   *
   * When several NCCL communicators live in the same process (as in multi-GPU tests), NCCL
   * kernels are blocking and may occasionally deadlock. Synchronizing the CUDA stream makes
   * sure the NCCL kernels have caught up, which avoids the deadlock.
   *
   * With the Rabit communicator there is one process per GPU, so the additional
   * synchronization is unnecessary. The in-memory communicator is used by tests that run
   * every rank/worker as a thread, where the synchronization is required.
   */
  explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
  ~NcclDeviceCommunicator() override;
  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override;
  void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
  void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
                  dh::caching_device_vector<char> *receive_buffer) override;
  void Synchronize() override;

 private:
  /** @brief Number of 64-bit words needed to hold a CUDA device UUID. */
  static constexpr std::size_t kUuidLength =
      sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);

  /** @brief Copy the UUID of the device `device_ordinal_` into `uuid`. */
  void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
    cudaDeviceProp prop{};
    dh::safe_cuda(cudaGetDeviceProperties(&prop, device_ordinal_));
    std::memcpy(uuid.data(), &prop.uuid, sizeof(prop.uuid));
  }

  /** @brief Render the UUID words as concatenated hexadecimal numbers. */
  static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
    std::stringstream out;
    // std::hex is sticky, so setting it once covers every word written below.
    out << std::hex;
    for (auto word : uuid) {
      out << word;
    }
    return out.str();
  }

  /**
   * @brief Obtain the NCCL unique id used to set up inter-process communication.
   *
   * Rank 0 asks NCCL for the id, which is then broadcast to every other rank.
   *
   * @return The unique id.
   */
  ncclUniqueId GetUniqueId() {
    constexpr int kRootRank = 0;
    ncclUniqueId id;
    if (rank_ == kRootRank) {
      auto status = stub_->GetUniqueId(&id);
      CHECK(status.OK()) << status.Report();
    }
    Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), kRootRank);
    return id;
  }

  void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                        Operation op);

  int const device_ordinal_;
  bool const needs_sync_;
  int const world_size_;
  int const rank_;
  ncclComm_t nccl_comm_{};
  std::shared_ptr<NcclStub> stub_;
  ncclUniqueId nccl_unique_id_{};
  size_t allreduce_bytes_{0};  // Running total of bytes passed through AllReduce.
  size_t allreduce_calls_{0};  // Running total of AllReduce invocations.
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -1,32 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* A no-op communicator, used for non-distributed training.
|
||||
*/
|
||||
class NoOpCommunicator : public Communicator {
|
||||
public:
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
bool IsFederated() const override { return false; }
|
||||
std::string AllGather(std::string_view) override { return {}; }
|
||||
std::string AllGatherV(std::string_view) override { return {}; }
|
||||
void AllReduce(void *, std::size_t, DataType, Operation) override {}
|
||||
void Broadcast(void *, std::size_t, int) override {}
|
||||
std::string GetProcessorName() override { return {}; }
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {}
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -41,20 +41,26 @@ struct Magic {
|
||||
|
||||
[[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
|
||||
std::int32_t magic{kMagic};
|
||||
auto n_bytes = p_sock->SendAll(&magic, sizeof(magic));
|
||||
if (n_bytes != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
|
||||
magic = 0;
|
||||
n_bytes = p_sock->RecvAll(&magic, sizeof(magic));
|
||||
if (n_bytes != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
if (magic != kMagic) {
|
||||
return xgboost::collective::Fail("Invalid verification number.");
|
||||
}
|
||||
return Success();
|
||||
std::size_t n_sent{0};
|
||||
return Success() << [&] {
|
||||
return p_sock->SendAll(&magic, sizeof(magic), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
magic = 0;
|
||||
return p_sock->RecvAll(&magic, sizeof(magic), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
if (magic != kMagic) {
|
||||
return xgboost::collective::Fail("Invalid verification number.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
@ -227,31 +233,43 @@ struct Error {
|
||||
|
||||
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
|
||||
std::int32_t err{ErrorSignal()};
|
||||
auto n_sent = worker->SendAll(&err, sizeof(err));
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send error signal");
|
||||
std::size_t n_sent{0};
|
||||
return Success() << [&] {
|
||||
return worker->SendAll(&err, sizeof(err), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send error signal");
|
||||
};
|
||||
}
|
||||
// self is localhost, we are sending the signal to the error handling thread for it to
|
||||
// close.
|
||||
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
|
||||
std::int32_t err{ShutdownSignal()};
|
||||
auto n_sent = self->SendAll(&err, sizeof(err));
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send shutdown signal");
|
||||
std::size_t n_sent{0};
|
||||
return Success() << [&] {
|
||||
return self->SendAll(&err, sizeof(err), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send shutdown signal");
|
||||
};
|
||||
}
|
||||
// get signal, either for error or for shutdown.
|
||||
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
|
||||
std::int32_t err{ShutdownSignal()};
|
||||
auto n_recv = peer->RecvAll(&err, sizeof(err));
|
||||
if (n_recv == sizeof(err)) {
|
||||
*p_is_error = err == 1;
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to receive error signal.");
|
||||
std::size_t n_recv{0};
|
||||
return Success() << [&] {
|
||||
return peer->RecvAll(&err, sizeof(err), &n_recv);
|
||||
} << [&] {
|
||||
if (n_recv == sizeof(err)) {
|
||||
*p_is_error = err == 1;
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to receive error signal.");
|
||||
};
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective::proto
|
||||
|
||||
@ -1,175 +0,0 @@
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.h"
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class RabitCommunicator : public Communicator {
|
||||
public:
|
||||
static Communicator *Create(Json const &config) {
|
||||
std::vector<std::string> args_str;
|
||||
for (auto &items : get<Object const>(config)) {
|
||||
switch (items.second.GetValue().Type()) {
|
||||
case xgboost::Value::ValueKind::kString: {
|
||||
args_str.push_back(items.first + "=" + get<String const>(items.second));
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kInteger: {
|
||||
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kBoolean: {
|
||||
if (get<Boolean const>(items.second)) {
|
||||
args_str.push_back(items.first + "=1");
|
||||
} else {
|
||||
args_str.push_back(items.first + "=0");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::vector<char *> args;
|
||||
for (auto &key_value : args_str) {
|
||||
args.push_back(&key_value[0]);
|
||||
}
|
||||
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
|
||||
LOG(FATAL) << "Failed to initialize Rabit";
|
||||
}
|
||||
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
|
||||
}
|
||||
|
||||
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
|
||||
|
||||
bool IsDistributed() const override { return rabit::IsDistributed(); }
|
||||
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
std::string AllGather(std::string_view input) override {
|
||||
auto const per_rank = input.size();
|
||||
auto const total_size = per_rank * GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(index, per_rank, input);
|
||||
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
auto const size_node_slice = input.size();
|
||||
auto const all_sizes = collective::Allgather(size_node_slice);
|
||||
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
|
||||
auto const begin_index =
|
||||
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
|
||||
auto const size_prev_slice =
|
||||
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
|
||||
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(begin_index, size_node_slice, input);
|
||||
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
|
||||
return result;
|
||||
}
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
DoAllReduce<char>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
DoAllReduce<float>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
DoAllReduce<double>(send_receive_buffer, count, op);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type";
|
||||
}
|
||||
}
|
||||
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
|
||||
rabit::Broadcast(send_receive_buffer, size, root);
|
||||
}
|
||||
|
||||
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
|
||||
|
||||
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
|
||||
|
||||
protected:
|
||||
void Shutdown() override { rabit::Finalize(); }
|
||||
|
||||
private:
|
||||
template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * = nullptr>
|
||||
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
switch (op) {
|
||||
case Operation::kBitwiseAND:
|
||||
rabit::Allreduce<rabit::op::BitAND, DType>(static_cast<DType *>(send_receive_buffer),
|
||||
count);
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
rabit::Allreduce<rabit::op::BitOR, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
rabit::Allreduce<rabit::op::BitXOR, DType>(static_cast<DType *>(send_receive_buffer),
|
||||
count);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown allreduce operation";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
|
||||
void DoBitwiseAllReduce(void *, std::size_t, Operation) {
|
||||
LOG(FATAL) << "Floating point types do not support bitwise operations.";
|
||||
}
|
||||
|
||||
template <typename DType>
|
||||
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
switch (op) {
|
||||
case Operation::kMax:
|
||||
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kMin:
|
||||
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kSum:
|
||||
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown allreduce operation";
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@ -62,20 +62,15 @@ void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
|
||||
ptr->prev = std::move(rhs);
|
||||
}
|
||||
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
|
||||
return std::forward<std::string>(msg);
|
||||
}
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
|
||||
auto name = std::filesystem::path{file}.filename();
|
||||
dmlc::DateLogger logger;
|
||||
if (file && line != -1) {
|
||||
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
|
||||
auto name = std::filesystem::path{ file }.filename();
|
||||
return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
|
||||
"]: " + std::forward<std::string>(msg);
|
||||
}
|
||||
return std::forward<std::string>(msg);
|
||||
return std::string{"["} + logger.HumanDate() + "]" + std::forward<std::string>(msg); // NOLINT
|
||||
}
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
void SafeColl(Result const& rc) {
|
||||
|
||||
@ -60,24 +60,46 @@ std::size_t TCPSocket::Send(StringView str) {
|
||||
CHECK(!this->IsClosed());
|
||||
CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());
|
||||
std::int32_t len = static_cast<std::int32_t>(str.size());
|
||||
CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length.";
|
||||
auto bytes = this->SendAll(str.c_str(), str.size());
|
||||
CHECK_EQ(bytes, str.size()) << "Failed to send string.";
|
||||
return bytes;
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = Success() << [&] {
|
||||
return this->SendAll(&len, sizeof(len), &n_bytes);
|
||||
} << [&] {
|
||||
if (n_bytes != sizeof(len)) {
|
||||
return Fail("Failed to send string length.");
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
return this->SendAll(str.c_str(), str.size(), &n_bytes);
|
||||
} << [&] {
|
||||
if (n_bytes != str.size()) {
|
||||
return Fail("Failed to send string.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
SafeColl(rc);
|
||||
return n_bytes;
|
||||
}
|
||||
|
||||
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
|
||||
CHECK(!this->IsClosed());
|
||||
std::int32_t len;
|
||||
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
|
||||
return Fail("Failed to recv string length.");
|
||||
}
|
||||
p_str->resize(len);
|
||||
auto bytes = this->RecvAll(&(*p_str)[0], len);
|
||||
if (static_cast<decltype(len)>(bytes) != len) {
|
||||
return Fail("Failed to recv string.");
|
||||
}
|
||||
return Success();
|
||||
std::size_t n_bytes{0};
|
||||
return Success() << [&] {
|
||||
return this->RecvAll(&len, sizeof(len), &n_bytes);
|
||||
} << [&] {
|
||||
if (n_bytes != sizeof(len)) {
|
||||
return Fail("Failed to recv string length.");
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
p_str->resize(len);
|
||||
return this->RecvAll(&(*p_str)[0], len, &n_bytes);
|
||||
} << [&] {
|
||||
if (static_cast<std::remove_reference_t<decltype(len)>>(n_bytes) != len) {
|
||||
return Fail("Failed to recv string.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||
|
||||
@ -31,14 +31,20 @@
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
Tracker::Tracker(Json const& config)
|
||||
: sortby_{static_cast<SortBy>(
|
||||
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
|
||||
n_workers_{
|
||||
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
|
||||
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
|
||||
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
|
||||
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
|
||||
timeout_{std::chrono::seconds{
|
||||
OptionalArg<Integer const>(config, "timeout", static_cast<std::int64_t>(0))}} {
|
||||
using std::chrono_literals::operator""s;
|
||||
// Some old configurations in JVM for the scala implementation (removed) use 0 to
|
||||
// indicate blocking. We continue that convention here.
|
||||
timeout_ = (timeout_ == 0s) ? -1s : timeout_;
|
||||
}
|
||||
|
||||
Result Tracker::WaitUntilReady() const {
|
||||
using namespace std::chrono_literals; // NOLINT
|
||||
@ -49,7 +55,7 @@ Result Tracker::WaitUntilReady() const {
|
||||
timer.Start();
|
||||
while (!this->Ready()) {
|
||||
auto ela = timer.Duration().count();
|
||||
if (ela > this->Timeout().count()) {
|
||||
if (HasTimeout(this->Timeout()) && ela > this->Timeout().count()) {
|
||||
return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) +
|
||||
" seconds.");
|
||||
}
|
||||
@ -250,8 +256,10 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
std::lock_guard lock{listener_mu_};
|
||||
return listener_.NonBlocking(true);
|
||||
} << [&] {
|
||||
std::lock_guard lock{listener_mu_};
|
||||
poll.WatchRead(listener_);
|
||||
{
|
||||
std::lock_guard lock{listener_mu_};
|
||||
poll.WatchRead(listener_);
|
||||
}
|
||||
if (state.running) {
|
||||
// Don't timeout if the communicator group is up and running.
|
||||
return poll.Poll(std::chrono::seconds{-1});
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
/** @brief Whether `timeout` is a strictly positive duration; zero or negative means none. */
inline bool HasTimeout(std::chrono::seconds timeout) {
  return timeout > std::chrono::seconds::zero();
}
|
||||
/**
|
||||
*
|
||||
* @brief Implementation of RABIT tracker.
|
||||
@ -52,7 +53,7 @@ class Tracker {
|
||||
protected:
|
||||
std::int32_t n_workers_{0};
|
||||
std::int32_t port_{-1};
|
||||
std::chrono::seconds timeout_{0};
|
||||
std::chrono::seconds timeout_{-1};
|
||||
std::atomic<bool> ready_{false};
|
||||
|
||||
public:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023, XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file io.h
|
||||
* \brief general stream interface for serialization, I/O
|
||||
* \author Tianqi Chen
|
||||
@ -8,7 +8,6 @@
|
||||
#define XGBOOST_COMMON_IO_H_
|
||||
|
||||
#include <dmlc/io.h>
|
||||
#include <rabit/internal/io.h> // for MemoryFixSizeBuffer, MemoryBufferStream
|
||||
|
||||
#include <algorithm> // for min, fill_n, copy_n
|
||||
#include <array> // for array
|
||||
@ -23,12 +22,99 @@
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "common.h"
|
||||
#include "common.h" // for DivRoundUp
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost::common {
|
||||
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
|
||||
using MemoryBufferStream = rabit::utils::MemoryBufferStream;
|
||||
struct MemoryFixSizeBuffer : public dmlc::SeekStream {
|
||||
public:
|
||||
// similar to SEEK_END in libc
|
||||
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Ctor
|
||||
*
|
||||
* @param p_buffer Pointer to the source buffer with size `buffer_size`.
|
||||
* @param buffer_size Size of the source buffer
|
||||
*/
|
||||
MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
|
||||
: p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
|
||||
~MemoryFixSizeBuffer() override = default;
|
||||
|
||||
std::size_t Read(void *ptr, std::size_t size) override {
|
||||
std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
return nread;
|
||||
}
|
||||
void Write(const void *ptr, std::size_t size) override {
|
||||
if (size == 0) return;
|
||||
CHECK_LE(curr_ptr_ + size, buffer_size_);
|
||||
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
void Seek(std::size_t pos) override {
|
||||
if (pos == kSeekEnd) {
|
||||
curr_ptr_ = buffer_size_;
|
||||
} else {
|
||||
curr_ptr_ = static_cast<std::size_t>(pos);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Current position in the buffer (stream).
|
||||
*/
|
||||
std::size_t Tell() override { return curr_ptr_; }
|
||||
[[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
|
||||
|
||||
protected:
|
||||
/*! \brief in memory buffer */
|
||||
char *p_buffer_{nullptr};
|
||||
/*! \brief current pointer */
|
||||
std::size_t buffer_size_{0};
|
||||
/*! \brief current pointer */
|
||||
std::size_t curr_ptr_{0};
|
||||
};
|
||||
|
||||
/*! \brief a in memory buffer that can be read and write as stream interface */
|
||||
struct MemoryBufferStream : public dmlc::SeekStream {
|
||||
public:
|
||||
explicit MemoryBufferStream(std::string *p_buffer)
|
||||
: p_buffer_(p_buffer) {
|
||||
curr_ptr_ = 0;
|
||||
}
|
||||
~MemoryBufferStream() override = default;
|
||||
size_t Read(void *ptr, size_t size) override {
|
||||
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
|
||||
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
return nread;
|
||||
}
|
||||
void Write(const void *ptr, size_t size) override {
|
||||
if (size == 0) return;
|
||||
if (curr_ptr_ + size > p_buffer_->length()) {
|
||||
p_buffer_->resize(curr_ptr_+size);
|
||||
}
|
||||
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
void Seek(size_t pos) override {
|
||||
curr_ptr_ = static_cast<size_t>(pos);
|
||||
}
|
||||
size_t Tell() override {
|
||||
return curr_ptr_;
|
||||
}
|
||||
virtual bool AtEnd() const {
|
||||
return curr_ptr_ == p_buffer_->length();
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief in memory buffer */
|
||||
std::string *p_buffer_;
|
||||
/*! \brief current pointer */
|
||||
size_t curr_ptr_;
|
||||
}; // class MemoryBufferStream
|
||||
|
||||
/*!
|
||||
* \brief Input stream that support additional PeekRead operation,
|
||||
|
||||
@ -116,19 +116,19 @@ INSTANTIATE(ColumnarAdapterBatch)
|
||||
|
||||
namespace {
|
||||
/**
|
||||
* \brief A view over gathered sketch values.
|
||||
* @brief A view over gathered sketch values.
|
||||
*/
|
||||
template <typename T>
|
||||
struct QuantileAllreduce {
|
||||
common::Span<T> global_values;
|
||||
common::Span<bst_idx_t> worker_indptr;
|
||||
common::Span<bst_idx_t> feature_indptr;
|
||||
size_t n_features{0};
|
||||
bst_feature_t n_features{0};
|
||||
/**
|
||||
* \brief Get sketch values of the a feature from a worker.
|
||||
* @brief Get sketch values of the a feature from a worker.
|
||||
*
|
||||
* \param rank rank of target worker
|
||||
* \param fidx feature idx
|
||||
* @param rank rank of target worker
|
||||
* @param fidx feature idx
|
||||
*/
|
||||
[[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const {
|
||||
// get span for worker
|
||||
@ -154,7 +154,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
worker_segments.resize(1, 0);
|
||||
auto world = collective::GetWorldSize();
|
||||
auto rank = collective::GetRank();
|
||||
auto n_columns = sketches_.size();
|
||||
bst_feature_t n_columns = sketches_.size();
|
||||
|
||||
// get the size of each feature.
|
||||
std::vector<bst_idx_t> sketch_size;
|
||||
@ -165,7 +165,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
sketch_size.push_back(reduced[i].size);
|
||||
}
|
||||
}
|
||||
// turn the size into CSC indptr
|
||||
// Turn the size into CSC indptr
|
||||
std::vector<bst_idx_t> &sketches_scan = *p_sketches_scan;
|
||||
sketches_scan.resize((n_columns + 1) * world, 0);
|
||||
size_t beg_scan = rank * (n_columns + 1); // starting storage for current worker.
|
||||
@ -174,7 +174,10 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
// Gather all column pointers
|
||||
auto rc =
|
||||
collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size()));
|
||||
collective::SafeColl(rc);
|
||||
if (!rc.OK()) {
|
||||
collective::SafeColl(collective::Fail("Failed to get sketch scan.", std::move(rc)));
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < world; ++i) {
|
||||
size_t back = (i + 1) * (n_columns + 1) - 1;
|
||||
auto n_entries = sketches_scan.at(back);
|
||||
@ -206,7 +209,9 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
ctx, info,
|
||||
linalg::MakeVec(reinterpret_cast<float *>(global_sketches.data()),
|
||||
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float)));
|
||||
collective::SafeColl(rc);
|
||||
if (!rc.OK()) {
|
||||
collective::SafeColl(collective::Fail("Failed to get sketch.", std::move(rc)));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename WQSketch>
|
||||
@ -260,7 +265,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, Meta
|
||||
rc = collective::GlobalSum(ctx, info,
|
||||
linalg::MakeVec(global_categories.data(), global_categories.size()));
|
||||
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
|
||||
categories_.size()};
|
||||
static_cast<bst_feature_t>(categories_.size())};
|
||||
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
|
||||
if (!IsCat(feature_types_, fidx)) {
|
||||
return;
|
||||
@ -285,8 +290,9 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
std::vector<typename WQSketch::SummaryContainer> *p_reduced, std::vector<int32_t> *p_num_cuts) {
|
||||
monitor_.Start(__func__);
|
||||
|
||||
size_t n_columns = sketches_.size();
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
|
||||
bst_feature_t n_columns = sketches_.size();
|
||||
auto rc = collective::Allreduce(ctx, &n_columns, collective::Op::kMax);
|
||||
collective::SafeColl(rc);
|
||||
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
|
||||
|
||||
AllreduceCategories(ctx, info);
|
||||
@ -300,8 +306,8 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
|
||||
// Prune the intermediate num cuts for synchronization.
|
||||
std::vector<bst_idx_t> global_column_size(columns_size_);
|
||||
auto rc = collective::GlobalSum(
|
||||
ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size()));
|
||||
rc = collective::GlobalSum(ctx, info,
|
||||
linalg::MakeVec(global_column_size.data(), global_column_size.size()));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user