Merge branch 'master' into sync-condition-2023Oct11
This commit is contained in:
commit
d7f1235b7d
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@ -151,4 +151,4 @@ jobs:
|
||||
python-package/xgboost/lib python-package/xgboost/rabit \
|
||||
python-package/xgboost/src
|
||||
|
||||
sh ./tests/ci_build/lint_cmake.sh || true
|
||||
sh ./tests/ci_build/lint_cmake.sh
|
||||
|
||||
@ -33,7 +33,7 @@ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
include(${xgboost_SOURCE_DIR}/cmake/FindPrefetchIntrinsics.cmake)
|
||||
include(${xgboost_SOURCE_DIR}/cmake/PrefetchIntrinsics.cmake)
|
||||
find_prefetch_intrinsics()
|
||||
include(${xgboost_SOURCE_DIR}/cmake/Version.cmake)
|
||||
write_version()
|
||||
|
||||
@ -10,14 +10,14 @@ How to tune parameters
|
||||
See :doc:`Parameter Tuning Guide </tutorials/param_tuning>`.
|
||||
|
||||
************************
|
||||
Description on the model
|
||||
Description of the model
|
||||
************************
|
||||
See :doc:`Introduction to Boosted Trees </tutorials/model>`.
|
||||
|
||||
********************
|
||||
I have a big dataset
|
||||
********************
|
||||
XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fit into your memory.
|
||||
XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fits into your memory.
|
||||
This usually means millions of instances.
|
||||
|
||||
If you are running out of memory, checkout the tutorial page for using :doc:`distributed training </tutorials/index>` with one of the many frameworks, or the :doc:`external memory version </tutorials/external_memory>` for using external memory.
|
||||
@ -37,14 +37,14 @@ The ultimate question will still come back to how to push the limit of each comp
|
||||
and use less resources to complete the task (thus with less communication and chance of failure).
|
||||
|
||||
To achieve these, we decide to reuse the optimizations in the single node XGBoost and build the distributed version on top of it.
|
||||
The demand of communication in machine learning is rather simple, in the sense that we can depend on a limited set of APIs (in our case rabit).
|
||||
The demand for communication in machine learning is rather simple, in the sense that we can depend on a limited set of APIs (in our case rabit).
|
||||
Such design allows us to reuse most of the code, while being portable to major platforms such as Hadoop/Yarn, MPI, SGE.
|
||||
Most importantly, it pushes the limit of the computation resources we can use.
|
||||
|
||||
****************************************
|
||||
How can I port a model to my own system?
|
||||
****************************************
|
||||
The model and data format of XGBoost is exchangeable,
|
||||
The model and data format of XGBoost are exchangeable,
|
||||
which means the model trained by one language can be loaded in another.
|
||||
This means you can train the model using R, while running prediction using
|
||||
Java or C++, which are more common in production systems.
|
||||
|
||||
@ -73,7 +73,7 @@ Parameters for Tree Booster
|
||||
===========================
|
||||
* ``eta`` [default=0.3, alias: ``learning_rate``]
|
||||
|
||||
- Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and ``eta`` shrinks the feature weights to make the boosting process more conservative.
|
||||
- Step size shrinkage used in update to prevent overfitting. After each boosting step, we can directly get the weights of new features, and ``eta`` shrinks the feature weights to make the boosting process more conservative.
|
||||
- range: [0,1]
|
||||
|
||||
* ``gamma`` [default=0, alias: ``min_split_loss``]
|
||||
|
||||
@ -87,8 +87,8 @@ XGBoost PySpark GPU support
|
||||
XGBoost PySpark fully supports GPU acceleration. Users are not only able to enable
|
||||
efficient training but also utilize their GPUs for the whole PySpark pipeline including
|
||||
ETL and inference. In below sections, we will walk through an example of training on a
|
||||
PySpark standalone GPU cluster. To get started, first we need to install some additional
|
||||
packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.
|
||||
Spark standalone cluster with GPU support. To get started, first we need to install some
|
||||
additional packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.
|
||||
|
||||
Prepare the necessary packages
|
||||
==============================
|
||||
@ -128,7 +128,8 @@ Write your PySpark application
|
||||
==============================
|
||||
|
||||
Below snippet is a small example for training xgboost model with PySpark. Notice that we are
|
||||
using a list of feature names and the additional parameter ``device``:
|
||||
using a list of feature names instead of vector type as the input. The parameter ``"device=cuda"``
|
||||
specifically indicates that the training will be performed on a GPU.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -163,14 +164,29 @@ using a list of feature names and the additional parameter ``device``:
|
||||
predict_df = model.transform(test_df)
|
||||
predict_df.show()
|
||||
|
||||
Like other distributed interfaces, the ```device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).
|
||||
Like other distributed interfaces, the ``device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).
|
||||
|
||||
.. _stage-level-scheduling:
|
||||
|
||||
Submit the PySpark application
|
||||
==============================
|
||||
|
||||
Assuming you have configured your Spark cluster with GPU support. Otherwise, please
|
||||
Assuming you have configured the Spark standalone cluster with GPU support. Otherwise, please
|
||||
refer to `spark standalone configuration with GPU support <https://nvidia.github.io/spark-rapids/docs/get-started/getting-started-on-prem.html#spark-standalone-cluster>`_.
|
||||
|
||||
Starting from XGBoost 2.0.1, stage-level scheduling is automatically enabled. Therefore,
|
||||
if you are using Spark standalone cluster version 3.4.0 or higher, we strongly recommend
|
||||
configuring the ``"spark.task.resource.gpu.amount"`` as a fractional value. This will
|
||||
enable running multiple tasks in parallel during the ETL phase. An example configuration
|
||||
would be ``"spark.task.resource.gpu.amount=1/spark.executor.cores"``. However, if you are
|
||||
using a XGBoost version earlier than 2.0.1 or a Spark standalone cluster version below 3.4.0,
|
||||
you still need to set ``"spark.task.resource.gpu.amount"`` equal to ``"spark.executor.resource.gpu.amount"``.
|
||||
|
||||
.. note::
|
||||
|
||||
As of now, the stage-level scheduling feature in XGBoost is limited to the Spark standalone cluster mode.
|
||||
However, we have plans to expand its compatibility to YARN and Kubernetes once Spark 3.5.1 is officially released.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
export PYSPARK_DRIVER_PYTHON=python
|
||||
@ -178,19 +194,21 @@ refer to `spark standalone configuration with GPU support <https://nvidia.github
|
||||
|
||||
spark-submit \
|
||||
--master spark://<master-ip>:7077 \
|
||||
--conf spark.executor.cores=12 \
|
||||
--conf spark.task.cpus=1 \
|
||||
--conf spark.executor.resource.gpu.amount=1 \
|
||||
--conf spark.task.resource.gpu.amount=1 \
|
||||
--conf spark.task.resource.gpu.amount=0.08 \
|
||||
--archives xgboost_env.tar.gz#environment \
|
||||
xgboost_app.py
|
||||
|
||||
|
||||
The submit command sends the Python environment created by pip or conda along with the
|
||||
specification of GPU allocation. We will revisit this command later on.
|
||||
The above command submits the xgboost pyspark application with the python environment created by pip or conda,
|
||||
specifying a request for 1 GPU and 12 CPUs per executor. So you can see, a total of 12 tasks per executor will be
|
||||
executed concurrently during the ETL phase.
|
||||
|
||||
Model Persistence
|
||||
=================
|
||||
|
||||
Similar to standard PySpark ml estimators, one can persist and reuse the model with ``save`
|
||||
Similar to standard PySpark ml estimators, one can persist and reuse the model with ``save``
|
||||
and ``load`` methods:
|
||||
|
||||
.. code-block:: python
|
||||
@ -230,8 +248,13 @@ Accelerate the whole pipeline for xgboost pyspark
|
||||
|
||||
With `RAPIDS Accelerator for Apache Spark <https://nvidia.github.io/spark-rapids/>`_, you
|
||||
can leverage GPUs to accelerate the whole pipeline (ETL, Train, Transform) for xgboost
|
||||
pyspark without any Python code change. An example submit command is shown below with
|
||||
additional spark configurations and dependencies:
|
||||
pyspark without the need for any code modifications. Likewise, you have the option to configure
|
||||
the ``"spark.task.resource.gpu.amount"`` setting as a fractional value, enabling a higher
|
||||
number of tasks to be executed in parallel during the ETL phase. please refer to
|
||||
:ref:`stage-level-scheduling` for more details.
|
||||
|
||||
|
||||
An example submit command is shown below with additional spark configurations and dependencies:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@ -240,8 +263,10 @@ additional spark configurations and dependencies:
|
||||
|
||||
spark-submit \
|
||||
--master spark://<master-ip>:7077 \
|
||||
--conf spark.executor.cores=12 \
|
||||
--conf spark.task.cpus=1 \
|
||||
--conf spark.executor.resource.gpu.amount=1 \
|
||||
--conf spark.task.resource.gpu.amount=1 \
|
||||
--conf spark.task.resource.gpu.amount=0.08 \
|
||||
--packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
|
||||
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
|
||||
--conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
|
||||
|
||||
@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h)
|
||||
target_link_libraries(federated_client INTERFACE federated_proto)
|
||||
|
||||
# Rabit engine for Federated Learning.
|
||||
target_sources(objxgboost PRIVATE federated_server.cc)
|
||||
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
|
||||
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
|
||||
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
|
||||
|
||||
114
plugin/federated/federated_comm.cc
Normal file
114
plugin/federated/federated_comm.cc
Normal file
@ -0,0 +1,114 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#include "federated_comm.h"
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <cstdlib> // for getenv
|
||||
#include <string> // for string, stoi
|
||||
|
||||
#include "../../src/common/common.h" // for Split
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world,
|
||||
std::int32_t rank, std::string const& server_cert,
|
||||
std::string const& client_key, std::string const& client_cert) {
|
||||
this->rank_ = rank;
|
||||
this->world_ = world;
|
||||
|
||||
this->tracker_.host = host;
|
||||
this->tracker_.port = port;
|
||||
this->tracker_.rank = rank;
|
||||
|
||||
CHECK_GE(world, 1) << "Invalid world size.";
|
||||
CHECK_GE(rank, 0) << "Invalid worker rank.";
|
||||
CHECK_LT(rank, world) << "Invalid worker rank.";
|
||||
|
||||
if (server_cert.empty()) {
|
||||
stub_ = [&] {
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
return federated::Federated::NewStub(
|
||||
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
|
||||
}();
|
||||
} else {
|
||||
stub_ = [&] {
|
||||
grpc::SslCredentialsOptions options;
|
||||
options.pem_root_certs = server_cert;
|
||||
options.pem_private_key = client_key;
|
||||
options.pem_cert_chain = client_cert;
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
|
||||
channel->WaitForConnected(
|
||||
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
|
||||
return federated::Federated::NewStub(channel);
|
||||
}();
|
||||
}
|
||||
}
|
||||
|
||||
FederatedComm::FederatedComm(Json const& config) {
|
||||
/**
|
||||
* Topology
|
||||
*/
|
||||
std::string server_address{};
|
||||
std::int32_t world_size{0};
|
||||
std::int32_t rank{-1};
|
||||
// Parse environment variables first.
|
||||
auto* value = std::getenv("FEDERATED_SERVER_ADDRESS");
|
||||
if (value != nullptr) {
|
||||
server_address = value;
|
||||
}
|
||||
value = std::getenv("FEDERATED_WORLD_SIZE");
|
||||
if (value != nullptr) {
|
||||
world_size = std::stoi(value);
|
||||
}
|
||||
value = std::getenv("FEDERATED_RANK");
|
||||
if (value != nullptr) {
|
||||
rank = std::stoi(value);
|
||||
}
|
||||
|
||||
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
|
||||
world_size =
|
||||
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
|
||||
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
|
||||
|
||||
auto parsed = common::Split(server_address, ':');
|
||||
CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;
|
||||
|
||||
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
|
||||
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
|
||||
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
|
||||
|
||||
/**
|
||||
* Certificates
|
||||
*/
|
||||
std::string server_cert{};
|
||||
std::string client_key{};
|
||||
std::string client_cert{};
|
||||
value = getenv("FEDERATED_SERVER_CERT_PATH");
|
||||
if (value != nullptr) {
|
||||
server_cert = value;
|
||||
}
|
||||
value = getenv("FEDERATED_CLIENT_KEY_PATH");
|
||||
if (value != nullptr) {
|
||||
client_key = value;
|
||||
}
|
||||
value = getenv("FEDERATED_CLIENT_CERT_PATH");
|
||||
if (value != nullptr) {
|
||||
client_cert = value;
|
||||
}
|
||||
|
||||
server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert);
|
||||
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
|
||||
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);
|
||||
|
||||
this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
|
||||
client_cert);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
53
plugin/federated/federated_comm.h
Normal file
53
plugin/federated/federated_comm.h
Normal file
@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.grpc.pb.h>
|
||||
#include <federated.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/comm.h" // for Comm
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedComm : public Comm {
|
||||
std::unique_ptr<federated::Federated::Stub> stub_;
|
||||
|
||||
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
|
||||
std::string const& server_cert, std::string const& client_key,
|
||||
std::string const& client_cert);
|
||||
|
||||
public:
|
||||
/**
|
||||
* @param config
|
||||
*
|
||||
* - federated_server_address: Tracker address
|
||||
* - federated_world_size: The number of workers
|
||||
* - federated_rank: Rank of federated worker
|
||||
* - federated_server_cert_path
|
||||
* - federated_client_key_path
|
||||
* - federated_client_cert_path
|
||||
*/
|
||||
explicit FederatedComm(Json const& config);
|
||||
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
|
||||
std::int32_t rank) {
|
||||
this->Init(host, port, world, rank, {}, {}, {});
|
||||
}
|
||||
~FederatedComm() override { stub_.reset(); }
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
|
||||
LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
|
||||
return nullptr;
|
||||
}
|
||||
[[nodiscard]] Result LogTracker(std::string msg) const override {
|
||||
LOG(CONSOLE) << msg;
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -4,12 +4,15 @@
|
||||
#include "federated_server.h"
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/server.h> // for Server
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "../../src/collective/comm.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "../../src/common/json_utils.h"
|
||||
|
||||
namespace xgboost::federated {
|
||||
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
|
||||
@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
auto options =
|
||||
@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
|
||||
void RunInsecureServer(int port, std::size_t world_size) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
|
||||
@ -1,18 +1,22 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.grpc.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <future> // for future
|
||||
|
||||
#include "../../src/collective/in_memory_handler.h"
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost {
|
||||
namespace federated {
|
||||
|
||||
namespace xgboost::federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
|
||||
explicit FederatedService(std::int32_t world_size)
|
||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file);
|
||||
|
||||
void RunInsecureServer(int port, std::size_t world_size);
|
||||
|
||||
} // namespace federated
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::federated
|
||||
|
||||
101
plugin/federated/federated_tracker.cc
Normal file
101
plugin/federated/federated_tracker.cc
Normal file
@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include "federated_tracker.h"
|
||||
|
||||
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
|
||||
#include <grpcpp/server_builder.h> // for ServerBuilder
|
||||
|
||||
#include <chrono> // for ms
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string
|
||||
#include <thread> // for sleep_for
|
||||
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for RequiredArg
|
||||
#include "../../src/common/timer.h" // for Timer
|
||||
#include "federated_server.h" // for FederatedService
|
||||
|
||||
namespace xgboost::collective {
|
||||
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
|
||||
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
|
||||
if (is_secure) {
|
||||
server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
|
||||
server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
|
||||
client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
std::future<Result> FederatedTracker::Run() {
|
||||
return std::async([this]() {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
|
||||
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
|
||||
grpc::ServerBuilder builder;
|
||||
|
||||
if (this->server_cert_file_.empty()) {
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||
if (this->port_ == 0) {
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
|
||||
} else {
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||
<< this->n_workers_;
|
||||
} else {
|
||||
auto options = grpc::SslServerCredentialsOptions(
|
||||
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_);
|
||||
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||
key.private_key = xgboost::common::ReadAll(server_key_path_);
|
||||
key.cert_chain = xgboost::common::ReadAll(server_cert_file_);
|
||||
options.pem_key_cert_pairs.push_back(key);
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||
if (this->port_ == 0) {
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_);
|
||||
} else {
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
|
||||
<< n_workers_;
|
||||
}
|
||||
|
||||
try {
|
||||
server_->Wait();
|
||||
} catch (std::exception const& e) {
|
||||
return collective::Fail(std::string{e.what()});
|
||||
}
|
||||
return collective::Success();
|
||||
});
|
||||
}
|
||||
|
||||
FederatedTracker::~FederatedTracker() = default;
|
||||
|
||||
Result FederatedTracker::Shutdown() {
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
using namespace std::chrono_literals;
|
||||
while (!server_) {
|
||||
timer.Stop();
|
||||
auto ela = timer.ElapsedSeconds();
|
||||
if (ela > this->Timeout().count()) {
|
||||
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
|
||||
" seconds.");
|
||||
}
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
|
||||
try {
|
||||
server_->Shutdown();
|
||||
} catch (std::exception const& e) {
|
||||
return Fail("Failed to shutdown:" + std::string{e.what()});
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
41
plugin/federated/federated_tracker.h
Normal file
41
plugin/federated/federated_tracker.h
Normal file
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <federated.grpc.pb.h> // for Server
|
||||
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedTracker : public collective::Tracker {
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
std::string server_key_path_;
|
||||
std::string server_cert_file_;
|
||||
std::string client_cert_file_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief CTOR
|
||||
*
|
||||
* @param config Configuration, other than the base configuration from Tracker, we have:
|
||||
*
|
||||
* - federated_secure: bool whether this is a secure server.
|
||||
* - server_key_path: path to the key.
|
||||
* - server_cert_path: certificate path.
|
||||
* - client_cert_path: certificate path for client.
|
||||
*/
|
||||
explicit FederatedTracker(Json const& config);
|
||||
~FederatedTracker() override;
|
||||
std::future<Result> Run() override;
|
||||
// federated tracker do not provide initialization parameters, users have to provide it
|
||||
// themseleves.
|
||||
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
|
||||
[[nodiscard]] Result Shutdown();
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -17,7 +17,7 @@ class HasArbitraryParamsDict(Params):
|
||||
Params._dummy(),
|
||||
"arbitrary_params_dict",
|
||||
"arbitrary_params_dict This parameter holds all of the additional parameters which are "
|
||||
"not exposed as the the XGBoost Spark estimator params but can be recognized by "
|
||||
"not exposed as the XGBoost Spark estimator params but can be recognized by "
|
||||
"underlying XGBoost library. It is stored as a dictionary.",
|
||||
)
|
||||
|
||||
|
||||
@ -106,7 +106,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
}
|
||||
}
|
||||
if (dmlc_role != "worker") {
|
||||
LOG(FATAL) << "Rabit Module currently only work with dmlc worker";
|
||||
LOG(FATAL) << "Rabit Module currently only works with dmlc worker";
|
||||
}
|
||||
|
||||
// clear the setting before start reconnection
|
||||
@ -273,7 +273,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine start up
|
||||
*/
|
||||
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
|
||||
@ -89,7 +89,7 @@ class AllreduceBase : public IEngine {
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
@ -281,7 +281,7 @@ class AllreduceBase : public IEngine {
|
||||
* this function can not be used together with ReadToRingBuffer
|
||||
* a link can either read into the ring buffer, or existing array
|
||||
* \param max_size maximum size of array
|
||||
* \return true if it is an successful read, false if there is some error happens, check errno
|
||||
* \return true if it is a successful read, false if there is some error happens, check errno
|
||||
*/
|
||||
inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
|
||||
if (max_size == size_read) return kSuccess;
|
||||
@ -299,7 +299,7 @@ class AllreduceBase : public IEngine {
|
||||
* \brief write data in array to sock
|
||||
* \param sendbuf_ head of array
|
||||
* \param max_size maximum size of array
|
||||
* \return true if it is an successful write, false if there is some error happens, check errno
|
||||
* \return true if it is a successful write, false if there is some error happens, check errno
|
||||
*/
|
||||
inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
|
||||
const char *p = static_cast<const char*>(sendbuf_);
|
||||
@ -333,7 +333,7 @@ class AllreduceBase : public IEngine {
|
||||
*/
|
||||
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine start up
|
||||
* \param cmd possible command to sent to tracker
|
||||
*/
|
||||
|
||||
@ -7,20 +7,23 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for partial_sum
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "broadcast.h"
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
|
||||
std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
|
||||
std::shared_ptr<Channel> next_ch) {
|
||||
auto world = comm.World();
|
||||
auto rank = comm.Rank();
|
||||
CHECK_LT(worker_off, world);
|
||||
if (world == 1) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r + worker_off) % world;
|
||||
@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
return Success();
|
||||
}
|
||||
|
||||
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t> recv) {
|
||||
std::size_t offset = 0;
|
||||
for (std::int32_t r = 0; r < comm.World(); ++r) {
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
} // namespace cpu_impl
|
||||
|
||||
namespace detail {
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t> offset,
|
||||
common::Span<std::int64_t const> offset,
|
||||
common::Span<std::int8_t> erased_result) {
|
||||
auto world = comm.World();
|
||||
if (world == 1) {
|
||||
return Success();
|
||||
}
|
||||
auto rank = comm.Rank();
|
||||
|
||||
auto prev = BootstrapPrev(rank, comm.World());
|
||||
@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
// get worker offset
|
||||
CHECK_EQ(world + 1, offset.size());
|
||||
std::fill_n(offset.data(), offset.size(), 0);
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
||||
CHECK_EQ(*offset.cbegin(), 0);
|
||||
|
||||
// copy data
|
||||
auto current = erased_result.subspan(offset[rank], data.size_bytes());
|
||||
auto erased_data = EraseType(data);
|
||||
std::copy_n(erased_data.data(), erased_data.size(), current.data());
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = offset[send_rank];
|
||||
@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
}
|
||||
return comm.Block();
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
} // namespace detail
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -12,25 +12,44 @@
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
/**
|
||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off
|
||||
* = 1, then it owns the third segment.
|
||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
|
||||
* worker_off = 1, then it owns the third segment.
|
||||
*/
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::size_t segment_size, std::int32_t worker_off,
|
||||
std::shared_ptr<Channel> prev_ch,
|
||||
std::shared_ptr<Channel> next_ch);
|
||||
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t> offset,
|
||||
common::Span<std::int8_t> erased_result);
|
||||
/**
|
||||
* @brief Implement allgather-v using broadcast.
|
||||
*
|
||||
* https://arxiv.org/abs/1812.05964
|
||||
*/
|
||||
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t> recv);
|
||||
} // namespace cpu_impl
|
||||
|
||||
namespace detail {
|
||||
inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> offset) {
|
||||
// get worker offset
|
||||
std::fill_n(offset.data(), offset.size(), 0);
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
||||
CHECK_EQ(*offset.cbegin(), 0);
|
||||
}
|
||||
|
||||
// An implementation that's used by both cpu and gpu
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t const> offset,
|
||||
common::Span<std::int8_t> erased_result);
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||
auto n_bytes = sizeof(T) * size;
|
||||
@ -68,9 +87,15 @@ template <typename T>
|
||||
auto h_result = common::Span{result.data(), result.size()};
|
||||
auto erased_result = common::EraseType(h_result);
|
||||
auto erased_data = common::EraseType(data);
|
||||
std::vector<std::int64_t> offset(world + 1);
|
||||
std::vector<std::int64_t> recv_segments(world + 1);
|
||||
auto s_segments = common::Span{recv_segments.data(), recv_segments.size()};
|
||||
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
|
||||
common::Span{offset.data(), offset.size()}, erased_result);
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, s_segments);
|
||||
// copy data
|
||||
auto current = erased_result.subspan(recv_segments[rank], data.size_bytes());
|
||||
std::copy_n(erased_data.data(), erased_data.size(), current.data());
|
||||
|
||||
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -12,12 +12,10 @@
|
||||
#include "allreduce.h" // for Allreduce
|
||||
#include "broadcast.h" // for Broadcast
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/context.h" // for Context
|
||||
|
||||
namespace xgboost::collective {
|
||||
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
|
||||
Op op) {
|
||||
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type, Op op) {
|
||||
namespace coll = ::xgboost::collective;
|
||||
|
||||
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
||||
@ -55,21 +53,45 @@ namespace xgboost::collective {
|
||||
return comm.Block();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::int32_t root) {
|
||||
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
return cpu_impl::Broadcast(comm, data, root);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::size_t size) {
|
||||
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
return RingAllgather(comm, data, size);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t const> data,
|
||||
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv) {
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, recv_segments);
|
||||
|
||||
// copy data
|
||||
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
|
||||
if (current.data() != data.data()) {
|
||||
std::copy_n(data.data(), data.size(), current.data());
|
||||
}
|
||||
|
||||
switch (algo) {
|
||||
case AllgatherVAlgo::kRing:
|
||||
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
|
||||
case AllgatherVAlgo::kBcast:
|
||||
return cpu_impl::BroadcastAllgatherV(comm, sizes, recv);
|
||||
default: {
|
||||
return Fail("Unknown algorithm for allgather-v");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_NCCL)
|
||||
Coll* Coll::MakeCUDAVar() {
|
||||
LOG(FATAL) << "NCCL is required for device communication.";
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace xgboost::collective
|
||||
|
||||
254
src/collective/coll.cu
Normal file
254
src/collective/coll.cu
Normal file
@ -0,0 +1,254 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../data/array_interface.h"
|
||||
#include "allgather.h" // for AllgatherVOffset
|
||||
#include "coll.cuh"
|
||||
#include "comm.cuh"
|
||||
#include "nccl.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; }
|
||||
|
||||
NCCLColl::~NCCLColl() = default;
|
||||
namespace {
|
||||
Result GetNCCLResult(ncclResult_t code) {
|
||||
if (code == ncclSuccess) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "NCCL failure: " << ncclGetErrorString(code) << ".";
|
||||
if (code == ncclUnhandledCudaError) {
|
||||
// nccl usually preserves the last error so we can get more details.
|
||||
auto err = cudaPeekAtLastError();
|
||||
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
|
||||
} else if (code == ncclSystemError) {
|
||||
ss << " This might be caused by a network configuration issue. Please consider specifying "
|
||||
"the network interface for NCCL via environment variables listed in its reference: "
|
||||
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
|
||||
}
|
||||
return Fail(ss.str());
|
||||
}
|
||||
|
||||
auto GetNCCLType(ArrayInterfaceHandler::Type type) {
|
||||
auto fatal = [] {
|
||||
LOG(FATAL) << "Invalid type for NCCL operation.";
|
||||
return ncclHalf; // dummy return to silent the compiler warning.
|
||||
};
|
||||
using H = ArrayInterfaceHandler;
|
||||
switch (type) {
|
||||
case H::kF2:
|
||||
return ncclHalf;
|
||||
case H::kF4:
|
||||
return ncclFloat32;
|
||||
case H::kF8:
|
||||
return ncclFloat64;
|
||||
case H::kF16:
|
||||
return fatal();
|
||||
case H::kI1:
|
||||
return ncclInt8;
|
||||
case H::kI2:
|
||||
return fatal();
|
||||
case H::kI4:
|
||||
return ncclInt32;
|
||||
case H::kI8:
|
||||
return ncclInt64;
|
||||
case H::kU1:
|
||||
return ncclUint8;
|
||||
case H::kU2:
|
||||
return fatal();
|
||||
case H::kU4:
|
||||
return ncclUint32;
|
||||
case H::kU8:
|
||||
return ncclUint64;
|
||||
}
|
||||
return fatal();
|
||||
}
|
||||
|
||||
bool IsBitwiseOp(Op const& op) {
|
||||
return op == Op::kBitwiseAND || op == Op::kBitwiseOR || op == Op::kBitwiseXOR;
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> out_buffer,
|
||||
std::int8_t const* device_buffer, Func func, std::int32_t world_size,
|
||||
std::size_t size) {
|
||||
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
|
||||
auto result = device_buffer[idx];
|
||||
for (auto rank = 1; rank < world_size; rank++) {
|
||||
result = func(result, device_buffer[rank * size + idx]);
|
||||
}
|
||||
out_buffer[idx] = result;
|
||||
});
|
||||
}
|
||||
|
||||
[[nodiscard]] Result BitwiseAllReduce(NCCLComm const* pcomm, ncclComm_t handle,
|
||||
common::Span<std::int8_t> data, Op op) {
|
||||
dh::device_vector<std::int8_t> buffer(data.size() * pcomm->World());
|
||||
auto* device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
CHECK(handle);
|
||||
auto rc = GetNCCLResult(
|
||||
ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream()));
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// Then reduce locally.
|
||||
switch (op) {
|
||||
case Op::kBitwiseAND:
|
||||
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_and<std::int8_t>(),
|
||||
pcomm->World(), data.size());
|
||||
break;
|
||||
case Op::kBitwiseOR:
|
||||
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_or<std::int8_t>(),
|
||||
pcomm->World(), data.size());
|
||||
break;
|
||||
case Op::kBitwiseXOR:
|
||||
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_xor<std::int8_t>(),
|
||||
pcomm->World(), data.size());
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Not a bitwise reduce operation.";
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
ncclRedOp_t result{ncclMax};
|
||||
switch (op) {
|
||||
case Op::kMax:
|
||||
result = ncclMax;
|
||||
break;
|
||||
case Op::kMin:
|
||||
result = ncclMin;
|
||||
break;
|
||||
case Op::kSum:
|
||||
result = ncclSum;
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported reduce operation.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
[[nodiscard]] Result NCCLColl::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
return Success() << [&] {
|
||||
if (IsBitwiseOp(op)) {
|
||||
return BitwiseAllReduce(nccl, nccl->Handle(), data, op);
|
||||
} else {
|
||||
return DispatchDType(type, [=](auto t) {
|
||||
using T = decltype(t);
|
||||
auto rdata = common::RestoreType<T>(data);
|
||||
auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
|
||||
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
|
||||
return GetNCCLResult(rc);
|
||||
});
|
||||
}
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result NCCLColl::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
|
||||
nccl->Handle(), nccl->Stream()));
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto send = data.subspan(comm.Rank() * size, size);
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(
|
||||
ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream()));
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
namespace cuda_impl {
|
||||
/**
|
||||
* @brief Implement allgather-v using broadcast.
|
||||
*
|
||||
* https://arxiv.org/abs/1812.05964
|
||||
*/
|
||||
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
|
||||
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||
std::size_t offset = 0;
|
||||
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||
ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
||||
if (rc != ncclSuccess) {
|
||||
return GetNCCLResult(rc);
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [] { return GetNCCLResult(ncclGroupEnd()); };
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
|
||||
[[nodiscard]] Result NCCLColl::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
switch (algo) {
|
||||
case AllgatherVAlgo::kRing: {
|
||||
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, recv_segments);
|
||||
// copy data
|
||||
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
|
||||
if (current.data() != data.data()) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(current.data(), data.data(), current.size_bytes(),
|
||||
cudaMemcpyDeviceToDevice, nccl->Stream()));
|
||||
}
|
||||
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
|
||||
} << [] {
|
||||
return GetNCCLResult(ncclGroupEnd());
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
case AllgatherVAlgo::kBcast: {
|
||||
return cuda_impl::BroadcastAllgatherV(nccl, data, sizes, recv);
|
||||
}
|
||||
default: {
|
||||
return Fail("Unknown algorithm for allgather-v");
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
29
src/collective/coll.cuh
Normal file
29
src/collective/coll.cuh
Normal file
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
class NCCLColl : public Coll {
|
||||
public:
|
||||
~NCCLColl() override;
|
||||
|
||||
[[nodiscard]] Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -2,17 +2,20 @@
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
#include <memory> // for enable_shared_from_this
|
||||
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
enum class AllgatherVAlgo {
|
||||
kRing = 0, // use ring-based allgather-v
|
||||
kBcast = 1, // use broadcast-based allgather-v
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Interface and base implementation for collective.
|
||||
*/
|
||||
@ -21,6 +24,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
Coll() = default;
|
||||
virtual ~Coll() noexcept(false) {} // NOLINT
|
||||
|
||||
Coll* MakeCUDAVar();
|
||||
|
||||
/**
|
||||
* @brief Allreduce
|
||||
*
|
||||
@ -29,8 +34,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
* @param [in] op Reduce operation. For custom operation, user needs to reach down to
|
||||
* the CPU implementation.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data,
|
||||
[[nodiscard]] virtual Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op);
|
||||
/**
|
||||
* @brief Broadcast
|
||||
@ -38,29 +42,29 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] root Root rank for broadcast.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::int32_t root);
|
||||
[[nodiscard]] virtual Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root);
|
||||
/**
|
||||
* @brief Allgather
|
||||
*
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] size Size of data for each worker.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::size_t size);
|
||||
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size);
|
||||
/**
|
||||
* @brief Allgather with variable length.
|
||||
*
|
||||
* @param [in] data Input data for the current worker.
|
||||
* @param [in] sizes Size of the input from each worker.
|
||||
* @param [out] recv_segments pre-allocated offset for each worker in the output, size
|
||||
* should be equal to (world + 1).
|
||||
* @param [out] recv_segments pre-allocated offset buffer for each worker in the output,
|
||||
* size should be equal to (world + 1). GPU ring-based implementation
|
||||
* doesn't use the buffer.
|
||||
* @param [out] recv pre-allocated buffer for output.
|
||||
*/
|
||||
[[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t const> data,
|
||||
[[nodiscard]] virtual Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv);
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo);
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -262,7 +262,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
||||
}
|
||||
|
||||
RabitComm::~RabitComm() noexcept(false) {
|
||||
if (!IsDistributed()) {
|
||||
if (!this->IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
auto rc = this->Shutdown();
|
||||
|
||||
112
src/collective/comm.cu
Normal file
112
src/collective/comm.cu
Normal file
@ -0,0 +1,112 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <algorithm> // for sort
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for uint64_t, int8_t
|
||||
#include <cstring> // for memcpy
|
||||
#include <memory> // for shared_ptr
|
||||
#include <sstream> // for stringstream
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/device_helpers.cuh" // for DefaultStream
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "broadcast.h" // for Broadcast
|
||||
#include "comm.cuh" // for NCCLComm
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (comm.Rank() == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
auto rc = Broadcast(comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)},
|
||||
kRootRank);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
*pid = id;
|
||||
return Success();
|
||||
}
|
||||
|
||||
inline constexpr std::size_t kUuidLength =
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(std::uint64_t);
|
||||
|
||||
void GetCudaUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid, DeviceOrd device) {
|
||||
cudaDeviceProp prob{};
|
||||
dh::safe_cuda(cudaGetDeviceProperties(&prob, device.ordinal));
|
||||
std::memcpy(uuid.data(), static_cast<void*>(&(prob.uuid)), sizeof(prob.uuid));
|
||||
}
|
||||
|
||||
static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid) {
|
||||
std::stringstream ss;
|
||||
for (auto v : uuid) {
|
||||
ss << std::hex << v;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
|
||||
return new NCCLComm{ctx, *this, pimpl};
|
||||
}
|
||||
|
||||
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
|
||||
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
|
||||
root.TaskID()},
|
||||
stream_{dh::DefaultStream()} {
|
||||
this->world_ = root.World();
|
||||
this->rank_ = root.Rank();
|
||||
this->domain_ = root.Domain();
|
||||
if (!root.IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
|
||||
std::vector<std::uint64_t> uuids(root.World() * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<std::uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid, ctx->Device());
|
||||
|
||||
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
|
||||
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
|
||||
std::size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] = s_uuid.subspan(i, kUuidLength);
|
||||
j++;
|
||||
}
|
||||
|
||||
std::sort(converted.begin(), converted.end());
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, root.World())
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
rc = GetUniqueId(root, &nccl_unique_id_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
|
||||
|
||||
for (std::int32_t r = 0; r < root.World(); ++r) {
|
||||
this->channels_.emplace_back(
|
||||
std::make_shared<NCCLChannel>(root, r, nccl_comm_, dh::DefaultStream()));
|
||||
}
|
||||
}
|
||||
|
||||
NCCLComm::~NCCLComm() {
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
67
src/collective/comm.cuh
Normal file
67
src/collective/comm.cuh
Normal file
@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
#endif // XGBOOST_USE_NCCL
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "coll.h"
|
||||
#include "comm.h"
|
||||
#include "xgboost/context.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
inline Result GetCUDAResult(cudaError rc) {
|
||||
if (rc == cudaSuccess) {
|
||||
return Success();
|
||||
}
|
||||
std::string msg = thrust::system_error(rc, thrust::cuda_category()).what();
|
||||
return Fail(msg);
|
||||
}
|
||||
|
||||
class NCCLComm : public Comm {
|
||||
ncclComm_t nccl_comm_{nullptr};
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
dh::CUDAStreamView stream_;
|
||||
|
||||
public:
|
||||
[[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; }
|
||||
|
||||
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl);
|
||||
[[nodiscard]] Result LogTracker(std::string) const override {
|
||||
LOG(FATAL) << "Device comm is used for logging.";
|
||||
return Fail("Undefined.");
|
||||
}
|
||||
~NCCLComm() override;
|
||||
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||
[[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; }
|
||||
[[nodiscard]] Result Block() const override {
|
||||
auto rc = this->Stream().Sync(false);
|
||||
return GetCUDAResult(rc);
|
||||
}
|
||||
};
|
||||
|
||||
class NCCLChannel : public Channel {
|
||||
std::int32_t rank_{-1};
|
||||
ncclComm_t nccl_comm_{};
|
||||
dh::CUDAStreamView stream_;
|
||||
|
||||
public:
|
||||
explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
|
||||
dh::CUDAStreamView stream)
|
||||
: rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) override {
|
||||
dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
}
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) override {
|
||||
dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
}
|
||||
[[nodiscard]] Result Block() override {
|
||||
auto rc = stream_.Sync(false);
|
||||
return GetCUDAResult(rc);
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -8,7 +8,6 @@
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <type_traits> // for remove_const_t
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
@ -16,6 +15,7 @@
|
||||
#include "protocol.h" // for PeerInfo
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
@ -35,13 +35,14 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
|
||||
}
|
||||
|
||||
class Channel;
|
||||
class Coll;
|
||||
|
||||
/**
|
||||
* @brief Base communicator storing info about the tracker and other communicators.
|
||||
*/
|
||||
class Comm {
|
||||
protected:
|
||||
std::int32_t world_{1};
|
||||
std::int32_t world_{-1};
|
||||
std::int32_t rank_{0};
|
||||
std::chrono::seconds timeout_{DefaultTimeoutSec()};
|
||||
std::int32_t retry_{DefaultRetry()};
|
||||
@ -69,12 +70,14 @@ class Comm {
|
||||
[[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
|
||||
[[nodiscard]] auto Domain() const { return domain_; }
|
||||
[[nodiscard]] auto Timeout() const { return timeout_; }
|
||||
[[nodiscard]] auto Retry() const { return retry_; }
|
||||
[[nodiscard]] auto TaskID() const { return task_id_; }
|
||||
|
||||
[[nodiscard]] auto Rank() const { return rank_; }
|
||||
[[nodiscard]] auto World() const { return world_; }
|
||||
[[nodiscard]] bool IsDistributed() const { return World() > 1; }
|
||||
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
|
||||
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
|
||||
void Submit(Loop::Op op) const { loop_->Submit(op); }
|
||||
[[nodiscard]] Result Block() const { return loop_->Block(); }
|
||||
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
|
||||
|
||||
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
|
||||
return channels_.at(rank);
|
||||
@ -83,6 +86,8 @@ class Comm {
|
||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
|
||||
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
|
||||
};
|
||||
|
||||
class RabitComm : public Comm {
|
||||
@ -116,7 +121,7 @@ class Channel {
|
||||
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
|
||||
: sock_{std::move(sock)}, comm_{comm} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
@ -125,7 +130,7 @@ class Channel {
|
||||
this->SendAll(data.data(), data.size_bytes());
|
||||
}
|
||||
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
@ -133,7 +138,7 @@ class Channel {
|
||||
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
|
||||
|
||||
[[nodiscard]] auto Socket() const { return sock_; }
|
||||
[[nodiscard]] Result Block() { return comm_.Block(); }
|
||||
[[nodiscard]] virtual Result Block() { return comm_.Block(); }
|
||||
};
|
||||
|
||||
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
||||
|
||||
@ -50,6 +50,7 @@ class Tracker {
|
||||
[[nodiscard]] virtual std::future<Result> Run() = 0;
|
||||
[[nodiscard]] virtual Json WorkerArgs() const = 0;
|
||||
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
|
||||
[[nodiscard]] virtual std::int32_t Port() const { return port_; }
|
||||
};
|
||||
|
||||
class RabitTracker : public Tracker {
|
||||
@ -124,7 +125,6 @@ class RabitTracker : public Tracker {
|
||||
|
||||
std::future<Result> Run() override;
|
||||
|
||||
[[nodiscard]] std::int32_t Port() const { return port_; }
|
||||
[[nodiscard]] Json WorkerArgs() const override {
|
||||
Json args{Object{}};
|
||||
args["DMLC_TRACKER_URI"] = String{host_};
|
||||
|
||||
@ -1171,7 +1171,13 @@ class CUDAStreamView {
|
||||
operator cudaStream_t() const { // NOLINT
|
||||
return stream_;
|
||||
}
|
||||
void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
|
||||
cudaError_t Sync(bool error = true) {
|
||||
if (error) {
|
||||
dh::safe_cuda(cudaStreamSynchronize(stream_));
|
||||
return cudaSuccess;
|
||||
}
|
||||
return cudaStreamSynchronize(stream_);
|
||||
}
|
||||
};
|
||||
|
||||
inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
|
||||
|
||||
@ -20,7 +20,6 @@
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../data/ellpack_page.h"
|
||||
@ -40,7 +39,6 @@
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/allgather.h" // for RingAllgather
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/comm.h" // for RabitComm
|
||||
#include "gtest/gtest.h" // for AssertionR...
|
||||
#include "test_worker.h" // for TestDistri...
|
||||
@ -63,25 +64,7 @@ class Worker : public WorkerForTest {
|
||||
}
|
||||
}
|
||||
|
||||
void TestV() {
|
||||
{
|
||||
// basic test
|
||||
std::int32_t n{comm_.Rank()};
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// V test
|
||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
|
||||
void CheckV(common::Span<std::int32_t> result) {
|
||||
std::int32_t k{0};
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
|
||||
@ -93,6 +76,66 @@ class Worker : public WorkerForTest {
|
||||
}
|
||||
}
|
||||
}
|
||||
void TestVRing() {
|
||||
// V test
|
||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
|
||||
CheckV(result);
|
||||
}
|
||||
|
||||
void TestVBasic() {
|
||||
// basic test
|
||||
std::int32_t n{comm_.Rank()};
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
void TestVAlgo() {
|
||||
// V test, broadcast
|
||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||
auto s_data = common::Span{data.data(), data.size()};
|
||||
|
||||
std::vector<std::int64_t> sizes(comm_.World(), 0);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
std::shared_ptr<Coll> pcoll{new Coll{}};
|
||||
|
||||
std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);
|
||||
std::vector<std::int32_t> recv(std::accumulate(sizes.cbegin(), sizes.cend(), 0));
|
||||
|
||||
auto s_recv = common::Span{recv.data(), recv.size()};
|
||||
|
||||
rc = pcoll->AllgatherV(comm_, common::EraseType(s_data),
|
||||
common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_segments.data(), recv_segments.size()},
|
||||
common::EraseType(s_recv), AllgatherVAlgo::kBcast);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
CheckV(s_recv);
|
||||
|
||||
// Test inplace
|
||||
auto test_inplace = [&] (AllgatherVAlgo algo) {
|
||||
std::fill_n(s_recv.data(), s_recv.size(), 0);
|
||||
auto current = s_recv.subspan(recv_segments[comm_.Rank()],
|
||||
recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]);
|
||||
std::copy_n(data.data(), data.size(), current.data());
|
||||
rc = pcoll->AllgatherV(comm_, common::EraseType(current),
|
||||
common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_segments.data(), recv_segments.size()},
|
||||
common::EraseType(s_recv), algo);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
CheckV(s_recv);
|
||||
};
|
||||
|
||||
test_inplace(AllgatherVAlgo::kBcast);
|
||||
test_inplace(AllgatherVAlgo::kRing);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
@ -106,12 +149,30 @@ TEST_F(AllgatherTest, Basic) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTest, V) {
|
||||
TEST_F(AllgatherTest, VBasic) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.TestV();
|
||||
worker.TestVBasic();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTest, VRing) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.TestVRing();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTest, VAlgo) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.TestVAlgo();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
117
tests/cpp/collective/test_allgather.cu
Normal file
117
tests/cpp/collective/test_allgather.cu
Normal file
@ -0,0 +1,117 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/device_vector.h> // for device_vector
|
||||
#include <thrust/equal.h> // for equal
|
||||
#include <xgboost/span.h> // for Span
|
||||
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, int64_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/allgather.h" // for RingAllgather
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
#include "test_worker.h" // for TestDistributed, WorkerForTest
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class Worker : public NCCLWorkerForTest {
|
||||
public:
|
||||
using NCCLWorkerForTest::NCCLWorkerForTest;
|
||||
|
||||
void TestV(AllgatherVAlgo algo) {
|
||||
{
|
||||
// basic test
|
||||
std::size_t n = 1;
|
||||
// create data
|
||||
dh::device_vector<std::int32_t> data(n, comm_.Rank());
|
||||
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(comm_.World(), -1);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
// create result
|
||||
dh::device_vector<std::int32_t> result(comm_.World(), -1);
|
||||
auto s_result = common::EraseType(dh::ToSpan(result));
|
||||
|
||||
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
|
||||
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
}
|
||||
}
|
||||
{
|
||||
// V test
|
||||
std::size_t n = 256 * 256;
|
||||
// create data
|
||||
dh::device_vector<std::int32_t> data(n * nccl_comm_->Rank(), nccl_comm_->Rank());
|
||||
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
|
||||
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||
// create result
|
||||
dh::device_vector<std::int32_t> result(n_bytes / sizeof(std::int32_t), -1);
|
||||
auto s_result = common::EraseType(dh::ToSpan(result));
|
||||
|
||||
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
|
||||
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
// check segment size
|
||||
if (algo != AllgatherVAlgo::kBcast) {
|
||||
auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];
|
||||
ASSERT_EQ(size, n * nccl_comm_->Rank() * sizeof(std::int32_t));
|
||||
ASSERT_EQ(size, sizes[nccl_comm_->Rank()]);
|
||||
}
|
||||
// check data
|
||||
std::size_t k{0};
|
||||
for (std::int32_t r = 0; r < nccl_comm_->World(); ++r) {
|
||||
std::size_t s = n * r;
|
||||
auto current = dh::ToSpan(result).subspan(k, s);
|
||||
std::vector<std::int32_t> h_data(current.size());
|
||||
dh::CopyDeviceSpanToVector(&h_data, current);
|
||||
for (auto v : h_data) {
|
||||
ASSERT_EQ(v, r);
|
||||
}
|
||||
k += s;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class AllgatherTestGPU : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllgatherTestGPU, MGPUTestVRing) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker w{host, port, timeout, n_workers, r};
|
||||
w.Setup();
|
||||
w.TestV(AllgatherVAlgo::kRing);
|
||||
w.TestV(AllgatherVAlgo::kBcast);
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTestGPU, MGPUTestVBcast) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker w{host, port, timeout, n_workers, r};
|
||||
w.Setup();
|
||||
w.TestV(AllgatherVAlgo::kBcast);
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
@ -6,10 +6,10 @@
|
||||
#include "../../../src/collective/allreduce.h"
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
namespace {
|
||||
class AllreduceWorker : public WorkerForTest {
|
||||
public:
|
||||
@ -50,11 +50,10 @@ class AllreduceWorker : public WorkerForTest {
|
||||
}
|
||||
|
||||
void BitOr() {
|
||||
Context ctx;
|
||||
std::vector<std::uint32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
|
||||
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
|
||||
auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
|
||||
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (auto v : data) {
|
||||
|
||||
70
tests/cpp/collective/test_allreduce.cu
Normal file
70
tests/cpp/collective/test_allreduce.cu
Normal file
@ -0,0 +1,70 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h> // for host_vector
|
||||
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/common/common.h"
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "../helpers.h" // for MakeCUDACtx
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class AllreduceTestGPU : public SocketTest {};
|
||||
|
||||
class Worker : public NCCLWorkerForTest {
|
||||
public:
|
||||
using NCCLWorkerForTest::NCCLWorkerForTest;
|
||||
|
||||
void BitOr() {
|
||||
dh::device_vector<std::uint32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
|
||||
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
thrust::host_vector<std::uint32_t> h_data(data.size());
|
||||
thrust::copy(data.cbegin(), data.cend(), h_data.begin());
|
||||
for (auto v : h_data) {
|
||||
ASSERT_EQ(v, ~std::uint32_t{0});
|
||||
}
|
||||
}
|
||||
|
||||
void Acc() {
|
||||
dh::device_vector<double> data(314, 1.5);
|
||||
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
|
||||
ArrayInterfaceHandler::kF8, Op::kSum);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||
auto v = data[i];
|
||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllreduceTestGPU, BitOr) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker w{host, port, timeout, n_workers, r};
|
||||
w.Setup();
|
||||
w.BitOr();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllreduceTestGPU, Sum) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker w{host, port, timeout, n_workers, r};
|
||||
w.Setup();
|
||||
w.Acc();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
@ -47,5 +47,5 @@ TEST_F(BroadcastTest, Basic) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.Run();
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xgboost::collective
|
||||
|
||||
32
tests/cpp/collective/test_worker.cuh
Normal file
32
tests/cpp/collective/test_worker.cuh
Normal file
@ -0,0 +1,32 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/comm.h" // for Comm
|
||||
#include "test_worker.h"
|
||||
#include "xgboost/context.h" // for Context
|
||||
|
||||
namespace xgboost::collective {
|
||||
class NCCLWorkerForTest : public WorkerForTest {
|
||||
protected:
|
||||
std::shared_ptr<Coll> coll_;
|
||||
std::shared_ptr<xgboost::collective::Comm> nccl_comm_;
|
||||
std::shared_ptr<Coll> nccl_coll_;
|
||||
Context ctx_;
|
||||
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Setup() {
|
||||
ctx_ = MakeCUDACtx(comm_.Rank());
|
||||
coll_.reset(new Coll{});
|
||||
nccl_comm_.reset(this->comm_.MakeCUDAVar(&ctx_, coll_));
|
||||
nccl_coll_.reset(coll_->MakeCUDAVar());
|
||||
ASSERT_EQ(comm_.World(), nccl_comm_->World());
|
||||
ASSERT_EQ(comm_.Rank(), nccl_comm_->Rank());
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
|
||||
@ -97,4 +97,29 @@ TEST(BitField, Clear) {
|
||||
TestBitFieldClear<RBitField8>(19);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BitField, CTZ) {
|
||||
{
|
||||
auto cnt = TrailingZeroBits(0);
|
||||
ASSERT_EQ(cnt, sizeof(std::uint32_t) * 8);
|
||||
}
|
||||
{
|
||||
auto cnt = TrailingZeroBits(0b00011100);
|
||||
ASSERT_EQ(cnt, 2);
|
||||
cnt = detail::TrailingZeroBitsImpl(0b00011100);
|
||||
ASSERT_EQ(cnt, 2);
|
||||
}
|
||||
{
|
||||
auto cnt = TrailingZeroBits(0b00011101);
|
||||
ASSERT_EQ(cnt, 0);
|
||||
cnt = detail::TrailingZeroBitsImpl(0b00011101);
|
||||
ASSERT_EQ(cnt, 0);
|
||||
}
|
||||
{
|
||||
auto cnt = TrailingZeroBits(0b1000000000000000);
|
||||
ASSERT_EQ(cnt, 15);
|
||||
cnt = detail::TrailingZeroBitsImpl(0b1000000000000000);
|
||||
ASSERT_EQ(cnt, 15);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -572,4 +572,31 @@ class BaseMGPUTest : public ::testing::Test {
|
||||
class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{};
|
||||
|
||||
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
|
||||
|
||||
/**
|
||||
* @brief poor man's gmock for message matching.
|
||||
*
|
||||
* @tparam Error The type of expected execption.
|
||||
*
|
||||
* @param submsg A substring of the actual error message.
|
||||
* @param fn The function that throws Error
|
||||
*/
|
||||
template <typename Error, typename Fn>
|
||||
void ExpectThrow(std::string submsg, Fn&& fn) {
|
||||
try {
|
||||
fn();
|
||||
} catch (Error const& exc) {
|
||||
auto actual = std::string{exc.what()};
|
||||
ASSERT_NE(actual.find(submsg), std::string::npos)
|
||||
<< "Expecting substring `" << submsg << "` from the error message."
|
||||
<< " Got:\n"
|
||||
<< actual << "\n";
|
||||
return;
|
||||
} catch (std::exception const& exc) {
|
||||
auto actual = exc.what();
|
||||
ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual;
|
||||
return;
|
||||
}
|
||||
ASSERT_TRUE(false) << "No exception is thrown";
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
84
tests/cpp/plugin/federated/test_federated_comm.cc
Normal file
84
tests/cpp/plugin/federated/test_federated_comm.cc
Normal file
@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_comm.h"
|
||||
#include "../../collective/net_test.h" // for SocketTest
|
||||
#include "../../helpers.h" // for ExpectThrow
|
||||
#include "test_worker.h" // for TestFederated
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class FederatedCommTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid world size.", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
||||
auto construct = [] {
|
||||
Json config{Object{}};
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = std::string("1");
|
||||
config["federated_rank"] = Integer(0);
|
||||
FederatedComm comm(config);
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
||||
auto construct = [] {
|
||||
Json config{Object{}};
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = 1;
|
||||
config["federated_rank"] = std::string("0");
|
||||
FederatedComm comm(config);
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
||||
Json config{Object{}};
|
||||
config["federated_world_size"] = 6;
|
||||
config["federated_rank"] = 3;
|
||||
config["federated_server_address"] = String{"localhost:0"};
|
||||
FederatedComm comm{config};
|
||||
EXPECT_EQ(comm.World(), 6);
|
||||
EXPECT_EQ(comm.Rank(), 3);
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, IsDistributed) {
|
||||
FederatedComm comm{"localhost", 0, 2, 1};
|
||||
EXPECT_TRUE(comm.IsDistributed());
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, InsecureTracker) {
|
||||
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
|
||||
TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
|
||||
Json config{Object{}};
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = rank;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedComm comm{config};
|
||||
ASSERT_EQ(comm.Rank(), rank);
|
||||
ASSERT_EQ(comm.World(), n_workers);
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
42
tests/cpp/plugin/federated/test_worker.h
Normal file
42
tests/cpp/plugin/federated/test_worker.h
Normal file
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for ms
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_tracker.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
template <typename WorkerFn>
|
||||
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
FederatedTracker tracker{config};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
using namespace std::chrono_literals;
|
||||
while (tracker.Port() == 0) {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
}
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { fn(port, i); });
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
auto rc = tracker.Shutdown();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@ -26,7 +26,7 @@ class ServerForTest {
|
||||
explicit ServerForTest(std::size_t world_size) {
|
||||
server_thread_.reset(new std::thread([this, world_size] {
|
||||
grpc::ServerBuilder builder;
|
||||
xgboost::federated::FederatedService service{world_size};
|
||||
xgboost::federated::FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
int selected_port;
|
||||
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
|
||||
builder.RegisterService(&service);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user