Merge branch 'master' into sync-condition-2023Oct11
This commit is contained in:
commit
d7f1235b7d
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@ -151,4 +151,4 @@ jobs:
|
|||||||
python-package/xgboost/lib python-package/xgboost/rabit \
|
python-package/xgboost/lib python-package/xgboost/rabit \
|
||||||
python-package/xgboost/src
|
python-package/xgboost/src
|
||||||
|
|
||||||
sh ./tests/ci_build/lint_cmake.sh || true
|
sh ./tests/ci_build/lint_cmake.sh
|
||||||
|
|||||||
@ -33,7 +33,7 @@ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
include(${xgboost_SOURCE_DIR}/cmake/FindPrefetchIntrinsics.cmake)
|
include(${xgboost_SOURCE_DIR}/cmake/PrefetchIntrinsics.cmake)
|
||||||
find_prefetch_intrinsics()
|
find_prefetch_intrinsics()
|
||||||
include(${xgboost_SOURCE_DIR}/cmake/Version.cmake)
|
include(${xgboost_SOURCE_DIR}/cmake/Version.cmake)
|
||||||
write_version()
|
write_version()
|
||||||
|
|||||||
10
doc/faq.rst
10
doc/faq.rst
@ -10,14 +10,14 @@ How to tune parameters
|
|||||||
See :doc:`Parameter Tuning Guide </tutorials/param_tuning>`.
|
See :doc:`Parameter Tuning Guide </tutorials/param_tuning>`.
|
||||||
|
|
||||||
************************
|
************************
|
||||||
Description on the model
|
Description of the model
|
||||||
************************
|
************************
|
||||||
See :doc:`Introduction to Boosted Trees </tutorials/model>`.
|
See :doc:`Introduction to Boosted Trees </tutorials/model>`.
|
||||||
|
|
||||||
********************
|
********************
|
||||||
I have a big dataset
|
I have a big dataset
|
||||||
********************
|
********************
|
||||||
XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fit into your memory.
|
XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fits into your memory.
|
||||||
This usually means millions of instances.
|
This usually means millions of instances.
|
||||||
|
|
||||||
If you are running out of memory, checkout the tutorial page for using :doc:`distributed training </tutorials/index>` with one of the many frameworks, or the :doc:`external memory version </tutorials/external_memory>` for using external memory.
|
If you are running out of memory, checkout the tutorial page for using :doc:`distributed training </tutorials/index>` with one of the many frameworks, or the :doc:`external memory version </tutorials/external_memory>` for using external memory.
|
||||||
@ -26,7 +26,7 @@ If you are running out of memory, checkout the tutorial page for using :doc:`dis
|
|||||||
**********************************
|
**********************************
|
||||||
How to handle categorical feature?
|
How to handle categorical feature?
|
||||||
**********************************
|
**********************************
|
||||||
Visit :doc:`this tutorial </tutorials/categorical>` for a walk through of categorical data handling and some worked examples.
|
Visit :doc:`this tutorial </tutorials/categorical>` for a walkthrough of categorical data handling and some worked examples.
|
||||||
|
|
||||||
******************************************************************
|
******************************************************************
|
||||||
Why not implement distributed XGBoost on top of X (Spark, Hadoop)?
|
Why not implement distributed XGBoost on top of X (Spark, Hadoop)?
|
||||||
@ -37,14 +37,14 @@ The ultimate question will still come back to how to push the limit of each comp
|
|||||||
and use less resources to complete the task (thus with less communication and chance of failure).
|
and use less resources to complete the task (thus with less communication and chance of failure).
|
||||||
|
|
||||||
To achieve these, we decide to reuse the optimizations in the single node XGBoost and build the distributed version on top of it.
|
To achieve these, we decide to reuse the optimizations in the single node XGBoost and build the distributed version on top of it.
|
||||||
The demand of communication in machine learning is rather simple, in the sense that we can depend on a limited set of APIs (in our case rabit).
|
The demand for communication in machine learning is rather simple, in the sense that we can depend on a limited set of APIs (in our case rabit).
|
||||||
Such design allows us to reuse most of the code, while being portable to major platforms such as Hadoop/Yarn, MPI, SGE.
|
Such design allows us to reuse most of the code, while being portable to major platforms such as Hadoop/Yarn, MPI, SGE.
|
||||||
Most importantly, it pushes the limit of the computation resources we can use.
|
Most importantly, it pushes the limit of the computation resources we can use.
|
||||||
|
|
||||||
****************************************
|
****************************************
|
||||||
How can I port a model to my own system?
|
How can I port a model to my own system?
|
||||||
****************************************
|
****************************************
|
||||||
The model and data format of XGBoost is exchangeable,
|
The model and data format of XGBoost are exchangeable,
|
||||||
which means the model trained by one language can be loaded in another.
|
which means the model trained by one language can be loaded in another.
|
||||||
This means you can train the model using R, while running prediction using
|
This means you can train the model using R, while running prediction using
|
||||||
Java or C++, which are more common in production systems.
|
Java or C++, which are more common in production systems.
|
||||||
|
|||||||
@ -73,7 +73,7 @@ Parameters for Tree Booster
|
|||||||
===========================
|
===========================
|
||||||
* ``eta`` [default=0.3, alias: ``learning_rate``]
|
* ``eta`` [default=0.3, alias: ``learning_rate``]
|
||||||
|
|
||||||
- Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and ``eta`` shrinks the feature weights to make the boosting process more conservative.
|
- Step size shrinkage used in update to prevent overfitting. After each boosting step, we can directly get the weights of new features, and ``eta`` shrinks the feature weights to make the boosting process more conservative.
|
||||||
- range: [0,1]
|
- range: [0,1]
|
||||||
|
|
||||||
* ``gamma`` [default=0, alias: ``min_split_loss``]
|
* ``gamma`` [default=0, alias: ``min_split_loss``]
|
||||||
|
|||||||
@ -87,8 +87,8 @@ XGBoost PySpark GPU support
|
|||||||
XGBoost PySpark fully supports GPU acceleration. Users are not only able to enable
|
XGBoost PySpark fully supports GPU acceleration. Users are not only able to enable
|
||||||
efficient training but also utilize their GPUs for the whole PySpark pipeline including
|
efficient training but also utilize their GPUs for the whole PySpark pipeline including
|
||||||
ETL and inference. In below sections, we will walk through an example of training on a
|
ETL and inference. In below sections, we will walk through an example of training on a
|
||||||
PySpark standalone GPU cluster. To get started, first we need to install some additional
|
Spark standalone cluster with GPU support. To get started, first we need to install some
|
||||||
packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.
|
additional packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.
|
||||||
|
|
||||||
Prepare the necessary packages
|
Prepare the necessary packages
|
||||||
==============================
|
==============================
|
||||||
@ -128,7 +128,8 @@ Write your PySpark application
|
|||||||
==============================
|
==============================
|
||||||
|
|
||||||
Below snippet is a small example for training xgboost model with PySpark. Notice that we are
|
Below snippet is a small example for training xgboost model with PySpark. Notice that we are
|
||||||
using a list of feature names and the additional parameter ``device``:
|
using a list of feature names instead of vector type as the input. The parameter ``"device=cuda"``
|
||||||
|
specifically indicates that the training will be performed on a GPU.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@ -163,14 +164,29 @@ using a list of feature names and the additional parameter ``device``:
|
|||||||
predict_df = model.transform(test_df)
|
predict_df = model.transform(test_df)
|
||||||
predict_df.show()
|
predict_df.show()
|
||||||
|
|
||||||
Like other distributed interfaces, the ```device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).
|
Like other distributed interfaces, the ``device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).
|
||||||
|
|
||||||
|
.. _stage-level-scheduling:
|
||||||
|
|
||||||
Submit the PySpark application
|
Submit the PySpark application
|
||||||
==============================
|
==============================
|
||||||
|
|
||||||
Assuming you have configured your Spark cluster with GPU support. Otherwise, please
|
Assuming you have configured the Spark standalone cluster with GPU support. Otherwise, please
|
||||||
refer to `spark standalone configuration with GPU support <https://nvidia.github.io/spark-rapids/docs/get-started/getting-started-on-prem.html#spark-standalone-cluster>`_.
|
refer to `spark standalone configuration with GPU support <https://nvidia.github.io/spark-rapids/docs/get-started/getting-started-on-prem.html#spark-standalone-cluster>`_.
|
||||||
|
|
||||||
|
Starting from XGBoost 2.0.1, stage-level scheduling is automatically enabled. Therefore,
|
||||||
|
if you are using Spark standalone cluster version 3.4.0 or higher, we strongly recommend
|
||||||
|
configuring the ``"spark.task.resource.gpu.amount"`` as a fractional value. This will
|
||||||
|
enable running multiple tasks in parallel during the ETL phase. An example configuration
|
||||||
|
would be ``"spark.task.resource.gpu.amount=1/spark.executor.cores"``. However, if you are
|
||||||
|
using an XGBoost version earlier than 2.0.1 or a Spark standalone cluster version below 3.4.0,
|
||||||
|
you still need to set ``"spark.task.resource.gpu.amount"`` equal to ``"spark.executor.resource.gpu.amount"``.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
As of now, the stage-level scheduling feature in XGBoost is limited to the Spark standalone cluster mode.
|
||||||
|
However, we have plans to expand its compatibility to YARN and Kubernetes once Spark 3.5.1 is officially released.
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
export PYSPARK_DRIVER_PYTHON=python
|
export PYSPARK_DRIVER_PYTHON=python
|
||||||
@ -178,19 +194,21 @@ refer to `spark standalone configuration with GPU support <https://nvidia.github
|
|||||||
|
|
||||||
spark-submit \
|
spark-submit \
|
||||||
--master spark://<master-ip>:7077 \
|
--master spark://<master-ip>:7077 \
|
||||||
|
--conf spark.executor.cores=12 \
|
||||||
|
--conf spark.task.cpus=1 \
|
||||||
--conf spark.executor.resource.gpu.amount=1 \
|
--conf spark.executor.resource.gpu.amount=1 \
|
||||||
--conf spark.task.resource.gpu.amount=1 \
|
--conf spark.task.resource.gpu.amount=0.08 \
|
||||||
--archives xgboost_env.tar.gz#environment \
|
--archives xgboost_env.tar.gz#environment \
|
||||||
xgboost_app.py
|
xgboost_app.py
|
||||||
|
|
||||||
|
The above command submits the xgboost pyspark application with the python environment created by pip or conda,
|
||||||
The submit command sends the Python environment created by pip or conda along with the
|
specifying a request for 1 GPU and 12 CPUs per executor. As you can see, a total of 12 tasks per executor will be
|
||||||
specification of GPU allocation. We will revisit this command later on.
|
executed concurrently during the ETL phase.
|
||||||
|
|
||||||
Model Persistence
|
Model Persistence
|
||||||
=================
|
=================
|
||||||
|
|
||||||
Similar to standard PySpark ml estimators, one can persist and reuse the model with ``save`
|
Similar to standard PySpark ml estimators, one can persist and reuse the model with ``save``
|
||||||
and ``load`` methods:
|
and ``load`` methods:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -230,8 +248,13 @@ Accelerate the whole pipeline for xgboost pyspark
|
|||||||
|
|
||||||
With `RAPIDS Accelerator for Apache Spark <https://nvidia.github.io/spark-rapids/>`_, you
|
With `RAPIDS Accelerator for Apache Spark <https://nvidia.github.io/spark-rapids/>`_, you
|
||||||
can leverage GPUs to accelerate the whole pipeline (ETL, Train, Transform) for xgboost
|
can leverage GPUs to accelerate the whole pipeline (ETL, Train, Transform) for xgboost
|
||||||
pyspark without any Python code change. An example submit command is shown below with
|
pyspark without the need for any code modifications. Likewise, you have the option to configure
|
||||||
additional spark configurations and dependencies:
|
the ``"spark.task.resource.gpu.amount"`` setting as a fractional value, enabling a higher
|
||||||
|
number of tasks to be executed in parallel during the ETL phase. Please refer to
|
||||||
|
:ref:`stage-level-scheduling` for more details.
|
||||||
|
|
||||||
|
|
||||||
|
An example submit command is shown below with additional spark configurations and dependencies:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
@ -240,8 +263,10 @@ additional spark configurations and dependencies:
|
|||||||
|
|
||||||
spark-submit \
|
spark-submit \
|
||||||
--master spark://<master-ip>:7077 \
|
--master spark://<master-ip>:7077 \
|
||||||
|
--conf spark.executor.cores=12 \
|
||||||
|
--conf spark.task.cpus=1 \
|
||||||
--conf spark.executor.resource.gpu.amount=1 \
|
--conf spark.executor.resource.gpu.amount=1 \
|
||||||
--conf spark.task.resource.gpu.amount=1 \
|
--conf spark.task.resource.gpu.amount=0.08 \
|
||||||
--packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
|
--packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
|
||||||
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
|
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
|
||||||
--conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
|
--conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
|
||||||
|
|||||||
@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h)
|
|||||||
target_link_libraries(federated_client INTERFACE federated_proto)
|
target_link_libraries(federated_client INTERFACE federated_proto)
|
||||||
|
|
||||||
# Rabit engine for Federated Learning.
|
# Rabit engine for Federated Learning.
|
||||||
target_sources(objxgboost PRIVATE federated_server.cc)
|
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
|
||||||
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
|
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
|
||||||
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
|
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
|
||||||
|
|||||||
114
plugin/federated/federated_comm.cc
Normal file
114
plugin/federated/federated_comm.cc
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "federated_comm.h"
|
||||||
|
|
||||||
|
#include <grpcpp/grpcpp.h>
|
||||||
|
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
|
#include <cstdlib> // for getenv
|
||||||
|
#include <string> // for string, stoi
|
||||||
|
|
||||||
|
#include "../../src/common/common.h" // for Split
|
||||||
|
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||||
|
#include "xgboost/json.h" // for Json
|
||||||
|
#include "xgboost/logging.h"
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world,
|
||||||
|
std::int32_t rank, std::string const& server_cert,
|
||||||
|
std::string const& client_key, std::string const& client_cert) {
|
||||||
|
this->rank_ = rank;
|
||||||
|
this->world_ = world;
|
||||||
|
|
||||||
|
this->tracker_.host = host;
|
||||||
|
this->tracker_.port = port;
|
||||||
|
this->tracker_.rank = rank;
|
||||||
|
|
||||||
|
CHECK_GE(world, 1) << "Invalid world size.";
|
||||||
|
CHECK_GE(rank, 0) << "Invalid worker rank.";
|
||||||
|
CHECK_LT(rank, world) << "Invalid worker rank.";
|
||||||
|
|
||||||
|
if (server_cert.empty()) {
|
||||||
|
stub_ = [&] {
|
||||||
|
grpc::ChannelArguments args;
|
||||||
|
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||||
|
return federated::Federated::NewStub(
|
||||||
|
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
|
||||||
|
}();
|
||||||
|
} else {
|
||||||
|
stub_ = [&] {
|
||||||
|
grpc::SslCredentialsOptions options;
|
||||||
|
options.pem_root_certs = server_cert;
|
||||||
|
options.pem_private_key = client_key;
|
||||||
|
options.pem_cert_chain = client_cert;
|
||||||
|
grpc::ChannelArguments args;
|
||||||
|
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||||
|
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
|
||||||
|
channel->WaitForConnected(
|
||||||
|
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
|
||||||
|
return federated::Federated::NewStub(channel);
|
||||||
|
}();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FederatedComm::FederatedComm(Json const& config) {
|
||||||
|
/**
|
||||||
|
* Topology
|
||||||
|
*/
|
||||||
|
std::string server_address{};
|
||||||
|
std::int32_t world_size{0};
|
||||||
|
std::int32_t rank{-1};
|
||||||
|
// Parse environment variables first.
|
||||||
|
auto* value = std::getenv("FEDERATED_SERVER_ADDRESS");
|
||||||
|
if (value != nullptr) {
|
||||||
|
server_address = value;
|
||||||
|
}
|
||||||
|
value = std::getenv("FEDERATED_WORLD_SIZE");
|
||||||
|
if (value != nullptr) {
|
||||||
|
world_size = std::stoi(value);
|
||||||
|
}
|
||||||
|
value = std::getenv("FEDERATED_RANK");
|
||||||
|
if (value != nullptr) {
|
||||||
|
rank = std::stoi(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
|
||||||
|
world_size =
|
||||||
|
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
|
||||||
|
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
|
||||||
|
|
||||||
|
auto parsed = common::Split(server_address, ':');
|
||||||
|
CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;
|
||||||
|
|
||||||
|
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
|
||||||
|
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
|
||||||
|
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Certificates
|
||||||
|
*/
|
||||||
|
std::string server_cert{};
|
||||||
|
std::string client_key{};
|
||||||
|
std::string client_cert{};
|
||||||
|
value = getenv("FEDERATED_SERVER_CERT_PATH");
|
||||||
|
if (value != nullptr) {
|
||||||
|
server_cert = value;
|
||||||
|
}
|
||||||
|
value = getenv("FEDERATED_CLIENT_KEY_PATH");
|
||||||
|
if (value != nullptr) {
|
||||||
|
client_key = value;
|
||||||
|
}
|
||||||
|
value = getenv("FEDERATED_CLIENT_CERT_PATH");
|
||||||
|
if (value != nullptr) {
|
||||||
|
client_cert = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert);
|
||||||
|
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
|
||||||
|
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);
|
||||||
|
|
||||||
|
this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
|
||||||
|
client_cert);
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
53
plugin/federated/federated_comm.h
Normal file
53
plugin/federated/federated_comm.h
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <federated.grpc.pb.h>
|
||||||
|
#include <federated.pb.h>
|
||||||
|
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
|
#include <memory> // for unique_ptr
|
||||||
|
#include <string> // for string
|
||||||
|
|
||||||
|
#include "../../src/collective/comm.h" // for Comm
|
||||||
|
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||||
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
class FederatedComm : public Comm {
|
||||||
|
std::unique_ptr<federated::Federated::Stub> stub_;
|
||||||
|
|
||||||
|
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
|
||||||
|
std::string const& server_cert, std::string const& client_key,
|
||||||
|
std::string const& client_cert);
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @param config
|
||||||
|
*
|
||||||
|
* - federated_server_address: Tracker address
|
||||||
|
* - federated_world_size: The number of workers
|
||||||
|
* - federated_rank: Rank of federated worker
|
||||||
|
* - federated_server_cert_path
|
||||||
|
* - federated_client_key_path
|
||||||
|
* - federated_client_cert_path
|
||||||
|
*/
|
||||||
|
explicit FederatedComm(Json const& config);
|
||||||
|
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
|
||||||
|
std::int32_t rank) {
|
||||||
|
this->Init(host, port, world, rank, {}, {}, {});
|
||||||
|
}
|
||||||
|
~FederatedComm() override { stub_.reset(); }
|
||||||
|
|
||||||
|
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
|
||||||
|
LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
[[nodiscard]] Result LogTracker(std::string msg) const override {
|
||||||
|
LOG(CONSOLE) << msg;
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||||
|
};
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -4,12 +4,15 @@
|
|||||||
#include "federated_server.h"
|
#include "federated_server.h"
|
||||||
|
|
||||||
#include <grpcpp/grpcpp.h>
|
#include <grpcpp/grpcpp.h>
|
||||||
|
#include <grpcpp/server.h> // for Server
|
||||||
#include <grpcpp/server_builder.h>
|
#include <grpcpp/server_builder.h>
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "../../src/collective/comm.h"
|
||||||
#include "../../src/common/io.h"
|
#include "../../src/common/io.h"
|
||||||
|
#include "../../src/common/json_utils.h"
|
||||||
|
|
||||||
namespace xgboost::federated {
|
namespace xgboost::federated {
|
||||||
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
|
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
|
||||||
@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
|
|||||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||||
char const* server_cert_file, char const* client_cert_file) {
|
char const* server_cert_file, char const* client_cert_file) {
|
||||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||||
FederatedService service{world_size};
|
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||||
|
|
||||||
grpc::ServerBuilder builder;
|
grpc::ServerBuilder builder;
|
||||||
auto options =
|
auto options =
|
||||||
@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
|||||||
|
|
||||||
void RunInsecureServer(int port, std::size_t world_size) {
|
void RunInsecureServer(int port, std::size_t world_size) {
|
||||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||||
FederatedService service{world_size};
|
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||||
|
|
||||||
grpc::ServerBuilder builder;
|
grpc::ServerBuilder builder;
|
||||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||||
|
|||||||
@ -1,18 +1,22 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2022 XGBoost contributors
|
* Copyright 2022-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <federated.grpc.pb.h>
|
#include <federated.grpc.pb.h>
|
||||||
|
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
|
#include <future> // for future
|
||||||
|
|
||||||
#include "../../src/collective/in_memory_handler.h"
|
#include "../../src/collective/in_memory_handler.h"
|
||||||
|
#include "../../src/collective/tracker.h" // for Tracker
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::federated {
|
||||||
namespace federated {
|
|
||||||
|
|
||||||
class FederatedService final : public Federated::Service {
|
class FederatedService final : public Federated::Service {
|
||||||
public:
|
public:
|
||||||
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
|
explicit FederatedService(std::int32_t world_size)
|
||||||
|
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||||
|
|
||||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||||
AllgatherReply* reply) override;
|
AllgatherReply* reply) override;
|
||||||
@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
|||||||
char const* server_cert_file, char const* client_cert_file);
|
char const* server_cert_file, char const* client_cert_file);
|
||||||
|
|
||||||
void RunInsecureServer(int port, std::size_t world_size);
|
void RunInsecureServer(int port, std::size_t world_size);
|
||||||
|
} // namespace xgboost::federated
|
||||||
} // namespace federated
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
101
plugin/federated/federated_tracker.cc
Normal file
101
plugin/federated/federated_tracker.cc
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022-2023, XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "federated_tracker.h"
|
||||||
|
|
||||||
|
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
|
||||||
|
#include <grpcpp/server_builder.h> // for ServerBuilder
|
||||||
|
|
||||||
|
#include <chrono> // for ms
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
|
#include <exception> // for exception
|
||||||
|
#include <limits> // for numeric_limits
|
||||||
|
#include <string> // for string
|
||||||
|
#include <thread> // for sleep_for
|
||||||
|
|
||||||
|
#include "../../src/common/io.h" // for ReadAll
|
||||||
|
#include "../../src/common/json_utils.h" // for RequiredArg
|
||||||
|
#include "../../src/common/timer.h" // for Timer
|
||||||
|
#include "federated_server.h" // for FederatedService
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
|
||||||
|
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
|
||||||
|
if (is_secure) {
|
||||||
|
server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
|
||||||
|
server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
|
||||||
|
client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::future<Result> FederatedTracker::Run() {
|
||||||
|
return std::async([this]() {
|
||||||
|
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
|
||||||
|
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
|
||||||
|
if (this->server_cert_file_.empty()) {
|
||||||
|
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||||
|
if (this->port_ == 0) {
|
||||||
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
|
||||||
|
} else {
|
||||||
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||||
|
}
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
server_ = builder.BuildAndStart();
|
||||||
|
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||||
|
<< this->n_workers_;
|
||||||
|
} else {
|
||||||
|
auto options = grpc::SslServerCredentialsOptions(
|
||||||
|
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||||
|
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_);
|
||||||
|
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||||
|
key.private_key = xgboost::common::ReadAll(server_key_path_);
|
||||||
|
key.cert_chain = xgboost::common::ReadAll(server_cert_file_);
|
||||||
|
options.pem_key_cert_pairs.push_back(key);
|
||||||
|
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||||
|
if (this->port_ == 0) {
|
||||||
|
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_);
|
||||||
|
} else {
|
||||||
|
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||||
|
}
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
server_ = builder.BuildAndStart();
|
||||||
|
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
|
||||||
|
<< n_workers_;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
server_->Wait();
|
||||||
|
} catch (std::exception const& e) {
|
||||||
|
return collective::Fail(std::string{e.what()});
|
||||||
|
}
|
||||||
|
return collective::Success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
FederatedTracker::~FederatedTracker() = default;
|
||||||
|
|
||||||
|
Result FederatedTracker::Shutdown() {
|
||||||
|
common::Timer timer;
|
||||||
|
timer.Start();
|
||||||
|
using namespace std::chrono_literals;
|
||||||
|
while (!server_) {
|
||||||
|
timer.Stop();
|
||||||
|
auto ela = timer.ElapsedSeconds();
|
||||||
|
if (ela > this->Timeout().count()) {
|
||||||
|
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
|
||||||
|
" seconds.");
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(10ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
server_->Shutdown();
|
||||||
|
} catch (std::exception const& e) {
|
||||||
|
return Fail("Failed to shutdown:" + std::string{e.what()});
|
||||||
|
}
|
||||||
|
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
41
plugin/federated/federated_tracker.h
Normal file
41
plugin/federated/federated_tracker.h
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022-2023, XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <federated.grpc.pb.h> // for Server
|
||||||
|
|
||||||
|
#include <future> // for future
|
||||||
|
#include <memory> // for unique_ptr
|
||||||
|
#include <string> // for string
|
||||||
|
|
||||||
|
#include "../../src/collective/tracker.h" // for Tracker
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/json.h" // for Json
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
class FederatedTracker : public collective::Tracker {
|
||||||
|
std::unique_ptr<grpc::Server> server_;
|
||||||
|
std::string server_key_path_;
|
||||||
|
std::string server_cert_file_;
|
||||||
|
std::string client_cert_file_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief CTOR
|
||||||
|
*
|
||||||
|
* @param config Configuration, other than the base configuration from Tracker, we have:
|
||||||
|
*
|
||||||
|
* - federated_secure: bool whether this is a secure server.
|
||||||
|
* - server_key_path: path to the key.
|
||||||
|
* - server_cert_path: certificate path.
|
||||||
|
* - client_cert_path: certificate path for client.
|
||||||
|
*/
|
||||||
|
explicit FederatedTracker(Json const& config);
|
||||||
|
~FederatedTracker() override;
|
||||||
|
std::future<Result> Run() override;
|
||||||
|
// federated tracker do not provide initialization parameters, users have to provide it
|
||||||
|
// themseleves.
|
||||||
|
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
|
||||||
|
[[nodiscard]] Result Shutdown();
|
||||||
|
};
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -17,7 +17,7 @@ class HasArbitraryParamsDict(Params):
|
|||||||
Params._dummy(),
|
Params._dummy(),
|
||||||
"arbitrary_params_dict",
|
"arbitrary_params_dict",
|
||||||
"arbitrary_params_dict This parameter holds all of the additional parameters which are "
|
"arbitrary_params_dict This parameter holds all of the additional parameters which are "
|
||||||
"not exposed as the the XGBoost Spark estimator params but can be recognized by "
|
"not exposed as the XGBoost Spark estimator params but can be recognized by "
|
||||||
"underlying XGBoost library. It is stored as a dictionary.",
|
"underlying XGBoost library. It is stored as a dictionary.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -106,7 +106,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (dmlc_role != "worker") {
|
if (dmlc_role != "worker") {
|
||||||
LOG(FATAL) << "Rabit Module currently only work with dmlc worker";
|
LOG(FATAL) << "Rabit Module currently only works with dmlc worker";
|
||||||
}
|
}
|
||||||
|
|
||||||
// clear the setting before start reconnection
|
// clear the setting before start reconnection
|
||||||
@ -273,7 +273,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
return xgboost::collective::Success();
|
return xgboost::collective::Success();
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief connect to the tracker to fix the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
|
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||||
|
|||||||
@ -89,7 +89,7 @@ class AllreduceBase : public IEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
* \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,
|
||||||
* the data provided by current node k is [slice_begin, slice_end),
|
* the data provided by current node k is [slice_begin, slice_end),
|
||||||
* the next node's segment must start with slice_end
|
* the next node's segment must start with slice_end
|
||||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||||
@ -281,7 +281,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* this function can not be used together with ReadToRingBuffer
|
* this function can not be used together with ReadToRingBuffer
|
||||||
* a link can either read into the ring buffer, or existing array
|
* a link can either read into the ring buffer, or existing array
|
||||||
* \param max_size maximum size of array
|
* \param max_size maximum size of array
|
||||||
* \return true if it is an successful read, false if there is some error happens, check errno
|
* \return true if it is a successful read, false if there is some error happens, check errno
|
||||||
*/
|
*/
|
||||||
inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
|
inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
|
||||||
if (max_size == size_read) return kSuccess;
|
if (max_size == size_read) return kSuccess;
|
||||||
@ -299,7 +299,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* \brief write data in array to sock
|
* \brief write data in array to sock
|
||||||
* \param sendbuf_ head of array
|
* \param sendbuf_ head of array
|
||||||
* \param max_size maximum size of array
|
* \param max_size maximum size of array
|
||||||
* \return true if it is an successful write, false if there is some error happens, check errno
|
* \return true if it is a successful write, false if there is some error happens, check errno
|
||||||
*/
|
*/
|
||||||
inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
|
inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
|
||||||
const char *p = static_cast<const char*>(sendbuf_);
|
const char *p = static_cast<const char*>(sendbuf_);
|
||||||
@ -333,7 +333,7 @@ class AllreduceBase : public IEngine {
|
|||||||
*/
|
*/
|
||||||
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
|
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief connect to the tracker to fix the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
* \param cmd possible command to sent to tracker
|
* \param cmd possible command to sent to tracker
|
||||||
*/
|
*/
|
||||||
@ -358,7 +358,7 @@ class AllreduceBase : public IEngine {
|
|||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer);
|
ReduceFunction reducer);
|
||||||
/*!
|
/*!
|
||||||
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
|
* \brief broadcast data from root to all nodes, this function can fail, and will return the cause of failure
|
||||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||||
* \param size the size of the data to be broadcasted
|
* \param size the size of the data to be broadcasted
|
||||||
* \param root the root worker id to broadcast the data
|
* \param root the root worker id to broadcast the data
|
||||||
|
|||||||
@ -7,20 +7,23 @@
|
|||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <numeric> // for partial_sum
|
|
||||||
#include <vector> // for vector
|
|
||||||
|
|
||||||
|
#include "broadcast.h"
|
||||||
#include "comm.h" // for Comm, Channel
|
#include "comm.h" // for Comm, Channel
|
||||||
#include "xgboost/collective/result.h" // for Result
|
#include "xgboost/collective/result.h" // for Result
|
||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost::collective::cpu_impl {
|
namespace xgboost::collective {
|
||||||
|
namespace cpu_impl {
|
||||||
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
|
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
|
||||||
std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
|
std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
|
||||||
std::shared_ptr<Channel> next_ch) {
|
std::shared_ptr<Channel> next_ch) {
|
||||||
auto world = comm.World();
|
auto world = comm.World();
|
||||||
auto rank = comm.Rank();
|
auto rank = comm.Rank();
|
||||||
CHECK_LT(worker_off, world);
|
CHECK_LT(worker_off, world);
|
||||||
|
if (world == 1) {
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
|
||||||
for (std::int32_t r = 0; r < world; ++r) {
|
for (std::int32_t r = 0; r < world; ++r) {
|
||||||
auto send_rank = (rank + world - r + worker_off) % world;
|
auto send_rank = (rank + world - r + worker_off) % world;
|
||||||
@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
|||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Allgather-v implemented as a sequence of broadcasts: each rank in turn acts as the root
// and broadcasts its own segment of `recv`. `sizes[r]` is the byte length of rank r's
// segment, and segments are laid out back to back in `recv`. Stops at the first failed
// broadcast and returns its result.
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
                           common::Span<std::int8_t> recv) {
  std::size_t begin = 0;  // byte offset of the current root's segment inside recv
  for (std::int32_t root = 0; root < comm.World(); ++root) {
    auto n_bytes = sizes[root];
    if (auto rc = Broadcast(comm, recv.subspan(begin, n_bytes), root); !rc.OK()) {
      return rc;
    }
    begin += n_bytes;
  }
  return Success();
}
|
||||||
|
} // namespace cpu_impl
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||||
common::Span<std::int8_t const> data,
|
common::Span<std::int64_t const> offset,
|
||||||
common::Span<std::int64_t> offset,
|
|
||||||
common::Span<std::int8_t> erased_result) {
|
common::Span<std::int8_t> erased_result) {
|
||||||
auto world = comm.World();
|
auto world = comm.World();
|
||||||
|
if (world == 1) {
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
auto rank = comm.Rank();
|
auto rank = comm.Rank();
|
||||||
|
|
||||||
auto prev = BootstrapPrev(rank, comm.World());
|
auto prev = BootstrapPrev(rank, comm.World());
|
||||||
@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
|||||||
auto prev_ch = comm.Chan(prev);
|
auto prev_ch = comm.Chan(prev);
|
||||||
auto next_ch = comm.Chan(next);
|
auto next_ch = comm.Chan(next);
|
||||||
|
|
||||||
// get worker offset
|
|
||||||
CHECK_EQ(world + 1, offset.size());
|
|
||||||
std::fill_n(offset.data(), offset.size(), 0);
|
|
||||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
|
||||||
CHECK_EQ(*offset.cbegin(), 0);
|
|
||||||
|
|
||||||
// copy data
|
|
||||||
auto current = erased_result.subspan(offset[rank], data.size_bytes());
|
|
||||||
auto erased_data = EraseType(data);
|
|
||||||
std::copy_n(erased_data.data(), erased_data.size(), current.data());
|
|
||||||
|
|
||||||
for (std::int32_t r = 0; r < world; ++r) {
|
for (std::int32_t r = 0; r < world; ++r) {
|
||||||
auto send_rank = (rank + world - r) % world;
|
auto send_rank = (rank + world - r) % world;
|
||||||
auto send_off = offset[send_rank];
|
auto send_off = offset[send_rank];
|
||||||
@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
|||||||
}
|
}
|
||||||
return comm.Block();
|
return comm.Block();
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective::cpu_impl
|
} // namespace detail
|
||||||
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -12,25 +12,44 @@
|
|||||||
#include "../common/type.h" // for EraseType
|
#include "../common/type.h" // for EraseType
|
||||||
#include "comm.h" // for Comm, Channel
|
#include "comm.h" // for Comm, Channel
|
||||||
#include "xgboost/collective/result.h" // for Result
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/linalg.h"
|
||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
namespace cpu_impl {
|
namespace cpu_impl {
|
||||||
/**
|
/**
|
||||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off
|
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
|
||||||
* = 1, then it owns the third segment.
|
* worker_off = 1, then it owns the third segment.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
std::size_t segment_size, std::int32_t worker_off,
|
std::size_t segment_size, std::int32_t worker_off,
|
||||||
std::shared_ptr<Channel> prev_ch,
|
std::shared_ptr<Channel> prev_ch,
|
||||||
std::shared_ptr<Channel> next_ch);
|
std::shared_ptr<Channel> next_ch);
|
||||||
|
|
||||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
/**
|
||||||
common::Span<std::int8_t const> data,
|
* @brief Implement allgather-v using broadcast.
|
||||||
common::Span<std::int64_t> offset,
|
*
|
||||||
common::Span<std::int8_t> erased_result);
|
* https://arxiv.org/abs/1812.05964
|
||||||
|
*/
|
||||||
|
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||||
|
common::Span<std::int8_t> recv);
|
||||||
} // namespace cpu_impl
|
} // namespace cpu_impl
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
/**
 * @brief Compute the starting offset of each worker's segment for allgather-v.
 *
 * After the call, offset[0] is 0 and offset[i + 1] is the sum of sizes[0..i], i.e. the
 * inclusive prefix sum of `sizes` shifted by one.
 *
 * @param sizes  Size of the segment contributed by each worker.
 * @param offset Output buffer; must hold at least sizes.size() + 1 elements.
 */
inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
                             common::Span<std::int64_t> offset) {
  // The prefix sum below writes sizes.size() values starting at offset[1]; verify the
  // output buffer is large enough rather than writing past its end.
  CHECK_GE(offset.size(), sizes.size() + 1);
  // get worker offset
  std::fill_n(offset.data(), offset.size(), 0);
  std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
  CHECK_EQ(*offset.cbegin(), 0);
}
|
||||||
|
|
||||||
|
// An implementation that's used by both cpu and gpu
|
||||||
|
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||||
|
common::Span<std::int64_t const> offset,
|
||||||
|
common::Span<std::int8_t> erased_result);
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||||
auto n_bytes = sizeof(T) * size;
|
auto n_bytes = sizeof(T) * size;
|
||||||
@ -68,9 +87,15 @@ template <typename T>
|
|||||||
auto h_result = common::Span{result.data(), result.size()};
|
auto h_result = common::Span{result.data(), result.size()};
|
||||||
auto erased_result = common::EraseType(h_result);
|
auto erased_result = common::EraseType(h_result);
|
||||||
auto erased_data = common::EraseType(data);
|
auto erased_data = common::EraseType(data);
|
||||||
std::vector<std::int64_t> offset(world + 1);
|
std::vector<std::int64_t> recv_segments(world + 1);
|
||||||
|
auto s_segments = common::Span{recv_segments.data(), recv_segments.size()};
|
||||||
|
|
||||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
|
// get worker offset
|
||||||
common::Span{offset.data(), offset.size()}, erased_result);
|
detail::AllgatherVOffset(sizes, s_segments);
|
||||||
|
// copy data
|
||||||
|
auto current = erased_result.subspan(recv_segments[rank], data.size_bytes());
|
||||||
|
std::copy_n(erased_data.data(), erased_data.size(), current.data());
|
||||||
|
|
||||||
|
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -12,12 +12,10 @@
|
|||||||
#include "allreduce.h" // for Allreduce
|
#include "allreduce.h" // for Allreduce
|
||||||
#include "broadcast.h" // for Broadcast
|
#include "broadcast.h" // for Broadcast
|
||||||
#include "comm.h" // for Comm
|
#include "comm.h" // for Comm
|
||||||
#include "xgboost/context.h" // for Context
|
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
|
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
|
ArrayInterfaceHandler::Type, Op op) {
|
||||||
Op op) {
|
|
||||||
namespace coll = ::xgboost::collective;
|
namespace coll = ::xgboost::collective;
|
||||||
|
|
||||||
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
||||||
@ -55,21 +53,45 @@ namespace xgboost::collective {
|
|||||||
return comm.Block();
|
return comm.Block();
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
|
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
common::Span<std::int8_t> data, std::int32_t root) {
|
std::int32_t root) {
|
||||||
return cpu_impl::Broadcast(comm, data, root);
|
return cpu_impl::Broadcast(comm, data, root);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
|
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
common::Span<std::int8_t> data, std::size_t size) {
|
std::int64_t size) {
|
||||||
return RingAllgather(comm, data, size);
|
return RingAllgather(comm, data, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
|
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||||
common::Span<std::int8_t const> data,
|
|
||||||
common::Span<std::int64_t const> sizes,
|
common::Span<std::int64_t const> sizes,
|
||||||
common::Span<std::int64_t> recv_segments,
|
common::Span<std::int64_t> recv_segments,
|
||||||
common::Span<std::int8_t> recv) {
|
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
|
||||||
return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
|
// get worker offset
|
||||||
|
detail::AllgatherVOffset(sizes, recv_segments);
|
||||||
|
|
||||||
|
// copy data
|
||||||
|
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
|
||||||
|
if (current.data() != data.data()) {
|
||||||
|
std::copy_n(data.data(), data.size(), current.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (algo) {
|
||||||
|
case AllgatherVAlgo::kRing:
|
||||||
|
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
|
||||||
|
case AllgatherVAlgo::kBcast:
|
||||||
|
return cpu_impl::BroadcastAllgatherV(comm, sizes, recv);
|
||||||
|
default: {
|
||||||
|
return Fail("Unknown algorithm for allgather-v");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if !defined(XGBOOST_USE_NCCL)
|
||||||
|
// Variant compiled when XGBOOST_USE_NCCL is not defined: there is no device
// implementation to create, so this reports a fatal error.
Coll* Coll::MakeCUDAVar() {
  Coll* no_impl = nullptr;
  LOG(FATAL) << "NCCL is required for device communication.";
  return no_impl;
}
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
254
src/collective/coll.cu
Normal file
254
src/collective/coll.cu
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
|
#include <cstdint> // for int8_t, int64_t
|
||||||
|
|
||||||
|
#include "../common/cuda_context.cuh"
|
||||||
|
#include "../common/device_helpers.cuh"
|
||||||
|
#include "../data/array_interface.h"
|
||||||
|
#include "allgather.h" // for AllgatherVOffset
|
||||||
|
#include "coll.cuh"
|
||||||
|
#include "comm.cuh"
|
||||||
|
#include "nccl.h"
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
// Create the NCCL-backed collective implementation. The object is allocated with `new`
// and handed back as a raw pointer.
Coll* Coll::MakeCUDAVar() {
  auto* impl = new NCCLColl{};
  return impl;
}
|
||||||
|
|
||||||
|
// Destructor is defaulted out of line; it performs no explicit cleanup of its own.
NCCLColl::~NCCLColl() = default;
|
||||||
|
namespace {
|
||||||
|
// Translate an NCCL status code into a Result. ncclSuccess maps to Success(); anything
// else becomes a failure whose message contains the NCCL error string, plus extra detail
// for unhandled CUDA errors and system errors.
Result GetNCCLResult(ncclResult_t code) {
  if (code == ncclSuccess) {
    return Success();
  }

  std::stringstream msg;
  msg << "NCCL failure: " << ncclGetErrorString(code) << ".";
  switch (code) {
    case ncclUnhandledCudaError: {
      // nccl usually preserves the last error so we can get more details.
      auto cuda_err = cudaPeekAtLastError();
      msg << " CUDA error: " << thrust::system_error(cuda_err, thrust::cuda_category()).what()
          << "\n";
      break;
    }
    case ncclSystemError: {
      msg << " This might be caused by a network configuration issue. Please consider specifying "
             "the network interface for NCCL via environment variables listed in its reference: "
             "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
      break;
    }
    default:
      break;
  }
  return Fail(msg.str());
}
|
||||||
|
|
||||||
|
// Map an array-interface element type to the matching NCCL data type. Types without a
// mapping here (16-byte float, 2-byte signed/unsigned integers) raise a fatal error.
auto GetNCCLType(ArrayInterfaceHandler::Type type) {
  using H = ArrayInterfaceHandler;
  auto unsupported = [] {
    LOG(FATAL) << "Invalid type for NCCL operation.";
    return ncclHalf;  // dummy return to silence the compiler warning.
  };

  switch (type) {
    // floating point
    case H::kF2:
      return ncclHalf;
    case H::kF4:
      return ncclFloat32;
    case H::kF8:
      return ncclFloat64;
    // signed integers
    case H::kI1:
      return ncclInt8;
    case H::kI4:
      return ncclInt32;
    case H::kI8:
      return ncclInt64;
    // unsigned integers
    case H::kU1:
      return ncclUint8;
    case H::kU4:
      return ncclUint32;
    case H::kU8:
      return ncclUint64;
    // no NCCL mapping is provided for these
    case H::kF16:
    case H::kI2:
    case H::kU2:
      return unsupported();
  }
  return unsupported();
}
|
||||||
|
|
||||||
|
// Whether `op` is one of the bitwise reductions (AND, OR, XOR).
bool IsBitwiseOp(Op const& op) {
  switch (op) {
    case Op::kBitwiseAND:
    case Op::kBitwiseOR:
    case Op::kBitwiseXOR:
      return true;
    default:
      return false;
  }
}
|
||||||
|
|
||||||
|
// Element-wise reduction over gathered data. `device_buffer` is read as `world_size`
// consecutive chunks of `size` bytes each; for every index the bytes at that index in all
// chunks are folded with `func` (starting from chunk 0) and the result is written to
// `out_buffer`. The work is launched on `stream`.
template <typename Func>
void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> out_buffer,
                         std::int8_t const* device_buffer, Func func, std::int32_t world_size,
                         std::size_t size) {
  dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
    // Walk the same index through each chunk, one chunk stride at a time.
    std::int8_t const* src = device_buffer + idx;
    auto acc = *src;
    for (std::int32_t peer = 1; peer < world_size; ++peer) {
      src += size;
      acc = func(acc, *src);
    }
    out_buffer[idx] = acc;
  });
}
|
||||||
|
|
||||||
|
// Bitwise allreduce built from an all-gather followed by a local reduction: every
// worker's bytes are gathered into a temporary device buffer, then reduced element-wise
// with the requested bitwise operator, writing the result back into `data`.
// Returns the all-gather failure if there is one, otherwise Success().
[[nodiscard]] Result BitwiseAllReduce(NCCLComm const* pcomm, ncclComm_t handle,
                                      common::Span<std::int8_t> data, Op op) {
  auto const world = pcomm->World();
  auto const n_bytes = data.size();

  // Scratch space holding one copy of `data` per worker.
  dh::device_vector<std::int8_t> gathered(n_bytes * world);
  auto* d_gathered = gathered.data().get();

  // First gather data from all the workers.
  CHECK(handle);
  auto rc = GetNCCLResult(
      ncclAllGather(data.data(), d_gathered, n_bytes, ncclInt8, handle, pcomm->Stream()));
  if (!rc.OK()) {
    return rc;
  }

  // Then reduce locally.
  switch (op) {
    case Op::kBitwiseAND: {
      RunBitwiseAllreduce(pcomm->Stream(), data, d_gathered, thrust::bit_and<std::int8_t>(),
                          world, n_bytes);
      break;
    }
    case Op::kBitwiseOR: {
      RunBitwiseAllreduce(pcomm->Stream(), data, d_gathered, thrust::bit_or<std::int8_t>(),
                          world, n_bytes);
      break;
    }
    case Op::kBitwiseXOR: {
      RunBitwiseAllreduce(pcomm->Stream(), data, d_gathered, thrust::bit_xor<std::int8_t>(),
                          world, n_bytes);
      break;
    }
    default:
      LOG(FATAL) << "Not a bitwise reduce operation.";
  }
  return Success();
}
|
||||||
|
|
||||||
|
// Map a reduce operation to the NCCL reduction operator. Only max, min and sum are
// mapped; any other operation raises a fatal error.
ncclRedOp_t GetNCCLRedOp(Op const& op) {
  switch (op) {
    case Op::kMax:
      return ncclMax;
    case Op::kMin:
      return ncclMin;
    case Op::kSum:
      return ncclSum;
    default:
      LOG(FATAL) << "Unsupported reduce operation.";
  }
  // Placeholder value for the path after the fatal error above.
  return ncclMax;
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// In-place allreduce of `data` through NCCL. Returns immediately with Success() when the
// communicator is not distributed. Bitwise operations go through the gather-and-reduce
// helper; all other operations call ncclAllReduce with the element type given by `type`.
// The communicator is blocked on afterwards when the reduction step succeeded.
[[nodiscard]] Result NCCLColl::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
                                         ArrayInterfaceHandler::Type type, Op op) {
  if (!comm.IsDistributed()) {
    return Success();
  }
  auto nccl = dynamic_cast<NCCLComm const*>(&comm);
  CHECK(nccl);  // this implementation requires an NCCL communicator

  return Success() << [&] {
    if (IsBitwiseOp(op)) {
      return BitwiseAllReduce(nccl, nccl->Handle(), data, op);
    }
    return DispatchDType(type, [=](auto t) {
      using T = decltype(t);
      // NCCL takes the element count, not the byte count.
      auto n_elements = common::RestoreType<T>(data).size();
      return GetNCCLResult(ncclAllReduce(data.data(), data.data(), n_elements,
                                         GetNCCLType(type), GetNCCLRedOp(op), nccl->Handle(),
                                         nccl->Stream()));
    });
  } << [&] { return nccl->Block(); };
}
|
||||||
|
|
||||||
|
// In-place broadcast of `data` from rank `root` through NCCL, treating the buffer as raw
// bytes. Returns immediately with Success() when the communicator is not distributed.
// The communicator is blocked on afterwards when the broadcast call succeeded.
[[nodiscard]] Result NCCLColl::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
                                         std::int32_t root) {
  if (!comm.IsDistributed()) {
    return Success();
  }
  auto nccl = dynamic_cast<NCCLComm const*>(&comm);
  CHECK(nccl);  // this implementation requires an NCCL communicator

  return Success() << [&] {
    // Same buffer is used for sending and receiving.
    auto status = ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
                                nccl->Handle(), nccl->Stream());
    return GetNCCLResult(status);
  } << [&] { return nccl->Block(); };
}
|
||||||
|
|
||||||
|
// Allgather through NCCL. `data` is the full output buffer; this worker's contribution is
// the `size`-byte slice starting at Rank() * size. Returns immediately with Success()
// when the communicator is not distributed. The communicator is blocked on afterwards
// when the gather call succeeded.
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
                                         std::int64_t size) {
  if (!comm.IsDistributed()) {
    return Success();
  }
  auto nccl = dynamic_cast<NCCLComm const*>(&comm);
  CHECK(nccl);  // this implementation requires an NCCL communicator

  // Segment of the output buffer owned by the current worker.
  auto send = data.subspan(comm.Rank() * size, size);
  return Success() << [&] {
    auto status =
        ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream());
    return GetNCCLResult(status);
  } << [&] { return nccl->Block(); };
}
|
||||||
|
|
||||||
|
namespace cuda_impl {
|
||||||
|
/**
|
||||||
|
* @brief Implement allgather-v using broadcast.
|
||||||
|
*
|
||||||
|
* https://arxiv.org/abs/1812.05964
|
||||||
|
*/
|
||||||
|
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
|
||||||
|
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
|
||||||
|
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||||
|
std::size_t offset = 0;
|
||||||
|
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
||||||
|
auto as_bytes = sizes[r];
|
||||||
|
auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||||
|
ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
||||||
|
if (rc != ncclSuccess) {
|
||||||
|
return GetNCCLResult(rc);
|
||||||
|
}
|
||||||
|
offset += as_bytes;
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
|
} << [] { return GetNCCLResult(ncclGroupEnd()); };
|
||||||
|
}
|
||||||
|
} // namespace cuda_impl
|
||||||
|
|
||||||
|
// Variable-length allgather on an NCCL communicator. The communicator type is checked
// first; a non-distributed communicator then returns Success() without doing any work.
//
// - kRing: inside an NCCL group, computes the segment offsets into `recv_segments`, copies
//   this worker's `data` into its own segment of `recv` (unless it is already there), runs
//   the ring implementation, closes the group and blocks on the communicator.
// - kBcast: delegates to the broadcast-based implementation.
// - anything else: returns a failure.
[[nodiscard]] Result NCCLColl::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
                                          common::Span<std::int64_t const> sizes,
                                          common::Span<std::int64_t> recv_segments,
                                          common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
  auto nccl = dynamic_cast<NCCLComm const*>(&comm);
  CHECK(nccl);  // this implementation requires an NCCL communicator
  if (!comm.IsDistributed()) {
    return Success();
  }

  if (algo == AllgatherVAlgo::kRing) {
    return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
      // get worker offset
      detail::AllgatherVOffset(sizes, recv_segments);
      // copy data
      auto own_segment = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
      if (own_segment.data() != data.data()) {
        dh::safe_cuda(cudaMemcpyAsync(own_segment.data(), data.data(), own_segment.size_bytes(),
                                      cudaMemcpyDeviceToDevice, nccl->Stream()));
      }
      return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
    } << [] { return GetNCCLResult(ncclGroupEnd()); } << [&] { return nccl->Block(); };
  }

  if (algo == AllgatherVAlgo::kBcast) {
    return cuda_impl::BroadcastAllgatherV(nccl, data, sizes, recv);
  }

  return Fail("Unknown algorithm for allgather-v");
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
|
|
||||||
|
#endif // defined(XGBOOST_USE_NCCL)
|
||||||
29
src/collective/coll.cuh
Normal file
29
src/collective/coll.cuh
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint> // for int8_t, int64_t
|
||||||
|
|
||||||
|
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||||
|
#include "coll.h" // for Coll
|
||||||
|
#include "comm.h" // for Comm
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
class NCCLColl : public Coll {
|
||||||
|
public:
|
||||||
|
~NCCLColl() override;
|
||||||
|
|
||||||
|
[[nodiscard]] Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
|
ArrayInterfaceHandler::Type type, Op op) override;
|
||||||
|
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
|
std::int32_t root) override;
|
||||||
|
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
|
std::int64_t size) override;
|
||||||
|
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||||
|
common::Span<std::int64_t const> sizes,
|
||||||
|
common::Span<std::int64_t> recv_segments,
|
||||||
|
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
|
||||||
|
};
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -2,17 +2,20 @@
|
|||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <cstddef> // for size_t
|
|
||||||
#include <cstdint> // for int8_t, int64_t
|
#include <cstdint> // for int8_t, int64_t
|
||||||
#include <memory> // for enable_shared_from_this
|
#include <memory> // for enable_shared_from_this
|
||||||
|
|
||||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||||
#include "comm.h" // for Comm
|
#include "comm.h" // for Comm
|
||||||
#include "xgboost/collective/result.h" // for Result
|
#include "xgboost/collective/result.h" // for Result
|
||||||
#include "xgboost/context.h" // for Context
|
|
||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
|
enum class AllgatherVAlgo {
|
||||||
|
kRing = 0, // use ring-based allgather-v
|
||||||
|
kBcast = 1, // use broadcast-based allgather-v
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Interface and base implementation for collective.
|
* @brief Interface and base implementation for collective.
|
||||||
*/
|
*/
|
||||||
@ -21,6 +24,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
|||||||
Coll() = default;
|
Coll() = default;
|
||||||
virtual ~Coll() noexcept(false) {} // NOLINT
|
virtual ~Coll() noexcept(false) {} // NOLINT
|
||||||
|
|
||||||
|
Coll* MakeCUDAVar();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Allreduce
|
* @brief Allreduce
|
||||||
*
|
*
|
||||||
@ -29,8 +34,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
|||||||
* @param [in] op Reduce operation. For custom operation, user needs to reach down to
|
* @param [in] op Reduce operation. For custom operation, user needs to reach down to
|
||||||
* the CPU implementation.
|
* the CPU implementation.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
|
[[nodiscard]] virtual Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
common::Span<std::int8_t> data,
|
|
||||||
ArrayInterfaceHandler::Type type, Op op);
|
ArrayInterfaceHandler::Type type, Op op);
|
||||||
/**
|
/**
|
||||||
* @brief Broadcast
|
* @brief Broadcast
|
||||||
@ -38,29 +42,29 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
|||||||
* @param [in,out] data Data buffer for input and output.
|
* @param [in,out] data Data buffer for input and output.
|
||||||
* @param [in] root Root rank for broadcast.
|
* @param [in] root Root rank for broadcast.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
|
[[nodiscard]] virtual Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
common::Span<std::int8_t> data, std::int32_t root);
|
std::int32_t root);
|
||||||
/**
|
/**
|
||||||
* @brief Allgather
|
* @brief Allgather
|
||||||
*
|
*
|
||||||
* @param [in,out] data Data buffer for input and output.
|
* @param [in,out] data Data buffer for input and output.
|
||||||
* @param [in] size Size of data for each worker.
|
* @param [in] size Size of data for each worker.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
|
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
common::Span<std::int8_t> data, std::size_t size);
|
std::int64_t size);
|
||||||
/**
|
/**
|
||||||
* @brief Allgather with variable length.
|
* @brief Allgather with variable length.
|
||||||
*
|
*
|
||||||
* @param [in] data Input data for the current worker.
|
* @param [in] data Input data for the current worker.
|
||||||
* @param [in] sizes Size of the input from each worker.
|
* @param [in] sizes Size of the input from each worker.
|
||||||
* @param [out] recv_segments pre-allocated offset for each worker in the output, size
|
* @param [out] recv_segments pre-allocated offset buffer for each worker in the output,
|
||||||
* should be equal to (world + 1).
|
* size should be equal to (world + 1). GPU ring-based implementation
|
||||||
|
* doesn't use the buffer.
|
||||||
* @param [out] recv pre-allocated buffer for output.
|
* @param [out] recv pre-allocated buffer for output.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
|
[[nodiscard]] virtual Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||||
common::Span<std::int8_t const> data,
|
|
||||||
common::Span<std::int64_t const> sizes,
|
common::Span<std::int64_t const> sizes,
|
||||||
common::Span<std::int64_t> recv_segments,
|
common::Span<std::int64_t> recv_segments,
|
||||||
common::Span<std::int8_t> recv);
|
common::Span<std::int8_t> recv, AllgatherVAlgo algo);
|
||||||
};
|
};
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -262,7 +262,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
|||||||
}
|
}
|
||||||
|
|
||||||
RabitComm::~RabitComm() noexcept(false) {
|
RabitComm::~RabitComm() noexcept(false) {
|
||||||
if (!IsDistributed()) {
|
if (!this->IsDistributed()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto rc = this->Shutdown();
|
auto rc = this->Shutdown();
|
||||||
|
|||||||
112
src/collective/comm.cu
Normal file
112
src/collective/comm.cu
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
|
#include <algorithm> // for sort
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for uint64_t, int8_t
|
||||||
|
#include <cstring> // for memcpy
|
||||||
|
#include <memory> // for shared_ptr
|
||||||
|
#include <sstream> // for stringstream
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../common/device_helpers.cuh" // for DefaultStream
|
||||||
|
#include "../common/type.h" // for EraseType
|
||||||
|
#include "broadcast.h" // for Broadcast
|
||||||
|
#include "comm.cuh" // for NCCLComm
|
||||||
|
#include "comm.h" // for Comm
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace {
|
||||||
|
/**
 * @brief Obtain an NCCL unique id on rank 0 and broadcast it to all workers.
 *
 * @param comm Communicator used for the broadcast.
 * @param pid  Output; written only after the broadcast has succeeded.
 *
 * @return The broadcast's result if it failed, otherwise Success().
 */
Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
  static const int kRoot = 0;
  ncclUniqueId unique_id;
  bool const is_root = comm.Rank() == kRoot;
  if (is_root) {
    // Only the root generates the id, the others receive it through the broadcast.
    dh::safe_nccl(ncclGetUniqueId(&unique_id));
  }
  auto rc = Broadcast(
      comm, common::Span{reinterpret_cast<std::int8_t*>(&unique_id), sizeof(ncclUniqueId)}, kRoot);
  if (rc.OK()) {
    *pid = unique_id;
    return Success();
  }
  return rc;
}
|
||||||
|
|
||||||
|
inline constexpr std::size_t kUuidLength =
|
||||||
|
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(std::uint64_t);
|
||||||
|
|
||||||
|
/**
 * @brief Query the properties of `device` and copy its UUID bytes into `uuid`.
 *
 * @param uuid   Destination, `kUuidLength` 64-bit words.
 * @param device Device whose ordinal is passed to the CUDA runtime.
 */
void GetCudaUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid, DeviceOrd device) {
  cudaDeviceProp prop{};
  dh::safe_cuda(cudaGetDeviceProperties(&prop, device.ordinal));
  auto* src = &prop.uuid;
  std::memcpy(uuid.data(), src, sizeof(prop.uuid));
}
|
||||||
|
|
||||||
|
/**
 * @brief Format a device UUID as a hexadecimal string for error messages.
 *
 * Every 64-bit word is zero-padded to its full width (two digits per byte). Without the
 * padding leading zeros are dropped, which makes the output variable-length and lets
 * different UUIDs produce the same string.
 *
 * @param uuid The words of the UUID, printed in storage order.
 *
 * @return The concatenated hexadecimal representation of all words.
 */
static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid) {
  std::stringstream ss;
  ss << std::hex;
  ss.fill('0');
  for (auto word : uuid) {
    // The field width is reset after each insertion, so it has to be set for every word.
    ss.width(2 * sizeof(word));
    ss << word;
  }
  return ss.str();
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/**
 * @brief Create an `NCCLComm` from this communicator.
 *
 * @param ctx   Context forwarded to the `NCCLComm` constructor.
 * @param pimpl Collective implementation forwarded to the `NCCLComm` constructor.
 *
 * @return A raw pointer to an object allocated with `new`.
 */
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
  Comm* device_comm = new NCCLComm{ctx, *this, pimpl};
  return device_comm;
}
|
||||||
|
|
||||||
|
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
|
||||||
|
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
|
||||||
|
root.TaskID()},
|
||||||
|
stream_{dh::DefaultStream()} {
|
||||||
|
this->world_ = root.World();
|
||||||
|
this->rank_ = root.Rank();
|
||||||
|
this->domain_ = root.Domain();
|
||||||
|
if (!root.IsDistributed()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||||
|
|
||||||
|
std::vector<std::uint64_t> uuids(root.World() * kUuidLength, 0);
|
||||||
|
auto s_uuid = xgboost::common::Span<std::uint64_t>{uuids.data(), uuids.size()};
|
||||||
|
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
|
||||||
|
GetCudaUUID(s_this_uuid, ctx->Device());
|
||||||
|
|
||||||
|
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
|
||||||
|
CHECK(rc.OK()) << rc.Report();
|
||||||
|
|
||||||
|
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
|
||||||
|
std::size_t j = 0;
|
||||||
|
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||||
|
converted[j] = s_uuid.subspan(i, kUuidLength);
|
||||||
|
j++;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(converted.begin(), converted.end());
|
||||||
|
auto iter = std::unique(converted.begin(), converted.end());
|
||||||
|
auto n_uniques = std::distance(converted.begin(), iter);
|
||||||
|
|
||||||
|
CHECK_EQ(n_uniques, root.World())
|
||||||
|
<< "Multiple processes within communication group running on same CUDA "
|
||||||
|
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||||
|
|
||||||
|
rc = GetUniqueId(root, &nccl_unique_id_);
|
||||||
|
CHECK(rc.OK()) << rc.Report();
|
||||||
|
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
|
||||||
|
|
||||||
|
for (std::int32_t r = 0; r < root.World(); ++r) {
|
||||||
|
this->channels_.emplace_back(
|
||||||
|
std::make_shared<NCCLChannel>(root, r, nccl_comm_, dh::DefaultStream()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Destroy the NCCL communicator if one was created. The handle stays null when
 *        the constructor returned early.
 */
NCCLComm::~NCCLComm() {
  if (nccl_comm_ == nullptr) {
    return;
  }
  dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
|
#endif // defined(XGBOOST_USE_NCCL)
|
||||||
67
src/collective/comm.cuh
Normal file
67
src/collective/comm.cuh
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef XGBOOST_USE_NCCL
|
||||||
|
#include "nccl.h"
|
||||||
|
#endif // XGBOOST_USE_NCCL
|
||||||
|
#include "../common/device_helpers.cuh"
|
||||||
|
#include "coll.h"
|
||||||
|
#include "comm.h"
|
||||||
|
#include "xgboost/context.h"
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
|
||||||
|
/**
 * @brief Convert a CUDA error code into a `Result`.
 *
 * @param rc Error code returned by a CUDA runtime call.
 *
 * @return Success() for `cudaSuccess`, otherwise a failure carrying the message that
 *         thrust produces for the code.
 */
inline Result GetCUDAResult(cudaError rc) {
  if (rc != cudaSuccess) {
    std::string msg = thrust::system_error(rc, thrust::cuda_category()).what();
    return Fail(msg);
  }
  return Success();
}
|
||||||
|
|
||||||
|
class NCCLComm : public Comm {
|
||||||
|
ncclComm_t nccl_comm_{nullptr};
|
||||||
|
ncclUniqueId nccl_unique_id_{};
|
||||||
|
dh::CUDAStreamView stream_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
[[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; }
|
||||||
|
|
||||||
|
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl);
|
||||||
|
[[nodiscard]] Result LogTracker(std::string) const override {
|
||||||
|
LOG(FATAL) << "Device comm is used for logging.";
|
||||||
|
return Fail("Undefined.");
|
||||||
|
}
|
||||||
|
~NCCLComm() override;
|
||||||
|
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||||
|
[[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; }
|
||||||
|
[[nodiscard]] Result Block() const override {
|
||||||
|
auto rc = this->Stream().Sync(false);
|
||||||
|
return GetCUDAResult(rc);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class NCCLChannel : public Channel {
|
||||||
|
std::int32_t rank_{-1};
|
||||||
|
ncclComm_t nccl_comm_{};
|
||||||
|
dh::CUDAStreamView stream_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
|
||||||
|
dh::CUDAStreamView stream)
|
||||||
|
: rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {}
|
||||||
|
|
||||||
|
void SendAll(std::int8_t const* ptr, std::size_t n) override {
|
||||||
|
dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||||
|
}
|
||||||
|
void RecvAll(std::int8_t* ptr, std::size_t n) override {
|
||||||
|
dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||||
|
}
|
||||||
|
[[nodiscard]] Result Block() override {
|
||||||
|
auto rc = stream_.Sync(false);
|
||||||
|
return GetCUDAResult(rc);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -8,7 +8,6 @@
|
|||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
#include <thread> // for thread
|
#include <thread> // for thread
|
||||||
#include <type_traits> // for remove_const_t
|
|
||||||
#include <utility> // for move
|
#include <utility> // for move
|
||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
@ -16,6 +15,7 @@
|
|||||||
#include "protocol.h" // for PeerInfo
|
#include "protocol.h" // for PeerInfo
|
||||||
#include "xgboost/collective/result.h" // for Result
|
#include "xgboost/collective/result.h" // for Result
|
||||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||||
|
#include "xgboost/context.h" // for Context
|
||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
@ -35,13 +35,14 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class Channel;
|
class Channel;
|
||||||
|
class Coll;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Base communicator storing info about the tracker and other communicators.
|
* @brief Base communicator storing info about the tracker and other communicators.
|
||||||
*/
|
*/
|
||||||
class Comm {
|
class Comm {
|
||||||
protected:
|
protected:
|
||||||
std::int32_t world_{1};
|
std::int32_t world_{-1};
|
||||||
std::int32_t rank_{0};
|
std::int32_t rank_{0};
|
||||||
std::chrono::seconds timeout_{DefaultTimeoutSec()};
|
std::chrono::seconds timeout_{DefaultTimeoutSec()};
|
||||||
std::int32_t retry_{DefaultRetry()};
|
std::int32_t retry_{DefaultRetry()};
|
||||||
@ -69,12 +70,14 @@ class Comm {
|
|||||||
[[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
|
[[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
|
||||||
[[nodiscard]] auto Domain() const { return domain_; }
|
[[nodiscard]] auto Domain() const { return domain_; }
|
||||||
[[nodiscard]] auto Timeout() const { return timeout_; }
|
[[nodiscard]] auto Timeout() const { return timeout_; }
|
||||||
|
[[nodiscard]] auto Retry() const { return retry_; }
|
||||||
|
[[nodiscard]] auto TaskID() const { return task_id_; }
|
||||||
|
|
||||||
[[nodiscard]] auto Rank() const { return rank_; }
|
[[nodiscard]] auto Rank() const { return rank_; }
|
||||||
[[nodiscard]] auto World() const { return world_; }
|
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
|
||||||
[[nodiscard]] bool IsDistributed() const { return World() > 1; }
|
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
|
||||||
void Submit(Loop::Op op) const { loop_->Submit(op); }
|
void Submit(Loop::Op op) const { loop_->Submit(op); }
|
||||||
[[nodiscard]] Result Block() const { return loop_->Block(); }
|
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
|
||||||
|
|
||||||
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
|
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
|
||||||
return channels_.at(rank);
|
return channels_.at(rank);
|
||||||
@ -83,6 +86,8 @@ class Comm {
|
|||||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||||
|
|
||||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||||
|
|
||||||
|
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
|
||||||
};
|
};
|
||||||
|
|
||||||
class RabitComm : public Comm {
|
class RabitComm : public Comm {
|
||||||
@ -116,7 +121,7 @@ class Channel {
|
|||||||
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
|
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
|
||||||
: sock_{std::move(sock)}, comm_{comm} {}
|
: sock_{std::move(sock)}, comm_{comm} {}
|
||||||
|
|
||||||
void SendAll(std::int8_t const* ptr, std::size_t n) {
|
virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||||
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
|
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
|
||||||
CHECK(sock_.get());
|
CHECK(sock_.get());
|
||||||
comm_.Submit(std::move(op));
|
comm_.Submit(std::move(op));
|
||||||
@ -125,7 +130,7 @@ class Channel {
|
|||||||
this->SendAll(data.data(), data.size_bytes());
|
this->SendAll(data.data(), data.size_bytes());
|
||||||
}
|
}
|
||||||
|
|
||||||
void RecvAll(std::int8_t* ptr, std::size_t n) {
|
virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||||
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
|
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
|
||||||
CHECK(sock_.get());
|
CHECK(sock_.get());
|
||||||
comm_.Submit(std::move(op));
|
comm_.Submit(std::move(op));
|
||||||
@ -133,7 +138,7 @@ class Channel {
|
|||||||
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
|
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
|
||||||
|
|
||||||
[[nodiscard]] auto Socket() const { return sock_; }
|
[[nodiscard]] auto Socket() const { return sock_; }
|
||||||
[[nodiscard]] Result Block() { return comm_.Block(); }
|
[[nodiscard]] virtual Result Block() { return comm_.Block(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
||||||
|
|||||||
@ -50,6 +50,7 @@ class Tracker {
|
|||||||
[[nodiscard]] virtual std::future<Result> Run() = 0;
|
[[nodiscard]] virtual std::future<Result> Run() = 0;
|
||||||
[[nodiscard]] virtual Json WorkerArgs() const = 0;
|
[[nodiscard]] virtual Json WorkerArgs() const = 0;
|
||||||
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
|
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
|
||||||
|
[[nodiscard]] virtual std::int32_t Port() const { return port_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
class RabitTracker : public Tracker {
|
class RabitTracker : public Tracker {
|
||||||
@ -124,7 +125,6 @@ class RabitTracker : public Tracker {
|
|||||||
|
|
||||||
std::future<Result> Run() override;
|
std::future<Result> Run() override;
|
||||||
|
|
||||||
[[nodiscard]] std::int32_t Port() const { return port_; }
|
|
||||||
[[nodiscard]] Json WorkerArgs() const override {
|
[[nodiscard]] Json WorkerArgs() const override {
|
||||||
Json args{Object{}};
|
Json args{Object{}};
|
||||||
args["DMLC_TRACKER_URI"] = String{host_};
|
args["DMLC_TRACKER_URI"] = String{host_};
|
||||||
|
|||||||
@ -1171,7 +1171,13 @@ class CUDAStreamView {
|
|||||||
operator cudaStream_t() const { // NOLINT
|
operator cudaStream_t() const { // NOLINT
|
||||||
return stream_;
|
return stream_;
|
||||||
}
|
}
|
||||||
void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
|
cudaError_t Sync(bool error = true) {
|
||||||
|
if (error) {
|
||||||
|
dh::safe_cuda(cudaStreamSynchronize(stream_));
|
||||||
|
return cudaSuccess;
|
||||||
|
}
|
||||||
|
return cudaStreamSynchronize(stream_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
|
inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
|
||||||
|
|||||||
@ -20,7 +20,6 @@
|
|||||||
#include "../common/cuda_context.cuh" // CUDAContext
|
#include "../common/cuda_context.cuh" // CUDAContext
|
||||||
#include "../common/device_helpers.cuh"
|
#include "../common/device_helpers.cuh"
|
||||||
#include "../common/hist_util.h"
|
#include "../common/hist_util.h"
|
||||||
#include "../common/io.h"
|
|
||||||
#include "../common/timer.h"
|
#include "../common/timer.h"
|
||||||
#include "../data/ellpack_page.cuh"
|
#include "../data/ellpack_page.cuh"
|
||||||
#include "../data/ellpack_page.h"
|
#include "../data/ellpack_page.h"
|
||||||
@ -40,7 +39,6 @@
|
|||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/host_device_vector.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/parameter.h"
|
|
||||||
#include "xgboost/span.h"
|
#include "xgboost/span.h"
|
||||||
#include "xgboost/task.h" // for ObjInfo
|
#include "xgboost/task.h" // for ObjInfo
|
||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
|
|||||||
@ -14,6 +14,7 @@
|
|||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../../../src/collective/allgather.h" // for RingAllgather
|
#include "../../../src/collective/allgather.h" // for RingAllgather
|
||||||
|
#include "../../../src/collective/coll.h" // for Coll
|
||||||
#include "../../../src/collective/comm.h" // for RabitComm
|
#include "../../../src/collective/comm.h" // for RabitComm
|
||||||
#include "gtest/gtest.h" // for AssertionR...
|
#include "gtest/gtest.h" // for AssertionR...
|
||||||
#include "test_worker.h" // for TestDistri...
|
#include "test_worker.h" // for TestDistri...
|
||||||
@ -63,25 +64,7 @@ class Worker : public WorkerForTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestV() {
|
void CheckV(common::Span<std::int32_t> result) {
|
||||||
{
|
|
||||||
// basic test
|
|
||||||
std::int32_t n{comm_.Rank()};
|
|
||||||
std::vector<std::int32_t> result;
|
|
||||||
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
|
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
|
||||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
|
||||||
ASSERT_EQ(result[i], i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// V test
|
|
||||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
|
||||||
std::vector<std::int32_t> result;
|
|
||||||
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
|
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
|
||||||
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
|
|
||||||
std::int32_t k{0};
|
std::int32_t k{0};
|
||||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||||
auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
|
auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
|
||||||
@ -93,6 +76,66 @@ class Worker : public WorkerForTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void TestVRing() {
|
||||||
|
// V test
|
||||||
|
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||||
|
std::vector<std::int32_t> result;
|
||||||
|
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
|
||||||
|
CheckV(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestVBasic() {
|
||||||
|
// basic test
|
||||||
|
std::int32_t n{comm_.Rank()};
|
||||||
|
std::vector<std::int32_t> result;
|
||||||
|
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||||
|
ASSERT_EQ(result[i], i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestVAlgo() {
|
||||||
|
// V test, broadcast
|
||||||
|
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||||
|
auto s_data = common::Span{data.data(), data.size()};
|
||||||
|
|
||||||
|
std::vector<std::int64_t> sizes(comm_.World(), 0);
|
||||||
|
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||||
|
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
std::shared_ptr<Coll> pcoll{new Coll{}};
|
||||||
|
|
||||||
|
std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);
|
||||||
|
std::vector<std::int32_t> recv(std::accumulate(sizes.cbegin(), sizes.cend(), 0));
|
||||||
|
|
||||||
|
auto s_recv = common::Span{recv.data(), recv.size()};
|
||||||
|
|
||||||
|
rc = pcoll->AllgatherV(comm_, common::EraseType(s_data),
|
||||||
|
common::Span{sizes.data(), sizes.size()},
|
||||||
|
common::Span{recv_segments.data(), recv_segments.size()},
|
||||||
|
common::EraseType(s_recv), AllgatherVAlgo::kBcast);
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
|
CheckV(s_recv);
|
||||||
|
|
||||||
|
// Test inplace
|
||||||
|
auto test_inplace = [&] (AllgatherVAlgo algo) {
|
||||||
|
std::fill_n(s_recv.data(), s_recv.size(), 0);
|
||||||
|
auto current = s_recv.subspan(recv_segments[comm_.Rank()],
|
||||||
|
recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]);
|
||||||
|
std::copy_n(data.data(), data.size(), current.data());
|
||||||
|
rc = pcoll->AllgatherV(comm_, common::EraseType(current),
|
||||||
|
common::Span{sizes.data(), sizes.size()},
|
||||||
|
common::Span{recv_segments.data(), recv_segments.size()},
|
||||||
|
common::EraseType(s_recv), algo);
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
|
CheckV(s_recv);
|
||||||
|
};
|
||||||
|
|
||||||
|
test_inplace(AllgatherVAlgo::kBcast);
|
||||||
|
test_inplace(AllgatherVAlgo::kRing);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -106,12 +149,30 @@ TEST_F(AllgatherTest, Basic) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(AllgatherTest, V) {
|
TEST_F(AllgatherTest, VBasic) {
|
||||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||||
std::int32_t r) {
|
std::int32_t r) {
|
||||||
Worker worker{host, port, timeout, n_workers, r};
|
Worker worker{host, port, timeout, n_workers, r};
|
||||||
worker.TestV();
|
worker.TestVBasic();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AllgatherTest, VRing) {
|
||||||
|
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||||
|
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||||
|
std::int32_t r) {
|
||||||
|
Worker worker{host, port, timeout, n_workers, r};
|
||||||
|
worker.TestVRing();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AllgatherTest, VAlgo) {
|
||||||
|
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||||
|
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||||
|
std::int32_t r) {
|
||||||
|
Worker worker{host, port, timeout, n_workers, r};
|
||||||
|
worker.TestVAlgo();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
117
tests/cpp/collective/test_allgather.cu
Normal file
117
tests/cpp/collective/test_allgather.cu
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <thrust/device_vector.h> // for device_vector
|
||||||
|
#include <thrust/equal.h> // for equal
|
||||||
|
#include <xgboost/span.h> // for Span
|
||||||
|
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for int32_t, int64_t
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../../../src/collective/allgather.h" // for RingAllgather
|
||||||
|
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||||
|
#include "../../../src/common/type.h" // for EraseType
|
||||||
|
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||||
|
#include "test_worker.h" // for TestDistributed, WorkerForTest
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace {
|
||||||
|
class Worker : public NCCLWorkerForTest {
|
||||||
|
public:
|
||||||
|
using NCCLWorkerForTest::NCCLWorkerForTest;
|
||||||
|
|
||||||
|
void TestV(AllgatherVAlgo algo) {
|
||||||
|
{
|
||||||
|
// basic test
|
||||||
|
std::size_t n = 1;
|
||||||
|
// create data
|
||||||
|
dh::device_vector<std::int32_t> data(n, comm_.Rank());
|
||||||
|
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
|
||||||
|
// get size
|
||||||
|
std::vector<std::int64_t> sizes(comm_.World(), -1);
|
||||||
|
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||||
|
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
// create result
|
||||||
|
dh::device_vector<std::int32_t> result(comm_.World(), -1);
|
||||||
|
auto s_result = common::EraseType(dh::ToSpan(result));
|
||||||
|
|
||||||
|
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
|
||||||
|
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
|
||||||
|
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
|
||||||
|
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||||
|
ASSERT_EQ(result[i], i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// V test
|
||||||
|
std::size_t n = 256 * 256;
|
||||||
|
// create data
|
||||||
|
dh::device_vector<std::int32_t> data(n * nccl_comm_->Rank(), nccl_comm_->Rank());
|
||||||
|
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
|
||||||
|
// get size
|
||||||
|
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
|
||||||
|
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
|
||||||
|
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||||
|
// create result
|
||||||
|
dh::device_vector<std::int32_t> result(n_bytes / sizeof(std::int32_t), -1);
|
||||||
|
auto s_result = common::EraseType(dh::ToSpan(result));
|
||||||
|
|
||||||
|
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
|
||||||
|
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
|
||||||
|
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
// check segment size
|
||||||
|
if (algo != AllgatherVAlgo::kBcast) {
|
||||||
|
auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];
|
||||||
|
ASSERT_EQ(size, n * nccl_comm_->Rank() * sizeof(std::int32_t));
|
||||||
|
ASSERT_EQ(size, sizes[nccl_comm_->Rank()]);
|
||||||
|
}
|
||||||
|
// check data
|
||||||
|
std::size_t k{0};
|
||||||
|
for (std::int32_t r = 0; r < nccl_comm_->World(); ++r) {
|
||||||
|
std::size_t s = n * r;
|
||||||
|
auto current = dh::ToSpan(result).subspan(k, s);
|
||||||
|
std::vector<std::int32_t> h_data(current.size());
|
||||||
|
dh::CopyDeviceSpanToVector(&h_data, current);
|
||||||
|
for (auto v : h_data) {
|
||||||
|
ASSERT_EQ(v, r);
|
||||||
|
}
|
||||||
|
k += s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fixture for the device allgather tests; adds nothing on top of SocketTest.
class AllgatherTestGPU : public SocketTest {};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// One worker per visible GPU; each worker runs the allgather-v test with the ring
// algorithm and then again with the broadcast algorithm.
TEST_F(AllgatherTestGPU, MGPUTestVRing) {
  auto n_workers = common::AllVisibleGPUs();
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t rank) {
    Worker worker{host, port, timeout, n_workers, rank};
    worker.Setup();
    worker.TestV(AllgatherVAlgo::kRing);
    worker.TestV(AllgatherVAlgo::kBcast);
  });
}
|
||||||
|
|
||||||
|
// Variable-length allgather using only the kBcast algorithm, one worker per visible GPU.
TEST_F(AllgatherTestGPU, MGPUTestVBcast) {
  auto n_workers = common::AllVisibleGPUs();
  auto run_worker = [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                        std::int32_t rank) {
    Worker worker{host, port, timeout, n_workers, rank};
    worker.Setup();
    worker.TestV(AllgatherVAlgo::kBcast);
  };
  TestDistributed(n_workers, run_worker);
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
|
#endif // defined(XGBOOST_USE_NCCL)
|
||||||
@ -6,10 +6,10 @@
|
|||||||
#include "../../../src/collective/allreduce.h"
|
#include "../../../src/collective/allreduce.h"
|
||||||
#include "../../../src/collective/coll.h" // for Coll
|
#include "../../../src/collective/coll.h" // for Coll
|
||||||
#include "../../../src/collective/tracker.h"
|
#include "../../../src/collective/tracker.h"
|
||||||
|
#include "../../../src/common/type.h" // for EraseType
|
||||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class AllreduceWorker : public WorkerForTest {
|
class AllreduceWorker : public WorkerForTest {
|
||||||
public:
|
public:
|
||||||
@ -50,11 +50,10 @@ class AllreduceWorker : public WorkerForTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BitOr() {
|
void BitOr() {
|
||||||
Context ctx;
|
|
||||||
std::vector<std::uint32_t> data(comm_.World(), 0);
|
std::vector<std::uint32_t> data(comm_.World(), 0);
|
||||||
data[comm_.Rank()] = ~std::uint32_t{0};
|
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||||
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
|
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
|
||||||
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
|
auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
|
||||||
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
for (auto v : data) {
|
for (auto v : data) {
|
||||||
|
|||||||
70
tests/cpp/collective/test_allreduce.cu
Normal file
70
tests/cpp/collective/test_allreduce.cu
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <thrust/host_vector.h> // for host_vector
|
||||||
|
|
||||||
|
#include "../../../src/collective/coll.h" // for Coll
|
||||||
|
#include "../../../src/common/common.h"
|
||||||
|
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||||
|
#include "../../../src/common/type.h" // for EraseType
|
||||||
|
#include "../helpers.h" // for MakeCUDACtx
|
||||||
|
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||||
|
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace {
|
||||||
|
// Fixture for the GPU allreduce tests below; adds nothing on top of SocketTest.
class AllreduceTestGPU : public SocketTest {};
|
||||||
|
|
||||||
|
class Worker : public NCCLWorkerForTest {
|
||||||
|
public:
|
||||||
|
using NCCLWorkerForTest::NCCLWorkerForTest;
|
||||||
|
|
||||||
|
void BitOr() {
|
||||||
|
dh::device_vector<std::uint32_t> data(comm_.World(), 0);
|
||||||
|
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||||
|
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
|
||||||
|
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
thrust::host_vector<std::uint32_t> h_data(data.size());
|
||||||
|
thrust::copy(data.cbegin(), data.cend(), h_data.begin());
|
||||||
|
for (auto v : h_data) {
|
||||||
|
ASSERT_EQ(v, ~std::uint32_t{0});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Acc() {
|
||||||
|
dh::device_vector<double> data(314, 1.5);
|
||||||
|
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
|
||||||
|
ArrayInterfaceHandler::kF8, Op::kSum);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||||
|
auto v = data[i];
|
||||||
|
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Runs Worker::BitOr with one worker per visible GPU.
TEST_F(AllreduceTestGPU, BitOr) {
  auto n_workers = common::AllVisibleGPUs();
  auto run_worker = [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                        std::int32_t rank) {
    Worker worker{host, port, timeout, n_workers, rank};
    worker.Setup();
    worker.BitOr();
  };
  TestDistributed(n_workers, run_worker);
}
|
||||||
|
|
||||||
|
// Runs Worker::Acc (sum reduction) with one worker per visible GPU.
TEST_F(AllreduceTestGPU, Sum) {
  auto n_workers = common::AllVisibleGPUs();
  auto run_worker = [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                        std::int32_t rank) {
    Worker worker{host, port, timeout, n_workers, rank};
    worker.Setup();
    worker.Acc();
  };
  TestDistributed(n_workers, run_worker);
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
|
#endif // defined(XGBOOST_USE_NCCL)
|
||||||
@ -47,5 +47,5 @@ TEST_F(BroadcastTest, Basic) {
|
|||||||
Worker worker{host, port, timeout, n_workers, r};
|
Worker worker{host, port, timeout, n_workers, r};
|
||||||
worker.Run();
|
worker.Run();
|
||||||
});
|
});
|
||||||
}
|
} // namespace
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
32
tests/cpp/collective/test_worker.cuh
Normal file
32
tests/cpp/collective/test_worker.cuh
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <memory> // for shared_ptr
|
||||||
|
|
||||||
|
#include "../../../src/collective/coll.h" // for Coll
|
||||||
|
#include "../../../src/collective/comm.h" // for Comm
|
||||||
|
#include "test_worker.h"
|
||||||
|
#include "xgboost/context.h" // for Context
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
class NCCLWorkerForTest : public WorkerForTest {
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<Coll> coll_;
|
||||||
|
std::shared_ptr<xgboost::collective::Comm> nccl_comm_;
|
||||||
|
std::shared_ptr<Coll> nccl_coll_;
|
||||||
|
Context ctx_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using WorkerForTest::WorkerForTest;
|
||||||
|
|
||||||
|
void Setup() {
|
||||||
|
ctx_ = MakeCUDACtx(comm_.Rank());
|
||||||
|
coll_.reset(new Coll{});
|
||||||
|
nccl_comm_.reset(this->comm_.MakeCUDAVar(&ctx_, coll_));
|
||||||
|
nccl_coll_.reset(coll_->MakeCUDAVar());
|
||||||
|
ASSERT_EQ(comm_.World(), nccl_comm_->World());
|
||||||
|
ASSERT_EQ(comm_.Rank(), nccl_comm_->Rank());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -1,6 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#pragma once
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <chrono> // for seconds
|
#include <chrono> // for seconds
|
||||||
|
|||||||
@ -97,4 +97,29 @@ TEST(BitField, Clear) {
|
|||||||
TestBitFieldClear<RBitField8>(19);
|
TestBitFieldClear<RBitField8>(19);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Counting trailing zero bits, checked for both the public function and its detail
// implementation.
TEST(BitField, CTZ) {
  {
    // An input of zero is expected to report the full bit width of std::uint32_t.
    auto n_zeros = TrailingZeroBits(0);
    ASSERT_EQ(n_zeros, sizeof(std::uint32_t) * 8);
  }

  struct Case {
    int value;
    int expected;
  };
  Case const cases[] = {
      {0b00011100, 2},
      {0b00011101, 0},
      {0b1000000000000000, 15},
  };
  for (auto const& c : cases) {
    auto n_zeros = TrailingZeroBits(c.value);
    ASSERT_EQ(n_zeros, c.expected);
    n_zeros = detail::TrailingZeroBitsImpl(c.value);
    ASSERT_EQ(n_zeros, c.expected);
  }
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -572,4 +572,31 @@ class BaseMGPUTest : public ::testing::Test {
|
|||||||
class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{};
|
class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{};
|
||||||
|
|
||||||
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
|
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief poor man's gmock for message matching.
|
||||||
|
*
|
||||||
|
* @tparam Error The type of expected execption.
|
||||||
|
*
|
||||||
|
* @param submsg A substring of the actual error message.
|
||||||
|
* @param fn The function that throws Error
|
||||||
|
*/
|
||||||
|
template <typename Error, typename Fn>
|
||||||
|
void ExpectThrow(std::string submsg, Fn&& fn) {
|
||||||
|
try {
|
||||||
|
fn();
|
||||||
|
} catch (Error const& exc) {
|
||||||
|
auto actual = std::string{exc.what()};
|
||||||
|
ASSERT_NE(actual.find(submsg), std::string::npos)
|
||||||
|
<< "Expecting substring `" << submsg << "` from the error message."
|
||||||
|
<< " Got:\n"
|
||||||
|
<< actual << "\n";
|
||||||
|
return;
|
||||||
|
} catch (std::exception const& exc) {
|
||||||
|
auto actual = exc.what();
|
||||||
|
ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(false) << "No exception is thrown";
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
84
tests/cpp/plugin/federated/test_federated_comm.cc
Normal file
84
tests/cpp/plugin/federated/test_federated_comm.cc
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022-2023, XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string> // for string
|
||||||
|
#include <thread> // for thread
|
||||||
|
|
||||||
|
#include "../../../../plugin/federated/federated_comm.h"
|
||||||
|
#include "../../collective/net_test.h" // for SocketTest
|
||||||
|
#include "../../helpers.h" // for ExpectThrow
|
||||||
|
#include "test_worker.h" // for TestFederated
|
||||||
|
#include "xgboost/json.h" // for Json
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace {
|
||||||
|
// Fixture for the FederatedComm tests below; adds nothing on top of SocketTest.
class FederatedCommTest : public SocketTest {};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Construction with these arguments must fail with "Invalid world size.".
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
  ExpectThrow<dmlc::Error>("Invalid world size.", [] {
    // NOTE(review): positional arguments look like (host, port, world size, rank) -- confirm.
    FederatedComm comm{"localhost", 0, 0, 0};
  });
}
|
||||||
|
|
||||||
|
// Construction with a last argument of -1 must fail with "Invalid worker rank.".
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
  ExpectThrow<dmlc::Error>("Invalid worker rank.", [] {
    // NOTE(review): positional arguments look like (host, port, world size, rank) -- confirm.
    FederatedComm comm{"localhost", 0, 1, -1};
  });
}
|
||||||
|
|
||||||
|
// Construction with the last two arguments both 1 must fail with "Invalid worker rank.".
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
  ExpectThrow<dmlc::Error>("Invalid worker rank.", [] {
    // NOTE(review): positional arguments look like (host, port, world size, rank) -- confirm.
    FederatedComm comm{"localhost", 0, 1, 1};
  });
}
|
||||||
|
|
||||||
|
// A JSON config whose world size is a string must be rejected with a type error message.
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
  auto make_comm = [] {
    Json cfg{Object{}};
    cfg["federated_server_address"] = std::string("localhost:0");
    // Deliberately a string instead of an integer.
    cfg["federated_world_size"] = std::string("1");
    cfg["federated_rank"] = Integer(0);
    FederatedComm comm(cfg);
  };
  ExpectThrow<dmlc::Error>("got: `String`", make_comm);
}
|
||||||
|
|
||||||
|
// A JSON config whose rank is a string must be rejected with a type error message.
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
  auto make_comm = [] {
    Json cfg{Object{}};
    cfg["federated_server_address"] = std::string("localhost:0");
    cfg["federated_world_size"] = 1;
    // Deliberately a string instead of an integer.
    cfg["federated_rank"] = std::string("0");
    FederatedComm comm(cfg);
  };
  ExpectThrow<dmlc::Error>("got: `String`", make_comm);
}
|
||||||
|
|
||||||
|
// World size and rank given in the JSON config are reported back by the communicator.
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
  Json cfg{Object{}};
  cfg["federated_world_size"] = 6;
  cfg["federated_rank"] = 3;
  cfg["federated_server_address"] = String{"localhost:0"};

  FederatedComm federated{cfg};
  EXPECT_EQ(federated.World(), 6);
  EXPECT_EQ(federated.Rank(), 3);
}
|
||||||
|
|
||||||
|
// A communicator built with these arguments must report itself as distributed.
TEST_F(FederatedCommTest, IsDistributed) {
  // NOTE(review): positional arguments look like (host, port, world size, rank) -- confirm.
  FederatedComm federated{"localhost", 0, 2, 1};
  EXPECT_TRUE(federated.IsDistributed());
}
|
||||||
|
|
||||||
|
// Connects up to three workers to a tracker started by TestFederated and checks that
// each communicator reports the rank and world size it was configured with.
TEST_F(FederatedCommTest, InsecureTracker) {
  // hardware_concurrency() may return 0 when the value is not computable. Always run at
  // least one worker so the test cannot silently degenerate into an empty one.
  auto const n_threads = std::thread::hardware_concurrency();
  std::int32_t n_workers =
      n_threads == 0 ? 1 : static_cast<std::int32_t>(std::min(n_threads, 3u));
  TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
    Json config{Object{}};
    config["federated_world_size"] = n_workers;
    config["federated_rank"] = rank;
    config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
    FederatedComm comm{config};
    ASSERT_EQ(comm.Rank(), rank);
    ASSERT_EQ(comm.World(), n_workers);
  });
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
42
tests/cpp/plugin/federated/test_worker.h
Normal file
42
tests/cpp/plugin/federated/test_worker.h
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022-2023, XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <chrono> // for ms
|
||||||
|
#include <thread> // for thread
|
||||||
|
|
||||||
|
#include "../../../../plugin/federated/federated_tracker.h"
|
||||||
|
#include "xgboost/json.h" // for Json
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
template <typename WorkerFn>
|
||||||
|
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||||
|
Json config{Object()};
|
||||||
|
config["federated_secure"] = Boolean{false};
|
||||||
|
config["n_workers"] = Integer{n_workers};
|
||||||
|
FederatedTracker tracker{config};
|
||||||
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
|
std::vector<std::thread> workers;
|
||||||
|
using namespace std::chrono_literals;
|
||||||
|
while (tracker.Port() == 0) {
|
||||||
|
std::this_thread::sleep_for(100ms);
|
||||||
|
}
|
||||||
|
std::int32_t port = tracker.Port();
|
||||||
|
|
||||||
|
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||||
|
workers.emplace_back([=] { fn(port, i); });
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& t : workers) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rc = tracker.Shutdown();
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
ASSERT_TRUE(fut.get().OK());
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2022-2023 XGBoost contributors
|
* Copyright 2022-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ class ServerForTest {
|
|||||||
explicit ServerForTest(std::size_t world_size) {
|
explicit ServerForTest(std::size_t world_size) {
|
||||||
server_thread_.reset(new std::thread([this, world_size] {
|
server_thread_.reset(new std::thread([this, world_size] {
|
||||||
grpc::ServerBuilder builder;
|
grpc::ServerBuilder builder;
|
||||||
xgboost::federated::FederatedService service{world_size};
|
xgboost::federated::FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||||
int selected_port;
|
int selected_port;
|
||||||
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
|
builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
|
||||||
builder.RegisterService(&service);
|
builder.RegisterService(&service);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user