Merge branch 'master' into sync-condition-2023Oct11

Hui Liu 2023-10-30 13:19:33 -07:00
commit d7f1235b7d
41 changed files with 1486 additions and 156 deletions

View File

@ -151,4 +151,4 @@ jobs:
python-package/xgboost/lib python-package/xgboost/rabit \
python-package/xgboost/src
sh ./tests/ci_build/lint_cmake.sh || true
sh ./tests/ci_build/lint_cmake.sh

View File

@ -33,7 +33,7 @@ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
endif()
endif()
include(${xgboost_SOURCE_DIR}/cmake/FindPrefetchIntrinsics.cmake)
include(${xgboost_SOURCE_DIR}/cmake/PrefetchIntrinsics.cmake)
find_prefetch_intrinsics()
include(${xgboost_SOURCE_DIR}/cmake/Version.cmake)
write_version()

View File

@ -10,14 +10,14 @@ How to tune parameters
See :doc:`Parameter Tuning Guide </tutorials/param_tuning>`.
************************
Description on the model
Description of the model
************************
See :doc:`Introduction to Boosted Trees </tutorials/model>`.
********************
I have a big dataset
********************
XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fit into your memory.
XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fits into your memory.
This usually means millions of instances.
If you are running out of memory, check out the tutorial page for using :doc:`distributed training </tutorials/index>` with one of the many frameworks, or the :doc:`external memory version </tutorials/external_memory>` for using external memory.
@ -26,7 +26,7 @@ If you are running out of memory, checkout the tutorial page for using :doc:`dis
***********************************
How to handle categorical features?
***********************************
Visit :doc:`this tutorial </tutorials/categorical>` for a walk through of categorical data handling and some worked examples.
Visit :doc:`this tutorial </tutorials/categorical>` for a walkthrough of categorical data handling and some worked examples.
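A minimal sketch of the native categorical support (dataset and column names are illustrative; ``enable_categorical`` follows the Python package):

.. code-block:: python

    import pandas as pd
    import xgboost as xgb

    # A pandas "category" dtype marks the column as categorical.
    X = pd.DataFrame(
        {
            "color": pd.Series(["red", "green", "red", "blue"], dtype="category"),
            "size": [1.0, 2.0, 3.0, 4.0],
        }
    )
    y = [0, 1, 0, 1]
    # enable_categorical turns on XGBoost's native categorical support.
    Xy = xgb.DMatrix(X, label=y, enable_categorical=True)
    booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=10)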
******************************************************************
Why not implement distributed XGBoost on top of X (Spark, Hadoop)?
@ -37,14 +37,14 @@ The ultimate question will still come back to how to push the limit of each comp
and use fewer resources to complete the task (thus with less communication and a lower chance of failure).
To achieve this, we decided to reuse the optimizations in the single-node XGBoost and build the distributed version on top of it.
The demand of communication in machine learning is rather simple, in the sense that we can depend on a limited set of APIs (in our case rabit).
The demand for communication in machine learning is rather simple, in the sense that we can depend on a limited set of APIs (in our case rabit).
Such a design allows us to reuse most of the code, while being portable to major platforms such as Hadoop/Yarn, MPI, and SGE.
Most importantly, it pushes the limit of the computation resources we can use.
****************************************
How can I port a model to my own system?
****************************************
The model and data format of XGBoost is exchangeable,
The model and data format of XGBoost are exchangeable,
which means a model trained in one language can be loaded in another.
This means you can train the model using R, while running prediction using
Java or C++, which are more common in production systems.
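As a minimal sketch (file name illustrative), a model trained with the Python package can be saved in the portable JSON format and loaded from any other binding:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(64, 4)), rng.integers(0, 2, size=64)
    booster = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y),
                        num_boost_round=5)
    # model.json uses the portable JSON format and can be loaded from R, Java, or C++.
    booster.save_model("model.json")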

View File

@ -73,7 +73,7 @@ Parameters for Tree Booster
===========================
* ``eta`` [default=0.3, alias: ``learning_rate``]
- Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and ``eta`` shrinks the feature weights to make the boosting process more conservative.
- Step size shrinkage used in update to prevent overfitting. After each boosting step, we can directly get the weights of new features, and ``eta`` shrinks the feature weights to make the boosting process more conservative.
- range: [0,1]
* ``gamma`` [default=0, alias: ``min_split_loss``]
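To illustrate the ``eta`` entry above, a minimal sketch (synthetic data) comparing the default and a more conservative learning rate:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    dtrain = xgb.DMatrix(rng.normal(size=(256, 8)), label=rng.normal(size=256))
    # A smaller eta shrinks each tree's contribution, making boosting more
    # conservative; more rounds are usually needed to compensate.
    for eta in (0.3, 0.05):
        xgb.train({"eta": eta, "objective": "reg:squarederror"}, dtrain,
                  num_boost_round=50)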

View File

@ -87,8 +87,8 @@ XGBoost PySpark GPU support
XGBoost PySpark fully supports GPU acceleration. Users can not only enable
efficient training but also utilize their GPUs for the whole PySpark pipeline, including
ETL and inference. In the sections below, we will walk through an example of training on a
PySpark standalone GPU cluster. To get started, first we need to install some additional
packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.
Spark standalone cluster with GPU support. To get started, first we need to install some
additional packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.
Prepare the necessary packages
==============================
@ -128,7 +128,8 @@ Write your PySpark application
==============================
The snippet below is a small example of training an xgboost model with PySpark. Notice that we are
using a list of feature names and the additional parameter ``device``:
using a list of feature names instead of a vector type as the input. The parameter ``"device=cuda"``
specifically indicates that the training will be performed on a GPU.
.. code-block:: python
@ -163,14 +164,29 @@ using a list of feature names and the additional parameter ``device``:
predict_df = model.transform(test_df)
predict_df.show()
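Since the diff elides most of the snippet, here is a minimal sketch of the pattern described above (column names and the ``train_df``/``test_df`` DataFrames are illustrative, not from the original snippet):

.. code-block:: python

    from xgboost.spark import SparkXGBClassifier

    # Feature columns are passed as a plain list of names rather than an
    # assembled vector column.
    feature_names = ["f0", "f1", "f2"]
    classifier = SparkXGBClassifier(
        features_col=feature_names,
        label_col="label",
        device="cuda",  # no ordinal: Spark assigns the GPU
    )
    model = classifier.fit(train_df)  # train_df/test_df: pre-built DataFrames
    predict_df = model.transform(test_df)
    predict_df.show()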
Like other distributed interfaces, the ```device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).
Like other distributed interfaces, the ``device`` parameter doesn't support specifying the ordinal, as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).
.. _stage-level-scheduling:
Submit the PySpark application
==============================
Assuming you have configured your Spark cluster with GPU support. Otherwise, please
This assumes you have configured the Spark standalone cluster with GPU support. Otherwise, please
refer to `spark standalone configuration with GPU support <https://nvidia.github.io/spark-rapids/docs/get-started/getting-started-on-prem.html#spark-standalone-cluster>`_.
Starting from XGBoost 2.0.1, stage-level scheduling is automatically enabled. Therefore,
if you are using Spark standalone cluster version 3.4.0 or higher, we strongly recommend
configuring the ``"spark.task.resource.gpu.amount"`` as a fractional value. This will
enable running multiple tasks in parallel during the ETL phase. An example configuration
would be ``"spark.task.resource.gpu.amount=1/spark.executor.cores"``. However, if you are
using an XGBoost version earlier than 2.0.1 or a Spark standalone cluster version below 3.4.0,
you still need to set ``"spark.task.resource.gpu.amount"`` equal to ``"spark.executor.resource.gpu.amount"``.
.. note::
As of now, the stage-level scheduling feature in XGBoost is limited to the Spark standalone cluster mode.
However, we have plans to expand its compatibility to YARN and Kubernetes once Spark 3.5.1 is officially released.
.. code-block:: bash
export PYSPARK_DRIVER_PYTHON=python
@ -178,19 +194,21 @@ refer to `spark standalone configuration with GPU support <https://nvidia.github
spark-submit \
--master spark://<master-ip>:7077 \
--conf spark.executor.cores=12 \
--conf spark.task.cpus=1 \
--conf spark.executor.resource.gpu.amount=1 \
--conf spark.task.resource.gpu.amount=1 \
--conf spark.task.resource.gpu.amount=0.08 \
--archives xgboost_env.tar.gz#environment \
xgboost_app.py
The submit command sends the Python environment created by pip or conda along with the
specification of GPU allocation. We will revisit this command later on.
The above command submits the xgboost pyspark application with the python environment created by pip or conda,
requesting 1 GPU and 12 CPUs per executor. Since each task requests 0.08 GPU (roughly 1/12 of the executor's
single GPU), a total of 12 tasks per executor will be executed concurrently during the ETL phase.
Model Persistence
=================
Similar to standard PySpark ml estimators, one can persist and reuse the model with ``save`
Similar to standard PySpark ml estimators, one can persist and reuse the model with ``save``
and ``load`` methods:
.. code-block:: python
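    # A minimal sketch (path and names assumed, not from the diff): persist the
    # fitted model, then reload it with the matching model class.
    from xgboost.spark import SparkXGBClassifierModel

    model.save("/tmp/xgboost-pyspark-model")
    model2 = SparkXGBClassifierModel.load("/tmp/xgboost-pyspark-model")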
@ -230,8 +248,13 @@ Accelerate the whole pipeline for xgboost pyspark
With `RAPIDS Accelerator for Apache Spark <https://nvidia.github.io/spark-rapids/>`_, you
can leverage GPUs to accelerate the whole pipeline (ETL, Train, Transform) for xgboost
pyspark without any Python code change. An example submit command is shown below with
additional spark configurations and dependencies:
pyspark without the need for any code modifications. Likewise, you have the option to configure
the ``"spark.task.resource.gpu.amount"`` setting as a fractional value, enabling a higher
number of tasks to be executed in parallel during the ETL phase. Please refer to
:ref:`stage-level-scheduling` for more details.
An example submit command is shown below with additional spark configurations and dependencies:
.. code-block:: bash
@ -240,8 +263,10 @@ additional spark configurations and dependencies:
spark-submit \
--master spark://<master-ip>:7077 \
--conf spark.executor.cores=12 \
--conf spark.task.cpus=1 \
--conf spark.executor.resource.gpu.amount=1 \
--conf spark.task.resource.gpu.amount=1 \
--conf spark.task.resource.gpu.amount=0.08 \
--packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
--conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \

View File

@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)
# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc)
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@ -0,0 +1,114 @@
/**
* Copyright 2023, XGBoost contributors
*/
#include "federated_comm.h"
#include <grpcpp/grpcpp.h>
#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <limits> // for numeric_limits
#include <string> // for string, stoi
#include "../../src/common/common.h" // for Split
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json
#include "xgboost/logging.h"
namespace xgboost::collective {
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank, std::string const& server_cert,
std::string const& client_key, std::string const& client_cert) {
this->rank_ = rank;
this->world_ = world;
this->tracker_.host = host;
this->tracker_.port = port;
this->tracker_.rank = rank;
CHECK_GE(world, 1) << "Invalid world size.";
CHECK_GE(rank, 0) << "Invalid worker rank.";
CHECK_LT(rank, world) << "Invalid worker rank.";
if (server_cert.empty()) {
stub_ = [&] {
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
return federated::Federated::NewStub(
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
}();
} else {
stub_ = [&] {
grpc::SslCredentialsOptions options;
options.pem_root_certs = server_cert;
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
return federated::Federated::NewStub(channel);
}();
}
}
FederatedComm::FederatedComm(Json const& config) {
/**
* Topology
*/
std::string server_address{};
std::int32_t world_size{0};
std::int32_t rank{-1};
// Parse environment variables first.
auto* value = std::getenv("FEDERATED_SERVER_ADDRESS");
if (value != nullptr) {
server_address = value;
}
value = std::getenv("FEDERATED_WORLD_SIZE");
if (value != nullptr) {
world_size = std::stoi(value);
}
value = std::getenv("FEDERATED_RANK");
if (value != nullptr) {
rank = std::stoi(value);
}
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
world_size =
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
auto parsed = common::Split(server_address, ':');
CHECK_EQ(parsed.size(), 2) << "Invalid server address: " << server_address;
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
/**
* Certificates
*/
std::string server_cert{};
std::string client_key{};
std::string client_cert{};
value = getenv("FEDERATED_SERVER_CERT_PATH");
if (value != nullptr) {
server_cert = value;
}
value = getenv("FEDERATED_CLIENT_KEY_PATH");
if (value != nullptr) {
client_key = value;
}
value = getenv("FEDERATED_CLIENT_CERT_PATH");
if (value != nullptr) {
client_cert = value;
}
server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert);
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);
this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
client_cert);
}
} // namespace xgboost::collective

View File

@ -0,0 +1,53 @@
/**
* Copyright 2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <string> // for string
#include "../../src/collective/comm.h" // for Comm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"
namespace xgboost::collective {
class FederatedComm : public Comm {
std::unique_ptr<federated::Federated::Stub> stub_;
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
std::string const& server_cert, std::string const& client_key,
std::string const& client_cert);
public:
/**
* @param config
*
* - federated_server_address: Tracker address
* - federated_world_size: The number of workers
* - federated_rank: Rank of federated worker
* - federated_server_cert_path
* - federated_client_key_path
* - federated_client_cert_path
*/
explicit FederatedComm(Json const& config);
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
~FederatedComm() override { stub_.reset(); }
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
return nullptr;
}
[[nodiscard]] Result LogTracker(std::string msg) const override {
LOG(CONSOLE) << msg;
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
};
} // namespace xgboost::collective

View File

@ -4,12 +4,15 @@
#include "federated_server.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/server.h> // for Server
#include <grpcpp/server_builder.h>
#include <xgboost/logging.h>
#include <sstream>
#include "../../src/collective/comm.h"
#include "../../src/common/io.h"
#include "../../src/common/json_utils.h"
namespace xgboost::federated {
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
FederatedService service{static_cast<std::int32_t>(world_size)};
grpc::ServerBuilder builder;
auto options =
@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
void RunInsecureServer(int port, std::size_t world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
FederatedService service{static_cast<std::int32_t>(world_size)};
grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

View File

@ -1,18 +1,22 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <cstdint> // for int32_t
#include <future> // for future
#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
namespace xgboost {
namespace federated {
namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file);
void RunInsecureServer(int port, std::size_t world_size);
} // namespace federated
} // namespace xgboost
} // namespace xgboost::federated

View File

@ -0,0 +1,101 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "federated_tracker.h"
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
#include <grpcpp/server_builder.h> // for ServerBuilder
#include <chrono> // for ms
#include <cstdint> // for int32_t
#include <exception> // for exception
#include <limits> // for numeric_limits
#include <string> // for string
#include <thread> // for sleep_for
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for RequiredArg
#include "../../src/common/timer.h" // for Timer
#include "federated_server.h" // for FederatedService
namespace xgboost::collective {
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
if (is_secure) {
server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
}
}
std::future<Result> FederatedTracker::Run() {
return std::async([this]() {
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
grpc::ServerBuilder builder;
if (this->server_cert_file_.empty()) {
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
if (this->port_ == 0) {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
} else {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< this->n_workers_;
} else {
auto options = grpc::SslServerCredentialsOptions(
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = xgboost::common::ReadAll(server_key_path_);
key.cert_chain = xgboost::common::ReadAll(server_cert_file_);
options.pem_key_cert_pairs.push_back(key);
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
if (this->port_ == 0) {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_);
} else {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< n_workers_;
}
try {
server_->Wait();
} catch (std::exception const& e) {
return collective::Fail(std::string{e.what()});
}
return collective::Success();
});
}
FederatedTracker::~FederatedTracker() = default;
Result FederatedTracker::Shutdown() {
common::Timer timer;
timer.Start();
using namespace std::chrono_literals;
while (!server_) {
timer.Stop();
auto ela = timer.ElapsedSeconds();
if (ela > this->Timeout().count()) {
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
" seconds.");
}
std::this_thread::sleep_for(10ms);
}
try {
server_->Shutdown();
} catch (std::exception const& e) {
return Fail("Failed to shutdown:" + std::string{e.what()});
}
return Success();
}
} // namespace xgboost::collective

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h> // for Server
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
class FederatedTracker : public collective::Tracker {
std::unique_ptr<grpc::Server> server_;
std::string server_key_path_;
std::string server_cert_file_;
std::string client_cert_file_;
public:
/**
* @brief CTOR
*
* @param config Configuration. In addition to the base configuration from Tracker, we have:
*
* - federated_secure: bool whether this is a secure server.
* - server_key_path: path to the key.
* - server_cert_path: certificate path.
* - client_cert_path: certificate path for client.
*/
explicit FederatedTracker(Json const& config);
~FederatedTracker() override;
std::future<Result> Run() override;
// The federated tracker does not provide initialization parameters; users have to provide
// them themselves.
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
[[nodiscard]] Result Shutdown();
};
} // namespace xgboost::collective

View File

@ -17,7 +17,7 @@ class HasArbitraryParamsDict(Params):
Params._dummy(),
"arbitrary_params_dict",
"arbitrary_params_dict This parameter holds all of the additional parameters which are "
"not exposed as the the XGBoost Spark estimator params but can be recognized by "
"not exposed as the XGBoost Spark estimator params but can be recognized by "
"underlying XGBoost library. It is stored as a dictionary.",
)

View File

@ -106,7 +106,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
}
}
if (dmlc_role != "worker") {
LOG(FATAL) << "Rabit Module currently only work with dmlc worker";
LOG(FATAL) << "Rabit Module currently only works with dmlc worker";
}
// clear the setting before start reconnection
@ -273,7 +273,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
return xgboost::collective::Success();
}
/*!
* \brief connect to the tracker to fix the the missing links
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine starts up
*/
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {

View File

@ -89,7 +89,7 @@ class AllreduceBase : public IEngine {
}
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
@ -281,7 +281,7 @@ class AllreduceBase : public IEngine {
* this function cannot be used together with ReadToRingBuffer
* a link can either read into the ring buffer, or an existing array
* \param max_size maximum size of array
* \return true if it is an successful read, false if there is some error happens, check errno
* \return true if it is a successful read, false if an error occurs; check errno
*/
inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
if (max_size == size_read) return kSuccess;
@ -299,7 +299,7 @@ class AllreduceBase : public IEngine {
* \brief write data in array to sock
* \param sendbuf_ head of array
* \param max_size maximum size of array
* \return true if it is an successful write, false if there is some error happens, check errno
* \return true if it is a successful write, false if an error occurs; check errno
*/
inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
const char *p = static_cast<const char*>(sendbuf_);
@ -333,7 +333,7 @@ class AllreduceBase : public IEngine {
*/
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
/*!
* \brief connect to the tracker to fix the the missing links
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine starts up
* \param cmd possible command to send to tracker
*/
@ -358,7 +358,7 @@ class AllreduceBase : public IEngine {
size_t count,
ReduceFunction reducer);
/*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
* \brief broadcast data from root to all nodes, this function can fail, and will return the cause of failure
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data

View File

@ -7,20 +7,23 @@
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <numeric> // for partial_sum
#include <vector> // for vector
#include "broadcast.h"
#include "comm.h" // for Comm, Channel
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl {
namespace xgboost::collective {
namespace cpu_impl {
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch) {
auto world = comm.World();
auto rank = comm.Rank();
CHECK_LT(worker_off, world);
if (world == 1) {
return Success();
}
for (std::int32_t r = 0; r < world; ++r) {
auto send_rank = (rank + world - r + worker_off) % world;
@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
return Success();
}
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t> recv) {
std::size_t offset = 0;
for (std::int32_t r = 0; r < comm.World(); ++r) {
auto as_bytes = sizes[r];
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
}
} // namespace cpu_impl
namespace detail {
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int64_t const> offset,
common::Span<std::int8_t> erased_result) {
auto world = comm.World();
if (world == 1) {
return Success();
}
auto rank = comm.Rank();
auto prev = BootstrapPrev(rank, comm.World());
@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);
// get worker offset
CHECK_EQ(world + 1, offset.size());
std::fill_n(offset.data(), offset.size(), 0);
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
CHECK_EQ(*offset.cbegin(), 0);
// copy data
auto current = erased_result.subspan(offset[rank], data.size_bytes());
auto erased_data = EraseType(data);
std::copy_n(erased_data.data(), erased_data.size(), current.data());
for (std::int32_t r = 0; r < world; ++r) {
auto send_rank = (rank + world - r) % world;
auto send_off = offset[send_rank];
@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
}
return comm.Block();
}
} // namespace xgboost::collective::cpu_impl
} // namespace detail
} // namespace xgboost::collective

View File

@ -9,28 +9,47 @@
#include <type_traits> // for remove_cv_t
#include <vector> // for vector
#include "../common/type.h" // for EraseType
#include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
#include "xgboost/linalg.h"
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
/**
* @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off
* = 1, then it owns the third segment.
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
* worker_off = 1, then it owns the third segment.
*/
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off,
std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch);
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int8_t> erased_result);
/**
* @brief Implement allgather-v using broadcast.
*
* https://arxiv.org/abs/1812.05964
*/
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t> recv);
} // namespace cpu_impl
namespace detail {
inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> offset) {
// get worker offset
std::fill_n(offset.data(), offset.size(), 0);
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
CHECK_EQ(*offset.cbegin(), 0);
}
// An implementation that's used by both cpu and gpu
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int64_t const> offset,
common::Span<std::int8_t> erased_result);
} // namespace detail
template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
auto n_bytes = sizeof(T) * size;
@ -68,9 +87,15 @@ template <typename T>
auto h_result = common::Span{result.data(), result.size()};
auto erased_result = common::EraseType(h_result);
auto erased_data = common::EraseType(data);
std::vector<std::int64_t> offset(world + 1);
std::vector<std::int64_t> recv_segments(world + 1);
auto s_segments = common::Span{recv_segments.data(), recv_segments.size()};
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
common::Span{offset.data(), offset.size()}, erased_result);
// get worker offset
detail::AllgatherVOffset(sizes, s_segments);
// copy data
auto current = erased_result.subspan(recv_segments[rank], data.size_bytes());
std::copy_n(erased_data.data(), erased_data.size(), current.data());
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
}
} // namespace xgboost::collective

View File

@ -8,16 +8,14 @@
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus
#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
#include "xgboost/context.h" // for Context
#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
namespace xgboost::collective {
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
Op op) {
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type, Op op) {
namespace coll = ::xgboost::collective;
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
@ -55,21 +53,45 @@ namespace xgboost::collective {
return comm.Block();
}
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
common::Span<std::int8_t> data, std::int32_t root) {
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root) {
return cpu_impl::Broadcast(comm, data, root);
}
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
common::Span<std::int8_t> data, std::size_t size) {
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) {
return RingAllgather(comm, data, size);
}
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
common::Span<std::int8_t const> data,
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv) {
return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
// get worker offset
detail::AllgatherVOffset(sizes, recv_segments);
// copy data
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
if (current.data() != data.data()) {
std::copy_n(data.data(), data.size(), current.data());
}
switch (algo) {
case AllgatherVAlgo::kRing:
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
case AllgatherVAlgo::kBcast:
return cpu_impl::BroadcastAllgatherV(comm, sizes, recv);
default: {
return Fail("Unknown algorithm for allgather-v");
}
}
}
#if !defined(XGBOOST_USE_NCCL)
Coll* Coll::MakeCUDAVar() {
LOG(FATAL) << "NCCL is required for device communication.";
return nullptr;
}
#endif
} // namespace xgboost::collective

254
src/collective/coll.cu Normal file
View File

@ -0,0 +1,254 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <cstdint> // for int8_t, int64_t
#include <sstream> // for stringstream
#include "../common/cuda_context.cuh"
#include "../common/device_helpers.cuh"
#include "../data/array_interface.h"
#include "allgather.h" // for AllgatherVOffset
#include "coll.cuh"
#include "comm.cuh"
#include "nccl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; }
NCCLColl::~NCCLColl() = default;
namespace {
Result GetNCCLResult(ncclResult_t code) {
if (code == ncclSuccess) {
return Success();
}
std::stringstream ss;
ss << "NCCL failure: " << ncclGetErrorString(code) << ".";
if (code == ncclUnhandledCudaError) {
// nccl usually preserves the last error so we can get more details.
auto err = cudaPeekAtLastError();
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
} else if (code == ncclSystemError) {
ss << " This might be caused by a network configuration issue. Please consider specifying "
"the network interface for NCCL via environment variables listed in its reference: "
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
}
return Fail(ss.str());
}
auto GetNCCLType(ArrayInterfaceHandler::Type type) {
auto fatal = [] {
LOG(FATAL) << "Invalid type for NCCL operation.";
return ncclHalf; // dummy return to silence the compiler warning.
};
using H = ArrayInterfaceHandler;
switch (type) {
case H::kF2:
return ncclHalf;
case H::kF4:
return ncclFloat32;
case H::kF8:
return ncclFloat64;
case H::kF16:
return fatal();
case H::kI1:
return ncclInt8;
case H::kI2:
return fatal();
case H::kI4:
return ncclInt32;
case H::kI8:
return ncclInt64;
case H::kU1:
return ncclUint8;
case H::kU2:
return fatal();
case H::kU4:
return ncclUint32;
case H::kU8:
return ncclUint64;
}
return fatal();
}
bool IsBitwiseOp(Op const& op) {
return op == Op::kBitwiseAND || op == Op::kBitwiseOR || op == Op::kBitwiseXOR;
}
template <typename Func>
void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> out_buffer,
std::int8_t const* device_buffer, Func func, std::int32_t world_size,
std::size_t size) {
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
result = func(result, device_buffer[rank * size + idx]);
}
out_buffer[idx] = result;
});
}
[[nodiscard]] Result BitwiseAllReduce(NCCLComm const* pcomm, ncclComm_t handle,
common::Span<std::int8_t> data, Op op) {
dh::device_vector<std::int8_t> buffer(data.size() * pcomm->World());
auto* device_buffer = buffer.data().get();
// First gather data from all the workers.
CHECK(handle);
auto rc = GetNCCLResult(
ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream()));
if (!rc.OK()) {
return rc;
}
// Then reduce locally.
switch (op) {
case Op::kBitwiseAND:
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_and<std::int8_t>(),
pcomm->World(), data.size());
break;
case Op::kBitwiseOR:
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_or<std::int8_t>(),
pcomm->World(), data.size());
break;
case Op::kBitwiseXOR:
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_xor<std::int8_t>(),
pcomm->World(), data.size());
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
}
return Success();
}
ncclRedOp_t GetNCCLRedOp(Op const& op) {
ncclRedOp_t result{ncclMax};
switch (op) {
case Op::kMax:
result = ncclMax;
break;
case Op::kMin:
result = ncclMin;
break;
case Op::kSum:
result = ncclSum;
break;
default:
LOG(FATAL) << "Unsupported reduce operation.";
}
return result;
}
} // namespace
[[nodiscard]] Result NCCLColl::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) {
if (!comm.IsDistributed()) {
return Success();
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
return Success() << [&] {
if (IsBitwiseOp(op)) {
return BitwiseAllReduce(nccl, nccl->Handle(), data, op);
} else {
return DispatchDType(type, [=](auto t) {
using T = decltype(t);
auto rdata = common::RestoreType<T>(data);
auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
return GetNCCLResult(rc);
});
}
} << [&] { return nccl->Block(); };
}
[[nodiscard]] Result NCCLColl::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (!comm.IsDistributed()) {
return Success();
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
return Success() << [&] {
return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
nccl->Handle(), nccl->Stream()));
} << [&] { return nccl->Block(); };
}
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) {
if (!comm.IsDistributed()) {
return Success();
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
auto send = data.subspan(comm.Rank() * size, size);
return Success() << [&] {
return GetNCCLResult(
ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream()));
} << [&] { return nccl->Block(); };
}
namespace cuda_impl {
/**
* @brief Implement allgather-v using broadcast.
*
* https://arxiv.org/abs/1812.05964
*/
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
std::size_t offset = 0;
for (std::int32_t r = 0; r < comm->World(); ++r) {
auto as_bytes = sizes[r];
auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
ncclInt8, r, comm->Handle(), dh::DefaultStream());
if (rc != ncclSuccess) {
return GetNCCLResult(rc);
}
offset += as_bytes;
}
return Success();
} << [] { return GetNCCLResult(ncclGroupEnd()); };
}
} // namespace cuda_impl
[[nodiscard]] Result NCCLColl::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
if (!comm.IsDistributed()) {
return Success();
}
switch (algo) {
case AllgatherVAlgo::kRing: {
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
// get worker offset
detail::AllgatherVOffset(sizes, recv_segments);
// copy data
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
if (current.data() != data.data()) {
dh::safe_cuda(cudaMemcpyAsync(current.data(), data.data(), current.size_bytes(),
cudaMemcpyDeviceToDevice, nccl->Stream()));
}
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
} << [] {
return GetNCCLResult(ncclGroupEnd());
} << [&] { return nccl->Block(); };
}
case AllgatherVAlgo::kBcast: {
return cuda_impl::BroadcastAllgatherV(nccl, data, sizes, recv);
}
default: {
return Fail("Unknown algorithm for allgather-v");
}
}
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)

29
src/collective/coll.cuh Normal file
View File

@ -0,0 +1,29 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t, int64_t
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "coll.h" // for Coll
#include "comm.h" // for Comm
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
class NCCLColl : public Coll {
public:
~NCCLColl() override;
[[nodiscard]] Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) override;
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
} // namespace xgboost::collective

View File

@ -2,17 +2,20 @@
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <memory> // for enable_shared_from_this
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
enum class AllgatherVAlgo {
kRing = 0, // use ring-based allgather-v
kBcast = 1, // use broadcast-based allgather-v
};
/**
* @brief Interface and base implementation for collective.
*/
@ -21,6 +24,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
Coll() = default;
virtual ~Coll() noexcept(false) {} // NOLINT
Coll* MakeCUDAVar();
/**
* @brief Allreduce
*
@ -29,8 +34,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
* @param [in] op Reduce operation. For custom operation, user needs to reach down to
* the CPU implementation.
*/
[[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
common::Span<std::int8_t> data,
[[nodiscard]] virtual Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op);
/**
* @brief Broadcast
@ -38,29 +42,29 @@ class Coll : public std::enable_shared_from_this<Coll> {
* @param [in,out] data Data buffer for input and output.
* @param [in] root Root rank for broadcast.
*/
[[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
common::Span<std::int8_t> data, std::int32_t root);
[[nodiscard]] virtual Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root);
/**
* @brief Allgather
*
* @param [in,out] data Data buffer for input and output.
* @param [in] size Size of data for each worker.
*/
[[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
common::Span<std::int8_t> data, std::size_t size);
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size);
/**
* @brief Allgather with variable length.
*
* @param [in] data Input data for the current worker.
* @param [in] sizes Size of the input from each worker.
* @param [out] recv_segments pre-allocated offset for each worker in the output, size
* should be equal to (world + 1).
* @param [out] recv_segments pre-allocated offset buffer for each worker in the output,
* size should be equal to (world + 1). GPU ring-based implementation
* doesn't use the buffer.
* @param [out] recv pre-allocated buffer for output.
*/
[[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
common::Span<std::int8_t const> data,
[[nodiscard]] virtual Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv);
common::Span<std::int8_t> recv, AllgatherVAlgo algo);
};
} // namespace xgboost::collective

View File

@ -262,7 +262,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
}
RabitComm::~RabitComm() noexcept(false) {
if (!IsDistributed()) {
if (!this->IsDistributed()) {
return;
}
auto rc = this->Shutdown();

112
src/collective/comm.cu Normal file
View File

@ -0,0 +1,112 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <algorithm> // for sort
#include <cstddef> // for size_t
#include <cstdint> // for uint64_t, int8_t
#include <cstring> // for memcpy
#include <memory> // for shared_ptr
#include <sstream> // for stringstream
#include <utility> // for declval
#include <vector> // for vector
#include "../common/device_helpers.cuh" // for DefaultStream
#include "../common/type.h" // for EraseType
#include "broadcast.h" // for Broadcast
#include "comm.cuh" // for NCCLComm
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace {
Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
static const int kRootRank = 0;
ncclUniqueId id;
if (comm.Rank() == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
auto rc = Broadcast(comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)},
kRootRank);
if (!rc.OK()) {
return rc;
}
*pid = id;
return Success();
}
inline constexpr std::size_t kUuidLength =
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(std::uint64_t);
void GetCudaUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid, DeviceOrd device) {
cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device.ordinal));
std::memcpy(uuid.data(), static_cast<void*>(&(prob.uuid)), sizeof(prob.uuid));
}
static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid) {
std::stringstream ss;
for (auto v : uuid) {
ss << std::hex << v;
}
return ss.str();
}
} // namespace
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
return new NCCLComm{ctx, *this, pimpl};
}
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
root.TaskID()},
stream_{dh::DefaultStream()} {
this->world_ = root.World();
this->rank_ = root.Rank();
this->domain_ = root.Domain();
if (!root.IsDistributed()) {
return;
}
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
std::vector<std::uint64_t> uuids(root.World() * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<std::uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid, ctx->Device());
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
CHECK(rc.OK()) << rc.Report();
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
std::size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = s_uuid.subspan(i, kUuidLength);
j++;
}
std::sort(converted.begin(), converted.end());
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, root.World())
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
rc = GetUniqueId(root, &nccl_unique_id_);
CHECK(rc.OK()) << rc.Report();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
for (std::int32_t r = 0; r < root.World(); ++r) {
this->channels_.emplace_back(
std::make_shared<NCCLChannel>(root, r, nccl_comm_, dh::DefaultStream()));
}
}
NCCLComm::~NCCLComm() {
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)

67
src/collective/comm.cuh Normal file
View File

@ -0,0 +1,67 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
#endif // XGBOOST_USE_NCCL
#include "../common/device_helpers.cuh"
#include "coll.h"
#include "comm.h"
#include "xgboost/context.h"
namespace xgboost::collective {
inline Result GetCUDAResult(cudaError rc) {
if (rc == cudaSuccess) {
return Success();
}
std::string msg = thrust::system_error(rc, thrust::cuda_category()).what();
return Fail(msg);
}
class NCCLComm : public Comm {
ncclComm_t nccl_comm_{nullptr};
ncclUniqueId nccl_unique_id_{};
dh::CUDAStreamView stream_;
public:
[[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; }
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl);
[[nodiscard]] Result LogTracker(std::string) const override {
LOG(FATAL) << "Device comm is used for logging.";
return Fail("Undefined.");
}
~NCCLComm() override;
[[nodiscard]] bool IsFederated() const override { return false; }
[[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; }
[[nodiscard]] Result Block() const override {
auto rc = this->Stream().Sync(false);
return GetCUDAResult(rc);
}
};
class NCCLChannel : public Channel {
std::int32_t rank_{-1};
ncclComm_t nccl_comm_{};
dh::CUDAStreamView stream_;
public:
explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
dh::CUDAStreamView stream)
: Channel{comm, nullptr}, rank_{rank}, nccl_comm_{nccl_comm}, stream_{stream} {}
void SendAll(std::int8_t const* ptr, std::size_t n) override {
dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
}
void RecvAll(std::int8_t* ptr, std::size_t n) override {
dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
}
[[nodiscard]] Result Block() override {
auto rc = stream_.Sync(false);
return GetCUDAResult(rc);
}
};
} // namespace xgboost::collective

View File

@ -2,20 +2,20 @@
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <string> // for string
#include <thread> // for thread
#include <type_traits> // for remove_const_t
#include <utility> // for move
#include <vector> // for vector
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <string> // for string
#include <thread> // for thread
#include <utility> // for move
#include <vector> // for vector
#include "loop.h" // for Loop
#include "protocol.h" // for PeerInfo
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
@ -35,13 +35,14 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
}
class Channel;
class Coll;
/**
* @brief Base communicator storing info about the tracker and other communicators.
*/
class Comm {
protected:
std::int32_t world_{1};
std::int32_t world_{-1};
std::int32_t rank_{0};
std::chrono::seconds timeout_{DefaultTimeoutSec()};
std::int32_t retry_{DefaultRetry()};
@ -69,12 +70,14 @@ class Comm {
[[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
[[nodiscard]] auto Domain() const { return domain_; }
[[nodiscard]] auto Timeout() const { return timeout_; }
[[nodiscard]] auto Retry() const { return retry_; }
[[nodiscard]] auto TaskID() const { return task_id_; }
[[nodiscard]] auto Rank() const { return rank_; }
[[nodiscard]] auto World() const { return world_; }
[[nodiscard]] bool IsDistributed() const { return World() > 1; }
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
void Submit(Loop::Op op) const { loop_->Submit(op); }
[[nodiscard]] Result Block() const { return loop_->Block(); }
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
return channels_.at(rank);
@ -83,6 +86,8 @@ class Comm {
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
};
class RabitComm : public Comm {
@ -116,7 +121,7 @@ class Channel {
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
: sock_{std::move(sock)}, comm_{comm} {}
void SendAll(std::int8_t const* ptr, std::size_t n) {
virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
CHECK(sock_.get());
comm_.Submit(std::move(op));
@ -125,7 +130,7 @@ class Channel {
this->SendAll(data.data(), data.size_bytes());
}
void RecvAll(std::int8_t* ptr, std::size_t n) {
virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
CHECK(sock_.get());
comm_.Submit(std::move(op));
@ -133,7 +138,7 @@ class Channel {
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
[[nodiscard]] auto Socket() const { return sock_; }
[[nodiscard]] Result Block() { return comm_.Block(); }
[[nodiscard]] virtual Result Block() { return comm_.Block(); }
};
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };

View File

@ -50,6 +50,7 @@ class Tracker {
[[nodiscard]] virtual std::future<Result> Run() = 0;
[[nodiscard]] virtual Json WorkerArgs() const = 0;
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
[[nodiscard]] virtual std::int32_t Port() const { return port_; }
};
class RabitTracker : public Tracker {
@ -124,7 +125,6 @@ class RabitTracker : public Tracker {
std::future<Result> Run() override;
[[nodiscard]] std::int32_t Port() const { return port_; }
[[nodiscard]] Json WorkerArgs() const override {
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};

View File

@ -1171,7 +1171,13 @@ class CUDAStreamView {
operator cudaStream_t() const { // NOLINT
return stream_;
}
void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
cudaError_t Sync(bool error = true) {
if (error) {
dh::safe_cuda(cudaStreamSynchronize(stream_));
return cudaSuccess;
}
return cudaStreamSynchronize(stream_);
}
};
inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT

View File

@ -20,7 +20,6 @@
#include "../common/cuda_context.cuh" // CUDAContext
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "../common/io.h"
#include "../common/timer.h"
#include "../data/ellpack_page.cuh"
#include "../data/ellpack_page.h"
@ -40,7 +39,6 @@
#include "xgboost/data.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "xgboost/task.h" // for ObjInfo
#include "xgboost/tree_model.h"

View File

@ -14,6 +14,7 @@
#include <vector> // for vector
#include "../../../src/collective/allgather.h" // for RingAllgather
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/comm.h" // for RabitComm
#include "gtest/gtest.h" // for AssertionR...
#include "test_worker.h" // for TestDistri...
@ -63,37 +64,79 @@ class Worker : public WorkerForTest {
}
}
void TestV() {
{
// basic test
std::int32_t n{comm_.Rank()};
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
}
}
{
// V test
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
std::int32_t k{0};
for (std::int32_t r = 0; r < comm_.World(); ++r) {
auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
if (comm_.Rank() == 0) {
for (auto v : seg) {
ASSERT_EQ(v, r);
}
k += seg.size();
void CheckV(common::Span<std::int32_t> result) {
std::int32_t k{0};
for (std::int32_t r = 0; r < comm_.World(); ++r) {
auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
if (comm_.Rank() == 0) {
for (auto v : seg) {
ASSERT_EQ(v, r);
}
k += seg.size();
}
}
}
void TestVRing() {
// V test
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
CheckV(result);
}
void TestVBasic() {
// basic test
std::int32_t n{comm_.Rank()};
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
}
}
void TestVAlgo() {
// V test, broadcast
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
auto s_data = common::Span{data.data(), data.size()};
std::vector<std::int64_t> sizes(comm_.World(), 0);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
ASSERT_TRUE(rc.OK()) << rc.Report();
std::shared_ptr<Coll> pcoll{new Coll{}};
std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);
std::vector<std::int32_t> recv(std::accumulate(sizes.cbegin(), sizes.cend(), 0));
auto s_recv = common::Span{recv.data(), recv.size()};
rc = pcoll->AllgatherV(comm_, common::EraseType(s_data),
common::Span{sizes.data(), sizes.size()},
common::Span{recv_segments.data(), recv_segments.size()},
common::EraseType(s_recv), AllgatherVAlgo::kBcast);
ASSERT_TRUE(rc.OK());
CheckV(s_recv);
// Test inplace
auto test_inplace = [&] (AllgatherVAlgo algo) {
std::fill_n(s_recv.data(), s_recv.size(), 0);
auto current = s_recv.subspan(recv_segments[comm_.Rank()],
recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]);
std::copy_n(data.data(), data.size(), current.data());
rc = pcoll->AllgatherV(comm_, common::EraseType(current),
common::Span{sizes.data(), sizes.size()},
common::Span{recv_segments.data(), recv_segments.size()},
common::EraseType(s_recv), algo);
ASSERT_TRUE(rc.OK());
CheckV(s_recv);
};
test_inplace(AllgatherVAlgo::kBcast);
test_inplace(AllgatherVAlgo::kRing);
}
};
} // namespace
@ -106,12 +149,30 @@ TEST_F(AllgatherTest, Basic) {
});
}
TEST_F(AllgatherTest, V) {
TEST_F(AllgatherTest, VBasic) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker worker{host, port, timeout, n_workers, r};
worker.TestV();
worker.TestVBasic();
});
}
TEST_F(AllgatherTest, VRing) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker worker{host, port, timeout, n_workers, r};
worker.TestVRing();
});
}
TEST_F(AllgatherTest, VAlgo) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker worker{host, port, timeout, n_workers, r};
worker.TestVAlgo();
});
}
} // namespace xgboost::collective
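For reference, the layout these V tests assert, worked through for a hypothetical three-worker run:

// Hypothetical W = 3: rank r contributes r + 1 copies of its own rank.
//   rank 0 -> {0}, rank 1 -> {1, 1}, rank 2 -> {2, 2, 2}
// Gathered result = {0, 1, 1, 2, 2, 2}, so
//   result.size() == (1 + W) * W / 2 == 6,
// which is what TestVRing asserts and what CheckV walks segment by segment.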

View File

@ -0,0 +1,117 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
#include <thrust/device_vector.h> // for device_vector
#include <thrust/equal.h> // for equal
#include <xgboost/span.h> // for Span
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int64_t
#include <vector> // for vector
#include "../../../src/collective/allgather.h" // for RingAllgather
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
#include "../../../src/common/type.h" // for EraseType
#include "test_worker.cuh" // for NCCLWorkerForTest
#include "test_worker.h" // for TestDistributed, WorkerForTest
namespace xgboost::collective {
namespace {
class Worker : public NCCLWorkerForTest {
public:
using NCCLWorkerForTest::NCCLWorkerForTest;
void TestV(AllgatherVAlgo algo) {
{
// basic test
std::size_t n = 1;
// create data
dh::device_vector<std::int32_t> data(n, comm_.Rank());
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
// get size
std::vector<std::int64_t> sizes(comm_.World(), -1);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
ASSERT_TRUE(rc.OK()) << rc.Report();
// create result
dh::device_vector<std::int32_t> result(comm_.World(), -1);
auto s_result = common::EraseType(dh::ToSpan(result));
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
}
}
{
// V test
std::size_t n = 256 * 256;
// create data
dh::device_vector<std::int32_t> data(n * nccl_comm_->Rank(), nccl_comm_->Rank());
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
// get size
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
ASSERT_TRUE(rc.OK()) << rc.Report();
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
// create result
dh::device_vector<std::int32_t> result(n_bytes / sizeof(std::int32_t), -1);
auto s_result = common::EraseType(dh::ToSpan(result));
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
ASSERT_TRUE(rc.OK()) << rc.Report();
// check segment size
if (algo != AllgatherVAlgo::kBcast) {
auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];
ASSERT_EQ(size, n * nccl_comm_->Rank() * sizeof(std::int32_t));
ASSERT_EQ(size, sizes[nccl_comm_->Rank()]);
}
// check data
std::size_t k{0};
for (std::int32_t r = 0; r < nccl_comm_->World(); ++r) {
std::size_t s = n * r;
auto current = dh::ToSpan(result).subspan(k, s);
std::vector<std::int32_t> h_data(current.size());
dh::CopyDeviceSpanToVector(&h_data, current);
for (auto v : h_data) {
ASSERT_EQ(v, r);
}
k += s;
}
}
}
};
class AllgatherTestGPU : public SocketTest {};
} // namespace
TEST_F(AllgatherTestGPU, MGPUTestVRing) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker w{host, port, timeout, n_workers, r};
w.Setup();
w.TestV(AllgatherVAlgo::kRing);
w.TestV(AllgatherVAlgo::kBcast);
});
}
TEST_F(AllgatherTestGPU, MGPUTestVBcast) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker w{host, port, timeout, n_workers, r};
w.Setup();
w.TestV(AllgatherVAlgo::kBcast);
});
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)
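Note that the size bookkeeping in this test is byte-based while the data checks count elements. A worked example under the same setup, for a hypothetical two-GPU run:

// Hypothetical W = 2 with n = 256 * 256:
//   rank 0 contributes n * 0 == 0 elements     -> sizes[0] == 0 bytes
//   rank 1 contributes n * 1 == 65536 int32_t  -> sizes[1] == 65536 * 4 == 262144 bytes
// Hence n_bytes == 262144 and the result buffer holds
//   n_bytes / sizeof(std::int32_t) == 65536 elements,
// while recv_seg stores byte offsets, matching the ASSERT_EQ on the segment size.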

View File

@ -6,10 +6,10 @@
#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for WorkerForTest, TestDistributed
#include "../../../src/common/type.h" // for EraseType
#include "test_worker.h" // for WorkerForTest, TestDistributed
namespace xgboost::collective {
namespace {
class AllreduceWorker : public WorkerForTest {
public:
@ -50,11 +50,10 @@ class AllreduceWorker : public WorkerForTest {
}
void BitOr() {
Context ctx;
std::vector<std::uint32_t> data(comm_.World(), 0);
data[comm_.Rank()] = ~std::uint32_t{0};
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (auto v : data) {

View File

@ -0,0 +1,70 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
#include <thrust/host_vector.h> // for host_vector
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/common/common.h"
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
#include "../../../src/common/type.h" // for EraseType
#include "../helpers.h" // for MakeCUDACtx
#include "test_worker.cuh" // for NCCLWorkerForTest
#include "test_worker.h" // for WorkerForTest, TestDistributed
namespace xgboost::collective {
namespace {
class AllreduceTestGPU : public SocketTest {};
class Worker : public NCCLWorkerForTest {
public:
using NCCLWorkerForTest::NCCLWorkerForTest;
void BitOr() {
dh::device_vector<std::uint32_t> data(comm_.World(), 0);
data[comm_.Rank()] = ~std::uint32_t{0};
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
thrust::host_vector<std::uint32_t> h_data(data.size());
thrust::copy(data.cbegin(), data.cend(), h_data.begin());
for (auto v : h_data) {
ASSERT_EQ(v, ~std::uint32_t{0});
}
}
void Acc() {
dh::device_vector<double> data(314, 1.5);
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
ArrayInterfaceHandler::kF8, Op::kSum);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
}
}
};
} // namespace
TEST_F(AllreduceTestGPU, BitOr) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker w{host, port, timeout, n_workers, r};
w.Setup();
w.BitOr();
});
}
TEST_F(AllreduceTestGPU, Sum) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker w{host, port, timeout, n_workers, r};
w.Setup();
w.Acc();
});
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)
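The expected values in these two tests, worked through for a hypothetical three-GPU run:

// BitOr with W = 3: the per-rank buffers before the reduction are
//   {~0u, 0, 0}, {0, ~0u, 0}, {0, 0, ~0u};
// a bitwise OR across ranks leaves every worker with {~0u, ~0u, ~0u}.
// Acc (Op::kSum): each of the 314 doubles starts at 1.5 on every rank,
// so after the reduction every element equals 1.5 * W == 4.5.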

View File

@ -47,5 +47,5 @@ TEST_F(BroadcastTest, Basic) {
Worker worker{host, port, timeout, n_workers, r};
worker.Run();
});
}
} // namespace
} // namespace xgboost::collective

View File

@ -0,0 +1,32 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <memory> // for shared_ptr
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/comm.h" // for Comm
#include "test_worker.h"
#include "xgboost/context.h" // for Context
namespace xgboost::collective {
class NCCLWorkerForTest : public WorkerForTest {
protected:
std::shared_ptr<Coll> coll_;
std::shared_ptr<xgboost::collective::Comm> nccl_comm_;
std::shared_ptr<Coll> nccl_coll_;
Context ctx_;
public:
using WorkerForTest::WorkerForTest;
void Setup() {
ctx_ = MakeCUDACtx(comm_.Rank());
coll_.reset(new Coll{});
nccl_comm_.reset(this->comm_.MakeCUDAVar(&ctx_, coll_));
nccl_coll_.reset(coll_->MakeCUDAVar());
ASSERT_EQ(comm_.World(), nccl_comm_->World());
ASSERT_EQ(comm_.Rank(), nccl_comm_->Rank());
}
};
} // namespace xgboost::collective

View File

@ -1,6 +1,7 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <gtest/gtest.h>
#include <chrono> // for seconds

View File

@ -97,4 +97,29 @@ TEST(BitField, Clear) {
TestBitFieldClear<RBitField8>(19);
}
}
TEST(BitField, CTZ) {
{
auto cnt = TrailingZeroBits(0);
ASSERT_EQ(cnt, sizeof(std::uint32_t) * 8);
}
{
auto cnt = TrailingZeroBits(0b00011100);
ASSERT_EQ(cnt, 2);
cnt = detail::TrailingZeroBitsImpl(0b00011100);
ASSERT_EQ(cnt, 2);
}
{
auto cnt = TrailingZeroBits(0b00011101);
ASSERT_EQ(cnt, 0);
cnt = detail::TrailingZeroBitsImpl(0b00011101);
ASSERT_EQ(cnt, 0);
}
{
auto cnt = TrailingZeroBits(0b1000000000000000);
ASSERT_EQ(cnt, 15);
cnt = detail::TrailingZeroBitsImpl(0b1000000000000000);
ASSERT_EQ(cnt, 15);
}
}
} // namespace xgboost
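A minimal portable sketch of the count-trailing-zeros contract these cases pin down, assuming the zero-input convention asserted above; this is illustrative only, not XGBoost's detail::TrailingZeroBitsImpl:

#include <cstdint>  // for uint32_t, int32_t

// Illustrative shift-loop fallback; returns the word width (32) for zero input.
inline std::int32_t CtzSketch(std::uint32_t v) {
  if (v == 0) {
    return static_cast<std::int32_t>(sizeof(std::uint32_t) * 8);  // no set bit
  }
  std::int32_t n{0};
  while ((v & 1u) == 0u) {  // shift until the lowest set bit reaches bit 0
    v >>= 1;
    ++n;
  }
  return n;
}
// CtzSketch(0b00011100) == 2 and CtzSketch(0b1000000000000000) == 15, as tested above.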

View File

@ -572,4 +572,31 @@ class BaseMGPUTest : public ::testing::Test {
class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{};
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
/**
* @brief Poor man's gmock for message matching.
*
* @tparam Error The type of the expected exception.
*
* @param submsg A substring expected in the actual error message.
* @param fn The function that is expected to throw Error.
*/
template <typename Error, typename Fn>
void ExpectThrow(std::string submsg, Fn&& fn) {
try {
fn();
} catch (Error const& exc) {
auto actual = std::string{exc.what()};
ASSERT_NE(actual.find(submsg), std::string::npos)
<< "Expecting substring `" << submsg << "` from the error message."
<< " Got:\n"
<< actual << "\n";
return;
} catch (std::exception const& exc) {
auto actual = exc.what();
ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual;
return;
}
ASSERT_TRUE(false) << "No exception is thrown";
}
} // namespace xgboost
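For reference, a minimal call site for ExpectThrow; the arguments mirror the federated-comm tests in the next file:

// Passes when the lambda throws a dmlc::Error whose message contains the
// substring; fails on no throw or on a different exception type.
ExpectThrow<dmlc::Error>("Invalid world size.", [] {
  FederatedComm comm{"localhost", 0, 0, 0};  // a world size of 0 is rejected
});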

View File

@ -0,0 +1,84 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <string> // for string
#include <thread> // for thread
#include "../../../../plugin/federated/federated_comm.h"
#include "../../collective/net_test.h" // for SocketTest
#include "../../helpers.h" // for ExpectThrow
#include "test_worker.h" // for TestFederated
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
namespace {
class FederatedCommTest : public SocketTest {};
} // namespace
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
ExpectThrow<dmlc::Error>("Invalid world size.", construct);
}
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
}
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; };
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
}
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
auto construct = [] {
Json config{Object{}};
config["federated_server_address"] = std::string("localhost:0");
config["federated_world_size"] = std::string("1");
config["federated_rank"] = Integer(0);
FederatedComm comm(config);
};
ExpectThrow<dmlc::Error>("got: `String`", construct);
}
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
auto construct = [] {
Json config{Object{}};
config["federated_server_address"] = std::string("localhost:0");
config["federated_world_size"] = 1;
config["federated_rank"] = std::string("0");
FederatedComm comm(config);
};
ExpectThrow<dmlc::Error>("got: `String`", construct);
}
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
Json config{Object{}};
config["federated_world_size"] = 6;
config["federated_rank"] = 3;
config["federated_server_address"] = String{"localhost:0"};
FederatedComm comm{config};
EXPECT_EQ(comm.World(), 6);
EXPECT_EQ(comm.Rank(), 3);
}
TEST_F(FederatedCommTest, IsDistributed) {
FederatedComm comm{"localhost", 0, 2, 1};
EXPECT_TRUE(comm.IsDistributed());
}
TEST_F(FederatedCommTest, InsecureTracker) {
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
Json config{Object{}};
config["federated_world_size"] = n_workers;
config["federated_rank"] = rank;
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
FederatedComm comm{config};
ASSERT_EQ(comm.Rank(), rank);
ASSERT_EQ(comm.World(), n_workers);
});
}
} // namespace xgboost::collective

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <gtest/gtest.h>
#include <chrono> // for ms
#include <thread> // for thread
#include "../../../../plugin/federated/federated_tracker.h"
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
template <typename WorkerFn>
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
Json config{Object()};
config["federated_secure"] = Boolean{false};
config["n_workers"] = Integer{n_workers};
FederatedTracker tracker{config};
auto fut = tracker.Run();
std::vector<std::thread> workers;
using namespace std::chrono_literals;
while (tracker.Port() == 0) {
std::this_thread::sleep_for(100ms);
}
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] { fn(port, i); });
}
for (auto& t : workers) {
t.join();
}
auto rc = tracker.Shutdown();
ASSERT_TRUE(rc.OK()) << rc.Report();
ASSERT_TRUE(fut.get().OK());
}
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2022-2023 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
@ -26,7 +26,7 @@ class ServerForTest {
explicit ServerForTest(std::size_t world_size) {
server_thread_.reset(new std::thread([this, world_size] {
grpc::ServerBuilder builder;
xgboost::federated::FederatedService service{world_size};
xgboost::federated::FederatedService service{static_cast<std::int32_t>(world_size)};
int selected_port;
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
builder.RegisterService(&service);