Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhausted tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan 2024-05-20 11:56:23 +08:00 committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
195 changed files with 2768 additions and 9234 deletions

View File

@ -69,7 +69,6 @@ option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF)
option(USE_DEVICE_DEBUG "Generate CUDA device debug info." OFF)
option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header")
option(RABIT_MOCK "Build rabit with mock" OFF)
option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
## CUDA
@ -282,9 +281,6 @@ if(MSVC)
endif()
endif()
# rabit
add_subdirectory(rabit)
# core xgboost
add_subdirectory(${xgboost_SOURCE_DIR}/src)
target_link_libraries(objxgboost PUBLIC dmlc)

View File

@ -106,10 +106,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/socket.o \
@ -134,7 +131,4 @@ OBJECTS= \
$(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \
$(PKGROOT)/src/c_api/c_api_error.o \
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine.o \
$(PKGROOT)/rabit/src/rabit_c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o
$(PKGROOT)/amalgamation/dmlc-minimum0.o

View File

@ -106,10 +106,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/socket.o \
@ -134,7 +131,4 @@ OBJECTS= \
$(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \
$(PKGROOT)/src/c_api/c_api_error.o \
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine.o \
$(PKGROOT)/rabit/src/rabit_c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o
$(PKGROOT)/amalgamation/dmlc-minimum0.o

View File

@ -151,6 +151,7 @@ function(xgboost_set_cuda_flags target)
target_include_directories(
${target} PRIVATE
${xgboost_SOURCE_DIR}/gputreeshap
${xgboost_SOURCE_DIR}/rabit/include
${CUDAToolkit_INCLUDE_DIRS})
if(MSVC)

View File

@ -16,7 +16,7 @@ def main(client: Client) -> None:
m = 100000
n = 100
rng = da.random.default_rng(1)
X = rng.normal(size=(m, n))
X = rng.normal(size=(m, n), chunks=(10000, -1))
y = X.sum(axis=1)
# DaskDMatrix acts like normal DMatrix, works as a proxy for local

View File

@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
*
* @return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
char const *c_json_config, DMatrixHandle m,
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values,
char const *config, DMatrixHandle m,
bst_ulong const **out_shape, bst_ulong *out_dim,
const float **out_result);
@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
*
* @brief Experimental support for exposing internal communicator in XGBoost.
*
* @note This is still under development.
*
* The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has
* changed significantly since its adoption. It consists of a tracker and a set of
* workers. The tracker is responsible for bootstrapping the communication group and
* handling centralized tasks like logging. The workers are actual communicators
* performing collective tasks like allreduce.
*
* To use the collective implementation, one needs to first create a tracker with
* corresponding parameters, then get the arguments for workers using
* XGTrackerWorkerArgs(). The obtained arguments can then be passed to the
* XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a
* XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses
* `std::thread` in C++, which has undefined behavior in a C++ destructor due to the
* runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the
* runtime is shutting down. This requirement is similar to a Python thread or socket,
* which should not be relied upon in a `__del__` function.
*
* Since it's used as a part of XGBoost, errors will be returned when a XGBoost function
* is called, for instance, training a booster might return a connection error.
*
* @{
*/
/**
* @brief Handle to tracker.
* @brief Handle to the tracker.
*
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
* other one is `federated`.
* other one is `federated`. `rabit` is used for normal collective communication, while
* `federated` is used for federated learning.
*
* This is still under development.
*/
typedef void *TrackerHandle; /* NOLINT */
@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */
*
* @param config JSON encoded parameters.
*
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
* and `federated`.
* - dmlc_communicator: String, the type of tracker to create. Available options are
* `rabit` and `federated`. See @ref TrackerHandle for more info.
* - n_workers: Integer, the number of workers.
* - port: (Optional) Integer, the port this tracker should listen to.
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
* - timeout: (Optional) Integer, timeout in seconds for various networking
operations. Default is 300 seconds.
*
* Some configurations are `rabit` specific:
*
* - host: (Optional) String, Used by the the `rabit` tracker to specify the address of the host.
* This can be useful when the communicator cannot reliably obtain the host address.
* - sortby: (Optional) Integer.
* + 0: Sort workers by their host name.
* + 1: Sort workers by task IDs.
*
* Some `federated` specific configurations:
* - federated_secure: Boolean, whether this is a secure server.
* - federated_secure: Boolean, whether this is a secure server. False for testing.
* - server_key_path: Path to the server key. Used only if this is a secure server.
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
*/
XGB_DLL int XGTrackerFree(TrackerHandle handle);
/*!
* \brief Initialize the collective communicator.
/**
* @brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
* Call this once in the worker process before using anything. Please make sure
* XGCommunicatorFinalize() is called after use. The initialized commuicator is a global
* thread-local variable.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* @param config JSON encoded configuration. Accepted JSON keys are:
* - dmlc_communicator: The type of the communicator, this should match the tracker type.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
* `USE_DLOPEN_NCCL`.
* Only applicable to the Federated communicator (use upper case for environment variables, use
*
* Only applicable to the `rabit` communicator:
* - dmlc_tracker_uri: Hostname or IP address of the tracker.
* - dmlc_tracker_port: Port number of the tracker.
* - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - dmlc_retry: The number of retries for connection failure.
* - dmlc_timeout: Timeout in seconds.
* - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`.
*
* Only applicable to the `federated` communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
* \return 0 for success, -1 for failure.
* - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key_path: Client key file path. Only needed for the SSL mode.
* - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorInit(char const* config);
/*!
* \brief Finalize the collective communicator.
/**
* @brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
* Call this function after you have finished all jobs.
*
* \return 0 for success, -1 for failure.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorFinalize(void);
/*!
* \brief Get rank of current process.
/**
* @brief Get rank of the current process.
*
* \return Rank of the worker.
* @return Rank of the worker.
*/
XGB_DLL int XGCommunicatorGetRank(void);
/*!
* \brief Get total number of processes.
/**
* @brief Get the total number of processes.
*
* \return Total world size.
* @return Total world size.
*/
XGB_DLL int XGCommunicatorGetWorldSize(void);
/*!
* \brief Get if the communicator is distributed.
/**
* @brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
* @return True if the communicator is distributed.
*/
XGB_DLL int XGCommunicatorIsDistributed(void);
/*!
* \brief Print the message to the communicator.
/**
* @brief Print the message to the tracker.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
* This function can be used to communicate the information of the progress to the user
* who monitors the tracker.
*
* \param message The message to be printed.
* \return 0 for success, -1 for failure.
* @param message The message to be printed.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorPrint(char const *message);
/*!
* \brief Get the name of the processor.
/**
* @brief Get the name of the processor.
*
* \param name_str Pointer to received returned processor name.
* \return 0 for success, -1 for failure.
* @param name_str Pointer to received returned processor name.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);
/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
/**
* @brief Broadcast a memory region to all others from root. This function is NOT
* thread-safe.
*
* Example:
* \code
* @code
* int a = 1;
* Broadcast(&a, sizeof(a), root);
* \endcode
* @endcode
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
* \return 0 for success, -1 for failure.
* @param send_receive_buffer Pointer to the send or receive buffer.
* @param size Size of the data in bytes.
* @param root The process rank to broadcast from.
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);
/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
/**
* @brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* \code
* vector<int> data(10);
* @code
* enum class Op {
* kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
* };
* std::vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
* Allreduce(data.data(), data.size(), DataType:kInt32, Op::kSum);
* ...
* \endcode
* @endcode
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
* \return 0 for success, -1 for failure.
* @param send_receive_buffer Buffer for both sending and receiving data.
* @param count Number of elements to be reduced.
* @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);

View File

@ -55,10 +55,9 @@ struct ResultImpl {
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
#define __builtin_FILE() nullptr
#define __builtin_LINE() (-1)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
#endif
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
} // namespace detail
/**

View File

@ -16,6 +16,10 @@
#include <system_error> // std::error_code, std::system_category
#include <utility> // std::swap
#if defined(__linux__)
#include <sys/ioctl.h> // for TIOCOUTQ, FIONREAD
#endif // defined(__linux__)
#if !defined(xgboost_IS_MINGW)
#if defined(__MINGW32__)
@ -319,7 +323,8 @@ class TCPSocket {
std::int32_t domain;
socklen_t len = sizeof(domain);
xgboost_CHECK_SYS_CALL(
getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len),
0);
return ret_iafamily(domain);
#else
struct sockaddr sa;
@ -426,6 +431,35 @@ class TCPSocket {
return Success();
}
[[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) {
socklen_t optlen;
auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(n_bytes),
&optlen);
if (rc != 0 || optlen != sizeof(std::int32_t)) {
return system::FailWithCode("getsockopt");
}
return Success();
}
[[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) {
socklen_t optlen;
auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(n_bytes),
&optlen);
if (rc != 0 || optlen != sizeof(std::int32_t)) {
return system::FailWithCode("getsockopt");
}
return Success();
}
#if defined(__linux__)
[[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const {
return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success()
: system::FailWithCode("ioctl");
}
[[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const {
return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success()
: system::FailWithCode("ioctl");
}
#endif // defined(__linux__)
[[nodiscard]] Result SetKeepAlive() {
std::int32_t keepalive = 1;
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
@ -436,10 +470,9 @@ class TCPSocket {
return Success();
}
[[nodiscard]] Result SetNoDelay() {
std::int32_t tcp_no_delay = 1;
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
sizeof(tcp_no_delay));
[[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) {
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&no_delay),
sizeof(no_delay));
if (rc != 0) {
return system::FailWithCode("Failed to set TCP no delay.");
}
@ -602,45 +635,47 @@ class TCPSocket {
}
/**
* \brief Send data, without error then all data should be sent.
* @brief Send data, without error then all data should be sent.
*/
[[nodiscard]] auto SendAll(void const *buf, std::size_t len) {
[[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) {
char const *_buf = reinterpret_cast<const char *>(buf);
std::size_t ndone = 0;
std::size_t &ndone = *n_sent;
ndone = 0;
while (ndone < len) {
ssize_t ret = send(handle_, _buf, len - ndone, 0);
if (ret == -1) {
if (system::LastErrorWouldBlock()) {
return ndone;
return Success();
}
system::ThrowAtError("send");
return system::FailWithCode("send");
}
_buf += ret;
ndone += ret;
}
return ndone;
return Success();
}
/**
* \brief Receive data, without error then all data should be received.
* @brief Receive data, without error then all data should be received.
*/
[[nodiscard]] auto RecvAll(void *buf, std::size_t len) {
[[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) {
char *_buf = reinterpret_cast<char *>(buf);
std::size_t ndone = 0;
std::size_t &ndone = *n_recv;
ndone = 0;
while (ndone < len) {
ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
if (ret == -1) {
if (system::LastErrorWouldBlock()) {
return ndone;
return Success();
}
system::ThrowAtError("recv");
return system::FailWithCode("recv");
}
if (ret == 0) {
return ndone;
return Success();
}
_buf += ret;
ndone += ret;
}
return ndone;
return Success();
}
/**
* \brief Send data using the socket

View File

@ -23,6 +23,7 @@ CONFIG = {
"USE_NCCL": "OFF",
"JVM_BINDINGS": "ON",
"LOG_CAPI_INVOCATION": "OFF",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
}
@ -97,10 +98,6 @@ def native_build(args):
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
# if enviorment set rabit_mock
if os.getenv("RABIT_MOCK", None) is not None:
args.append("-DRABIT_MOCK:BOOL=ON")
# if enviorment set GPU_ARCH_FLAG
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
if gpu_arch_flag is not None:
@ -162,12 +159,6 @@ def native_build(args):
maybe_makedirs(output_folder)
cp("../lib/" + library_name, output_folder)
print("copying pure-Python tracker")
cp(
"../python-package/xgboost/tracker.py",
"{}/src/main/resources".format(xgboost4j),
)
print("copying train/test files")
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
with cd("../demo/CLI/regression"):

View File

@ -489,6 +489,11 @@
<artifactId>kryo</artifactId>
<version>5.6.0</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.14.2</version>
</dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>

View File

@ -54,9 +54,9 @@ public class XGBoost {
private final Map<String, Object> params;
private final int round;
private final Map<String, String> workerEnvs;
private final Map<String, Object> workerEnvs;
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
public MapFunction(Map<String, Object> params, int round, Map<String, Object> workerEnvs) {
this.params = params;
this.round = round;
this.workerEnvs = workerEnvs;
@ -174,9 +174,9 @@ public class XGBoost {
int numBoostRound) throws Exception {
final RabitTracker tracker =
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
if (tracker.start(0L)) {
if (tracker.start()) {
return dtrain
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
.reduce((x, y) -> x)
.collect()
.get(0);

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -22,7 +22,7 @@ import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@ -38,21 +38,17 @@ import org.apache.spark.sql.SparkSession
/**
* Rabit tracker configurations.
*
* @param workerConnectionTimeout The timeout for all workers to connect to the tracker.
* Set timeout length to zero to disable timeout.
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (in milliseconds)
* (supported by "scala" implementation only.)
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
* @param timeout The number of seconds before timeout waiting for workers to connect. and
* for the tracker to shutdown.
* @param hostIp The Rabit Tracker host IP address.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The python executed path for Rabit Tracker,
* which is only used for python implementation.
* @param port The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
case class TrackerConf(workerConnectionTimeout: Long,
hostIp: String = "", pythonExec: String = "")
case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L)
def apply(): TrackerConf = TrackerConf(0)
}
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
@ -421,7 +417,7 @@ object XGBoost extends XGBoostStageLevel {
private def buildDistributedBooster(
buildWatches: () => Watches,
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
rabitEnv: java.util.Map[String, Object],
obj: ObjectiveTrait,
eval: EvalTrait,
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
@ -430,7 +426,6 @@ object XGBoost extends XGBoostStageLevel {
val taskId = TaskContext.getPartitionId().toString
val attempt = TaskContext.get().attemptNumber.toString
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
@ -481,16 +476,15 @@ object XGBoost extends XGBoostStageLevel {
}
/** visiable for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = new PyRabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
)
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker: ITracker = new RabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
tracker
}
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
val tracker = getTracker(nWorkers, trackerConf)
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
require(tracker.start(), "FAULT: Failed to start tracker")
tracker
}
@ -525,8 +519,8 @@ object XGBoost extends XGBoostStageLevel {
// Train for every ${savingRound} rounds and save the partially completed booster
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
val (booster, metrics) = try {
tracker.getWorkerEnvs().putAll(xgbRabitParams)
val rabitEnv = tracker.getWorkerEnvs
tracker.workerArgs().putAll(xgbRabitParams)
val rabitEnv = tracker.workerArgs
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
var optionWatches: Option[() => Watches] = None
@ -548,11 +542,6 @@ object XGBoost extends XGBoostStageLevel {
// of the training task fails the training stage can retry. ResultStage won't retry when
// it fails.
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
val trackerReturnVal = tracker.waitFor(0L)
logger.info(s"Rabit returns with exit code $trackerReturnVal")
if (trackerReturnVal != 0) {
throw new XGBoostError("XGBoostModel training failed.")
}
(booster, metrics)
} finally {
tracker.stop()

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -145,28 +145,28 @@ private[spark] trait GeneralParams extends Params {
* Rabit tracker configurations. The parameter must be provided as an instance of the
* TrackerConf class, which has the following definition:
*
* case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration,
* trackerImpl: String)
* case class TrackerConf(timeout: Int, hostIp: String, port: Int)
*
* See below for detailed explanations.
*
* - trackerImpl: Select the implementation of Rabit tracker.
* default: "python"
*
* Choice between "python" or "scala". The former utilizes the Java wrapper of the
* Python Rabit tracker (in dmlc_core), and does not support timeout settings.
* The "scala" version removes Python components, and fully supports timeout settings.
*
* - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
* default: 0 millisecond (no timeout)
* - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds)
* default: 0 (no timeout)
*
* Timeout for constructing the communication group and waiting for the tracker to
* shutdown when it's instructed to, doesn't apply to communication when tracking
* is running.
* The timeout value should take the time of data loading and pre-processing into account,
* due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
* due to potential lazy execution. Alternatively, you may force Spark to
* perform data transformation before calling XGBoost.train(), so that this timeout truly
* reflects the connection delay. Set a reasonable timeout value to prevent model
* training/testing from hanging indefinitely, possible due to network issues.
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
* Ignored if the tracker implementation is "python".
*
* - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP
* cannot be automatically guessed.
*
* - port : The port number for the tracker to listen to. Use a system allocated one by
* default.
*/
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
setDefault(trackerConf, TrackerConf())

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -20,8 +20,7 @@ import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.funsuite.AnyFunSuite
@ -33,50 +32,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
xgbParamsFactory.buildXGBRuntimeParams
}
test("Customize host ip and python exec for Rabit tracker") {
val hostIp = "192.168.22.111"
val pythonExec = "/usr/bin/python3"
val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.contains(hostIp))
assert(cmd.startsWith("python"))
case _ => assert(false, "expected python tracker implementation")
}
val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(!cmd.contains(hostIp))
case _ => assert(false, "expected python tracker implementation")
}
val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
case pyTracker: PyRabitTracker =>
val cmd = pyTracker.getRabitTrackerCommand
assert(cmd.startsWith(pythonExec))
assert(cmd.contains(s" --host-ip=${hostIp}"))
case _ => assert(false, "expected python tracker implementation")
}
}
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@ -88,9 +43,9 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
*/
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new PyRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker. workerArgs
val workerCount: Int = numWorkers
/*
@ -99,22 +54,8 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
The Java RabitTracker class reacts to exceptions by killing the spawned process running
the Python tracker. If at least one Rabit worker has yet connected to the tracker before
it is killed, the resulted connection failure will trigger the Rabit worker to call
"exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
running in separate containers. However, as unit tests are run in Spark local mode, in which
tasks are executed by threads belonging to the same process, one thread calling "exit(-1);"
ultimately kills the entire process, which also happens to host the Spark driver, causing
the entire Spark application to crash.
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
the exception is thrown at last, ideally after all worker connections have been established.
For the same reason, the Java RabitTracker class delays the killing of the Python tracker
process to ensure that pending worker connections are handled.
*/
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
@ -137,7 +78,32 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0) != 0)
}
test("Communicator allreduce works.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker.workerArgs
val workerCount: Int = numWorkers
rdd.mapPartitions { iter =>
val index = iter.next()
Communicator.init(trackerEnvs)
val a = Array(1.0f, 2.0f, 3.0f)
System.out.println(a.mkString(", "))
val b = Communicator.allReduce(a, Communicator.OpType.SUM)
for (i <- 0 to 2) {
assert(a(i) * workerCount == b(i))
}
val c = Communicator.allReduce(a, Communicator.OpType.MIN);
for (i <- 0 to 2) {
assert(a(i) == c(i))
}
Communicator.shutdown()
Iterator(index)
}.collect()
}
test("should allow the dataframe containing communicator calls to be partially evaluated for" +

View File

@ -23,7 +23,6 @@ import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
"objective_type" -> "classification")
@ -50,7 +49,6 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
intercept[SparkException] {
xgb.fit(trainingDF)
}
}
test("fail training elegantly with unsupported eval metrics") {

View File

@ -47,11 +47,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
.fit(training)
assert(Communicator.communicatorEnvs.asScala.size > 3)
Communicator.communicatorEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
@ -70,10 +65,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
).fit(training)
assert(Communicator.communicatorEnvs.asScala.size > 3)
Communicator.communicatorEnvs.asScala.foreach( item => {
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
})
// check the equality of single instance prediction
val prediction2 = model2.transform(testDF).select("prediction").collect()
// check parity w/o rabit cache
@ -81,25 +72,4 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
assert(math.abs(p1 - p2) < predictionErrorMin)
}
}
test("test rabit timeout fail handle") {
val training = buildDataFrame(Classification.train)
// mock rank 0 failure during 8th allreduce synchronization
Communicator.mockList = Array("0,8,0,0").toList.asJava
intercept[SparkException] {
new XGBoostClassifier(Map(
"eta" -> "0.1",
"max_depth" -> "10",
"verbosity" -> "1",
"objective" -> "binary:logistic",
"num_round" -> 5,
"num_workers" -> numWorkers,
"rabit_timeout" -> 0))
.fit(training)
}
Communicator.mockList = Array.empty.toList.asJava
}
}

View File

@ -51,6 +51,11 @@ pom_template = """
<artifactId>commons-logging</artifactId>
<version>1.2</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.14.2</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${{scala.binary.version}}</artifactId>

View File

@ -7,6 +7,9 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* Collective communicator global class for synchronization.
*
@ -30,8 +33,9 @@ public class Communicator {
}
public enum DataType implements Serializable {
INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8),
INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8),
UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8);
private final int enumOp;
private final int size;
@ -56,30 +60,20 @@ public class Communicator {
}
}
// used as way to test/debug passed communicator init parameters
public static Map<String, String> communicatorEnvs;
public static List<String> mockList = new LinkedList<>();
/**
* Initialize the collective communicator on current working thread.
*
* @param envs The additional environment variables to pass to the communicator.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
communicatorEnvs = envs;
String[] args = new String[envs.size() * 2 + mockList.size() * 2];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey();
args[idx++] = e.getValue();
public static void init(Map<String, Object> envs) throws XGBoostError {
ObjectMapper mapper = new ObjectMapper();
try {
String jconfig = mapper.writeValueAsString(envs);
checkCall(XGBoostJNI.CommunicatorInit(jconfig));
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to read arguments for the communicator.", ex);
}
// pass list of rabit mock strings eg mock=0,1,0,0
for (String mock : mockList) {
args[idx++] = "mock";
args[idx++] = mock;
}
checkCall(XGBoostJNI.CommunicatorInit(args));
}
/**

View File

@ -1,14 +1,13 @@
package ml.dmlc.xgboost4j.java;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* Interface for Rabit tracker implementations with three public methods:
* Interface for a tracker implementations with three public methods:
*
* - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given
* timeout value (in milliseconds.)
* - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients.
* - start(timeout): Start the tracker awaiting for worker connections, with a given
* timeout value (in seconds).
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
* milliseconds.
*
@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit;
* The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and
* brokers connections between workers.
*/
public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
public interface ITracker extends Thread.UncaughtExceptionHandler {
enum TrackerStatus {
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler {
}
}
Map<String, String> getWorkerEnvs();
boolean start(long workerConnectionTimeout);
void stop();
// taskExecutionTimeout has no effect in current version of XGBoost.
int waitFor(long taskExecutionTimeout);
Map<String, Object> workerArgs() throws XGBoostError;
boolean start() throws XGBoostError;
void stop() throws XGBoostError;
void waitFor(long taskExecutionTimeout) throws XGBoostError;
}

View File

@ -1,101 +1,40 @@
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Java implementation of the Rabit tracker to coordinate distributed workers.
* As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both
* start() and waitFor() methods (i.e., the timeout is infinite.)
*
* For systems lacking Python environment, or for timeout functionality, consider using the Scala
* Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and
* provides timeout support.
*
* The tracker must be started on driver node before running distributed jobs.
*/
public class RabitTracker implements IRabitTracker {
public class RabitTracker implements ITracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file.
private static String tracker_py = null;
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
// environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int numWorkers;
private String hostIp = "";
private String pythonExec = "";
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
private long handle = 0;
private Thread tracker_daemon;
static {
try {
initTrackerPy();
} catch (IOException ex) {
logger.error("load tracker library failed.");
logger.error(ex);
}
public RabitTracker(int numWorkers) throws XGBoostError {
this(numWorkers, "");
}
/**
* Tracker logger that logs output from tracker.
*/
private class TrackerProcessLogger implements Runnable {
public void run() {
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
BufferedReader reader = new BufferedReader(new InputStreamReader(
trackerProcess.get().getErrorStream()));
String line;
try {
while ((line = reader.readLine()) != null) {
trackerProcessLogger.info(line);
}
trackerProcess.get().waitFor();
int exitValue = trackerProcess.get().exitValue();
if (exitValue != 0) {
trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue);
} else {
trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue);
}
} catch (IOException ex) {
trackerProcessLogger.error(ex.toString());
} catch (InterruptedException ie) {
// we should not get here as RabitTracker is accessed in the main thread
ie.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
}
}
}
private static void initTrackerPy() throws IOException {
try {
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
} catch (IOException ioe) {
logger.trace("cannot access tracker python script");
throw ioe;
}
}
public RabitTracker(int numWorkers)
public RabitTracker(int numWorkers, String hostIp)
throws XGBoostError {
this(numWorkers, hostIp, 0, 300);
}
public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError {
if (numWorkers < 1) {
throw new XGBoostError("numWorkers must be greater equal to one");
}
this.numWorkers = numWorkers;
}
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
throws XGBoostError {
this(numWorkers);
this.hostIp = hostIp;
this.pythonExec = pythonExec;
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out));
this.handle = out[0];
}
public void uncaughtException(Thread t, Throwable e) {
@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker {
} catch (InterruptedException ex) {
logger.error(ex);
} finally {
trackerProcess.get().destroy();
this.tracker_daemon.interrupt();
}
}
@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker {
* Get environments that can be used to pass to worker.
* @return The environment settings.
*/
public Map<String, String> getWorkerEnvs() {
return envs;
public Map<String, Object> workerArgs() throws XGBoostError {
// fixme: timeout
String[] args = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
ObjectMapper mapper = new ObjectMapper();
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
};
Map<String, Object> config;
try {
config = mapper.readValue(args[0], typeRef);
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to get worker arguments.", ex);
}
return config;
}
private void loadEnvs(InputStream ins) throws IOException {
try {
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
break;
}
String[] sep = line.split("=");
if (sep.length == 2) {
envs.put(sep[0], sep[1]);
}
public void stop() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle));
}
public boolean start() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
this.tracker_daemon = new Thread(() -> {
try {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
} catch (XGBoostError ex) {
logger.error(ex);
return; // exit the thread
}
reader.close();
} catch (IOException ioe){
logger.error("cannot get runtime configuration from tracker process");
ioe.printStackTrace();
throw ioe;
}
});
this.tracker_daemon.setDaemon(true);
this.tracker_daemon.start();
return this.tracker_daemon.isAlive();
}
/** visible for testing */
public String getRabitTrackerCommand() {
StringBuilder sb = new StringBuilder();
if (pythonExec == null || pythonExec.isEmpty()) {
sb.append("python ");
} else {
sb.append(pythonExec + " ");
}
sb.append(" " + tracker_py + " ");
sb.append(" --log-level=DEBUG" + " ");
sb.append(" --num-workers=" + numWorkers + " ");
// we first check the property then check the parameter
String hostIpFromProperties = trackerProperties.getHostIp();
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
sb.append(" --host-ip=" + hostIpFromProperties + " ");
} else if (hostIp != null & !hostIp.isEmpty()) {
logger.debug("Using the parametr host-ip: " + hostIp);
sb.append(" --host-ip=" + hostIp + " ");
}
return sb.toString();
}
private boolean startTrackerProcess() {
try {
String cmd = getRabitTrackerCommand();
trackerProcess.set(Runtime.getRuntime().exec(cmd));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {
ioe.printStackTrace();
return false;
}
}
public void stop() {
if (trackerProcess.get() != null) {
trackerProcess.get().destroy();
}
}
public boolean start(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for all workers to connect indefinitely, unless " +
"it is interrupted manually. Use the Scala RabitTracker for timeout support.");
}
if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString());
System.out.println("Tracker started, with env=" + envs.toString());
// also start a tracker logger
Thread logger_thread = new Thread(new TrackerProcessLogger());
logger_thread.setDaemon(true);
logger_thread.start();
return true;
} else {
logger.error("FAULT: failed to start tracker process");
stop();
return false;
}
}
public int waitFor(long timeout) {
if (timeout > 0L) {
logger.warn("Python RabitTracker does not support timeout. " +
"The tracker will wait for either all workers to finish tasks and send " +
"shutdown signal, or manual interruptions. " +
"Use the Scala RabitTracker for timeout support.");
}
try {
trackerProcess.get().waitFor();
int returnVal = trackerProcess.get().exitValue();
logger.info("Tracker Process ends with exit code " + returnVal);
stop();
return returnVal;
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
return TrackerStatus.INTERRUPTED.getStatusCode();
}
public void waitFor(long timeout) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout));
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2023 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -54,7 +54,7 @@ class XGBoostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
float[] data, int shapeParam,
@ -146,12 +146,24 @@ class XGBoostJNI {
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorInit(String args);
public final static native int CommunicatorFinalize();
public final static native int CommunicatorPrint(String msg);
public final static native int CommunicatorGetRank(int[] out);
public final static native int CommunicatorGetWorldSize(int[] out);
// Tracker functions
public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout,
long[] out);
public final static native int TrackerRun(long handle);
public final static native int TrackerWaitFor(long handle, long timeout);
public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out);
public final static native int TrackerFree(long handle);
// Perform Allreduce operation on data in sendrecvbuf.
final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
@ -168,5 +180,4 @@ class XGBoostJNI {
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
}

View File

@ -42,5 +42,4 @@ public final class UtilUnsafe {
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
}
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams {
}
}
}

View File

@ -1,20 +1,21 @@
/**
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
* Copyright 2014-2024, XGBoost Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./xgboost4j.h"
#include <rabit/c_api.h>
#include <xgboost/base.h>
#include <xgboost/c_api.h>
#include <xgboost/json.h>
@ -23,7 +24,6 @@
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv,
jclass jcls,
jstring jargs) {
xgboost::Json config{xgboost::Object{}};
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
assert(len % 2 == 0);
for (bst_ulong i = 0; i < len / 2; ++i) {
jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
config[key_str] = xgboost::String(value_str);
const char *args = jenv->GetStringUTFChars(jargs, nullptr);
JVM_CHECK_CALL(XGCommunicatorInit(args));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerCreate
* Signature: (Ljava/lang/String;IIIJ[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate(
JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout,
jlongArray jout) {
using namespace xgboost; // NOLINT
TrackerHandle handle;
Json config{Object{}};
std::string shost{jenv->GetStringUTFChars(host, nullptr),
static_cast<std::string::size_type>(jenv->GetStringLength(host))};
if (!shost.empty()) {
config["host"] = shost;
}
std::string json_str;
xgboost::Json::Dump(config, &json_str);
JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
config["port"] = Integer{static_cast<Integer::Int>(port)};
config["n_workers"] = Integer{static_cast<Integer::Int>(n_workers)};
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
config["sortby"] = Integer{static_cast<Integer::Int>(sortby)};
config["dmlc_communicator"] = String{"rabit"};
std::string sconfig = Json::Dump(config);
JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle));
setHandle(jenv, jout, handle);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerRun
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass,
jlong jhandle) {
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
JVM_CHECK_CALL(XGTrackerRun(handle, nullptr));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWaitFor
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass,
jlong jhandle,
jlong timeout) {
using namespace xgboost; // NOLINT
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
Json config{Object{}};
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
std::string sconfig = Json::Dump(config);
JVM_CHECK_CALL(XGTrackerWaitFor(handle, sconfig.c_str()));
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWorkerArgs
* Signature: (JJ[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
JNIEnv *jenv, jclass, jlong jhandle, jlong timeout, jobjectArray jout) {
using namespace xgboost; // NOLINT
Json config{Object{}};
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
std::string sconfig = Json::Dump(config);
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
char const *args;
JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));
auto jargs = Json::Load(StringView{args});
jstring jret = jenv->NewStringUTF(args);
jenv->SetObjectArrayElement(jout, 0, jret);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass,
jlong jhandle) {
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
JVM_CHECK_CALL(XGTrackerFree(handle));
return 0;
}
@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
* Method: CommunicatorFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
(JNIEnv *jenv, jclass jcls) {
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *,
jclass) {
JVM_CHECK_CALL(XGCommunicatorFinalize());
return 0;
}

View File

@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
* Signature: ([Ljava/lang/String;)I
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
(JNIEnv *, jclass, jobjectArray);
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerCreate
* Signature: (Ljava/lang/String;IIIJ[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate
(JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerRun
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWaitFor
* Signature: (JJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerWorkerArgs
* Signature: (JJ[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs
(JNIEnv *, jclass, jlong, jlong, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: TrackerFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorAllreduce

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -298,7 +298,7 @@ public class DMatrixTest {
@Test
public void testTrainWithDenseMatrixRef() throws XGBoostError {
Map<String, String> rabitEnv = new HashMap<>();
Map<String, Object> rabitEnv = new HashMap<>();
rabitEnv.put("DMLC_TASK_ID", "0");
Communicator.init(rabitEnv);
DMatrix trainMat = null;

View File

@ -31,31 +31,13 @@ protobuf_generate(
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
add_library(federated_old_proto STATIC federated.old.proto)
target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
xgboost_target_properties(federated_old_proto)
protobuf_generate(
TARGET federated_old_proto
LANGUAGE cpp
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
protobuf_generate(
TARGET federated_old_proto
LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
# Wrapper for the gRPC client.
add_library(federated_client INTERFACE)
target_sources(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)
target_link_libraries(federated_client INTERFACE federated_old_proto)
# Rabit engine for Federated Learning.
target_sources(
objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc
objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc
)
if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)

View File

@ -1,81 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
syntax = "proto3";
package xgboost.federated;
service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}
enum DataType {
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
}
enum ReduceOperation {
MAX = 0;
MIN = 1;
SUM = 2;
BITWISE_AND = 3;
BITWISE_OR = 4;
BITWISE_XOR = 5;
}
message AllgatherRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}
message AllgatherReply {
bytes receive_buffer = 1;
}
message AllgatherVRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}
message AllgatherVReply {
bytes receive_buffer = 1;
}
message AllreduceRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
DataType data_type = 4;
ReduceOperation reduce_operation = 5;
}
message AllreduceReply {
bytes receive_buffer = 1;
}
message BroadcastRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
// The root rank to broadcast from.
int32 root = 4;
}
message BroadcastReply {
bytes receive_buffer = 1;
}

View File

@ -1,132 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <federated.old.grpc.pb.h>
#include <federated.old.pb.h>
#include <grpcpp/grpcpp.h>
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <string>
namespace xgboost::federated {
/**
* @brief A wrapper around the gRPC client.
*/
class FederatedClient {
public:
FederatedClient(std::string const &server_address, int rank, std::string const &server_cert,
std::string const &client_key, std::string const &client_cert)
: stub_{[&] {
grpc::SslCredentialsOptions options;
options.pem_root_certs = server_cert;
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto channel =
grpc::CreateCustomChannel(server_address, grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
return Federated::NewStub(channel);
}()},
rank_{rank} {}
/** @brief Insecure client for connecting to localhost only. */
FederatedClient(std::string const &server_address, int rank)
: stub_{[&] {
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
return Federated::NewStub(
grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args));
}()},
rank_{rank} {}
std::string Allgather(std::string_view send_buffer) {
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer.data(), send_buffer.size());
AllgatherReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->Allgather(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allgather RPC failed");
}
}
std::string AllgatherV(std::string_view send_buffer) {
AllgatherVRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer.data(), send_buffer.size());
AllgatherVReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->AllgatherV(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("AllgatherV RPC failed");
}
}
std::string Allreduce(std::string const &send_buffer, DataType data_type,
ReduceOperation reduce_operation) {
AllreduceRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
request.set_data_type(data_type);
request.set_reduce_operation(reduce_operation);
AllreduceReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->Allreduce(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allreduce RPC failed");
}
}
std::string Broadcast(std::string const &send_buffer, int root) {
BroadcastRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);
request.set_root(root);
BroadcastReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->Broadcast(&context, request, &reply);
if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Broadcast RPC failed");
}
}
private:
std::unique_ptr<Federated::Stub> const stub_;
int const rank_;
uint64_t sequence_number_{};
};
} // namespace xgboost::federated

View File

@ -1,195 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>
#include "../../src/c_api/c_api_utils.h"
#include "../../src/collective/communicator.h"
#include "../../src/common/io.h"
#include "federated_client.h"
namespace xgboost::collective {
/**
* @brief A Federated Learning communicator class that handles collective communication.
*/
class FederatedCommunicator : public Communicator {
public:
/**
* @brief Create a new communicator based on JSON configuration.
* @param config JSON configuration.
* @return Communicator as specified by the JSON configuration.
*/
static Communicator *Create(Json const &config) {
std::string server_address{};
int world_size{0};
int rank{-1};
std::string server_cert{};
std::string client_key{};
std::string client_cert{};
// Parse environment variables first.
auto *value = getenv("FEDERATED_SERVER_ADDRESS");
if (value != nullptr) {
server_address = value;
}
value = getenv("FEDERATED_WORLD_SIZE");
if (value != nullptr) {
world_size = std::stoi(value);
}
value = getenv("FEDERATED_RANK");
if (value != nullptr) {
rank = std::stoi(value);
}
value = getenv("FEDERATED_SERVER_CERT");
if (value != nullptr) {
server_cert = value;
}
value = getenv("FEDERATED_CLIENT_KEY");
if (value != nullptr) {
client_key = value;
}
value = getenv("FEDERATED_CLIENT_CERT");
if (value != nullptr) {
client_cert = value;
}
// Runtime configuration overrides, optional as users can specify them as env vars.
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
world_size =
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
client_key = OptionalArg<String>(config, "federated_client_key", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);
if (server_address.empty()) {
LOG(FATAL) << "Federated server address must be set.";
}
if (world_size == 0) {
LOG(FATAL) << "Federated world size must be set.";
}
if (rank == -1) {
LOG(FATAL) << "Federated rank must be set.";
}
return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key,
client_cert);
}
/**
* @brief Construct a new federated communicator.
*
* @param world_size Total number of processes.
* @param rank Rank of the current process.
* @param server_address Address of the federated server (host:port).
* @param server_cert_path Path to the server cert file.
* @param client_key_path Path to the client key file.
* @param client_cert_path Path to the client cert file.
*/
FederatedCommunicator(int world_size, int rank, std::string const &server_address,
std::string const &server_cert_path, std::string const &client_key_path,
std::string const &client_cert_path)
: Communicator{world_size, rank} {
if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) {
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
} else {
client_.reset(new xgboost::federated::FederatedClient(
server_address, rank, xgboost::common::ReadAll(server_cert_path),
xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
}
}
/**
* @brief Construct an insecure federated communicator without using SSL.
* @param world_size Total number of processes.
* @param rank Rank of the current process.
* @param server_address Address of the federated server (host:port).
*/
FederatedCommunicator(int world_size, int rank, std::string const &server_address)
: Communicator{world_size, rank} {
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
}
~FederatedCommunicator() override { client_.reset(); }
/**
* \brief Get if the communicator is distributed.
* \return True.
*/
[[nodiscard]] bool IsDistributed() const override { return true; }
/**
* \brief Get if the communicator is federated.
* \return True.
*/
[[nodiscard]] bool IsFederated() const override { return true; }
/**
* \brief Perform allgather.
* \param input Buffer for sending data.
*/
std::string AllGather(std::string_view input) override {
return client_->Allgather(input);
}
/**
* \brief Perform variable-length allgather.
* \param input Buffer for sending data.
*/
std::string AllGatherV(std::string_view input) override {
return client_->AllgatherV(input);
}
/**
* \brief Perform in-place allreduce.
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type.
* \param op Enumeration of operation type.
*/
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer),
count * GetTypeSize(data_type));
auto const received =
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(data_type),
static_cast<xgboost::federated::ReduceOperation>(op));
received.copy(reinterpret_cast<char *>(send_receive_buffer), count * GetTypeSize(data_type));
}
/**
* \brief Broadcast a memory region to all others from root.
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
*/
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
if (GetWorldSize() == 1) return;
if (GetRank() == root) {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
client_->Broadcast(send_buffer, root);
} else {
auto const received = client_->Broadcast("", root);
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
}
}
/**
* \brief Get the name of the processor.
* \return Name of the processor.
*/
std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
/**
* \brief Print the message to the communicator.
* \param message The message to be printed.
*/
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
protected:
void Shutdown() override {}
private:
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
};
} // namespace xgboost::collective

View File

@ -1,86 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "federated_server.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/server.h> // for Server
#include <grpcpp/server_builder.h>
#include <xgboost/logging.h>
#include <sstream>
#include "../../src/collective/comm.h"
#include "../../src/common/io.h"
#include "../../src/common/json_utils.h"
namespace xgboost::federated {
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
AllgatherReply* reply) {
handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
return grpc::Status::OK;
}
grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
AllgatherVReply* reply) {
handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
return grpc::Status::OK;
}
grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
AllreduceReply* reply) {
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
static_cast<xgboost::collective::DataType>(request->data_type()),
static_cast<xgboost::collective::Operation>(request->reduce_operation()));
return grpc::Status::OK;
}
grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request,
BroadcastReply* reply) {
handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
request->root());
return grpc::Status::OK;
}
void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{static_cast<std::int32_t>(world_size)};
grpc::ServerBuilder builder;
auto options =
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = xgboost::common::ReadAll(server_key_file);
key.cert_chain = xgboost::common::ReadAll(server_cert_file);
options.pem_key_cert_pairs.push_back(key);
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< world_size;
server->Wait();
}
void RunInsecureServer(int port, std::size_t world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{static_cast<std::int32_t>(world_size)};
grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< world_size;
server->Wait();
}
} // namespace xgboost::federated

View File

@ -1,37 +0,0 @@
/**
* Copyright 2022-2024, XGBoost contributors
*/
#pragma once
#include <federated.old.grpc.pb.h>
#include <cstdint> // for int32_t
#include "../../src/collective/in_memory_handler.h"
namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
AllgatherVReply* reply) override;
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override;
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
BroadcastReply* reply) override;
private:
xgboost::collective::InMemoryHandler handler_;
};
void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file);
void RunInsecureServer(int port, std::size_t world_size);
} // namespace xgboost::federated

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022-2023, XGBoost contributors
* Copyright 2022-2024, XGBoost contributors
*/
#include "federated_tracker.h"
@ -8,13 +8,12 @@
#include <cstdint> // for int32_t
#include <exception> // for exception
#include <future> // for future, async
#include <limits> // for numeric_limits
#include <string> // for string
#include <thread> // for sleep_for
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for RequiredArg
#include "../../src/common/timer.h" // for Timer
namespace xgboost::collective {
namespace federated {
@ -36,8 +35,8 @@ grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest
AllreduceReply* reply) {
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
static_cast<xgboost::collective::DataType>(request->data_type()),
static_cast<xgboost::collective::Operation>(request->reduce_operation()));
static_cast<xgboost::ArrayInterfaceHandler::Type>(request->data_type()),
static_cast<xgboost::collective::Op>(request->reduce_operation()));
return grpc::Status::OK;
}
@ -53,9 +52,13 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
if (is_secure) {
StringView msg{"Empty certificate path."};
server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
CHECK(!server_key_path_.empty()) << msg;
server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
CHECK(!server_cert_file_.empty()) << msg;
client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
CHECK(!client_cert_file_.empty()) << msg;
}
}

View File

@ -5,11 +5,12 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <rabit/rabit.h>
#pragma GCC diagnostic pop
#include "../sycl/device_manager.h"
#include "../../src/collective/communicator-inl.h"
namespace xgboost {
namespace sycl {
@ -21,22 +22,23 @@ namespace sycl {
}
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
(rabit::IsDistributed());
(collective::IsDistributed());
if (not_use_default_selector) {
DeviceRegister& device_register = GetDevicesRegister();
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
const int device_idx =
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
if (device_spec.IsSyclDefault()) {
auto& devices = device_register.devices;
CHECK_LT(device_idx, devices.size());
return devices[device_idx];
auto& devices = device_register.devices;
CHECK_LT(device_idx, devices.size());
return devices[device_idx];
} else if (device_spec.IsSyclCPU()) {
auto& cpu_devices = device_register.cpu_devices;
CHECK_LT(device_idx, cpu_devices.size());
return cpu_devices[device_idx];
auto& cpu_devices = device_register.cpu_devices;
CHECK_LT(device_idx, cpu_devices.size());
return cpu_devices[device_idx];
} else {
auto& gpu_devices = device_register.gpu_devices;
CHECK_LT(device_idx, gpu_devices.size());
return gpu_devices[device_idx];
auto& gpu_devices = device_register.gpu_devices;
CHECK_LT(device_idx, gpu_devices.size());
return gpu_devices[device_idx];
}
} else {
if (device_spec.IsSyclCPU()) {
@ -62,24 +64,25 @@ namespace sycl {
}
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
(rabit::IsDistributed());
(collective::IsDistributed());
std::lock_guard<std::mutex> guard(queue_registering_mutex);
if (not_use_default_selector) {
DeviceRegister& device_register = GetDevicesRegister();
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
if (device_spec.IsSyclDefault()) {
auto& devices = device_register.devices;
CHECK_LT(device_idx, devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
} else if (device_spec.IsSyclCPU()) {
auto& cpu_devices = device_register.cpu_devices;
CHECK_LT(device_idx, cpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);;
} else if (device_spec.IsSyclGPU()) {
auto& gpu_devices = device_register.gpu_devices;
CHECK_LT(device_idx, gpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
}
DeviceRegister& device_register = GetDevicesRegister();
const int device_idx =
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
if (device_spec.IsSyclDefault()) {
auto& devices = device_register.devices;
CHECK_LT(device_idx, devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
} else if (device_spec.IsSyclCPU()) {
auto& cpu_devices = device_register.cpu_devices;
CHECK_LT(device_idx, cpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);
} else if (device_spec.IsSyclGPU()) {
auto& gpu_devices = device_register.gpu_devices;
CHECK_LT(device_idx, gpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
}
} else {
if (device_spec.IsSyclCPU()) {
queue_register[device_spec.Name()] = ::sycl::queue(::sycl::cpu_selector_v);

View File

@ -6,7 +6,6 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <rabit/rabit.h>
#pragma GCC diagnostic pop
#include <vector>

View File

@ -9,7 +9,6 @@
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#pragma GCC diagnostic pop
#include <rabit/rabit.h>
#include <cmath>
#include <memory>

View File

@ -4,7 +4,6 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <rabit/rabit.h>
#pragma GCC diagnostic pop
#include <cstddef>

View File

@ -1,17 +1,17 @@
"""XGBoost collective communication related API."""
import ctypes
import json
import logging
import os
import pickle
import platform
from enum import IntEnum, unique
from typing import Any, Dict, List, Optional
import numpy as np
from ._typing import _T
from .core import _LIB, _check_call, build_info, c_str, from_pystr_to_cstr, py_str
from .core import _LIB, _check_call, build_info, c_str, make_jcargs, py_str
LOGGER = logging.getLogger("[xgboost.collective]")
@ -21,49 +21,35 @@ def init(**args: Any) -> None:
Parameters
----------
args: Dict[str, Any]
args :
Keyword arguments representing the parameters and their values.
Accepted parameters:
- xgboost_communicator: The type of the communicator. Can be set as an environment
variable.
- dmlc_communicator: The type of the communicator.
* rabit: Use Rabit. This is the default if the type is unspecified.
* federated: Use the gRPC interface for Federated Learning.
Only applicable to the Rabit communicator (these are case sensitive):
-- rabit_tracker_uri: Hostname of the tracker.
-- rabit_tracker_port: Port number of the tracker.
-- rabit_task_id: ID of the current task, can be used to obtain deterministic rank
assignment.
-- rabit_world_size: Total number of workers.
-- rabit_hadoop_mode: Enable Hadoop support.
-- rabit_tree_reduce_minsize: Minimal size for tree reduce.
-- rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
-- rabit_reduce_buffer: Size of the reduce buffer.
-- rabit_bootstrap_cache: Size of the bootstrap cache.
-- rabit_debug: Enable debugging.
-- rabit_timeout: Enable timeout.
-- rabit_timeout_sec: Timeout in seconds.
-- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
environment variables):
-- DMLC_TRACKER_URI: Hostname of the tracker.
-- DMLC_TRACKER_PORT: Port number of the tracker.
-- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank
assignment.
-- DMLC_ROLE: Role of the current task, "worker" or "server".
-- DMLC_NUM_ATTEMPT: Number of attempts after task failure.
-- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
Only applicable to the Federated communicator (use upper case for environment variables, use
lower case for runtime configuration):
-- federated_server_address: Address of the federated server.
-- federated_world_size: Number of federated workers.
-- federated_rank: Rank of the current worker.
-- federated_server_cert: Server certificate file path. Only needed for the SSL mode.
-- federated_client_key: Client key file path. Only needed for the SSL mode.
-- federated_client_cert: Client certificate file path. Only needed for the SSL mode.
Only applicable to the Rabit communicator:
- dmlc_tracker_uri: Hostname of the tracker.
- dmlc_tracker_port: Port number of the tracker.
- dmlc_task_id: ID of the current task, can be used to obtain deterministic
- dmlc_retry: The number of retry when handling network errors.
- dmlc_timeout: Timeout in seconds.
- dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication.
Only applicable to the Federated communicator (use upper case for environment
variables, use lower case for runtime configuration):
- federated_server_address: Address of the federated server.
- federated_world_size: Number of federated workers.
- federated_rank: Rank of the current worker.
- federated_server_cert: Server certificate file path. Only needed for the SSL
mode.
- federated_client_key: Client key file path. Only needed for the SSL mode.
- federated_client_cert: Client certificate file path. Only needed for the SSL
mode.
"""
config = from_pystr_to_cstr(json.dumps(args))
_check_call(_LIB.XGCommunicatorInit(config))
_check_call(_LIB.XGCommunicatorInit(make_jcargs(**args)))
def finalize() -> None:
@ -157,7 +143,7 @@ def broadcast(data: _T, root: int) -> _T:
assert data is not None, "need to pass in data when broadcasting"
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
length.value = len(s)
# run first broadcast
# Run first broadcast
_check_call(
_LIB.XGCommunicatorBroadcast(
ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root
@ -184,16 +170,27 @@ def broadcast(data: _T, root: int) -> _T:
# enumeration of dtypes
DTYPE_ENUM__ = {
np.dtype("int8"): 0,
np.dtype("uint8"): 1,
np.dtype("int32"): 2,
np.dtype("uint32"): 3,
np.dtype("int64"): 4,
np.dtype("uint64"): 5,
np.dtype("float32"): 6,
np.dtype("float64"): 7,
}
def _map_dtype(dtype: np.dtype) -> int:
dtype_map = {
np.dtype("float16"): 0,
np.dtype("float32"): 1,
np.dtype("float64"): 2,
np.dtype("int8"): 4,
np.dtype("int16"): 5,
np.dtype("int32"): 6,
np.dtype("int64"): 7,
np.dtype("uint8"): 8,
np.dtype("uint16"): 9,
np.dtype("uint32"): 10,
np.dtype("uint64"): 11,
}
if platform.system() != "Windows":
dtype_map.update({np.dtype("float128"): 3})
if dtype not in dtype_map:
raise TypeError(f"data type {dtype} is not supported on the current platform.")
return dtype_map[dtype]
@unique
@ -229,24 +226,23 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid
"""
if not isinstance(data, np.ndarray):
raise TypeError("allreduce only takes in numpy.ndarray")
buf = data.ravel()
if buf.base is data.base:
buf = buf.copy()
if buf.dtype not in DTYPE_ENUM__:
raise TypeError(f"data type {buf.dtype} not supported")
buf = data.ravel().copy()
_check_call(
_LIB.XGCommunicatorAllreduce(
buf.ctypes.data_as(ctypes.c_void_p),
buf.size,
DTYPE_ENUM__[buf.dtype],
_map_dtype(buf.dtype),
int(op),
None,
None,
)
)
return buf
def signal_error() -> None:
"""Kill the process."""
_check_call(_LIB.XGCommunicatorSignalError())
class CommunicatorContext:
"""A context controlling collective communicator initialization and finalization."""

View File

@ -295,7 +295,7 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
if device and device.find(":") != -1:
raise ValueError(
"Distributed training doesn't support selecting device ordinal as GPUs are"
" managed by the distributed framework. use `device=cuda` or `device=gpu`"
" managed by the distributed frameworks. use `device=cuda` or `device=gpu`"
" instead."
)

View File

@ -71,6 +71,7 @@ from xgboost.core import (
Metric,
Objective,
QuantileDMatrix,
XGBoostError,
_check_distributed_params,
_deprecate_positional_args,
_expect,
@ -90,7 +91,7 @@ from xgboost.sklearn import (
_wrap_evaluation_matrices,
xgboost_model_doc,
)
from xgboost.tracker import RabitTracker, get_host_ip
from xgboost.tracker import RabitTracker
from xgboost.training import train as worker_train
from .utils import get_n_threads
@ -160,36 +161,38 @@ def _try_start_tracker(
n_workers: int,
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
) -> Dict[str, Union[int, str]]:
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
env: Dict[str, Union[int, str]] = {}
try:
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_tracker = RabitTracker(
host_ip=get_host_ip(host_ip),
n_workers=n_workers,
host_ip=host_ip,
port=port,
use_logger=False,
sortby="task",
)
else:
addr = addrs[0]
assert isinstance(addr, str) or addr is None
host_ip = get_host_ip(addr)
rabit_tracker = RabitTracker(
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
n_workers=n_workers, host_ip=addr, sortby="task"
)
env.update(rabit_tracker.worker_envs())
rabit_tracker.start(n_workers)
thread = Thread(target=rabit_tracker.join)
rabit_tracker.start()
thread = Thread(target=rabit_tracker.wait_for)
thread.daemon = True
thread.start()
except socket.error as e:
if len(addrs) < 2 or e.errno != 99:
env.update(rabit_tracker.worker_args())
except XGBoostError as e:
if len(addrs) < 2:
raise
LOGGER.warning(
"Failed to bind address '%s', trying to use '%s' instead.",
"Failed to bind address '%s', trying to use '%s' instead. Error:\n %s",
str(addrs[0]),
str(addrs[1]),
str(e),
)
env = _try_start_tracker(n_workers, addrs[1:])

View File

@ -1,45 +1,85 @@
"""XGBoost Federated Learning related API."""
"""XGBoost Experimental Federated Learning related API."""
from .core import _LIB, XGBoostError, _check_call, build_info, c_str
import ctypes
from threading import Thread
from typing import Any, Dict, Optional
from .core import _LIB, _check_call, make_jcargs
from .tracker import RabitTracker
def run_federated_server(
port: int,
world_size: int,
server_key_path: str = "",
server_cert_path: str = "",
client_cert_path: str = "",
) -> None:
"""Run the Federated Learning server.
class FederatedTracker(RabitTracker):
"""Tracker for federated training.
Parameters
----------
port : int
The port to listen on.
world_size: int
n_workers :
The number of federated workers.
server_key_path: str
Path to the server private key file. SSL is turned off if empty.
server_cert_path: str
Path to the server certificate file. SSL is turned off if empty.
client_cert_path: str
Path to the client certificate file. SSL is turned off if empty.
port :
The port to listen on.
secure :
Whether this is a secure instance. If True, then the following arguments for SSL
must be provided.
server_key_path :
Path to the server private key file.
server_cert_path :
Path to the server certificate file.
client_cert_path :
Path to the client certificate file.
"""
if build_info()["USE_FEDERATED"]:
if not server_key_path or not server_cert_path or not client_cert_path:
_check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size))
else:
_check_call(
_LIB.XGBRunFederatedServer(
port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path),
)
)
else:
raise XGBoostError(
"XGBoost needs to be built with the federated learning plugin "
"enabled in order to use this module"
def __init__( # pylint: disable=R0913, W0231
self,
n_workers: int,
port: int,
secure: bool,
server_key_path: str = "",
server_cert_path: str = "",
client_cert_path: str = "",
timeout: int = 300,
) -> None:
handle = ctypes.c_void_p()
args = make_jcargs(
n_workers=n_workers,
port=port,
dmlc_communicator="federated",
federated_secure=secure,
server_key_path=server_key_path,
server_cert_path=server_cert_path,
client_cert_path=client_cert_path,
timeout=int(timeout),
)
_check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
self.handle = handle
def run_federated_server( # pylint: disable=too-many-arguments
n_workers: int,
port: int,
server_key_path: Optional[str] = None,
server_cert_path: Optional[str] = None,
client_cert_path: Optional[str] = None,
timeout: int = 300,
) -> Dict[str, Any]:
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info."""
args: Dict[str, Any] = {"n_workers": n_workers}
secure = all(
path is not None
for path in [server_key_path, server_cert_path, client_cert_path]
)
tracker = FederatedTracker(
n_workers=n_workers, port=port, secure=secure, timeout=timeout
)
tracker.start()
thread = Thread(target=tracker.wait_for)
thread.daemon = True
thread.start()
args.update(tracker.worker_args())
return args

View File

@ -47,21 +47,21 @@ class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
"""Context with PySpark specific task ID."""
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
args["DMLC_TASK_ID"] = str(context.partitionId())
args["dmlc_task_id"] = str(context.partitionId())
super().__init__(**args)
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker with n_workers"""
env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
args: Dict[str, Any] = {"n_workers": n_workers}
host = _get_host_ip(context)
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers, sortby="task")
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task")
tracker.start()
thread = Thread(target=tracker.wait_for)
thread.daemon = True
thread.start()
return env
args.update(tracker.worker_args())
return args
def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:

View File

@ -111,8 +111,6 @@ def no_sklearn() -> PytestSkip:
def no_dask() -> PytestSkip:
if sys.platform.startswith("win"):
return {"reason": "Unsupported platform.", "condition": True}
return no_mod("dask")
@ -193,6 +191,10 @@ def no_multiple(*args: Any) -> PytestSkip:
return {"condition": condition, "reason": reason}
def skip_win() -> PytestSkip:
return {"reason": "Unsupported platform.", "condition": is_windows()}
def skip_s390x() -> PytestSkip:
condition = platform.machine() == "s390x"
reason = "Known to fail on s390x"
@ -968,18 +970,18 @@ def run_with_rabit(
exception_queue.put(e)
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
tracker.start(world_size)
tracker.start()
workers = []
for _ in range(world_size):
worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),))
worker = threading.Thread(target=run_worker, args=(tracker.worker_args(),))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
assert exception_queue.empty(), f"Worker failed: {exception_queue.get()}"
tracker.join()
tracker.wait_for()
def column_split_feature_names(

View File

@ -1,64 +1,12 @@
# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches
"""
This script is a variant of dmlc-core/dmlc_tracker/tracker.py,
which is a specialized version for xgboost tasks.
"""
import argparse
import logging
"""Tracker for XGBoost collective."""
import ctypes
import json
import socket
import struct
import sys
from threading import Thread
from typing import Dict, List, Optional, Set, Tuple, Union
from enum import IntEnum, unique
from typing import Dict, Optional, Union
_RingMap = Dict[int, Tuple[int, int]]
_TreeMap = Dict[int, List[int]]
class ExSocket:
"""
Extension of socket to handle recv and send of special data
"""
def __init__(self, sock: socket.socket) -> None:
self.sock = sock
def recvall(self, nbytes: int) -> bytes:
"""Receive number of bytes."""
res = []
nread = 0
while nread < nbytes:
chunk = self.sock.recv(min(nbytes - nread, 1024))
nread += len(chunk)
res.append(chunk)
return b"".join(res)
def recvint(self) -> int:
"""Receive an integer of 32 bytes"""
return struct.unpack("@i", self.recvall(4))[0]
def sendint(self, value: int) -> None:
"""Send an integer of 32 bytes"""
self.sock.sendall(struct.pack("@i", value))
def sendstr(self, value: str) -> None:
"""Send a Python string"""
self.sendint(len(value))
self.sock.sendall(value.encode())
def recvstr(self) -> str:
"""Receive a Python string"""
slen = self.recvint()
return self.recvall(slen).decode()
# magic number used to verify existence of data
MAGIC_NUM = 0xFF99
def get_some_ip(host: str) -> str:
"""Get ip from host"""
return socket.getaddrinfo(host, None)[0][4][0]
from .core import _LIB, _check_call, make_jcargs
def get_family(addr: str) -> int:
@ -66,439 +14,95 @@ def get_family(addr: str) -> int:
return socket.getaddrinfo(addr, None)[0][0]
class WorkerEntry:
"""Hanlder to each worker."""
def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
worker = ExSocket(sock)
self.sock = worker
self.host = get_some_ip(s_addr[0])
magic = worker.recvint()
assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}"
worker.sendint(MAGIC_NUM)
self.rank = worker.recvint()
self.world_size = worker.recvint()
self.task_id = worker.recvstr()
self.cmd = worker.recvstr()
self.wait_accept = 0
self.port: Optional[int] = None
def print(self, use_logger: bool) -> None:
"""Execute the print command from worker."""
msg = self.sock.recvstr()
# On dask we use print to avoid setting global verbosity.
if use_logger:
logging.info(msg.strip())
else:
print(msg.strip(), flush=True)
def decide_rank(self, job_map: Dict[str, int]) -> int:
"""Get the rank of current entry."""
if self.rank >= 0:
return self.rank
if self.task_id != "NULL" and self.task_id in job_map:
return job_map[self.task_id]
return -1
def assign_rank(
self,
rank: int,
wait_conn: Dict[int, "WorkerEntry"],
tree_map: _TreeMap,
parent_map: Dict[int, int],
ring_map: _RingMap,
) -> List[int]:
"""Assign the rank for current entry."""
self.rank = rank
nnset = set(tree_map[rank])
rprev, next_rank = ring_map[rank]
self.sock.sendint(rank)
# send parent rank
self.sock.sendint(parent_map[rank])
# send world size
self.sock.sendint(len(tree_map))
self.sock.sendint(len(nnset))
# send the rprev and next link
for r in nnset:
self.sock.sendint(r)
# send prev link
if rprev not in (-1, rank):
nnset.add(rprev)
self.sock.sendint(rprev)
else:
self.sock.sendint(-1)
# send next link
if next_rank not in (-1, rank):
nnset.add(next_rank)
self.sock.sendint(next_rank)
else:
self.sock.sendint(-1)
return self._get_remote(wait_conn, nnset)
def _get_remote(
self, wait_conn: Dict[int, "WorkerEntry"], badset: Set[int]
) -> List[int]:
while True:
conset = []
for r in badset:
if r in wait_conn:
conset.append(r)
self.sock.sendint(len(conset))
self.sock.sendint(len(badset) - len(conset))
for r in conset:
self.sock.sendstr(wait_conn[r].host)
port = wait_conn[r].port
assert port is not None
# send port of this node to other workers so that they can call connect
self.sock.sendint(port)
self.sock.sendint(r)
nerr = self.sock.recvint()
if nerr != 0:
continue
self.port = self.sock.recvint()
rmset = []
# all connection was successuly setup
for r in conset:
wait_conn[r].wait_accept -= 1
if wait_conn[r].wait_accept == 0:
rmset.append(r)
for r in rmset:
wait_conn.pop(r, None)
self.wait_accept = len(badset) - len(conset)
return rmset
class RabitTracker:
"""
tracker for rabit
"""
def __init__(
self,
host_ip: str,
n_workers: int,
port: int = 0,
use_logger: bool = False,
sortby: str = "host",
) -> None:
"""A Python implementation of RABIT tracker.
Parameters
..........
use_logger:
Use logging.info for tracker print command. When set to False, Python print
function is used instead.
sortby:
How to sort the workers for rank assignment. The default is host, but users
can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
deterministic rank assignment. Available options are:
- host
- task
"""
sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
sock.bind((host_ip, port))
self.port = sock.getsockname()[1]
sock.listen(256)
self.sock = sock
self.host_ip = host_ip
self.thread: Optional[Thread] = None
self.n_workers = n_workers
self._use_logger = use_logger
self._sortby = sortby
logging.info("start listen on %s:%d", host_ip, self.port)
def __del__(self) -> None:
if hasattr(self, "sock"):
self.sock.close()
@staticmethod
def _get_neighbor(rank: int, n_workers: int) -> List[int]:
rank = rank + 1
ret = []
if rank > 1:
ret.append(rank // 2 - 1)
if rank * 2 - 1 < n_workers:
ret.append(rank * 2 - 1)
if rank * 2 < n_workers:
ret.append(rank * 2)
return ret
def worker_envs(self) -> Dict[str, Union[str, int]]:
"""
get environment variables for workers
can be passed in as args or envs
"""
return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}
def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
tree_map: _TreeMap = {}
parent_map: Dict[int, int] = {}
for r in range(n_workers):
tree_map[r] = self._get_neighbor(r, n_workers)
parent_map[r] = (r + 1) // 2 - 1
return tree_map, parent_map
def find_share_ring(
self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int
) -> List[int]:
"""
get a ring structure that tends to share nodes with the tree
return a list starting from rank
"""
nset = set(tree_map[rank])
cset = nset - {parent_map[rank]}
if not cset:
return [rank]
rlst = [rank]
cnt = 0
for v in cset:
vlst = self.find_share_ring(tree_map, parent_map, v)
cnt += 1
if cnt == len(cset):
vlst.reverse()
rlst += vlst
return rlst
def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap:
"""
get a ring connection used to recover local data
"""
assert parent_map[0] == -1
rlst = self.find_share_ring(tree_map, parent_map, 0)
assert len(rlst) == len(tree_map)
ring_map: _RingMap = {}
n_workers = len(tree_map)
for r in range(n_workers):
rprev = (r + n_workers - 1) % n_workers
rnext = (r + 1) % n_workers
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
return ring_map
def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]:
"""
get the link map, this is a bit hacky, call for better algorithm
to place similar nodes together
"""
tree_map, parent_map = self._get_tree(n_workers)
ring_map = self.get_ring(tree_map, parent_map)
rmap = {0: 0}
k = 0
for i in range(n_workers - 1):
k = ring_map[k][1]
rmap[k] = i + 1
ring_map_: _RingMap = {}
tree_map_: _TreeMap = {}
parent_map_: Dict[int, int] = {}
for k, v in ring_map.items():
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
for k, tree_nodes in tree_map.items():
tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes]
for k, parent in parent_map.items():
if k != 0:
parent_map_[rmap[k]] = rmap[parent]
else:
parent_map_[rmap[k]] = -1
return tree_map_, parent_map_, ring_map_
def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
if self._sortby == "host":
pending.sort(key=lambda s: s.host)
elif self._sortby == "task":
pending.sort(key=lambda s: s.task_id)
return pending
def accept_workers(self, n_workers: int) -> None:
"""Wait for all workers to connect to the tracker."""
# set of nodes that finishes the job
shutdown: Dict[int, WorkerEntry] = {}
# set of nodes that is waiting for connections
wait_conn: Dict[int, WorkerEntry] = {}
# maps job id to rank
job_map: Dict[str, int] = {}
# list of workers that is pending to be assigned rank
pending: List[WorkerEntry] = []
# lazy initialize tree_map
tree_map = None
while len(shutdown) != n_workers:
fd, s_addr = self.sock.accept()
s = WorkerEntry(fd, s_addr)
if s.cmd == "print":
s.print(self._use_logger)
continue
if s.cmd == "shutdown":
assert s.rank >= 0 and s.rank not in shutdown
assert s.rank not in wait_conn
shutdown[s.rank] = s
logging.debug("Received %s signal from %d", s.cmd, s.rank)
continue
assert s.cmd == "start"
# lazily initialize the workers
if tree_map is None:
assert s.cmd == "start"
if s.world_size > 0:
n_workers = s.world_size
tree_map, parent_map, ring_map = self.get_link_map(n_workers)
# set of nodes that is pending for getting up
todo_nodes = list(range(n_workers))
else:
assert s.world_size in (-1, n_workers)
if s.cmd == "recover":
assert s.rank >= 0
rank = s.decide_rank(job_map)
# batch assignment of ranks
if rank == -1:
assert todo_nodes
pending.append(s)
if len(pending) == len(todo_nodes):
pending = self._sort_pending(pending)
for s in pending:
rank = todo_nodes.pop(0)
if s.task_id != "NULL":
job_map[s.task_id] = rank
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.wait_accept > 0:
wait_conn[rank] = s
logging.debug(
"Received %s signal from %s; assign rank %d",
s.cmd,
s.host,
s.rank,
)
if not todo_nodes:
logging.info("@tracker All of %d nodes getting started", n_workers)
else:
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
logging.debug("Received %s signal from %d", s.cmd, s.rank)
if s.wait_accept > 0:
wait_conn[rank] = s
logging.info("@tracker All nodes finishes job")
def start(self, n_workers: int) -> None:
"""Strat the tracker, it will wait for `n_workers` to connect."""
def run() -> None:
self.accept_workers(n_workers)
self.thread = Thread(target=run, args=(), daemon=True)
self.thread.start()
def join(self) -> None:
"""Wait for the tracker to finish."""
while self.thread is not None and self.thread.is_alive():
self.thread.join(100)
def alive(self) -> bool:
"""Wether the tracker thread is alive"""
return self.thread is not None and self.thread.is_alive()
def get_host_ip(host_ip: Optional[str] = None) -> str:
"""Get the IP address of current host. If `host_ip` is not none then it will be
returned as it's
"""
if host_ip is None or host_ip == "auto":
host_ip = "ip"
if host_ip == "dns":
host_ip = socket.getfqdn()
elif host_ip == "ip":
from socket import gaierror
try:
host_ip = socket.gethostbyname(socket.getfqdn())
except gaierror:
logging.debug(
"gethostbyname(socket.getfqdn()) failed... trying on hostname()"
)
host_ip = socket.gethostbyname(socket.gethostname())
if host_ip.startswith("127."):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# doesn't have to be reachable
s.connect(("10.255.255.255", 1))
host_ip = s.getsockname()[0]
assert host_ip is not None
return host_ip
def start_rabit_tracker(args: argparse.Namespace) -> None:
"""Standalone function to start rabit tracker.
"""Tracker for the collective used in XGBoost, acting as a coordinator between
workers.
Parameters
----------
args: arguments to start the rabit tracker.
..........
sortby:
How to sort the workers for rank assignment. The default is host, but users can
set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
deterministic rank assignment. Available options are:
- host
- task
timeout :
Timeout for constructing the communication group and waiting for the tracker to
shutdown when it's instructed to, doesn't apply to communication when tracking
is running.
The timeout value should take the time of data loading and pre-processing into
account, due to potential lazy execution.
The :py:meth:`.wait_for` method has a different timeout parameter that can stop
the tracker even if the tracker is still being used. A value error is raised
when timeout is reached.
"""
envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers}
rabit = RabitTracker(
host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
)
envs.update(rabit.worker_envs())
rabit.start(args.num_workers)
sys.stdout.write("DMLC_TRACKER_ENV_START\n")
# simply write configuration to stdout
for k, v in envs.items():
sys.stdout.write(f"{k}={v}\n")
sys.stdout.write("DMLC_TRACKER_ENV_END\n")
sys.stdout.flush()
rabit.join()
@unique
class _SortBy(IntEnum):
HOST = 0
TASK = 1
def main() -> None:
"""Main function if tracker is executed in standalone mode."""
parser = argparse.ArgumentParser(description="Rabit Tracker start.")
parser.add_argument(
"--num-workers",
required=True,
type=int,
help="Number of worker process to be launched.",
)
parser.add_argument(
"--num-servers",
default=0,
type=int,
help="Number of server process to be launched. Only used in PS jobs.",
)
parser.add_argument(
"--host-ip",
default=None,
type=str,
help=(
"Host IP addressed, this is only needed "
+ "if the host IP cannot be automatically guessed."
),
)
parser.add_argument(
"--log-level",
default="INFO",
type=str,
choices=["INFO", "DEBUG"],
help="Logging level of the logger.",
)
args = parser.parse_args()
def __init__( # pylint: disable=too-many-arguments
self,
n_workers: int,
host_ip: Optional[str],
port: int = 0,
sortby: str = "host",
timeout: int = 0,
) -> None:
fmt = "%(asctime)s %(levelname)s %(message)s"
if args.log_level == "INFO":
level = logging.INFO
elif args.log_level == "DEBUG":
level = logging.DEBUG
else:
raise RuntimeError(f"Unknown logging level {args.log_level}")
handle = ctypes.c_void_p()
if sortby not in ("host", "task"):
raise ValueError("Expecting either 'host' or 'task' for sortby.")
if host_ip is not None:
get_family(host_ip) # use python socket to stop early for invalid address
args = make_jcargs(
host=host_ip,
n_workers=n_workers,
port=port,
dmlc_communicator="rabit",
sortby=self._SortBy.HOST if sortby == "host" else self._SortBy.TASK,
timeout=int(timeout),
)
_check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
self.handle = handle
logging.basicConfig(format=fmt, level=level)
def free(self) -> None:
"""Internal function for testing."""
if hasattr(self, "handle"):
handle = self.handle
del self.handle
_check_call(_LIB.XGTrackerFree(handle))
if args.num_servers == 0:
start_rabit_tracker(args)
else:
raise RuntimeError("Do not yet support start ps tracker in standalone mode.")
def __del__(self) -> None:
self.free()
def start(self) -> None:
"""Start the tracker. Once started, the client still need to call the
:py:meth:`wait_for` method in order to wait for it to finish (think of it as a
thread).
if __name__ == "__main__":
main()
"""
_check_call(_LIB.XGTrackerRun(self.handle, make_jcargs()))
def wait_for(self, timeout: Optional[int] = None) -> None:
"""Wait for the tracker to finish all the work and shutdown. When timeout is
reached, a value error is raised. By default we don't have timeout since we
don't know how long it takes for the model to finish training.
"""
_check_call(_LIB.XGTrackerWaitFor(self.handle, make_jcargs(timeout=timeout)))
def worker_args(self) -> Dict[str, Union[str, int]]:
"""Get arguments for workers."""
c_env = ctypes.c_char_p()
_check_call(_LIB.XGTrackerWorkerArgs(self.handle, ctypes.byref(c_env)))
assert c_env.value is not None
env = json.loads(c_env.value)
return env

View File

@ -1,15 +0,0 @@
cmake_minimum_required(VERSION 3.18)
find_package(Threads REQUIRED)
set(RABIT_SOURCES
${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc
${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc)
if(RABIT_MOCK)
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc)
else()
list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine.cc)
endif()
set(RABIT_SOURCES ${RABIT_SOURCES} PARENT_SCOPE)

View File

@ -1,28 +0,0 @@
Copyright (c) 2014 by Contributors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of rabit nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1 +0,0 @@
# This directory contains the CPU network module for XGBoost. The library originates from [RABIT](https://github.com/dmlc/rabit)

View File

@ -1,19 +0,0 @@
/*!
* Copyright (c) 2020 by Contributors
* \file base.h
* \brief Macros common to all headers
*
* \author Hyunsu Cho
*/
#ifndef RABIT_BASE_H_
#define RABIT_BASE_H_
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS
#endif // _CRT_SECURE_NO_WARNINGS
#ifndef _CRT_SECURE_NO_DEPRECATE
#define _CRT_SECURE_NO_DEPRECATE
#endif // _CRT_SECURE_NO_DEPRECATE
#endif // RABIT_BASE_H_

View File

@ -1,157 +0,0 @@
/*!
* Copyright by Contributors
* \file c_api.h
* \author Tianqi Chen
* \brief a C style API of rabit.
*/
#ifndef RABIT_C_API_H_
#define RABIT_C_API_H_
#ifdef __cplusplus
#define RABIT_EXTERN_C extern "C"
#include <cstdio>
#else
#define RABIT_EXTERN_C
#include <stdio.h>
#endif // __cplusplus
#if defined(_MSC_VER) || defined(_WIN32)
#define RABIT_DLL RABIT_EXTERN_C __declspec(dllexport)
#else
#define RABIT_DLL RABIT_EXTERN_C __attribute__ ((visibility ("default")))
#endif // defined(_MSC_VER) || defined(_WIN32)
/*! \brief rabit unsigned long type */
typedef unsigned long rbt_ulong; // NOLINT(*)
/*!
* \brief initialize the rabit module,
* call this once before using anything
* The additional arguments is not necessary.
* Usually rabit will detect settings
* from environment variables.
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true if rabit is initialized successfully otherwise false
*/
RABIT_DLL bool RabitInit(int argc, char *argv[]);
/*!
* \brief finalize the rabit engine,
* call this function after you finished all jobs.
* \return true if rabit is initialized successfully otherwise false
*/
RABIT_DLL int RabitFinalize(void);
/*!
* \brief get rank of previous process in ring topology
* \return rank number of worker
* */
RABIT_DLL int RabitGetRingPrevRank(void);
/*!
* \brief get rank of current process
* \return rank number of worker
* */
RABIT_DLL int RabitGetRank(void);
/*!
* \brief get total number of process
* \return total world size
* */
RABIT_DLL int RabitGetWorldSize(void);
/*!
* \brief get rank of current process
* \return if rabit is distributed
* */
RABIT_DLL int RabitIsDistributed(void);
/*!
* \brief print the msg to the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg the message to be printed
*/
RABIT_DLL int RabitTrackerPrint(const char *msg);
/*!
* \brief get name of processor
* \param out_name hold output string
* \param out_len hold length of output string
* \param max_len maximum buffer length of input
*/
RABIT_DLL void RabitGetProcessorName(char *out_name,
rbt_ulong *out_len,
rbt_ulong max_len);
/*!
* \brief broadcast an memory region to all others from root
*
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
* \param sendrecv_data the pointer to send or receive buffer,
* \param size the size of the data
* \param root the root of process
*/
RABIT_DLL int RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root);
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param beginIndex beginning of the current slice in sendrecvbuf of type enum_dtype
* \param size_node_slice size of the current node slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
RABIT_DLL int RabitAllgather(void *sendrecvbuf, size_t total_size,
size_t beginIndex, size_t size_node_slice,
size_t size_prev_slice, int enum_dtype);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
*
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size());
* ...
* \param sendrecvbuf buffer for both sending and receiving data
* \param count number of elements to be reduced
* \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include
* \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
int enum_op, void (*prepare_fun)(void *arg),
void *prepare_arg);
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \return rabit version number
*/
RABIT_DLL int RabitVersionNumber(void);
/*!
* \brief a Dummy function,
* used to cause force link of C API into the DLL.
* \code
* \/\/force link rabit C API library.
* static int must_link_rabit_ = RabitLinkTag();
* \endcode
* \return a dummy integer.
*/
RABIT_DLL int RabitLinkTag(void);
#endif // RABIT_C_API_H_

View File

@ -1,197 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine.h
* \brief This file defines the core interface of rabit library
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_INTERNAL_ENGINE_H_
#define RABIT_INTERNAL_ENGINE_H_
#include <string>
#include "rabit/serializable.h"
namespace MPI { // NOLINT
/*! \brief MPI data type just to be compatible with MPI reduce function*/
class Datatype;
}
/*! \brief namespace of rabit */
namespace rabit {
/*! \brief core interface of the engine */
namespace engine {
/*! \brief interface of core Allreduce engine */
class IEngine {
public:
/*!
* \brief Preprocessing function, that is called before AllReduce,
* used to prepare the data used by AllReduce
* \param arg additional possible argument used to invoke the preprocessor
*/
typedef void (PreprocFunction) (void *arg); // NOLINT
/*!
* \brief reduce function, the same form of MPI reduce function is used,
* to be compatible with MPI interface
* In all the functions, the memory is ensured to aligned to 64-bit
* which means it is OK to cast src,dst to double* int* etc
* \param src pointer to source space
* \param dst pointer to destination reduction
* \param count total number of elements to be reduced (note this is total number of elements instead of bytes)
* the definition of the reduce function should be type aware
* \param dtype the data type object, to be compatible with MPI reduce
*/
typedef void (ReduceFunction) (const void *src, // NOLINT
void *dst, int count,
const MPI::Datatype &dtype);
/*! \brief virtual destructor */
~IEngine() = default;
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
*/
virtual void Allgather(void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice) = 0;
/*!
* \brief performs in-place Allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the number of bytes the type has
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr) = 0;
/*!
* \brief broadcasts data from root to every other node
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
*/
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0;
/*!
* deprecated
*/
virtual int LoadCheckPoint() = 0;
/*!
* \brief Increase internal version number. Deprecated.
*/
virtual void CheckPoint() = 0;
/*!
* \return version number of the current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber() const = 0;
/*! \brief gets rank of previous node in ring topology */
virtual int GetRingPrevRank() const = 0;
/*! \brief gets rank of current node */
virtual int GetRank() const = 0;
/*! \brief gets total number of nodes */
virtual int GetWorldSize() const = 0;
/*! \brief whether we run in distribted mode */
virtual bool IsDistributed() const = 0;
/*! \brief gets the host name of the current node */
virtual std::string GetHost() const = 0;
/*!
* \brief prints the msg in the tracker,
* this function can be used to communicate progress information to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg) = 0;
};
/*! \brief initializes the engine module */
bool Init(int argc, char *argv[]);
/*! \brief finalizes the engine module */
bool Finalize();
/*! \brief singleton method to get engine */
IEngine *GetEngine();
/*! \brief namespace that contains stubs to be compatible with MPI */
namespace mpi {
/*!\brief enum of all operators */
enum OpType {
kMax = 0,
kMin = 1,
kSum = 2,
kBitwiseAND = 3,
kBitwiseOR = 4,
kBitwiseXOR = 5,
};
/*!\brief enum of supported data types */
enum DataType {
kChar = 0,
kUChar = 1,
kInt = 2,
kUInt = 3,
kLong = 4,
kULong = 5,
kFloat = 6,
kDouble = 7,
kLongLong = 8,
kULongLong = 9
};
} // namespace mpi
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
*/
void Allgather(void* sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice);
/*!
* \brief perform in-place Allreduce, on sendrecvbuf
* this is an internal function used by rabit to be able to compile with MPI
* do not use this function directly
* \param sendrecvbuf buffer for both sending and receiving data
* \param type_nbytes the number of bytes the type has
* \param count number of elements to be reduced
* \param reducer reduce function
* \param dtype the data type
* \param op the reduce operator type
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function.
*/
void Allreduce_(void *sendrecvbuf, // NOLINT
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr);
} // namespace engine
} // namespace rabit
#endif // RABIT_INTERNAL_ENGINE_H_

View File

@ -1,118 +0,0 @@
/**
* Copyright 2014-2023, XGBoost Contributors
* \file io.h
* \brief utilities with different serializable implementations
* \author Tianqi Chen
*/
#ifndef RABIT_INTERNAL_IO_H_
#define RABIT_INTERNAL_IO_H_
#include <algorithm>
#include <cstddef> // for size_t
#include <cstdio>
#include <cstring> // for memcpy
#include <limits>
#include <numeric>
#include <string>
#include <vector>
#include "dmlc/io.h"
#include "xgboost/logging.h"
namespace rabit::utils {
/*! \brief re-use definition of dmlc::SeekStream */
using SeekStream = dmlc::SeekStream;
/**
* @brief Fixed size memory buffer as a stream.
*/
struct MemoryFixSizeBuffer : public SeekStream {
public:
// similar to SEEK_END in libc
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
public:
/**
* @brief Ctor
*
* @param p_buffer Pointer to the source buffer with size `buffer_size`.
* @param buffer_size Size of the source buffer
*/
MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
: p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
~MemoryFixSizeBuffer() override = default;
std::size_t Read(void *ptr, std::size_t size) override {
std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
void Write(const void *ptr, std::size_t size) override {
if (size == 0) return;
CHECK_LE(curr_ptr_ + size, buffer_size_);
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
void Seek(std::size_t pos) override {
if (pos == kSeekEnd) {
curr_ptr_ = buffer_size_;
} else {
curr_ptr_ = static_cast<std::size_t>(pos);
}
}
/**
* @brief Current position in the buffer (stream).
*/
std::size_t Tell() override { return curr_ptr_; }
[[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
protected:
/*! \brief in memory buffer */
char *p_buffer_{nullptr};
/*! \brief current pointer */
std::size_t buffer_size_{0};
/*! \brief current pointer */
std::size_t curr_ptr_{0};
};
/*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public SeekStream {
public:
explicit MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
~MemoryBufferStream() override = default;
size_t Read(void *ptr, size_t size) override {
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
void Write(const void *ptr, size_t size) override {
if (size == 0) return;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos);
}
size_t Tell() override {
return curr_ptr_;
}
virtual bool AtEnd() const {
return curr_ptr_ == p_buffer_->length();
}
private:
/*! \brief in memory buffer */
std::string *p_buffer_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryBufferStream
} // namespace rabit::utils
#endif // RABIT_INTERNAL_IO_H_

View File

@ -1,234 +0,0 @@
/*!
* Copyright (c) 2014-2019 by Contributors
* \file rabit-inl.h
* \brief implementation of inline template function for rabit interface
*
* \author Tianqi Chen
*/
#ifndef RABIT_INTERNAL_RABIT_INL_H_
#define RABIT_INTERNAL_RABIT_INL_H_
// use engine for implementation
#include <vector>
#include <string>
#include "rabit/internal/io.h"
#include "rabit/internal/utils.h"
#include "rabit/rabit.h"
namespace rabit {
namespace engine {
namespace mpi {
// template function to translate type to enum indicator
template<typename DType>
inline DataType GetType();
template<>
inline DataType GetType<char>() {
return kChar;
}
template<>
inline DataType GetType<unsigned char>() {
return kUChar;
}
template<>
inline DataType GetType<int>() {
return kInt;
}
template<>
inline DataType GetType<unsigned int>() { // NOLINT(*)
return kUInt;
}
template<>
inline DataType GetType<long>() { // NOLINT(*)
return kLong;
}
template<>
inline DataType GetType<unsigned long>() { // NOLINT(*)
return kULong;
}
template<>
inline DataType GetType<float>() {
return kFloat;
}
template<>
inline DataType GetType<double>() {
return kDouble;
}
template<>
inline DataType GetType<long long>() { // NOLINT(*)
return kLongLong;
}
template<>
inline DataType GetType<unsigned long long>() { // NOLINT(*)
return kULongLong;
}
} // namespace mpi
} // namespace engine
namespace op {
struct Max {
static const engine::mpi::OpType kType = engine::mpi::kMax;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
if (dst < src) dst = src;
}
};
struct Min {
static const engine::mpi::OpType kType = engine::mpi::kMin;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
if (dst > src) dst = src;
}
};
struct Sum {
static const engine::mpi::OpType kType = engine::mpi::kSum;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst += src;
}
};
struct BitAND {
static const engine::mpi::OpType kType = engine::mpi::kBitwiseAND;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst &= src;
}
};
struct BitOR {
static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst |= src;
}
};
struct BitXOR {
static const engine::mpi::OpType kType = engine::mpi::kBitwiseXOR;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst ^= src;
}
};
template <typename OP, typename DType>
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &) {
const DType *src = static_cast<const DType *>(src_);
DType *dst = (DType *)dst_; // NOLINT(*)
for (int i = 0; i < len; i++) {
OP::Reduce(dst[i], src[i]);
}
}
} // namespace op
// initialize the rabit engine
inline bool Init(int argc, char *argv[]) {
return engine::Init(argc, argv);
}
// finalize the rabit engine
inline bool Finalize() {
return engine::Finalize();
}
// get the rank of the previous worker in ring topology
inline int GetRingPrevRank() {
return engine::GetEngine()->GetRingPrevRank();
}
// get the rank of current process
inline int GetRank() {
return engine::GetEngine()->GetRank();
}
// the the size of the world
inline int GetWorldSize() {
return engine::GetEngine()->GetWorldSize();
}
// whether rabit is distributed
inline bool IsDistributed() {
return engine::GetEngine()->IsDistributed();
}
// get the name of current processor
inline std::string GetProcessorName() {
return engine::GetEngine()->GetHost();
}
// broadcast data to all other nodes from root
inline void Broadcast(void *sendrecv_data, size_t size, int root) {
engine::GetEngine()->Broadcast(sendrecv_data, size, root);
}
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
size_t size = sendrecv_data->size();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->size() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root);
}
}
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
}
}
// perform inplace Allreduce
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
}
// C++11 support for lambda prepare function
#if DMLC_USE_CXX11
inline void InvokeLambda(void *fun) {
(*static_cast<std::function<void()>*>(fun))();
}
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun);
}
// Performs inplace Allgather
template<typename DType>
inline void Allgather(DType *sendrecvbuf,
size_t totalSize,
size_t beginIndex,
size_t sizeNodeSlice,
size_t sizePrevSlice) {
engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
(beginIndex + sizeNodeSlice) * sizeof(DType),
sizePrevSlice * sizeof(DType));
}
#endif // C++11
// print message to the tracker
inline void TrackerPrint(const std::string &msg) {
engine::GetEngine()->TrackerPrint(msg);
}
#ifndef RABIT_STRICT_CXX98_
inline void TrackerPrintf(const char *fmt, ...) {
const int kPrintBuffer = 1 << 10;
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
msg.resize(strlen(msg.c_str()));
TrackerPrint(msg);
}
#endif // RABIT_STRICT_CXX98_
// deprecated, planned for removal after checkpoing from JVM package is removed.
inline int LoadCheckPoint() { return engine::GetEngine()->LoadCheckPoint(); }
// deprecated, increase internal version number
inline void CheckPoint() { engine::GetEngine()->CheckPoint(); }
// return the version number of currently stored model
inline int VersionNumber() {
return engine::GetEngine()->VersionNumber();
}
} // namespace rabit
#endif // RABIT_INTERNAL_RABIT_INL_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2014-2023, XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file socket.h
* \author Tianqi Chen
*/
@ -95,7 +95,10 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
template <typename E>
std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E const& revents) {
if ((revents & POLLERR) != 0) {
return xgboost::system::FailWithCode("Poll error condition.");
auto err = errno;
auto str = strerror(err);
return xgboost::system::FailWithCode(std::string{"Poll error condition:"} + std::string{str} +
" code:" + std::to_string(err));
}
if ((revents & POLLNVAL) != 0) {
return xgboost::system::FailWithCode("Invalid polling request.");
@ -211,12 +214,7 @@ struct PollHelper {
}
auto revents = pfd.revents & pfd.events;
if (!revents) {
// FIXME(jiamingy): remove this once rabit is replaced.
fds.erase(pfd.fd);
} else {
fds[pfd.fd].events = revents;
}
fds[pfd.fd].events = revents;
}
return xgboost::collective::Success();
}

View File

@ -1,146 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file utils.h
* \brief simple utils to support the code
* \author Tianqi Chen
*/
#ifndef RABIT_INTERNAL_UTILS_H_
#define RABIT_INTERNAL_UTILS_H_
#include <rabit/base.h>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>
#include "dmlc/io.h"
#include "xgboost/logging.h"
#if !defined(__GNUC__) || defined(__FreeBSD__)
#define fopen64 std::fopen
#endif // !defined(__GNUC__) || defined(__FreeBSD__)
#ifndef _MSC_VER
#ifdef _FILE_OFFSET_BITS
#if _FILE_OFFSET_BITS == 32
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
#endif // _FILE_OFFSET_BITS == 32
#endif // _FILE_OFFSET_BITS
#ifdef __APPLE__
#define off64_t off_t
#define fopen64 std::fopen
#endif // __APPLE__
extern "C" {
#include <sys/types.h>
}
#endif // _MSC_VER
#include <cinttypes>
namespace rabit {
/*! \brief namespace for helper utils of the project */
namespace utils {
/*! \brief error message buffer length */
const int kPrintBuffer = 1 << 12;
/* \brief Case-insensitive string comparison */
inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) {
#ifdef _MSC_VER
return _stricmp(s1, s2);
#else // _MSC_VER
return strcasecmp(s1, s2);
#endif // _MSC_VER
}
/* \brief parse config string too bool*/
inline bool StringToBool(const char* s) {
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
}
/*! \brief printf, prints messages to the console */
inline void Printf(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
LOG(CONSOLE) << msg;
}
/*! \brief assert a condition is true, use this to handle debug information */
inline void Assert(bool exp, const char *fmt, ...) {
if (!exp) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
LOG(FATAL) << msg;
}
}
/*!\brief same as assert, but this is intended to be used as a message for users */
inline void Check(bool exp, const char *fmt, ...) {
if (!exp) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
LOG(FATAL) << msg;
}
}
/*! \brief report error message, same as check */
inline void Error(const char *fmt, ...) {
{
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
LOG(FATAL) << msg;
}
}
} // namespace utils
// Can not use std::min on Windows with msvc due to:
// error C2589: '(': illegal token on right side of '::'
template <typename T>
auto Min(T const& l, T const& r) {
return l < r ? l : r;
}
// same with Min
template <typename T>
auto Max(T const& l, T const& r) {
return l > r ? l : r;
}
// easy utils that can be directly accessed in xgboost
/*! \brief get the beginning address of a vector */
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return nullptr;
} else {
return &vec[0];
}
}
inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return nullptr;
return &str[0];
}
inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return nullptr;
return &str[0];
}
} // namespace rabit
#endif // RABIT_INTERNAL_UTILS_H_

View File

@ -1,237 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file rabit.h
* \brief This file defines rabit's Allreduce/Broadcast interface
* The rabit engine contains the actual implementation
* Code that only uses this header can also be compiled with MPI Allreduce (non fault-tolerant),
*
* rabit.h and serializable.h is all what the user needs to use the rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_RABIT_H_ // NOLINT(*)
#define RABIT_RABIT_H_ // NOLINT(*)
#include <string>
#include <vector>
#include <functional>
// engine definition of rabit, defines internal implementation
// to use rabit interface, there is no need to read engine.h
// rabit.h and serializable.h are enough to use the interface
#include "./internal/engine.h"
/*! \brief rabit namespace */
namespace rabit {
/*!
* \brief defines stream used in rabit
* see definition of Stream in dmlc/io.h
*/
using Stream = dmlc::Stream;
/*!
* \brief defines serializable objects used in rabit
* see definition of Serializable in dmlc/io.h
*/
using Serializable = dmlc::Serializable;
/*!
* \brief reduction operators namespace
*/
namespace op {
/*!
* \class rabit::op::Max
* \brief maximum reduction operator
*/
struct Max;
/*!
* \class rabit::op::Min
* \brief minimum reduction operator
*/
struct Min;
/*!
* \class rabit::op::Sum
* \brief sum reduction operator
*/
struct Sum;
/*!
* \class rabit::op::BitAND
* \brief bitwise AND reduction operator
*/
struct BitAND;
/*!
* \class rabit::op::BitOR
* \brief bitwise OR reduction operator
*/
struct BitOR;
/*!
* \class rabit::op::BitXOR
* \brief bitwise XOR reduction operator
*/
struct BitXOR;
} // namespace op
/*!
* \brief initializes rabit, call this once at the beginning of your program
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true if initialized successfully, otherwise false
*/
inline bool Init(int argc, char *argv[]);
/*!
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
* \return true if finalized successfully, otherwise false
*/
inline bool Finalize();
/*! \brief gets rank of the current process
* \return rank number of worker*/
inline int GetRank();
/*! \brief gets total number of processes
* \return total world size*/
inline int GetWorldSize();
/*! \brief whether rabit env is in distributed mode
* \return is distributed*/
inline bool IsDistributed();
/*! \brief gets processor's name
* \return processor name*/
inline std::string GetProcessorName();
/*!
* \brief prints the msg to the tracker,
* this function can be used to communicate progress information to
* the user who monitors the tracker
* \param msg the message to be printed
*/
inline void TrackerPrint(const std::string &msg);
#ifndef RABIT_STRICT_CXX98_
/*!
* \brief prints the msg to the tracker, this function may not be available
* in very strict c++98 compilers, though it usually is.
* this function can be used to communicate progress information to
* the user who monitors the tracker
* \param fmt the format string
*/
inline void TrackerPrintf(const char *fmt, ...);
#endif // RABIT_STRICT_CXX98_
/*!
* \brief broadcasts a memory region to every node from the root
*
* Example: int a = 1; Broadcast(&a, sizeof(a), root);
* \param sendrecv_data the pointer to the send/receive buffer,
* \param size the data size
* \param root the process root
*/
inline void Broadcast(void *sendrecv_data, size_t size, int root);
/*!
* \brief broadcasts an std::vector<DType> to every node from root
* \param sendrecv_data the pointer to send/receive vector,
* for the receiver, the vector does not need to be pre-allocated
* \param root the process root
* \tparam DType the data type stored in the vector, has to be a simple data type
* that can be directly transmitted by sending the sizeof(DType)
*/
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
/*!
* \brief broadcasts a std::string to every node from the root
* \param sendrecv_data the pointer to the send/receive buffer,
* for the receiver, the vector does not need to be pre-allocated
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
* \param root the process root
*/
inline void Broadcast(std::string *sendrecv_data, int root);
/*!
* \brief performs in-place Allreduce on sendrecvbuf
* this function is NOT thread-safe
*
* Example Usage: the following code does an Allreduce and outputs the sum as the result
* \code{.cpp}
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size());
* ...
* \endcode
*
* \param sendrecvbuf buffer for both sending and receiving data
* \param count number of elements to be reduced
* \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to pass into the lazy preprocessing function
* \tparam OP see namespace op, reduce operator
* \tparam DType data type
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = nullptr);
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
*/
template<typename DType>
inline void Allgather(DType *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice);
// C++11 support for lambda prepare function
#if DMLC_USE_CXX11
/*!
* \brief performs in-place Allreduce, on sendrecvbuf
* with a prepare function specified by a lambda function
*
* Example Usage:
* \code{.cpp}
* // the following code does an Allreduce and outputs the sum as the result
* vector<int> data(10);
* ...
* Allreduce<op::Sum>(&data[0], data.size(), [&]() {
* for (int i = 0; i < 10; ++i) {
* data[i] = i;
* }
* });
* ...
* \endcode
* \param sendrecvbuf buffer for both sending and receiving data
* \param count number of elements to be reduced
* \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked
* by the function before performing Allreduce in order to initialize the data in sendrecvbuf.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \tparam OP see namespace op, reduce operator
* \tparam DType data type
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun);
#endif // C++11
/*!
* \brief deprecated, planned for removal after checkpoing from JVM package is removed.
*/
inline int LoadCheckPoint();
/*!
* \brief deprecated, planned for removal after checkpoing from JVM package is removed.
*/
inline void CheckPoint();
/*!
* \return version number of the current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
inline int VersionNumber();
} // namespace rabit
// implementation of template functions
#include "./internal/rabit-inl.h"
#endif // RABIT_RABIT_H_ // NOLINT(*)

View File

@ -1,26 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file serializable.h
* \brief defines serializable interface of rabit
* \author Tianqi Chen
*/
#ifndef RABIT_SERIALIZABLE_H_
#define RABIT_SERIALIZABLE_H_
#include <vector>
#include <string>
#include "rabit/internal/utils.h"
namespace rabit {
/*!
* \brief defines stream used in rabit
* see definition of Stream in dmlc/io.h
*/
using Stream = dmlc::Stream ;
/*!
* \brief defines serializable objects used in rabit
* see definition of Serializable in dmlc/io.h
*/
using Serializable = dmlc::Serializable;
} // namespace rabit
#endif // RABIT_SERIALIZABLE_H_

View File

@ -1,997 +0,0 @@
/**
* Copyright 2014-2023, XGBoost Contributors
* \file allreduce_base.cc
* \brief Basic implementation of AllReduce
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#if !defined(NOMINMAX) && defined(_WIN32)
#define NOMINMAX
#endif // !defined(NOMINMAX)
#include "allreduce_base.h"
#include "rabit/base.h"
#include "rabit/internal/rabit-inl.h"
#include "xgboost/collective/result.h"
#ifndef _WIN32
#include <netinet/tcp.h>
#endif // _WIN32
#include <cstring>
#include <map>
namespace rabit::engine {
// constructor
AllreduceBase::AllreduceBase() {
tracker_uri = "NULL";
tracker_port = 9000;
host_uri = "";
rank = 0;
world_size = -1;
connect_retry = 5;
hadoop_mode = false;
version_number = 0;
// 32 K items
reduce_ring_mincount = 32 << 10;
// 1M reducer size each time
tree_reduce_minsize = 1 << 20;
// tracker URL
task_id = "NULL";
err_link = nullptr;
dmlc_role = "worker";
this->SetParam("rabit_reduce_buffer", "256MB");
// setup possible environment variable of interest
// include dmlc support direct variables
env_vars.emplace_back("DMLC_TASK_ID");
env_vars.emplace_back("DMLC_ROLE");
env_vars.emplace_back("DMLC_NUM_ATTEMPT");
env_vars.emplace_back("DMLC_TRACKER_URI");
env_vars.emplace_back("DMLC_TRACKER_PORT");
env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY");
}
// initialization function
bool AllreduceBase::Init(int argc, char* argv[]) {
// setup from environment variables
// handler to get variables from env
for (auto & env_var : env_vars) {
const char *value = getenv(env_var.c_str());
if (value != nullptr) {
this->SetParam(env_var.c_str(), value);
}
}
// pass in arguments override env variable.
for (int i = 0; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
this->SetParam(name, val);
}
}
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");
if (task_id == nullptr) {
task_id = getenv("mapreduce_task_id");
}
if (hadoop_mode) {
utils::Check(task_id != nullptr,
"hadoop_mode is set but cannot find mapred_task_id");
}
if (task_id != nullptr) {
this->SetParam("rabit_task_id", task_id);
this->SetParam("rabit_hadoop_mode", "1");
}
const char *attempt_id = getenv("mapred_task_id");
if (attempt_id != nullptr) {
const char *att = strrchr(attempt_id, '_');
int num_trial;
if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) {
this->SetParam("rabit_num_trial", att + 1);
}
}
// handling for hadoop
const char *num_task = getenv("mapred_map_tasks");
if (num_task == nullptr) {
num_task = getenv("mapreduce_job_maps");
}
if (hadoop_mode) {
utils::Check(num_task != nullptr,
"hadoop_mode is set but cannot find mapred_map_tasks");
}
if (num_task != nullptr) {
this->SetParam("rabit_world_size", num_task);
}
}
if (dmlc_role != "worker") {
LOG(FATAL) << "Rabit Module currently only works with dmlc worker";
}
// clear the setting before start reconnection
this->rank = -1;
//---------------------
// start socket
xgboost::system::SocketStartup();
utils::Assert(all_links.size() == 0, "can only call Init once");
auto rc = xgboost::collective::GetHostName(&this->host_uri);
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
// get information from tracker
rc = this->ReConnectLinks();
if (rc.OK()) {
return true;
}
LOG(FATAL) << rc.Report();
return false;
}
bool AllreduceBase::Shutdown() {
try {
for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) {
SafeColl(all_link.sock.Close());
}
}
all_links.clear();
tree_links.plinks.clear();
if (tracker_uri == "NULL") return true;
// notify tracker rank i have shutdown
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"shutdown"});
SafeColl(tracker.Close());
xgboost::system::SocketFinalize();
return true;
} catch (std::exception const &e) {
LOG(WARNING) << "Failed to shutdown due to" << e.what();
return false;
}
}
void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
SafeColl(tracker.Close());
}
// util to parse data with unit suffix
inline size_t ParseUnit(const char *name, const char *val) {
char unit;
unsigned long amt; // NOLINT(*)
int n = sscanf(val, "%lu%c", &amt, &unit);
size_t amount = amt;
if (n == 2) {
switch (unit) {
case 'B': return amount;
case 'K': return amount << 10UL;
case 'M': return amount << 20UL;
case 'G': return amount << 30UL;
default: utils::Error("invalid format for %s", name); return 0;
}
} else if (n == 1) {
return amount;
} else {
utils::Error("invalid format for %s," \
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
return 0;
}
}
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
if (!strcmp(name, "rabit_task_id")) task_id = val;
if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val;
if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val);
if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val);
if (!strcmp(name, "rabit_tree_reduce_minsize")) tree_reduce_minsize = atoi(val);
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
reduce_ring_mincount = atoi(val);
utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0");
}
if (!strcmp(name, "rabit_reduce_buffer")) {
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
}
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
connect_retry = atoi(val);
}
if (!strcmp(name, "rabit_timeout")) {
rabit_timeout = utils::StringToBool(val);
}
if (!strcmp(name, "rabit_timeout_sec")) {
timeout_sec = std::chrono::seconds(atoi(val));
utils::Assert(timeout_sec.count() >= 0, "rabit_timeout_sec should be non negative second");
}
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
if (!strcmp(val, "true")) {
rabit_enable_tcp_no_delay = true;
} else {
rabit_enable_tcp_no_delay = false;
}
}
}
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker(
xgboost::collective::TCPSocket *out) const {
int magic = kMagic;
// get information from tracker
xgboost::collective::TCPSocket &tracker = *out;
auto rc =
Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker);
if (!rc.OK()) {
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
}
using utils::Assert;
if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) {
return xgboost::collective::Fail("Failed to send the verification number.");
}
if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) {
return xgboost::collective::Fail("Failed to recieve the verification number.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) {
return xgboost::collective::Fail("Failed to send the local rank back to the tracker.");
}
if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) {
return xgboost::collective::Fail("Failed to send the world size back to the tracker.");
}
if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) {
return xgboost::collective::Fail("Failed to send the task ID back to the tracker.");
}
return xgboost::collective::Success();
}
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
*/
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0;
world_size = 1;
return xgboost::collective::Success();
}
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
}
LOG(INFO) << "task " << task_id << " connected to the tracker";
tracker.Send(xgboost::StringView{cmd});
try {
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors
std::map<int, int> tree_neighbors;
using utils::Assert;
// get new ranks
int newrank, num_neighbors;
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \
sizeof(parent_rank), "ReConnectLink failure 4");
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
Assert(rank == -1 || newrank == rank,
"must keep rank to same if the node already have one");
rank = newrank;
if (rank == -1) {
LOG(FATAL) << "tracker got overwhelmed and not able to assign correct rank";
}
LOG(CONSOLE) << "task " << task_id << " got new rank " << rank;
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
sizeof(num_neighbors), "ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
int nrank;
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
"ReConnectLink failure 4");
tree_neighbors[nrank] = 1;
}
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket
std::int32_t port{0};
SafeColl(sock_listen.BindHost(&port));
SafeColl(sock_listen.Listen());
// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
for (auto & all_link : all_links) {
SafeColl(all_link.sock.Close());
}
// tracker construct goodset
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
"ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
SafeColl(tracker.Recv(&hname));
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
// connect to peer
if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry,
timeout_sec, &r.sock)
.OK()) {
num_error += 1;
SafeColl(r.sock.Close());
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 12");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 13");
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false;
for (auto & all_link : all_links) {
if (all_link.rank == hrank) {
Assert(all_link.sock.IsClosed(), "Override a link that is active");
all_link.sock = std::move(r.sock);
match = true;
break;
}
}
if (!match) all_links.emplace_back(std::move(r));
}
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker
SafeColl(tracker.Close());
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
r.sock = sock_listen.Accept();
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 15");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15");
bool match = false;
for (auto & all_link : all_links) {
if (all_link.rank == r.rank) {
utils::Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_link.sock = std::move(r.sock);
match = true;
break;
}
}
if (!match) all_links.emplace_back(std::move(r));
}
SafeColl(sock_listen.Close());
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
for (auto &all_link : all_links) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
CHECK(all_link.sock.NonBlocking(true).OK());
CHECK(all_link.sock.SetKeepAlive().OK());
if (rabit_enable_tcp_no_delay) {
CHECK(all_link.sock.SetNoDelay().OK());
}
if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());
}
tree_links.plinks.push_back(&all_link);
}
if (all_link.rank == prev_rank) ring_prev = &all_link;
if (all_link.rank == next_rank) ring_next = &all_link;
}
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != nullptr,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != nullptr,
"cannot find next ring in the link");
return xgboost::collective::Success();
} catch (const std::exception& e) {
std::stringstream ss;
ss << "Failed in ReconnectLink " << e.what();
return xgboost::collective::Fail(ss.str());
}
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*
* NOTE on Allreduce:
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
* It only means the current node get the correct result of Allreduce.
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
if (count > reduce_ring_mincount) {
return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
} else {
return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
}
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
RefLinkVector &links = tree_links;
if (links.Size() == 0 || count == 0) return kSuccess;
// total size of message
const size_t total_size = type_nbytes * count;
// number of links
const int nlink = static_cast<int>(links.Size());
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// size of space that we already performs reduce in up pass
size_t size_up_reduce = 0;
// size of space that we have already passed to parent
size_t size_up_out = 0;
// size of message we received, and send in the down pass
size_t size_down_in = 0;
// minimal size of each reducer
const size_t eachreduce = (tree_reduce_minsize / type_nbytes * type_nbytes);
// initialize the link ring-buffer and pointer
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
}
links[i].ResetSize();
}
// if no children, no need to reduce
if (nlink == static_cast<int>(parent_index != -1)) {
size_up_reduce = total_size;
}
// while we have not passed the messages out
while (true) {
// select helper
bool finished = true;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
if (size_down_in != total_size) {
watcher.WatchRead(links[i].sock);
// only watch for exception in live channels
watcher.WatchException(links[i].sock);
finished = false;
}
if (size_up_out != total_size && size_up_out < size_up_reduce) {
watcher.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
watcher.WatchRead(links[i].sock);
}
// size_write <= size_read
if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) {
watcher.WatchWrite(links[i].sock);
}
// only watch for exception in live channels
watcher.WatchException(links[i].sock);
finished = false;
}
}
}
// finish running allreduce
if (finished) {
break;
}
// select must return
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
// read data from childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
// make sure to receive minimal reducer size
// since each child reduce and sends the minimal reducer size
while (links[i].size_read < total_size
&& links[i].size_read - size_up_reduce < eachreduce) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
}
// this node have children, perform reduce
if (nlink > static_cast<int>(parent_index != -1)) {
size_t buffer_size = 0;
// do upstream reduce
size_t max_reduce = total_size;
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
max_reduce = std::min(max_reduce, links[i].size_read);
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
"buffer size inconsistent");
buffer_size = links[i].buffer_size;
}
}
utils::Assert(buffer_size != 0, "must assign buffer_size");
// round to type_n4bytes
max_reduce = (max_reduce / type_nbytes * type_nbytes);
// if max reduce is less than total size, we reduce multiple times of
// each reduce size
if (max_reduce < total_size) {
max_reduce = max_reduce - max_reduce % eachreduce;
}
// perform reduce, can be at most two rounds
while (size_up_reduce < max_reduce) {
// start position
size_t start = size_up_reduce % buffer_size;
// perform read till end of buffer
size_t nread = std::min(buffer_size - start,
max_reduce - size_up_reduce);
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
reducer(links[i].buffer_head + start,
sendrecvbuf + size_up_reduce,
static_cast<int>(nread / type_nbytes),
MPI::Datatype(type_nbytes));
}
}
size_up_reduce += nread;
}
}
if (parent_index != -1) {
// pass message up to parent, can pass data that are already been reduced
if (size_up_out < size_up_reduce) {
ssize_t len = links[parent_index].sock.
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
if (len != -1) {
size_up_out += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) {
return ReportError(&links[parent_index], ret);
}
}
}
// read data from parent
if (watcher.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
size_t left_size = total_size-size_down_in;
size_t reduce_size_min = std::min(left_size, eachreduce);
size_t recved = 0;
while (recved < reduce_size_min) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) {
SafeColl(links[parent_index].sock.Close());
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {
size_down_in += static_cast<size_t>(len);
utils::Assert(size_down_in <= size_up_out,
"Allreduce: boundary error");
recved+=len;
// if it receives more data than each reduce, it means the next block is sent.
// we double the reduce_size_min or add to left_size
while (recved > reduce_size_min) {
reduce_size_min += std::min(left_size-reduce_size_min, eachreduce);
}
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) {
return ReportError(&links[parent_index], ret);
}
}
}
}
} else {
// this is root, can use reduce as most recent point
size_down_in = size_up_out = size_up_reduce;
}
// can pass message down to children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write < size_down_in) {
ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
}
return kSuccess;
}
/*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param total_size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links;
if (links.Size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size,
"Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.Size());
// size of space already read from data
size_t size_in = 0;
// input link, -2 means unknown yet, -1 means this is root
int in_link = -2;
// initialize the link statistics
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
}
// root have all the data
if (this->rank == root) {
size_in = total_size;
in_link = -1;
}
// while we have not passed the messages out
while (true) {
bool finished = true;
// select helper
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (in_link == -2) {
watcher.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
watcher.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) {
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
}
// finish running
if (finished) break;
// select
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
size_in = links[i].size_read;
if (size_in != 0) {
in_link = i; break;
}
}
}
} else {
// read from in link
if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[in_link], ret);
}
size_in = links[in_link].size_read;
}
}
// send data to all out-link
for (int i = 0; i < nlink; ++i) {
if (i != in_link && links[i].size_write < size_in) {
ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
}
return kSuccess;
}
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice) {
// read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to reply on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size &&
rank == (prev.rank + 1) % world_size,
"need to assume rank structure");
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
const size_t stop_read = total_size + slice_begin;
const size_t stop_write = total_size + slice_begin - size_prev_slice;
size_t write_ptr = slice_begin;
size_t read_ptr = slice_end;
while (true) {
// select helper
bool finished = true;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) {
break;
}
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
if (len != -1) {
read_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) {
auto err = ReportError(&next, ret);
return err;
}
}
}
if (write_ptr < read_ptr && write_ptr != stop_write) {
size_t size = std::min(read_ptr, stop_write) - write_ptr;
size_t start = write_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
if (len != -1) {
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) {
auto err = ReportError(&prev, ret);
return err;
}
}
}
}
return kSuccess;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
* and will return the cause of failure
*
* Ring-based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce
*/
AllreduceBase::ReturnType
AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
// read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to reply on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size &&
rank == (prev.rank + 1) % world_size,
"need to assume rank structure");
// total size of message
const size_t total_size = type_nbytes * count;
size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n;
size_t r = static_cast<size_t>(next.rank);
size_t write_ptr = std::min(r * step, count) * type_nbytes;
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
size_t reduce_ptr = read_ptr;
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// position to stop reading
const size_t stop_read = total_size + write_ptr;
// position to stop writing
size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
if (stop_write > stop_read) {
stop_write -= total_size;
utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
}
// use ring buffer in next position
next.InitBuffer(type_nbytes, step, reduce_buffer_size);
// set size_read to read pointer for ring buffer to work properly
next.size_read = read_ptr;
while (true) {
// select helper
bool finished = true;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) {
break;
}
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);
}
// sync the rate
read_ptr = next.size_read;
utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank);
const size_t buffer_size = next.buffer_size;
size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
while (reduce_ptr < max_reduce) {
size_t bstart = reduce_ptr % buffer_size;
size_t nread = std::min(buffer_size - bstart,
max_reduce - reduce_ptr);
size_t rstart = reduce_ptr % total_size;
nread = std::min(nread, total_size - rstart);
reducer(next.buffer_head + bstart,
sendrecvbuf + rstart,
static_cast<int>(nread / type_nbytes),
MPI::Datatype(type_nbytes));
reduce_ptr += nread;
}
}
if (write_ptr < reduce_ptr && write_ptr != stop_write) {
size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
size_t start = write_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
if (len != -1) {
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) return ReportError(&prev, ret);
}
}
}
return kSuccess;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
if (ret != kSuccess) return ret;
size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n;
size_t begin = std::min(rank * step, count) * type_nbytes;
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
// previous rank
int prank = ring_prev->rank;
// get rank of previous
return TryAllgatherRing
(sendrecvbuf_, type_nbytes * count,
begin, end,
(std::min((prank + 1) * step, count) -
std::min(prank * step, count)) * type_nbytes);
}
} // namespace rabit::engine

View File

@ -1,501 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.h
* \brief Basic implementation of AllReduce
* using TCP non-block socket and tree-shape reduction.
*
* This implementation provides basic utility of AllReduce and Broadcast
* without considering node failure
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_
#include <algorithm>
#include <functional>
#include <future>
#include <string>
#include <vector>
#include "rabit/internal/engine.h"
#include "rabit/internal/socket.h"
#include "rabit/internal/utils.h"
#include "xgboost/collective/result.h"
#ifdef RABIT_CXXTESTDEFS_H
#define private public
#define protected public
#endif // RABIT_CXXTESTDEFS_H
namespace MPI { // NOLINT
// MPI data type to be compatible with existing MPI interface
class Datatype {
public:
size_t type_size;
explicit Datatype(size_t type_size) : type_size(type_size) {}
};
}
namespace rabit {
namespace engine {
/*! \brief implementation of basic Allreduce engine */
class AllreduceBase : public IEngine {
public:
// magic number to verify server
static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening
AllreduceBase();
virtual ~AllreduceBase() = default;
// initialize the manager
virtual bool Init(int argc, char* argv[]);
// shutdown the engine
virtual bool Shutdown();
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
void TrackerPrint(const std::string &msg) override;
/*! \brief get rank of previous node in ring topology*/
int GetRingPrevRank() const override {
return ring_prev->rank;
}
/*! \brief get rank */
int GetRank() const override {
return rank;
}
/*! \brief get rank */
int GetWorldSize() const override {
if (world_size == -1) return 1;
return world_size;
}
/*! \brief whether is distributed or not */
bool IsDistributed() const override {
return tracker_uri != "NULL";
}
/*! \brief get rank */
std::string GetHost() const override {
return host_uri;
}
/*!
* \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
*/
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice) override {
if (world_size == 1 || world_size == -1) {
return;
}
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin,
slice_end, size_prev_slice) == kSuccess,
"AllgatherRing failed");
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr) override {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) ==
kSuccess,
"Allreduce failed");
}
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"Broadcast failed");
}
/*!
* \brief deprecated
* \sa CheckPoint, VersionNumber
*/
int LoadCheckPoint() override { return 0; }
// deprecated, increase internal version number
void CheckPoint() override { version_number += 1; }
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
int VersionNumber() const override {
return version_number;
}
/*!
* \brief report current status to the job tracker
* depending on the job tracker we are in
*/
inline void ReportStatus() const {
if (hadoop_mode != 0) {
LOG(CONSOLE) << "reporter:status:Rabit Phase[" << version_number << "] Operation " << seq_counter << "\n";
}
}
protected:
/*! \brief enumeration of possible returning results from Try functions */
enum ReturnTypeEnum {
/*! \brief execution is successful */
kSuccess,
/*! \brief a link was reset by peer */
kConnReset,
/*! \brief received a zero length message */
kRecvZeroLen,
/*! \brief a neighbor node go down, the connection is dropped */
kSockError,
/*!
* \brief another node which is not my neighbor go down,
* get Out-of-Band exception notification from my neighbor
*/
kGetExcept
};
/*! \brief struct return type to avoid implicit conversion to int/bool */
struct ReturnType {
/*! \brief internal return type */
ReturnTypeEnum value;
// constructor
ReturnType() = default;
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
inline bool operator==(const ReturnTypeEnum &v) const {
return value == v;
}
inline bool operator!=(const ReturnTypeEnum &v) const {
return value != v;
}
};
/*! \brief translate errno to return type */
static ReturnType Errno2Return() {
int errsv = xgboost::system::LastError();
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
#ifdef _WIN32
if (errsv == WSAEWOULDBLOCK) return kSuccess;
if (errsv == WSAECONNRESET) return kConnReset;
#endif // _WIN32
if (errsv == ECONNRESET) return kConnReset;
return kSockError;
}
// link record to a neighbor
struct LinkRecord {
public:
// socket to get data from/to link
xgboost::collective::TCPSocket sock;
// rank of the node in this link
int rank;
// size of data readed from link
size_t size_read;
// size of data sent to the link
size_t size_write;
// pointer to buffer head
char *buffer_head {nullptr};
// buffer size, in bytes
size_t buffer_size {0};
// constructor
LinkRecord() = default;
// initialize buffer
void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
size_t n = (type_nbytes * count + 7)/ 8;
auto to = Min(reduce_buffer_size, n);
buffer_.resize(to);
// make sure align to type_nbytes
buffer_size =
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
utils::Assert(type_nbytes <= buffer_size,
"too large type_nbytes=%lu, buffer_size=%lu",
type_nbytes, buffer_size);
// set buffer head
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
}
// reset the recv and sent size
inline void ResetSize() {
size_write = size_read = 0;
}
/*!
* \brief read data into ring-buffer, with care not to existing useful override data
* position after protect_start
* \param protect_start all data start from protect_start is still needed in buffer
* read shall not override this
* \param max_size_read maximum logical amount we can read, size_read cannot exceed this value
* \return the type of reading
*/
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated");
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
size_t offset = size_read % buffer_size;
size_t nmax = max_size_read - size_read;
nmax = Min(nmax, buffer_size - ngap);
nmax = Min(nmax, buffer_size - offset);
if (nmax == 0) return kSuccess;
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
if (len == 0) {
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
return kSuccess;
}
/*!
* \brief read data into array,
* this function can not be used together with ReadToRingBuffer
* a link can either read into the ring buffer, or existing array
* \param max_size maximum size of array
* \return true if it is a successful read, false if there is some error happens, check errno
*/
inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
if (max_size == size_read) return kSuccess;
char *p = static_cast<char*>(recvbuf_);
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected
if (len == 0) {
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
return kSuccess;
}
/*!
* \brief write data in array to sock
* \param sendbuf_ head of array
* \param max_size maximum size of array
* \return true if it is a successful write, false if there is some error happens, check errno
*/
inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
const char *p = static_cast<const char*>(sendbuf_);
ssize_t len = sock.Send(p + size_write, max_size - size_write);
if (len == -1) return Errno2Return();
size_write += static_cast<size_t>(len);
return kSuccess;
}
private:
// recv buffer to get data from child
// aligned with 64 bits, will be able to perform 64 bits operations freely
std::vector<uint64_t> buffer_;
};
/*!
* \brief simple data structure that works like a vector
* but takes reference instead of space
*/
struct RefLinkVector {
std::vector<LinkRecord*> plinks;
inline LinkRecord &operator[](size_t i) {
return *plinks[i];
}
inline size_t Size() const {
return plinks.size();
}
};
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
* \param cmd possible command to sent to tracker
*/
[[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start");
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*
* NOTE on Allreduce:
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
* It only means the current node get the correct result of Allreduce.
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief broadcast data from root to all nodes, this function can fail, and will return the cause of failure
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
/*!
* \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduceTree(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin, size_t slice_end,
size_t size_prev_slice);
/*!
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
*
* after the function, node k get k-th segment of the reduction result
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
* where step = ceil(count / world_size)
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce
*/
ReturnType TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* use a ring based algorithm, reduce-scatter + allgather
*
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduceRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief function used to report error when a link goes wrong
* \param link the pointer to the link who causes the error
* \param err the error type
*/
inline ReturnType ReportError(LinkRecord *link, ReturnType err) {
err_link = link; return err;
}
//---- data structure related to model ----
// call sequence counter, records how many calls we made so far
// from last call to CheckPoint, LoadCheckPoint
int seq_counter{0}; // NOLINT
// version number of model
int version_number {0}; // NOLINT
// whether the job is running in Hadoop
bool hadoop_mode; // NOLINT
//---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree
int parent_index; // NOLINT
// rank of parent node, can be -1
int parent_rank; // NOLINT
// sockets of all links this connects to
std::vector<LinkRecord> all_links; // NOLINT
// used to record the link where things goes wrong
LinkRecord *err_link; // NOLINT
// all the links in the reduction tree connection
RefLinkVector tree_links; // NOLINT
// pointer to links in the ring
LinkRecord *ring_prev, *ring_next; // NOLINT
//----- meta information-----
// list of enviroment variables that are of possible interest
std::vector<std::string> env_vars; // NOLINT
// unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL
std::string task_id; // NOLINT
// uri of current host, to be set by Init
std::string host_uri; // NOLINT
// uri of tracker
std::string tracker_uri; // NOLINT
// role in dmlc jobs
std::string dmlc_role; // NOLINT
// port of tracker address
int tracker_port; // NOLINT
// reduce buffer size
size_t reduce_buffer_size; // NOLINT
// reduction method
int reduce_method; // NOLINT
// minimum count of cells to use ring based method
size_t reduce_ring_mincount; // NOLINT
// minimum block size per tree reduce
size_t tree_reduce_minsize; // NOLINT
// current rank
int rank; // NOLINT
// world size
int world_size; // NOLINT
// connect retry time
int connect_retry; // NOLINT
// by default, if rabit worker not recover in half an hour exit
std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT
// flag to enable rabit_timeout
bool rabit_timeout = false; // NOLINT
// Enable TCP node delay
bool rabit_enable_tcp_no_delay = false; // NOLINT
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_BASE_H_

View File

@ -1,147 +0,0 @@
/*!
* Copyright by Contributors
* \file allreduce_mock.h
* \brief Mock test module of AllReduce engine,
* insert failures in certain call point, to test if the engine is robust to failure
*
* \author Ignacio Cano, Tianqi Chen
*/
#ifndef RABIT_ALLREDUCE_MOCK_H_
#define RABIT_ALLREDUCE_MOCK_H_
#include <vector>
#include <map>
#include <sstream>
#include <dmlc/timer.h>
#include "rabit/internal/engine.h"
#include "allreduce_base.h"
namespace rabit {
namespace engine {
class AllreduceMock : public AllreduceBase {
public:
// constructor
AllreduceMock() {
num_trial_ = 0;
force_local_ = 0;
report_stats_ = 0;
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
}
// destructor
~AllreduceMock() override = default;
void SetParam(const char *name, const char *val) override {
AllreduceBase::SetParam(name, val);
// additional parameters
if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
if (!strcmp(name, "report_stats")) report_stats_ = atoi(val);
if (!strcmp(name, "force_local")) force_local_ = atoi(val);
if (!strcmp(name, "mock")) {
MockKey k;
utils::Check(sscanf(val, "%d,%d,%d,%d",
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
"invalid mock parameter");
mock_map_[k] = 1;
}
}
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
double tstart = dmlc::GetTime();
AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer,
prepare_fun, prepare_arg);
tsum_allreduce_ += dmlc::GetTime() - tstart;
}
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
double tstart = dmlc::GetTime();
AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end,
size_prev_slice);
tsum_allgather_ += dmlc::GetTime() - tstart;
}
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
AllreduceBase::Broadcast(sendrecvbuf_, total_size, root);
}
int LoadCheckPoint() override {
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
time_checkpoint_ = dmlc::GetTime();
if (force_local_ == 0) {
return AllreduceBase::LoadCheckPoint();
} else {
return AllreduceBase::LoadCheckPoint();
}
}
void CheckPoint() override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
double tstart = dmlc::GetTime();
double tbet_chkpt = tstart - time_checkpoint_;
AllreduceBase::CheckPoint();
time_checkpoint_ = dmlc::GetTime();
double tcost = dmlc::GetTime() - tstart;
if (report_stats_ != 0 && rank == 0) {
std::stringstream ss;
ss << "[v" << version_number << "] global_size="
<< ",check_tcost="<< tcost <<" sec"
<< ",allreduce_tcost=" << tsum_allreduce_ << " sec"
<< ",allgather_tcost=" << tsum_allgather_ << " sec"
<< ",between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str());
}
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
}
protected:
// force checkpoint to local
int force_local_;
// whether report statistics
int report_stats_;
// sum of allreduce
double tsum_allreduce_;
// sum of allgather
double tsum_allgather_;
double time_checkpoint_;
private:
// key to identify the mock stage
struct MockKey {
int rank;
int version;
int seqno;
int ntrial;
MockKey() = default;
MockKey(int rank, int version, int seqno, int ntrial)
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
inline bool operator==(const MockKey &b) const {
return rank == b.rank &&
version == b.version &&
seqno == b.seqno &&
ntrial == b.ntrial;
}
inline bool operator<(const MockKey &b) const {
if (rank != b.rank) return rank < b.rank;
if (version != b.version) return version < b.version;
if (seqno != b.seqno) return seqno < b.seqno;
return ntrial < b.ntrial;
}
};
// number of failure trials
int num_trial_;
// record all mock actions
std::map<MockKey, int> mock_map_;
// used to generate all kinds of exceptions
inline void Verify(const MockKey &key, const char *name) {
if (mock_map_.count(key) != 0) {
num_trial_ += 1;
// data processing frameworks runs on shared process
throw dmlc::Error(std::to_string(rank) + "@@@Hit Mock Error: " + name);
}
}
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_MOCK_H_

View File

@ -1,106 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine.cc
* \brief this file governs which implementation of engine we are actually using
* provides an singleton of engine interface
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#include <rabit/base.h>
#include <dmlc/thread_local.h>
#include <memory>
#include "rabit/internal/engine.h"
#include "allreduce_base.h"
namespace rabit {
namespace engine {
// singleton sync manager
#ifndef RABIT_USE_BASE
#ifndef RABIT_USE_MOCK
using Manager = AllreduceBase;
#else
typedef AllreduceMock Manager;
#endif // RABIT_USE_MOCK
#else
typedef AllreduceBase Manager;
#endif // RABIT_USE_BASE
/*! \brief entry to to easily hold returning information */
struct ThreadLocalEntry {
/*! \brief stores the current engine */
std::unique_ptr<Manager> engine;
/*! \brief whether init has been called */
bool initialized{false};
/*! \brief constructor */
ThreadLocalEntry() = default;
};
// define the threadlocal store.
using EngineThreadLocal = dmlc::ThreadLocalStore<ThreadLocalEntry>;
/*! \brief intiialize the synchronization module */
bool Init(int argc, char *argv[]) {
ThreadLocalEntry* e = EngineThreadLocal::Get();
if (e->engine.get() == nullptr) {
e->initialized = true;
e->engine.reset(new Manager());
return e->engine->Init(argc, argv);
} else {
return true;
}
}
/*! \brief finalize syncrhonization module */
bool Finalize() {
ThreadLocalEntry* e = EngineThreadLocal::Get();
if (e->engine.get() != nullptr) {
if (e->engine->Shutdown()) {
e->engine.reset(nullptr);
e->initialized = false;
return true;
} else {
return false;
}
} else {
return true;
}
}
/*! \brief singleton method to get engine */
IEngine *GetEngine() {
// un-initialized default manager.
static AllreduceBase default_manager;
ThreadLocalEntry* e = EngineThreadLocal::Get();
IEngine* ptr = e->engine.get();
if (ptr == nullptr) {
utils::Check(!e->initialized, "the rabit has not been initialized");
return &default_manager;
} else {
return ptr;
}
}
// perform in-place allgather, on sendrecvbuf
void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice) {
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
slice_end, size_prev_slice);
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, // NOLINT
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType,
mpi::OpType ,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
}
} // namespace engine
} // namespace rabit

View File

@ -1,14 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc
* \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure
* \author Tianqi Chen
*/
// define use MOCK, os we will use mock Manager
#define NOMINMAX
// switch engine to AllreduceMock
#define RABIT_USE_MOCK
#include <rabit/base.h>
#include "allreduce_mock.h"
#include "engine.cc"

View File

@ -1,342 +0,0 @@
// Copyright by Contributors
// implementations in ctypes
#include <rabit/base.h>
#include <cstring>
#include <string>
#include "rabit/rabit.h"
#include "rabit/c_api.h"
#include "../../src/c_api/c_api_error.h"
namespace rabit {
namespace c_api {
// helper use to avoid BitOR operator
template<typename OP, typename DType>
struct FHelper {
static void
Allreduce(DType *senrecvbuf_,
size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
rabit::Allreduce<OP>(senrecvbuf_, count,
prepare_fun, prepare_arg);
}
};
template<typename DType>
struct FHelper<op::BitAND, DType> {
static void
Allreduce(DType *,
size_t ,
void (*)(void *arg),
void *) {
utils::Error("DataType does not support bitwise AND operation");
}
};
template<typename DType>
struct FHelper<op::BitOR, DType> {
static void
Allreduce(DType *,
size_t ,
void (*)(void *arg),
void *) {
utils::Error("DataType does not support bitwise OR operation");
}
};
template<typename DType>
struct FHelper<op::BitXOR, DType> {
static void
Allreduce(DType *,
size_t ,
void (*)(void *arg),
void *) {
utils::Error("DataType does not support bitwise XOR operation");
}
};
template<typename OP>
void Allreduce(void *sendrecvbuf_,
size_t count,
engine::mpi::DataType enum_dtype,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi; // NOLINT
switch (enum_dtype) {
case kChar:
rabit::Allreduce<OP>
(static_cast<char*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kUChar:
rabit::Allreduce<OP>
(static_cast<unsigned char*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kInt:
rabit::Allreduce<OP>
(static_cast<int*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kUInt:
rabit::Allreduce<OP>
(static_cast<unsigned*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kLong:
rabit::Allreduce<OP>
(static_cast<long*>(sendrecvbuf_), // NOLINT(*)
count, prepare_fun, prepare_arg);
return;
case kULong:
rabit::Allreduce<OP>
(static_cast<unsigned long*>(sendrecvbuf_), // NOLINT(*)
count, prepare_fun, prepare_arg);
return;
case kFloat:
FHelper<OP, float>::Allreduce
(static_cast<float*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kDouble:
FHelper<OP, double>::Allreduce
(static_cast<double*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
default: utils::Error("unknown data_type");
}
}
void Allreduce(void *sendrecvbuf,
size_t count,
engine::mpi::DataType enum_dtype,
engine::mpi::OpType enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi; // NOLINT
switch (enum_op) {
case kMax:
Allreduce<op::Max>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kMin:
Allreduce<op::Min>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kSum:
Allreduce<op::Sum>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseAND:
Allreduce<op::BitAND>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseOR:
Allreduce<op::BitOR>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseXOR:
Allreduce<op::BitXOR>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
default: utils::Error("unknown enum_op");
}
}
void Allgather(void *sendrecvbuf_,
size_t total_size,
size_t beginIndex,
size_t size_node_slice,
size_t size_prev_slice,
int enum_dtype) {
using namespace engine::mpi; // NOLINT
size_t type_size = 0;
switch (enum_dtype) {
case kChar:
type_size = sizeof(char);
rabit::Allgather(static_cast<char*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kUChar:
type_size = sizeof(unsigned char);
rabit::Allgather(static_cast<unsigned char*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kInt:
type_size = sizeof(int);
rabit::Allgather(static_cast<int*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kUInt:
type_size = sizeof(unsigned);
rabit::Allgather(static_cast<unsigned*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kLong:
type_size = sizeof(int64_t);
rabit::Allgather(static_cast<int64_t*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kULong:
type_size = sizeof(uint64_t);
rabit::Allgather(static_cast<uint64_t*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kFloat:
type_size = sizeof(float);
rabit::Allgather(static_cast<float*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kDouble:
type_size = sizeof(double);
rabit::Allgather(static_cast<double*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
default: utils::Error("unknown data_type");
}
}
// wrapper for serialization
struct ReadWrapper : public Serializable {
std::string *p_str;
explicit ReadWrapper(std::string *p_str)
: p_str(p_str) {}
void Load(Stream *fi) override {
uint64_t sz;
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
"Read pickle string");
p_str->resize(sz);
if (sz != 0) {
utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0,
"Read pickle string");
}
}
void Save(Stream *) const override {
utils::Error("not implemented");
}
};
struct WriteWrapper : public Serializable {
const char *data;
size_t length;
explicit WriteWrapper(const char *data,
size_t length)
: data(data), length(length) {
}
void Load(Stream *) override {
utils::Error("not implemented");
}
void Save(Stream *fo) const override {
uint64_t sz = static_cast<uint16_t>(length);
fo->Write(&sz, sizeof(sz));
fo->Write(data, length * sizeof(char));
}
};
} // namespace c_api
} // namespace rabit
RABIT_DLL bool RabitInit(int argc, char *argv[]) {
auto ret = rabit::Init(argc, argv);
if (!ret) {
XGBAPISetLastError("Failed to initialize RABIT.");
}
return ret;
}
RABIT_DLL int RabitFinalize() {
auto ret = rabit::Finalize();
if (!ret) {
XGBAPISetLastError("Failed to shutdown RABIT worker.");
}
return static_cast<int>(ret);
}
RABIT_DLL int RabitGetRingPrevRank() {
return rabit::GetRingPrevRank();
}
RABIT_DLL int RabitGetRank() {
return rabit::GetRank();
}
RABIT_DLL int RabitGetWorldSize() {
return rabit::GetWorldSize();
}
RABIT_DLL int RabitIsDistributed() {
return rabit::IsDistributed();
}
RABIT_DLL int RabitTrackerPrint(const char *msg) {
API_BEGIN()
std::string m(msg);
rabit::TrackerPrint(m);
API_END()
}
RABIT_DLL void RabitGetProcessorName(char *out_name,
rbt_ulong *out_len,
rbt_ulong max_len) {
std::string s = rabit::GetProcessorName();
if (s.length() > max_len) {
s.resize(max_len - 1);
}
strcpy(out_name, s.c_str()); // NOLINT(*)
*out_len = static_cast<rbt_ulong>(s.length());
}
RABIT_DLL int RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root) {
API_BEGIN()
rabit::Broadcast(sendrecv_data, size, root);
API_END()
}
RABIT_DLL int RabitAllgather(void *sendrecvbuf_, size_t total_size,
size_t beginIndex, size_t size_node_slice,
size_t size_prev_slice, int enum_dtype) {
API_BEGIN()
rabit::c_api::Allgather(
sendrecvbuf_, total_size, beginIndex, size_node_slice, size_prev_slice,
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
API_END()
}
RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
int enum_op, void (*prepare_fun)(void *arg),
void *prepare_arg) {
API_BEGIN()
rabit::c_api::Allreduce(sendrecvbuf, count,
static_cast<rabit::engine::mpi::DataType>(enum_dtype),
static_cast<rabit::engine::mpi::OpType>(enum_op),
prepare_fun, prepare_arg);
API_END()
}
RABIT_DLL int RabitVersionNumber() {
return rabit::VersionNumber();
}
RABIT_DLL int RabitLinkTag() {
return 0;
}

View File

@ -15,9 +15,9 @@
#include <utility> // for pair
#include <vector> // for vector
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/error_msg.h" // for NoFederated
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
@ -27,11 +27,10 @@
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/io.h" // for Stream
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "rabit/c_api.h" // for RabitLinkTag
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
@ -46,10 +45,6 @@
#include "xgboost/string_view.h" // for StringView, operator<<
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
#endif
using namespace xgboost; // NOLINT(*);
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
@ -1759,76 +1754,3 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
*out_features = dmlc::BeginPtr(feature_names_c);
API_END();
}
XGB_DLL int XGCommunicatorInit(char const* json_config) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(json_config);
Json config{Json::Load(StringView{json_config})};
collective::Init(config);
API_END();
}
XGB_DLL int XGCommunicatorFinalize() {
API_BEGIN();
collective::Finalize();
API_END();
}
XGB_DLL int XGCommunicatorGetRank(void) {
return collective::GetRank();
}
XGB_DLL int XGCommunicatorGetWorldSize(void) {
return collective::GetWorldSize();
}
XGB_DLL int XGCommunicatorIsDistributed(void) {
return collective::IsDistributed();
}
XGB_DLL int XGCommunicatorPrint(char const *message) {
API_BEGIN();
collective::Print(message);
API_END();
}
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
API_BEGIN();
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
local.ret_str = collective::GetProcessorName();
xgboost_CHECK_C_ARG_PTR(name_str);
*name_str = local.ret_str.c_str();
API_END();
}
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
API_BEGIN();
collective::Broadcast(send_receive_buffer, size, root);
API_END();
}
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op) {
API_BEGIN();
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
API_END();
}
#if defined(XGBOOST_USE_FEDERATED)
XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path,
char const *server_cert_path, char const *client_cert_path) {
API_BEGIN();
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
API_END();
}
// Run a server without SSL for local testing.
XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) {
API_BEGIN();
federated::RunInsecureServer(port, world_size);
API_END();
}
#endif
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();

View File

@ -1,22 +1,28 @@
/*!
* Copyright (c) 2015 by Contributors
/**
* Copyright 2015-2023, XGBoost Contributors
* \file c_api_error.cc
* \brief C error handling
*/
#include <dmlc/thread_local.h>
#include "xgboost/c_api.h"
#include "./c_api_error.h"
#include <dmlc/thread_local.h>
#include "xgboost/c_api.h"
#include "../collective/comm.h"
#include "../collective/comm_group.h"
struct XGBAPIErrorEntry {
std::string last_error;
std::int32_t code{-1};
};
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
XGB_DLL const char *XGBGetLastError() {
return XGBAPIErrorStore::Get()->last_error.c_str();
}
XGB_DLL const char* XGBGetLastError() { return XGBAPIErrorStore::Get()->last_error.c_str(); }
void XGBAPISetLastError(const char* msg) {
XGBAPIErrorStore::Get()->last_error = msg;
XGBAPIErrorStore::Get()->code = -1;
}
XGB_DLL int XGBGetLastErrorCode() { return XGBAPIErrorStore::Get()->code; }

View File

@ -10,6 +10,7 @@
#include <dmlc/logging.h>
#include "c_api_utils.h"
#include "xgboost/collective/result.h"
/*! \brief macro to guard beginning and end section of all functions */
#ifdef LOG_CAPI_INVOCATION
@ -30,7 +31,7 @@
#define API_END() \
} catch (dmlc::Error & _except_) { \
return XGBAPIHandleException(_except_); \
} catch (std::exception const &_except_) { \
} catch (std::exception const& _except_) { \
return XGBAPIHandleException(dmlc::Error(_except_.what())); \
} \
return 0; // NOLINT(*)
@ -48,7 +49,7 @@ void XGBAPISetLastError(const char* msg);
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int XGBAPIHandleException(const dmlc::Error &e) {
inline int XGBAPIHandleException(const dmlc::Error& e) {
XGBAPISetLastError(e.what());
return -1;
}

View File

@ -9,10 +9,15 @@
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "../collective/allgather.h" // for Allgather
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/broadcast.h" // for Broadcast
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/comm_group.h" // for GlobalCommGroup
#include "../collective/communicator-inl.h" // for GetProcessorName
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
@ -20,10 +25,36 @@
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
#else
#include "../common/error_msg.h" // for NoFederated
#endif
namespace xgboost::collective {
void Allreduce(void *send_receive_buffer, std::size_t count, std::int32_t data_type, int op) {
Context ctx;
DispatchDType(static_cast<ArrayInterfaceHandler::Type>(data_type), [&](auto t) {
using T = decltype(t);
auto data = linalg::MakeTensorView(
&ctx, common::Span{static_cast<T *>(send_receive_buffer), count}, count);
auto rc = Allreduce(&ctx, *GlobalCommGroup(), data, static_cast<Op>(op));
SafeColl(rc);
});
}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) {
Context ctx;
auto rc = Broadcast(&ctx, *GlobalCommGroup(),
linalg::MakeVec(static_cast<std::int8_t *>(send_receive_buffer), size), root);
SafeColl(rc);
}
void Allgather(void *send_receive_buffer, std::size_t size) {
Context ctx;
auto const &comm = GlobalCommGroup();
auto rc = Allgather(&ctx, *comm,
linalg::MakeVec(reinterpret_cast<std::int8_t *>(send_receive_buffer), size));
SafeColl(rc);
}
} // namespace xgboost::collective
using namespace xgboost; // NOLINT
namespace {
@ -44,7 +75,8 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count())
: kDft};
common::Timer timer;
timer.Start();
@ -62,7 +94,7 @@ void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
break;
}
if (timer.Duration() > timeout && timeout.count() != 0) {
if (timer.Duration() > timeout && collective::HasTimeout(timeout)) {
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
}
}
@ -141,7 +173,7 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
// Make sure no one else is waiting on the tracker.
while (!ptr->first.unique()) {
auto ela = timer.Duration().count();
if (ela > ptr->first->Timeout().count()) {
if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) {
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
<< " seconds reached for TrackerFree, killing the tracker.";
break;
@ -151,3 +183,71 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
delete ptr;
API_END();
}
XGB_DLL int XGCommunicatorInit(char const *json_config) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(json_config);
Json config{Json::Load(StringView{json_config})};
collective::GlobalCommGroupInit(config);
API_END();
}
XGB_DLL int XGCommunicatorFinalize(void) {
API_BEGIN();
collective::GlobalCommGroupFinalize();
API_END();
}
XGB_DLL int XGCommunicatorGetRank(void) {
API_BEGIN();
return collective::GetRank();
API_END();
}
XGB_DLL int XGCommunicatorGetWorldSize(void) { return collective::GetWorldSize(); }
XGB_DLL int XGCommunicatorIsDistributed(void) { return collective::IsDistributed(); }
XGB_DLL int XGCommunicatorPrint(char const *message) {
API_BEGIN();
collective::Print(message);
API_END();
}
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
API_BEGIN();
auto &local = *CollAPIThreadLocalStore::Get();
local.ret_str = collective::GetProcessorName();
xgboost_CHECK_C_ARG_PTR(name_str);
*name_str = local.ret_str.c_str();
API_END();
}
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
API_BEGIN();
collective::Broadcast(send_receive_buffer, size, root);
API_END();
}
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op) {
API_BEGIN();
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
API_END();
}
// Not exposed to the public since the previous implementation didn't and we don't want to
// add unnecessary communicator API to a machine learning library.
XGB_DLL int XGCommunicatorAllgather(void *send_receive_buffer, size_t count) {
API_BEGIN();
collective::Allgather(send_receive_buffer, count);
API_END();
}
// Not yet exposed to the public, error recovery is still WIP.
XGB_DLL int XGCommunicatorSignalError() {
API_BEGIN();
auto msg = XGBGetLastError();
SafeColl(xgboost::collective::GlobalCommGroup()->SignalError(xgboost::collective::Fail(msg)));
API_END()
}

View File

@ -22,7 +22,6 @@
#include <cstdio>
#include <cstring>
#include <vector>
#include "collective/communicator-inl.h"
#include "common/common.h"
#include "common/config.h"
#include "common/io.h"
@ -193,10 +192,6 @@ class CLI {
void CLITrain() {
const double tstart_data_load = dmlc::GetTime();
if (collective::IsDistributed()) {
std::string pname = collective::GetProcessorName();
LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank();
}
// load in data.
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
@ -235,15 +230,9 @@ class CLI {
version += 1;
}
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
if (collective::IsDistributed()) {
if (collective::GetRank() == 0) {
LOG(TRACKER) << res;
}
} else {
LOG(CONSOLE) << res;
}
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
collective::GetRank() == 0) {
LOG(CONSOLE) << res;
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0) {
std::ostringstream os;
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
<< i + 1 << ".model";
@ -256,8 +245,7 @@ class CLI {
<< " sec";
// always save final round
if ((param_.save_period == 0 ||
param_.num_round % param_.save_period != 0) &&
collective::GetRank() == 0) {
param_.num_round % param_.save_period != 0)) {
std::ostringstream os;
if (param_.model_out == CLIParam::kNull) {
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
@ -465,13 +453,6 @@ class CLI {
}
}
// Initialize the collective communicator.
Json json{JsonObject()};
for (auto& kv : cfg) {
json[kv.first] = String(kv.second);
}
collective::Init(json);
param_.Configure(cfg);
}
@ -507,10 +488,6 @@ class CLI {
}
return 0;
}
~CLI() {
collective::Finalize();
}
};
} // namespace xgboost

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*
* Higher level functions built on top the Communicator API, taking care of behavioral differences
* between row-split vs column-split distributed training, and horizontal vs vertical federated
@ -13,7 +13,8 @@
#include <utility>
#include <vector>
#include "communicator-inl.cuh"
#include "allreduce.h"
#include "xgboost/collective/result.h" // for Result
namespace xgboost::collective {
@ -24,15 +25,17 @@ namespace xgboost::collective {
* column-wise (vertically), the original values are returned.
*
* @tparam T The type of the values.
*
* @param info MetaInfo about the DMatrix.
* @param device The device id.
* @param values Pointer to the inputs to sum.
* @param size Number of values to sum.
*/
template <typename T>
void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
return collective::Allreduce(ctx, values, collective::Op::kSum);
}
return Success();
}
} // namespace xgboost::collective

View File

@ -11,11 +11,44 @@
#include <utility>
#include <vector>
#include "allreduce.h"
#include "broadcast.h"
#include "comm.h"
#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaINfo
namespace xgboost::collective {
namespace detail {
template <typename Fn>
[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) {
std::string msg;
if (collective::GetRank() == 0) {
try {
fn();
} catch (dmlc::Error const& e) {
msg = e.what();
}
}
std::size_t msg_size{msg.size()};
auto rc = Success() << [&] {
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
return rc;
} << [&] {
if (msg_size > 0) {
msg.resize(msg_size);
return collective::Broadcast(ctx, linalg::MakeVec(msg.data(), msg.size()), 0);
}
return Success();
} << [&] {
if (msg_size > 0) {
LOG(FATAL) << msg;
}
return Success();
};
return rc;
}
} // namespace detail
/**
* @brief Apply the given function where the labels are.
@ -30,29 +63,19 @@ namespace xgboost::collective {
* @param size The size of the buffer.
* @param function The function used to calculate the results.
*/
template <typename FN>
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
FN&& function) {
template <typename Fn>
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std::size_t size,
Fn&& fn) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<FN>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}
collective::Broadcast(&message, 0);
if (message.empty()) {
collective::Broadcast(buffer, size, 0);
} else {
LOG(FATAL) << &message[0];
}
auto rc = detail::TryApplyWithLabels(ctx, fn) << [&] {
// We assume labels are only available on worker 0, so the calculation is done there and
// result broadcast to other workers.
return collective::Broadcast(
ctx, linalg::MakeVec(reinterpret_cast<std::int8_t*>(buffer), size), 0);
};
SafeColl(rc);
} else {
std::forward<FN>(function)();
std::forward<Fn>(fn)();
}
}
@ -69,37 +92,24 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
* @param result The HostDeviceVector storing the results.
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
Function&& function) {
template <typename T, typename Fn>
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
Fn&& fn) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}
auto rc = detail::TryApplyWithLabels(ctx, fn);
collective::Broadcast(&message, 0);
if (!message.empty()) {
LOG(FATAL) << &message[0];
return;
}
std::size_t size{};
if (collective::GetRank() == 0) {
size = result->Size();
}
collective::Broadcast(&size, sizeof(std::size_t), 0);
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
std::size_t size{result->Size()};
rc = std::move(rc) << [&] {
return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0);
} << [&] {
result->Resize(size);
return collective::Broadcast(ctx, linalg::MakeVec(result->HostPointer(), size), 0);
};
SafeColl(rc);
} else {
std::forward<Function>(function)();
std::forward<Fn>(fn)();
}
}
@ -115,11 +125,12 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
* @return The global max of the input.
*/
template <typename T>
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const* ctx,
MetaInfo const& info,
T value) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kMax>(&value, 1);
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&value, 1), collective::Op::kMax);
SafeColl(rc);
}
return value;
}
@ -136,19 +147,14 @@ std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context co
* @param size Number of values to sum.
*/
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
return collective::Allreduce(ctx, values, collective::Op::kSum);
}
return Success();
}
template <typename Container>
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
return GlobalSum(ctx, info, values->data(), values->size());
}
/**
* @brief Find the global ratio of the given two values across all workers.
*

View File

@ -47,7 +47,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
return comm.Block();
};
if (!rc.OK()) {
return rc;
return Fail("Ring allgather failed, current iteration:" + std::to_string(r), std::move(rc));
}
}
@ -61,7 +61,8 @@ Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> si
auto as_bytes = sizes[r];
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
if (!rc.OK()) {
return rc;
return Fail("Broadcast AllgatherV failed, current iteration:" + std::to_string(r),
std::move(rc));
}
offset += as_bytes;
}
@ -102,7 +103,7 @@ namespace detail {
return prev_ch->Block();
};
if (!rc.OK()) {
return rc;
return Fail("Ring AllgatherV failed, current iterataion:" + std::to_string(r), std::move(rc));
}
}
return comm.Block();

View File

@ -36,7 +36,7 @@ Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func
auto rc = RingAllgather(comm, typed);
if (!rc.OK()) {
return rc;
return Fail("Ring allreduce small failed.", std::move(rc));
}
auto first = s_buffer.subspan(0, data.size_bytes());
CHECK_EQ(first.size(), data.size());
@ -64,7 +64,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto next_ch = comm.Chan(dst_rank);
auto prev_ch = comm.Chan(src_rank);
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, -1);
auto s_buf = common::Span{buffer.data(), buffer.size()};
for (std::int32_t r = 0; r < world - 1; ++r) {
@ -97,6 +97,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return Fail("Ring scatter reduce failed, current iteration:" + std::to_string(r),
std::move(rc));
}
// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());
@ -128,7 +132,7 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
auto n_bytes_in_seg = (n / world) * sizeof(T);
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
if (!rc.OK()) {
return rc;
return Fail("Ring Allreduce failed.", std::move(rc));
}
auto prev = BootstrapPrev(comm.Rank(), comm.World());

View File

@ -150,9 +150,12 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
}
auto rank = comm.Rank();
auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.");
std::size_t n_bytes{0};
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.", std::move(rc));
}
workers[r] = std::move(worker);
}
@ -169,8 +172,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return rc;
}
std::int32_t rank{-1};
auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
if (n_bytes != sizeof(comm.Rank())) {
std::size_t n_bytes{0};
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to recv rank.");
}
workers[rank] = std::move(peer);

View File

@ -94,7 +94,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
void Submit(Loop::Op op) const {
CHECK(loop_);
loop_->Submit(op);
loop_->Submit(std::move(op));
}
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }

View File

@ -76,7 +76,7 @@ CommGroup::CommGroup()
// Common args
auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
auto timeout =
get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
get_param("dmlc_timeout", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
if (type == "rabit") {
@ -123,4 +123,30 @@ void GlobalCommGroupFinalize() {
sptr.reset();
SafeColl(rc);
}
void Init(Json const& config) { GlobalCommGroupInit(config); }
void Finalize() { GlobalCommGroupFinalize(); }
std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }
std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }
bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }
[[nodiscard]] bool IsFederated() {
return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
}
void Print(std::string const& message) {
auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
SafeColl(rc);
}
std::string GetProcessorName() {
std::string out;
auto rc = GlobalCommGroup()->ProcessorName(&out);
SafeColl(rc);
return out;
}
} // namespace xgboost::collective

View File

@ -1,34 +0,0 @@
/**
* Copyright 2024, XGBoost contributors
*/
#include "communicator-inl.h"
namespace xgboost::collective {
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
std::vector<std::vector<char>> const &input) {
auto n_inputs = input.size();
std::vector<std::int64_t> sizes(n_inputs);
std::transform(input.cbegin(), input.cend(), sizes.begin(),
[](auto const &vec) { return vec.size(); });
std::vector<std::int64_t> global_sizes = AllgatherV(sizes);
std::vector<std::int64_t> offset(global_sizes.size() + 1);
offset[0] = 0;
for (std::size_t i = 1; i < offset.size(); i++) {
offset[i] = offset[i - 1] + global_sizes[i - 1];
}
std::vector<char> collected;
for (auto const &vec : input) {
collected.insert(collected.end(), vec.cbegin(), vec.cend());
}
auto out = AllgatherV(collected);
std::vector<std::vector<char>> result;
for (std::size_t i = 1; i < offset.size(); ++i) {
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
result.emplace_back(std::move(local));
}
return result;
}
} // namespace xgboost::collective

View File

@ -1,95 +0,0 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#pragma once
#include <string>
#include <vector>
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
/**
* @brief Reduce values from all processes and distribute the result back to all processes.
* @param device ID of the device.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
template <Operation op>
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}
template <Operation op>
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}
template <Operation op>
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
template <Operation op>
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}
template <Operation op>
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
/**
* @brief Gather values from all all processes.
*
* This assumes all ranks have the same size.
*
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
}
/**
* @brief Gather variable-length values from all processes.
* @param device ID of the device.
* @param send_buffer Buffer storing the input data.
* @param length_bytes Length in bytes of the input data.
* @param segments Size of each segment.
* @param receive_buffer Buffer storing the output data.
*/
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
}
/**
* @brief Synchronize device operations.
* @param device ID of the device.
*/
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }
} // namespace collective
} // namespace xgboost

View File

@ -3,308 +3,63 @@
*/
#pragma once
#include <string>
#include <vector>
#include "communicator.h"
#include "xgboost/json.h" // for Json
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Initialize the collective communicator.
*/
void Init(Json const& config);
/**
* \brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param json_config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * mpi: Use MPI.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_hadoop_mode: Enable Hadoop support.
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
* - rabit_reduce_buffer: Size of the reduce buffer.
* - rabit_bootstrap_cache: Size of the bootstrap cache.
* - rabit_debug: Enable debugging.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_ROLE: Role of the current task, "worker" or "server".
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* Only applicable to the Federated communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
*/
inline void Init(Json const &config) { Communicator::Init(config); }
/*!
* \brief Finalize the collective communicator.
* @brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
*/
inline void Finalize() { Communicator::Finalize(); }
void Finalize();
/*!
* \brief Get rank of current process.
/**
* @brief Get rank of current process.
*
* \return Rank of the worker.
* @return Rank of the worker.
*/
inline int GetRank() { return Communicator::Get()->GetRank(); }
[[nodiscard]] std::int32_t GetRank() noexcept;
/*!
* \brief Get total number of processes.
/**
* @brief Get total number of processes.
*
* \return Total world size.
* @return Total world size.
*/
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
[[nodiscard]] std::int32_t GetWorldSize() noexcept;
/*!
* \brief Get if the communicator is distributed.
/**
* @brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
* @return True if the communicator is distributed.
*/
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
[[nodiscard]] bool IsDistributed() noexcept;
/*!
* \brief Get if the communicator is federated.
/**
* @brief Get if the communicator is federated.
*
* \return True if the communicator is federated.
* @return True if the communicator is federated.
*/
inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
[[nodiscard]] bool IsFederated();
/*!
* \brief Print the message to the communicator.
/**
* @brief Print the message to the communicator.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
*
* \param message The message to be printed.
* @param message The message to be printed.
*/
inline void Print(char const *message) { Communicator::Get()->Print(message); }
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
/*!
* \brief Get the name of the processor.
*
* \return Name of the processor.
*/
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
*
* Example:
* int a = 1;
* Broadcast(&a, sizeof(a), root);
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
*/
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
}
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
}
}
void Print(std::string const& message);
/**
* @brief Gathers a single value all processes and distributes the result to all processes.
* @brief Get the name of the processor.
*
* @param input The single value.
* @return Name of the processor.
*/
template <typename T>
inline std::vector<T> Allgather(T const &input) {
std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
auto const output = Communicator::Get()->AllGather(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
return result;
}
/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size.
*
* @param input Buffer storing the data.
*/
template <typename T>
inline std::vector<T> Allgather(std::vector<T> const &input) {
if (input.empty()) {
return input;
}
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
input.size() * sizeof(T)};
auto const output = Communicator::Get()->AllGather(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
return result;
}
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
* @param input Buffer storing the data.
*/
template <typename T>
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
input.size() * sizeof(T)};
auto const output = Communicator::Get()->AllGatherV(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
if (!output.empty()) {
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
}
return result;
}
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* @param inputs All the inputs from the local worker. The number of inputs can vary
* across different workers. Along with which, the size of each vector in
* the input can also vary.
*
* @return The AllgatherV result, containing vectors from all workers.
*/
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
std::vector<std::vector<char>> const &input);
/**
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings.
*/
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
std::size_t total_size{0};
for (auto const &s : input) {
total_size += s.length() + 1; // +1 for null-terminators
}
std::string flat_string;
flat_string.reserve(total_size);
for (auto const &s : input) {
flat_string.append(s);
flat_string.push_back('\0'); // Append a null-terminator after each string
}
auto const output = Communicator::Get()->AllGatherV(flat_string);
std::vector<std::string> result;
std::size_t start_index = 0;
// Iterate through the output, find each null-terminated substring.
for (std::size_t i = 0; i < output.size(); i++) {
if (output[i] == '\0') {
// Construct a std::string from the char* substring
result.emplace_back(&output[start_index]);
// Move to the next substring
start_index = i + 1;
}
}
return result;
}
/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
* ...
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*/
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
static_cast<Operation>(op));
}
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
}
template <Operation op>
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}
template <Operation op>
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}
template <Operation op>
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}
template <Operation op>
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}
template <Operation op>
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}
template <Operation op>
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
// Specialization for size_t, which is implementation defined, so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <Operation op, typename T,
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
inline void Allreduce(T *send_receive_buffer, size_t count) {
static_assert(sizeof(T) == sizeof(uint64_t));
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
template <Operation op>
inline void Allreduce(float *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}
template <Operation op>
inline void Allreduce(double *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
} // namespace collective
} // namespace xgboost
std::string GetProcessorName();
} // namespace xgboost::collective

View File

@ -1,63 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "comm.h"
#include "in_memory_communicator.h"
#include "noop_communicator.h"
#include "rabit_communicator.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#endif
namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
thread_local std::string Communicator::nccl_path_{};
void Communicator::Init(Json const& config) {
auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
nccl_path_ = nccl;
auto type = GetTypeFromEnv();
auto const arg = GetTypeFromConfig(config);
if (arg != CommunicatorType::kUnknown) {
type = arg;
}
if (type == CommunicatorType::kUnknown) {
// Default to Rabit if unspecified.
type = CommunicatorType::kRabit;
}
type_ = type;
switch (type) {
case CommunicatorType::kRabit: {
communicator_.reset(RabitCommunicator::Create(config));
break;
}
case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
communicator_.reset(FederatedCommunicator::Create(config));
#else
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
break;
}
case CommunicatorType::kInMemory:
case CommunicatorType::kInMemoryNccl: {
communicator_.reset(InMemoryCommunicator::Create(config));
break;
}
case CommunicatorType::kUnknown:
LOG(FATAL) << "Unknown communicator type.";
}
}
#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(new NoOpCommunicator());
}
#endif
} // namespace xgboost::collective

View File

@ -1,54 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
#include "noop_communicator.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
namespace xgboost {
namespace collective {
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(new NoOpCommunicator());
device_communicator_.reset(nullptr);
}
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
thread_local auto old_device_ordinal = -1;
// If the number of GPUs changes, we need to re-initialize NCCL.
thread_local auto old_world_size = -1;
if (!device_communicator_ || device_ordinal != old_device_ordinal ||
communicator_->GetWorldSize() != old_world_size) {
old_device_ordinal = device_ordinal;
old_world_size = communicator_->GetWorldSize();
#ifdef XGBOOST_USE_NCCL
switch (type_) {
case CommunicatorType::kRabit:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
break;
case CommunicatorType::kFederated:
case CommunicatorType::kInMemory:
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
break;
case CommunicatorType::kInMemoryNccl:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_));
break;
default:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
}
return device_communicator_.get();
}
} // namespace collective
} // namespace xgboost

View File

@ -1,247 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <memory>
#include <string>
namespace xgboost {
namespace collective {
/** @brief Defines the integral and floating data types. */
enum class DataType {
kInt8 = 0,
kUInt8 = 1,
kInt32 = 2,
kUInt32 = 3,
kInt64 = 4,
kUInt64 = 5,
kFloat = 6,
kDouble = 7
};
/** @brief Get the size of the data type. */
inline std::size_t GetTypeSize(DataType data_type) {
std::size_t size{0};
switch (data_type) {
case DataType::kInt8:
size = sizeof(std::int8_t);
break;
case DataType::kUInt8:
size = sizeof(std::uint8_t);
break;
case DataType::kInt32:
size = sizeof(std::int32_t);
break;
case DataType::kUInt32:
size = sizeof(std::uint32_t);
break;
case DataType::kInt64:
size = sizeof(std::int64_t);
break;
case DataType::kUInt64:
size = sizeof(std::uint64_t);
break;
case DataType::kFloat:
size = sizeof(float);
break;
case DataType::kDouble:
size = sizeof(double);
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return size;
}
/** @brief Defines the reduction operation. */
enum class Operation {
kMax = 0,
kMin = 1,
kSum = 2,
kBitwiseAND = 3,
kBitwiseOR = 4,
kBitwiseXOR = 5
};
class DeviceCommunicator;
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl };
/** \brief Case-insensitive string comparison. */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#ifdef _MSC_VER
return _stricmp(s1, s2);
#else // _MSC_VER
return strcasecmp(s1, s2);
#endif // _MSC_VER
}
/**
* @brief A communicator class that handles collective communication.
*/
class Communicator {
public:
/**
* @brief Initialize the communicator. This can only be done once.
*
* @param config JSON configuration for the communicator.
*/
static void Init(Json const &config);
/** @brief Finalize the communicator. */
static void Finalize();
/** @brief Get the communicator instance. */
static Communicator *Get() { return communicator_.get(); }
#if defined(XGBOOST_USE_CUDA)
/**
* @brief Get the device communicator.
*
* @param device_ordinal ID of the device.
* @return An instance of device communicator.
*/
static DeviceCommunicator *GetDevice(int device_ordinal);
#endif
virtual ~Communicator() = default;
/** @brief Get the total number of processes. */
int GetWorldSize() const { return world_size_; }
/** @brief Get the rank of the current processes. */
int GetRank() const { return rank_; }
/** @brief Whether the communicator is running in distributed mode. */
virtual bool IsDistributed() const = 0;
/** @brief Whether the communicator is running in federated mode. */
virtual bool IsFederated() const = 0;
/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size.
*
* @param input Buffer storing the data.
*/
virtual std::string AllGather(std::string_view input) = 0;
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
* @param input Buffer storing the data.
*/
virtual std::string AllGatherV(std::string_view input) = 0;
/**
* @brief Combines values from all processes and distributes the result back to all processes.
*
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
* @param data_type Data type stored in the buffer.
* @param op The operation to perform.
*/
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) = 0;
/**
* @brief Broadcasts a message from the process with rank `root` to all other processes of the
* group.
*
* @param send_receive_buffer Buffer storing the data.
* @param size Size of the data in bytes.
* @param root Rank of broadcast root.
*/
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;
/**
* @brief Gets the name of the processor.
*/
virtual std::string GetProcessorName() = 0;
/**
* @brief Prints the message.
*/
virtual void Print(std::string const &message) = 0;
/** @brief Get the communicator type from environment variables. Visible for testing. */
static CommunicatorType GetTypeFromEnv() {
auto *env = std::getenv("XGBOOST_COMMUNICATOR");
if (env != nullptr) {
return StringToType(env);
} else {
return CommunicatorType::kUnknown;
}
}
/** @brief Get the communicator type from runtime configuration. Visible for testing. */
static CommunicatorType GetTypeFromConfig(Json const &config) {
auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
if (IsA<String const>(j_upper)) {
return StringToType(get<String const>(j_upper).c_str());
}
auto const &j_lower = config["xgboost_communicator"];
if (IsA<String const>(j_lower)) {
return StringToType(get<String const>(j_lower).c_str());
}
return CommunicatorType::kUnknown;
}
protected:
/**
* @brief Construct a new communicator.
*
* @param world_size Total number of processes.
* @param rank Rank of the current process.
*/
Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
if (world_size < 1) {
LOG(FATAL) << "World size " << world_size << " is less than 1.";
}
if (rank < 0) {
LOG(FATAL) << "Rank " << rank << " is less than 0.";
}
if (rank >= world_size) {
LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
}
}
/**
* @brief Shuts down the communicator.
*/
virtual void Shutdown() = 0;
private:
static CommunicatorType StringToType(char const *str) {
CommunicatorType result = CommunicatorType::kUnknown;
if (!CompareStringsCaseInsensitive("rabit", str)) {
result = CommunicatorType::kRabit;
} else if (!CompareStringsCaseInsensitive("federated", str)) {
result = CommunicatorType::kFederated;
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
result = CommunicatorType::kInMemory;
} else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
result = CommunicatorType::kInMemoryNccl;
} else {
LOG(FATAL) << "Unknown communicator type " << str;
}
return result;
}
static thread_local std::unique_ptr<Communicator> communicator_;
static thread_local CommunicatorType type_;
static thread_local std::string nccl_path_;
#if defined(XGBOOST_USE_CUDA)
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
#endif
int const world_size_;
int const rank_;
};
} // namespace collective
} // namespace xgboost

View File

@ -1,57 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <vector>
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace collective {
/**
* @brief Collective communicator for device buffers.
*/
class DeviceCommunicator {
public:
virtual ~DeviceCommunicator() = default;
/**
* @brief Combines values from all processes and distributes the result back to all processes.
*
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
* @param data_type Data type stored in the buffer.
* @param op The operation to perform.
*/
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) = 0;
/**
* @brief Gather values from all all processes.
*
* This assumes all ranks have the same size.
*
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;
/**
* @brief Gather variable-length values from all processes.
* @param send_buffer Buffer storing the input data.
* @param length_bytes Length in bytes of the input data.
* @param segments Size of each segment.
* @param receive_buffer Buffer storing the output data.
*/
virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) = 0;
/** @brief Synchronize device operations. */
virtual void Synchronize() = 0;
};
} // namespace collective
} // namespace xgboost

View File

@ -1,94 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <numeric> // for accumulate
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
class DeviceCommunicatorAdapter : public DeviceCommunicator {
public:
explicit DeviceCommunicatorAdapter(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
}
~DeviceCommunicatorAdapter() override = default;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * GetTypeSize(data_type);
host_buffer_.resize(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
Allreduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
host_buffer_.resize(send_size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
auto const output = Allgather(host_buffer_);
dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
segments->clear();
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
host_buffer_.resize(total_bytes);
size_t offset = 0;
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
if (i == rank_) {
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
cudaMemcpyDefault));
}
Broadcast(host_buffer_.data() + offset, as_bytes, i);
offset += as_bytes;
}
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
cudaMemcpyDefault));
}
void Synchronize() override {
// Noop.
}
private:
int const device_ordinal_;
int const world_size_;
int const rank_;
/// Host buffer used to call communicator functions.
std::vector<char> host_buffer_{};
};
} // namespace collective
} // namespace xgboost

View File

@ -1,12 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "in_memory_communicator.h"
namespace xgboost {
namespace collective {
InMemoryHandler InMemoryCommunicator::handler_{};
} // namespace collective
} // namespace xgboost

View File

@ -15,14 +15,14 @@ namespace collective {
/**
* An in-memory communicator, useful for testing.
*/
class InMemoryCommunicator : public Communicator {
class InMemoryCommunicator {
public:
/**
* @brief Create a new communicator based on JSON configuration.
* @param config JSON configuration.
* @return Communicator as specified by the JSON configuration.
*/
static Communicator* Create(Json const& config) {
static InMemoryCommunicator* Create(Json const& config) {
int world_size{0};
int rank{-1};
@ -51,7 +51,7 @@ class InMemoryCommunicator : public Communicator {
return new InMemoryCommunicator(world_size, rank);
}
InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
InMemoryCommunicator(int world_size, int rank) {
handler_.Init(world_size, rank);
}

View File

@ -1,14 +1,13 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "in_memory_handler.h"
#include <algorithm>
#include <functional>
#include "comm.h"
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Functor for allgather.
*/
@ -16,7 +15,7 @@ class AllgatherFunctor {
public:
std::string const name{"Allgather"};
AllgatherFunctor(std::size_t world_size, std::size_t rank)
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
: world_size_{world_size}, rank_{rank} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
@ -30,8 +29,8 @@ class AllgatherFunctor {
}
private:
std::size_t world_size_;
std::size_t rank_;
std::int32_t world_size_;
std::int32_t rank_;
};
/**
@ -41,13 +40,13 @@ class AllgatherVFunctor {
public:
std::string const name{"AllgatherV"};
AllgatherVFunctor(std::size_t world_size, std::size_t rank,
AllgatherVFunctor(std::int32_t world_size, std::int32_t rank,
std::map<std::size_t, std::string_view>* data)
: world_size_{world_size}, rank_{rank}, data_{data} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
data_->emplace(rank_, std::string_view{input, bytes});
if (data_->size() == world_size_) {
if (data_->size() == static_cast<std::size_t>(world_size_)) {
for (auto const& kv : *data_) {
buffer->append(kv.second);
}
@ -56,8 +55,8 @@ class AllgatherVFunctor {
}
private:
std::size_t world_size_;
std::size_t rank_;
std::int32_t world_size_;
std::int32_t rank_;
std::map<std::size_t, std::string_view>* data_;
};
@ -68,7 +67,7 @@ class AllreduceFunctor {
public:
std::string const name{"Allreduce"};
AllreduceFunctor(DataType dataType, Operation operation)
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
: data_type_{dataType}, operation_{operation} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
@ -76,23 +75,23 @@ class AllreduceFunctor {
// Copy the input if this is the first request.
buffer->assign(input, bytes);
} else {
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
// Apply the reduce_operation to the input and the buffer.
Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front());
Accumulate(input, bytes / n_bytes_type, &buffer->front());
}
}
private:
template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
void AccumulateBitwise(T* buffer, T const* input, std::size_t size,
Operation reduce_operation) const {
void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
switch (reduce_operation) {
case Operation::kBitwiseAND:
case Op::kBitwiseAND:
std::transform(buffer, buffer + size, input, buffer, std::bit_and<T>());
break;
case Operation::kBitwiseOR:
case Op::kBitwiseOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_or<T>());
break;
case Operation::kBitwiseXOR:
case Op::kBitwiseXOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_xor<T>());
break;
default:
@ -101,27 +100,27 @@ class AllreduceFunctor {
}
template <class T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
void AccumulateBitwise(T*, T const*, std::size_t, Operation) const {
void AccumulateBitwise(T*, T const*, std::size_t, Op) const {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}
template <class T>
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
switch (reduce_operation) {
case Operation::kMax:
case Op::kMax:
std::transform(buffer, buffer + size, input, buffer,
[](T a, T b) { return std::max(a, b); });
break;
case Operation::kMin:
case Op::kMin:
std::transform(buffer, buffer + size, input, buffer,
[](T a, T b) { return std::min(a, b); });
break;
case Operation::kSum:
case Op::kSum:
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
case Op::kBitwiseAND:
case Op::kBitwiseOR:
case Op::kBitwiseXOR:
AccumulateBitwise(buffer, input, size, reduce_operation);
break;
default:
@ -130,36 +129,37 @@ class AllreduceFunctor {
}
void Accumulate(char const* input, std::size_t size, char* buffer) const {
using Type = ArrayInterfaceHandler::Type;
switch (data_type_) {
case DataType::kInt8:
case Type::kI1:
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
reinterpret_cast<std::int8_t const*>(input), size, operation_);
break;
case DataType::kUInt8:
case Type::kU1:
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
break;
case DataType::kInt32:
case Type::kI4:
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
reinterpret_cast<std::int32_t const*>(input), size, operation_);
break;
case DataType::kUInt32:
case Type::kU4:
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
break;
case DataType::kInt64:
case Type::kI8:
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
reinterpret_cast<std::int64_t const*>(input), size, operation_);
break;
case DataType::kUInt64:
case Type::kU8:
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
break;
case DataType::kFloat:
case Type::kF4:
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
operation_);
break;
case DataType::kDouble:
case Type::kF8:
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
operation_);
break;
@ -169,8 +169,8 @@ class AllreduceFunctor {
}
private:
DataType data_type_;
Operation operation_;
ArrayInterfaceHandler::Type data_type_;
Op operation_;
};
/**
@ -180,7 +180,7 @@ class BroadcastFunctor {
public:
std::string const name{"Broadcast"};
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (rank_ == root_) {
@ -190,11 +190,11 @@ class BroadcastFunctor {
}
private:
std::size_t rank_;
std::size_t root_;
std::int32_t rank_;
std::int32_t root_;
};
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) {
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
std::unique_lock<std::mutex> lock(mutex_);
@ -204,7 +204,7 @@ void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
cv_.notify_all();
}
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) {
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
std::unique_lock<std::mutex> lock(mutex_);
@ -220,29 +220,29 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
}
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank) {
std::size_t sequence_number, std::int32_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
}
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank) {
std::size_t sequence_number, std::int32_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
}
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, DataType data_type,
Operation op) {
std::size_t sequence_number, std::int32_t rank,
ArrayInterfaceHandler::Type data_type, Op op) {
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
}
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, std::size_t root) {
std::size_t sequence_number, std::int32_t rank, std::int32_t root) {
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
}
template <class HandlerFunctor>
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank,
std::size_t sequence_number, std::int32_t rank,
HandlerFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
@ -287,5 +287,4 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
cv_.notify_all();
}
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@ -1,16 +1,15 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <condition_variable>
#include <map>
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
#include "../data/array_interface.h"
#include "comm.h"
namespace xgboost::collective {
/**
* @brief Handles collective communication primitives in memory.
*
@ -28,12 +27,11 @@ class InMemoryHandler {
/**
* @brief Construct a handler with the given world size.
* @param world_size Number of workers.
* @param world Number of workers.
*
* This is used when the handler only needs to be initialized once with a known world size.
*/
explicit InMemoryHandler(std::int32_t worldSize)
: world_size_{static_cast<std::size_t>(worldSize)} {}
explicit InMemoryHandler(std::int32_t world) : world_size_{world} {}
/**
* @brief Initialize the handler with the world size and rank.
@ -43,7 +41,7 @@ class InMemoryHandler {
* This is used when multiple objects/threads are accessing the same handler and need to
* initialize it collectively.
*/
void Init(std::size_t world_size, std::size_t rank);
void Init(std::int32_t world_size, std::int32_t rank);
/**
* @brief Shut down the handler.
@ -53,7 +51,7 @@ class InMemoryHandler {
* This is used when multiple objects/threads are accessing the same handler and need to
* shut it down collectively.
*/
void Shutdown(uint64_t sequence_number, std::size_t rank);
void Shutdown(uint64_t sequence_number, std::int32_t rank);
/**
* @brief Perform allgather.
@ -64,7 +62,7 @@ class InMemoryHandler {
* @param rank Index of the worker.
*/
void Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank);
std::size_t sequence_number, std::int32_t rank);
/**
* @brief Perform variable-length allgather.
@ -75,7 +73,7 @@ class InMemoryHandler {
* @param rank Index of the worker.
*/
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank);
std::size_t sequence_number, std::int32_t rank);
/**
* @brief Perform allreduce.
@ -88,7 +86,8 @@ class InMemoryHandler {
* @param op The reduce operation.
*/
void Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
std::size_t sequence_number, std::int32_t rank,
ArrayInterfaceHandler::Type data_type, Op op);
/**
* @brief Perform broadcast.
@ -100,7 +99,7 @@ class InMemoryHandler {
* @param root Index of the worker to broadcast from.
*/
void Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, std::size_t root);
std::size_t sequence_number, std::int32_t rank, std::int32_t root);
private:
/**
@ -115,17 +114,15 @@ class InMemoryHandler {
*/
template <class HandlerFunctor>
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
std::size_t rank, HandlerFunctor const& functor);
std::int32_t rank, HandlerFunctor const& functor);
std::size_t world_size_{}; /// Number of workers.
std::size_t received_{}; /// Number of calls received with the current sequence.
std::size_t sent_{}; /// Number of calls completed with the current sequence.
std::int32_t world_size_{}; /// Number of workers.
std::int64_t received_{}; /// Number of calls received with the current sequence.
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
std::string buffer_{}; /// A shared common buffer.
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
uint64_t sequence_number_{}; /// Call sequence number.
mutable std::mutex mutex_; /// Lock.
mutable std::condition_variable cv_; /// Conditional variable to wait on.
};
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@ -6,6 +6,8 @@
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <exception> // for exception, current_exception, rethrow_exception
#include <future> // for promise
#include <memory> // for make_shared
#include <mutex> // for lock_guard, unique_lock
#include <queue> // for queue
#include <string> // for string
@ -18,9 +20,10 @@
#include "xgboost/logging.h" // for CHECK
namespace xgboost::collective {
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
Result Loop::ProcessQueue(std::queue<Op>* p_queue) const {
timer_.Start(__func__);
auto error = [this] {
auto error = [this](Op op) {
op.pr->set_value();
timer_.Stop(__func__);
};
@ -38,7 +41,7 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
// Iterate through all the ops for poll
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
auto op = std::move(qcopy.front());
qcopy.pop();
switch (op.code) {
@ -54,12 +57,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
break;
}
default: {
error();
error(op);
return Fail("Invalid socket operation.");
}
}
qcopy.push(op);
qcopy.push(std::move(op));
}
// poll, work on fds that are ready.
@ -67,18 +70,18 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
if (!poll.fds.empty()) {
auto rc = poll.Poll(timeout_);
if (!rc.OK()) {
error();
timer_.Stop(__func__);
return rc;
}
}
timer_.Stop("poll");
// we wonldn't be here if the queue is empty.
// We wonldn't be here if the queue is empty.
CHECK(!qcopy.empty());
// Iterate through all the ops for performing the operations
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
auto op = std::move(qcopy.front());
qcopy.pop();
std::int32_t n_bytes_done{0};
@ -93,8 +96,9 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
if (poll.CheckRead(*op.sock)) {
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
if (n_bytes_done == 0) {
error();
return Fail("Encountered EOF. The other end is likely closed.");
error(op);
return Fail("Encountered EOF. The other end is likely closed.",
op.sock->GetSockError());
}
}
break;
@ -112,14 +116,14 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
break;
}
default: {
error();
error(op);
return Fail("Invalid socket operation.");
}
}
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
auto rc = system::FailWithCode("Invalid socket output.");
error();
error(op);
return rc;
}
@ -127,14 +131,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
CHECK_LE(op.off, op.n);
if (op.off != op.n) {
// not yet finished, push back to queue for next round.
// not yet finished, push back to queue for the next round.
qcopy.push(op);
} else {
op.pr->set_value();
}
}
if (!blocking) {
break;
}
}
timer_.Stop(__func__);
@ -148,8 +150,7 @@ void Loop::Process() {
};
// This loop cannot exit unless `stop_` is set to true. There must always be a thread to
// answer the blocking call even if there are errors, otherwise the blocking will wait
// forever.
// answer the call even if there are errors.
while (true) {
try {
std::unique_lock lock{mu_};
@ -170,44 +171,15 @@ void Loop::Process() {
// Move the global queue into a local variable to unblock it.
std::queue<Op> qcopy;
bool is_blocking = false;
while (!queue_.empty()) {
auto op = queue_.front();
auto op = std::move(queue_.front());
queue_.pop();
if (op.code == Op::kBlock) {
is_blocking = true;
} else {
qcopy.push(op);
}
qcopy.push(op);
}
lock.unlock();
// Clear the local queue, if `is_blocking` is true, this is blocking the current
// worker thread (but not the client thread), wait until all operations are
// finished.
auto rc = this->ProcessQueue(&qcopy, is_blocking);
if (is_blocking && rc.OK()) {
CHECK(qcopy.empty());
}
// Push back the remaining operations.
if (rc.OK()) {
std::unique_lock lock{mu_};
while (!qcopy.empty()) {
queue_.push(qcopy.front());
qcopy.pop();
}
}
// Notify the client thread who called block after all error conditions are set.
auto notify_if_block = [&] {
if (is_blocking) {
std::unique_lock lock{mu_};
block_done_ = true;
lock.unlock();
block_cv_.notify_one();
}
};
// Clear the local queue.
auto rc = this->ProcessQueue(&qcopy);
// Handle error
if (!rc.OK()) {
@ -215,8 +187,6 @@ void Loop::Process() {
} else {
CHECK(qcopy.empty());
}
notify_if_block();
} catch (std::exception const& e) {
curr_exce_ = std::current_exception();
set_rc(Fail("Exception inside the event loop:" + std::string{e.what()}));
@ -256,20 +226,28 @@ Result Loop::Stop() {
stop_ = true;
}
}
if (!this->worker_.joinable()) {
std::lock_guard<std::mutex> guard{rc_lock_};
return Fail("Worker has stopped.", std::move(rc_));
}
this->Submit(Op{Op::kBlock});
{
// Wait for the block call to finish.
std::unique_lock lock{mu_};
block_cv_.wait(lock, [this] { return block_done_ || stop_; });
block_done_ = false;
cv_.notify_one();
}
for (auto& fut : futures_) {
if (fut.valid()) {
try {
fut.get();
} catch (std::future_error const&) {
// Do nothing. If something went wrong in the worker, we have a std::future_error
// due to broken promise. This function will transfer the rc back to the caller.
}
}
}
futures_.clear();
{
// Transfer the rc.
std::lock_guard<std::mutex> lock{rc_lock_};
@ -278,13 +256,13 @@ Result Loop::Stop() {
}
void Loop::Submit(Op op) {
auto p = std::make_shared<std::promise<void>>();
op.pr = std::move(p);
futures_.emplace_back(op.pr->get_future());
CHECK_NE(op.n, 0);
std::unique_lock lock{mu_};
if (op.code != Op::kBlock) {
CHECK_NE(op.n, 0);
}
queue_.push(op);
lock.unlock();
cv_.notify_one();
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {

View File

@ -7,9 +7,12 @@
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t
#include <exception> // for exception_ptr
#include <mutex> // for unique_lock, mutex
#include <future> // for future
#include <memory> // for shared_ptr
#include <mutex> // for mutex
#include <queue> // for queue
#include <thread> // for thread
#include <vector> // for vector
#include "../common/timer.h" // for Monitor
#include "xgboost/collective/result.h" // for Result
@ -20,14 +23,15 @@ class Loop {
public:
struct Op {
// kSleep is only for testing
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
enum Code : std::int8_t { kRead = 0, kWrite = 1, kSleep = 3 } code;
std::int32_t rank{-1};
std::int8_t* ptr{nullptr};
std::size_t n{0};
TCPSocket* sock{nullptr};
std::size_t off{0};
std::shared_ptr<std::promise<void>> pr;
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
explicit Op(Code c) : code{c} { CHECK(c == kSleep); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default;
@ -45,12 +49,11 @@ class Loop {
private:
std::thread worker_; // thread worker to execute the tasks
std::condition_variable cv_; // CV used to notify a new submit call
std::condition_variable block_cv_; // CV used to notify the blocking call
bool block_done_{false}; // Flag to indicate whether the blocking call has finished.
std::condition_variable cv_; // CV used to notify a new submit call
std::queue<Op> queue_; // event queue
std::mutex mu_; // mutex to protect the queue, cv, and block_done
std::vector<std::future<void>> futures_;
std::mutex mu_; // mutex to protect the queue, cv, and block_done
std::chrono::seconds timeout_;
@ -61,7 +64,7 @@ class Loop {
std::exception_ptr curr_exce_{nullptr};
common::Monitor mutable timer_;
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
Result ProcessQueue(std::queue<Op>* p_queue) const;
// The cunsumer function that runs inside a worker thread.
void Process();

View File

@ -1,243 +0,0 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <numeric> // for accumulate
#include "comm.cuh"
#include "nccl_device_communicator.cuh"
namespace xgboost {
namespace collective {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
StringView nccl_path)
: device_ordinal_{device_ordinal},
needs_sync_{needs_sync},
world_size_{GetWorldSize()},
rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (world_size_ == 1) {
return;
}
stub_ = std::make_shared<NcclStub>(std::move(nccl_path));
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world_size_)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
CHECK(rc.OK()) << rc.Report();
}
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (world_size_ == 1) {
return;
}
if (nccl_comm_) {
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
}
}
namespace {
ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result{ncclInt8};
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}
bool IsBitwiseOp(Operation const &op) {
return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR ||
op == Operation::kBitwiseXOR;
}
ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result{ncclMax};
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
default:
LOG(FATAL) << "Unsupported reduce operation.";
}
return result;
}
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
std::size_t size) {
dh::LaunchN(size, [=] __device__(std::size_t idx) {
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
result = func(result, device_buffer[rank * size + idx]);
}
out_buffer[idx] = result;
});
}
} // anonymous namespace
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
auto const size = count * GetTypeSize(data_type);
dh::caching_device_vector<char> buffer(size * world_size_);
auto *device_buffer = buffer.data().get();
// First gather data from all the workers.
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
if (needs_sync_) {
dh::DefaultStream().Sync();
}
// Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
}
}
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
if (IsBitwiseOp(op)) {
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
} else {
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
segments->clear();
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream());
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
} << [&] { return stub_->GroupEnd(); };
}
void NcclDeviceCommunicator::Synchronize() {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::DefaultStream().Sync();
}
} // namespace collective
} // namespace xgboost
#endif

View File

@ -1,91 +0,0 @@
/*!
* Copyright 2022-2023 XGBoost contributors
*/
#pragma once
#include "../common/device_helpers.cuh"
#include "comm.cuh"
#include "communicator.h"
#include "device_communicator.cuh"
#include "nccl_stub.h"
namespace xgboost {
namespace collective {
class NcclDeviceCommunicator : public DeviceCommunicator {
public:
/**
* @brief Construct a new NCCL communicator.
* @param device_ordinal The GPU device id.
* @param needs_sync Whether extra CUDA stream synchronization is needed.
*
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
*
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
* needed. The in-memory communicator is used in tests with multiple threads, each thread
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
*/
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override;
void Synchronize() override;
private:
static constexpr std::size_t kUuidLength =
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
}
static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
std::stringstream ss;
for (auto v : uuid) {
ss << std::hex << v;
}
return ss.str();
}
/**
* \fn ncclUniqueId GetUniqueId()
*
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
* communication
*
* \return the Unique ID
*/
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (rank_ == kRootRank) {
auto rc = stub_->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
}
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}
void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op);
int const device_ordinal_;
bool const needs_sync_;
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
std::shared_ptr<NcclStub> stub_;
ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
};
} // namespace collective
} // namespace xgboost

View File

@ -1,32 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
/**
* A no-op communicator, used for non-distributed training.
*/
class NoOpCommunicator : public Communicator {
public:
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
std::string AllGather(std::string_view) override { return {}; }
std::string AllGatherV(std::string_view) override { return {}; }
void AllReduce(void *, std::size_t, DataType, Operation) override {}
void Broadcast(void *, std::size_t, int) override {}
std::string GetProcessorName() override { return {}; }
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
protected:
void Shutdown() override {}
};
} // namespace collective
} // namespace xgboost

View File

@ -41,20 +41,26 @@ struct Magic {
[[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
std::int32_t magic{kMagic};
auto n_bytes = p_sock->SendAll(&magic, sizeof(magic));
if (n_bytes != sizeof(magic)) {
return Fail("Failed to verify.");
}
magic = 0;
n_bytes = p_sock->RecvAll(&magic, sizeof(magic));
if (n_bytes != sizeof(magic)) {
return Fail("Failed to verify.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
return Success();
std::size_t n_sent{0};
return Success() << [&] {
return p_sock->SendAll(&magic, sizeof(magic), &n_sent);
} << [&] {
if (n_sent != sizeof(magic)) {
return Fail("Failed to verify.");
}
return Success();
} << [&] {
magic = 0;
return p_sock->RecvAll(&magic, sizeof(magic), &n_sent);
} << [&] {
if (n_sent != sizeof(magic)) {
return Fail("Failed to verify.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
return Success();
};
}
};
@ -227,31 +233,43 @@ struct Error {
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
std::int32_t err{ErrorSignal()};
auto n_sent = worker->SendAll(&err, sizeof(err));
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send error signal");
std::size_t n_sent{0};
return Success() << [&] {
return worker->SendAll(&err, sizeof(err), &n_sent);
} << [&] {
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send error signal");
};
}
// self is localhost, we are sending the signal to the error handling thread for it to
// close.
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
std::int32_t err{ShutdownSignal()};
auto n_sent = self->SendAll(&err, sizeof(err));
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send shutdown signal");
std::size_t n_sent{0};
return Success() << [&] {
return self->SendAll(&err, sizeof(err), &n_sent);
} << [&] {
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send shutdown signal");
};
}
// get signal, either for error or for shutdown.
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
std::int32_t err{ShutdownSignal()};
auto n_recv = peer->RecvAll(&err, sizeof(err));
if (n_recv == sizeof(err)) {
*p_is_error = err == 1;
return Success();
}
return Fail("Failed to receive error signal.");
std::size_t n_recv{0};
return Success() << [&] {
return peer->RecvAll(&err, sizeof(err), &n_recv);
} << [&] {
if (n_recv == sizeof(err)) {
*p_is_error = err == 1;
return Success();
}
return Fail("Failed to receive error signal.");
};
}
};
} // namespace xgboost::collective::proto

View File

@ -1,175 +0,0 @@
/**
* Copyright 2022-2023 by XGBoost contributors
*/
#pragma once
#include <rabit/rabit.h>
#include <string>
#include <vector>
#include "communicator-inl.h"
#include "communicator.h"
#include "xgboost/json.h"
namespace xgboost {
namespace collective {
class RabitCommunicator : public Communicator {
public:
static Communicator *Create(Json const &config) {
std::vector<std::string> args_str;
for (auto &items : get<Object const>(config)) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kString: {
args_str.push_back(items.first + "=" + get<String const>(items.second));
break;
}
case xgboost::Value::ValueKind::kInteger: {
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
break;
}
case xgboost::Value::ValueKind::kBoolean: {
if (get<Boolean const>(items.second)) {
args_str.push_back(items.first + "=1");
} else {
args_str.push_back(items.first + "=0");
}
break;
}
default:
break;
}
}
std::vector<char *> args;
for (auto &key_value : args_str) {
args.push_back(&key_value[0]);
}
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
LOG(FATAL) << "Failed to initialize Rabit";
}
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
}
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
bool IsDistributed() const override { return rabit::IsDistributed(); }
bool IsFederated() const override { return false; }
std::string AllGather(std::string_view input) override {
auto const per_rank = input.size();
auto const total_size = per_rank * GetWorldSize();
auto const index = per_rank * GetRank();
std::string result(total_size, '\0');
result.replace(index, per_rank, input);
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
return result;
}
std::string AllGatherV(std::string_view input) override {
auto const size_node_slice = input.size();
auto const all_sizes = collective::Allgather(size_node_slice);
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
auto const begin_index =
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
auto const size_prev_slice =
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
std::string result(total_size, '\0');
result.replace(begin_index, size_node_slice, input);
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
return result;
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
switch (data_type) {
case DataType::kInt8:
DoAllReduce<char>(send_receive_buffer, count, op);
break;
case DataType::kUInt8:
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
break;
case DataType::kInt32:
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt32:
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
break;
case DataType::kInt64:
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt64:
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
break;
case DataType::kFloat:
DoAllReduce<float>(send_receive_buffer, count, op);
break;
case DataType::kDouble:
DoAllReduce<double>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown data type";
}
}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
rabit::Broadcast(send_receive_buffer, size, root);
}
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
protected:
void Shutdown() override { rabit::Finalize(); }
private:
template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * = nullptr>
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kBitwiseAND:
rabit::Allreduce<rabit::op::BitAND, DType>(static_cast<DType *>(send_receive_buffer),
count);
break;
case Operation::kBitwiseOR:
rabit::Allreduce<rabit::op::BitOR, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kBitwiseXOR:
rabit::Allreduce<rabit::op::BitXOR, DType>(static_cast<DType *>(send_receive_buffer),
count);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}
template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
void DoBitwiseAllReduce(void *, std::size_t, Operation) {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}
template <typename DType>
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kMax:
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kMin:
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kSum:
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}
};
} // namespace collective
} // namespace xgboost

View File

@ -62,20 +62,15 @@ void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
ptr->prev = std::move(rhs);
}
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
return std::forward<std::string>(msg);
}
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
auto name = std::filesystem::path{file}.filename();
dmlc::DateLogger logger;
if (file && line != -1) {
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
auto name = std::filesystem::path{ file }.filename();
return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
"]: " + std::forward<std::string>(msg);
}
return std::forward<std::string>(msg);
return std::string{"["} + logger.HumanDate() + "]" + std::forward<std::string>(msg); // NOLINT
}
#endif
} // namespace detail
void SafeColl(Result const& rc) {

View File

@ -60,24 +60,46 @@ std::size_t TCPSocket::Send(StringView str) {
CHECK(!this->IsClosed());
CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());
std::int32_t len = static_cast<std::int32_t>(str.size());
CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length.";
auto bytes = this->SendAll(str.c_str(), str.size());
CHECK_EQ(bytes, str.size()) << "Failed to send string.";
return bytes;
std::size_t n_bytes{0};
auto rc = Success() << [&] {
return this->SendAll(&len, sizeof(len), &n_bytes);
} << [&] {
if (n_bytes != sizeof(len)) {
return Fail("Failed to send string length.");
}
return Success();
} << [&] {
return this->SendAll(str.c_str(), str.size(), &n_bytes);
} << [&] {
if (n_bytes != str.size()) {
return Fail("Failed to send string.");
}
return Success();
};
SafeColl(rc);
return n_bytes;
}
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
CHECK(!this->IsClosed());
std::int32_t len;
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
return Fail("Failed to recv string length.");
}
p_str->resize(len);
auto bytes = this->RecvAll(&(*p_str)[0], len);
if (static_cast<decltype(len)>(bytes) != len) {
return Fail("Failed to recv string.");
}
return Success();
std::size_t n_bytes{0};
return Success() << [&] {
return this->RecvAll(&len, sizeof(len), &n_bytes);
} << [&] {
if (n_bytes != sizeof(len)) {
return Fail("Failed to recv string length.");
}
return Success();
} << [&] {
p_str->resize(len);
return this->RecvAll(&(*p_str)[0], len, &n_bytes);
} << [&] {
if (static_cast<std::remove_reference_t<decltype(len)>>(n_bytes) != len) {
return Fail("Failed to recv string.");
}
return Success();
};
}
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,

View File

@ -31,14 +31,20 @@
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
Tracker::Tracker(Json const& config)
: sortby_{static_cast<SortBy>(
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
n_workers_{
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
timeout_{std::chrono::seconds{
OptionalArg<Integer const>(config, "timeout", static_cast<std::int64_t>(0))}} {
using std::chrono_literals::operator""s;
// Some old configurations in JVM for the scala implementation (removed) use 0 to
// indicate blocking. We continue that convention here.
timeout_ = (timeout_ == 0s) ? -1s : timeout_;
}
Result Tracker::WaitUntilReady() const {
using namespace std::chrono_literals; // NOLINT
@ -49,7 +55,7 @@ Result Tracker::WaitUntilReady() const {
timer.Start();
while (!this->Ready()) {
auto ela = timer.Duration().count();
if (ela > this->Timeout().count()) {
if (HasTimeout(this->Timeout()) && ela > this->Timeout().count()) {
return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) +
" seconds.");
}
@ -250,8 +256,10 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
std::lock_guard lock{listener_mu_};
return listener_.NonBlocking(true);
} << [&] {
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
{
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
}
if (state.running) {
// Don't timeout if the communicator group is up and running.
return poll.Poll(std::chrono::seconds{-1});

View File

@ -15,6 +15,7 @@
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
inline bool HasTimeout(std::chrono::seconds timeout) { return timeout.count() > 0; }
/**
*
* @brief Implementation of RABIT tracker.
@ -52,7 +53,7 @@ class Tracker {
protected:
std::int32_t n_workers_{0};
std::int32_t port_{-1};
std::chrono::seconds timeout_{0};
std::chrono::seconds timeout_{-1};
std::atomic<bool> ready_{false};
public:

View File

@ -1,5 +1,5 @@
/**
* Copyright 2014-2023, XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file io.h
* \brief general stream interface for serialization, I/O
* \author Tianqi Chen
@ -8,7 +8,6 @@
#define XGBOOST_COMMON_IO_H_
#include <dmlc/io.h>
#include <rabit/internal/io.h> // for MemoryFixSizeBuffer, MemoryBufferStream
#include <algorithm> // for min, fill_n, copy_n
#include <array> // for array
@ -23,12 +22,99 @@
#include <utility> // for move
#include <vector> // for vector
#include "common.h"
#include "common.h" // for DivRoundUp
#include "xgboost/string_view.h" // for StringView
namespace xgboost::common {
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
using MemoryBufferStream = rabit::utils::MemoryBufferStream;
struct MemoryFixSizeBuffer : public dmlc::SeekStream {
public:
// similar to SEEK_END in libc
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
public:
/**
* @brief Ctor
*
* @param p_buffer Pointer to the source buffer with size `buffer_size`.
* @param buffer_size Size of the source buffer
*/
MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
: p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
~MemoryFixSizeBuffer() override = default;
std::size_t Read(void *ptr, std::size_t size) override {
std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
void Write(const void *ptr, std::size_t size) override {
if (size == 0) return;
CHECK_LE(curr_ptr_ + size, buffer_size_);
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
void Seek(std::size_t pos) override {
if (pos == kSeekEnd) {
curr_ptr_ = buffer_size_;
} else {
curr_ptr_ = static_cast<std::size_t>(pos);
}
}
/**
* @brief Current position in the buffer (stream).
*/
std::size_t Tell() override { return curr_ptr_; }
[[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
protected:
/*! \brief in memory buffer */
char *p_buffer_{nullptr};
/*! \brief current pointer */
std::size_t buffer_size_{0};
/*! \brief current pointer */
std::size_t curr_ptr_{0};
};
/*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public dmlc::SeekStream {
public:
explicit MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
~MemoryBufferStream() override = default;
size_t Read(void *ptr, size_t size) override {
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
void Write(const void *ptr, size_t size) override {
if (size == 0) return;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos);
}
size_t Tell() override {
return curr_ptr_;
}
virtual bool AtEnd() const {
return curr_ptr_ == p_buffer_->length();
}
private:
/*! \brief in memory buffer */
std::string *p_buffer_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryBufferStream
/*!
* \brief Input stream that support additional PeekRead operation,

View File

@ -116,19 +116,19 @@ INSTANTIATE(ColumnarAdapterBatch)
namespace {
/**
* \brief A view over gathered sketch values.
* @brief A view over gathered sketch values.
*/
template <typename T>
struct QuantileAllreduce {
common::Span<T> global_values;
common::Span<bst_idx_t> worker_indptr;
common::Span<bst_idx_t> feature_indptr;
size_t n_features{0};
bst_feature_t n_features{0};
/**
* \brief Get sketch values of the a feature from a worker.
* @brief Get sketch values of the a feature from a worker.
*
* \param rank rank of target worker
* \param fidx feature idx
* @param rank rank of target worker
* @param fidx feature idx
*/
[[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const {
// get span for worker
@ -154,7 +154,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
worker_segments.resize(1, 0);
auto world = collective::GetWorldSize();
auto rank = collective::GetRank();
auto n_columns = sketches_.size();
bst_feature_t n_columns = sketches_.size();
// get the size of each feature.
std::vector<bst_idx_t> sketch_size;
@ -165,7 +165,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
sketch_size.push_back(reduced[i].size);
}
}
// turn the size into CSC indptr
// Turn the size into CSC indptr
std::vector<bst_idx_t> &sketches_scan = *p_sketches_scan;
sketches_scan.resize((n_columns + 1) * world, 0);
size_t beg_scan = rank * (n_columns + 1); // starting storage for current worker.
@ -174,7 +174,10 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
// Gather all column pointers
auto rc =
collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size()));
collective::SafeColl(rc);
if (!rc.OK()) {
collective::SafeColl(collective::Fail("Failed to get sketch scan.", std::move(rc)));
}
for (int32_t i = 0; i < world; ++i) {
size_t back = (i + 1) * (n_columns + 1) - 1;
auto n_entries = sketches_scan.at(back);
@ -206,7 +209,9 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
ctx, info,
linalg::MakeVec(reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float)));
collective::SafeColl(rc);
if (!rc.OK()) {
collective::SafeColl(collective::Fail("Failed to get sketch.", std::move(rc)));
}
}
template <typename WQSketch>
@ -260,7 +265,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, Meta
rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(global_categories.data(), global_categories.size()));
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories_.size()};
static_cast<bst_feature_t>(categories_.size())};
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
if (!IsCat(feature_types_, fidx)) {
return;
@ -285,8 +290,9 @@ void SketchContainerImpl<WQSketch>::AllReduce(
std::vector<typename WQSketch::SummaryContainer> *p_reduced, std::vector<int32_t> *p_num_cuts) {
monitor_.Start(__func__);
size_t n_columns = sketches_.size();
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
bst_feature_t n_columns = sketches_.size();
auto rc = collective::Allreduce(ctx, &n_columns, collective::Op::kMax);
collective::SafeColl(rc);
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
AllreduceCategories(ctx, info);
@ -300,8 +306,8 @@ void SketchContainerImpl<WQSketch>::AllReduce(
// Prune the intermediate num cuts for synchronization.
std::vector<bst_idx_t> global_column_size(columns_size_);
auto rc = collective::GlobalSum(
ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size()));
rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(global_column_size.data(), global_column_size.size()));
collective::SafeColl(rc);
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {

Some files were not shown because too many files have changed in this diff Show More