From a5a58102e5e82fa508514c34cd8e5f408dcfd3e1 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 20 May 2024 11:56:23 +0800 Subject: [PATCH] Revamp the rabit implementation. (#10112) This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhausted tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages. --- CMakeLists.txt | 4 - R-package/src/Makevars.in | 8 +- R-package/src/Makevars.win | 8 +- cmake/Utils.cmake | 1 + demo/dask/cpu_training.py | 2 +- include/xgboost/c_api.h | 184 ++-- include/xgboost/collective/result.h | 5 +- include/xgboost/collective/socket.h | 71 +- jvm-packages/create_jni.py | 11 +- jvm-packages/pom.xml | 5 + .../ml/dmlc/xgboost4j/java/flink/XGBoost.java | 8 +- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 45 +- .../scala/spark/params/GeneralParams.scala | 28 +- .../spark/CommunicatorRobustnessSuite.scala | 96 +- .../scala/spark/ParameterSuite.scala | 2 - .../XGBoostCommunicatorRegressionSuite.scala | 30 - jvm-packages/xgboost4j-tester/generate_pom.py | 5 + .../ml/dmlc/xgboost4j/java/Communicator.java | 32 +- .../{IRabitTracker.java => ITracker.java} | 23 +- .../ml/dmlc/xgboost4j/java/RabitTracker.java | 231 +--- .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 2 +- .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 17 +- .../dmlc/xgboost4j/java/util/UtilUnsafe.java | 1 - .../ml/dmlc/xgboost4j/scala/XGBoost.scala | 4 +- .../xgboost4j/src/native/xgboost4j.cpp | 142 ++- jvm-packages/xgboost4j/src/native/xgboost4j.h | 44 +- .../ml/dmlc/xgboost4j/java/DMatrixTest.java | 4 +- plugin/federated/CMakeLists.txt | 20 +- plugin/federated/federated.old.proto | 81 -- plugin/federated/federated_client.h | 132 --- plugin/federated/federated_communicator.h | 195 ---- plugin/federated/federated_server.cc | 86 -- plugin/federated/federated_server.h | 37 - plugin/federated/federated_tracker.cc | 13 +- plugin/sycl/device_manager.cc | 59 +- plugin/sycl/objective/multiclass_obj.cc | 1 - plugin/sycl/objective/regression_obj.cc | 1 - plugin/sycl/predictor/predictor.cc | 1 - python-package/xgboost/collective.py | 112 +- python-package/xgboost/core.py | 2 +- python-package/xgboost/dask/__init__.py | 27 +- python-package/xgboost/federated.py | 112 +- python-package/xgboost/spark/utils.py | 14 +- python-package/xgboost/testing/__init__.py | 12 +- python-package/xgboost/tracker.py | 572 ++-------- rabit/CMakeLists.txt | 15 - rabit/LICENSE | 28 - rabit/README.md | 1 - rabit/include/rabit/base.h | 19 - rabit/include/rabit/c_api.h | 157 --- rabit/include/rabit/internal/engine.h | 197 ---- rabit/include/rabit/internal/io.h | 118 --- rabit/include/rabit/internal/rabit-inl.h | 234 ---- rabit/include/rabit/internal/socket.h | 14 +- rabit/include/rabit/internal/utils.h | 146 --- rabit/include/rabit/rabit.h | 237 ----- rabit/include/rabit/serializable.h | 26 - rabit/src/allreduce_base.cc | 997 ------------------ rabit/src/allreduce_base.h | 501 --------- rabit/src/allreduce_mock.h | 147 --- rabit/src/engine.cc | 106 -- rabit/src/engine_mock.cc | 14 - rabit/src/rabit_c_api.cc | 342 ------ src/c_api/c_api.cc | 82 +- src/c_api/c_api_error.cc | 20 +- src/c_api/c_api_error.h | 5 +- src/c_api/coll_c_api.cc | 118 ++- src/cli_main.cc | 31 +- src/collective/aggregator.cuh | 15 +- 
src/collective/aggregator.h | 118 ++- src/collective/allgather.cc | 7 +- src/collective/allreduce.cc | 10 +- src/collective/comm.cc | 16 +- src/collective/comm.h | 2 +- src/collective/comm_group.cc | 28 +- src/collective/communicator-inl.cc | 34 - src/collective/communicator-inl.cuh | 95 -- src/collective/communicator-inl.h | 309 +----- src/collective/communicator.cc | 63 -- src/collective/communicator.cu | 54 - src/collective/communicator.h | 247 ----- src/collective/device_communicator.cuh | 57 - .../device_communicator_adapter.cuh | 94 -- src/collective/in_memory_communicator.cc | 12 - src/collective/in_memory_communicator.h | 6 +- src/collective/in_memory_handler.cc | 99 +- src/collective/in_memory_handler.h | 41 +- src/collective/loop.cc | 106 +- src/collective/loop.h | 19 +- src/collective/nccl_device_communicator.cu | 243 ----- src/collective/nccl_device_communicator.cuh | 91 -- src/collective/noop_communicator.h | 32 - src/collective/protocol.h | 78 +- src/collective/rabit_communicator.h | 175 --- src/collective/result.cc | 13 +- src/collective/socket.cc | 48 +- src/collective/tracker.cc | 18 +- src/collective/tracker.h | 3 +- src/common/io.h | 96 +- src/common/quantile.cc | 34 +- src/common/quantile.cu | 38 +- src/common/random.h | 9 +- src/data/array_interface.h | 7 +- src/data/data.cc | 62 +- src/data/iterative_dmatrix.cc | 20 +- src/data/iterative_dmatrix.cu | 9 +- src/data/proxy_dmatrix.h | 13 +- src/data/simple_dmatrix.cc | 10 +- src/data/sparse_page_source.h | 1 + src/learner.cc | 10 +- src/logging.cc | 9 +- src/metric/auc.cc | 9 +- src/metric/auc.cu | 23 +- src/metric/auc.h | 6 +- src/objective/adaptive.h | 4 +- src/predictor/cpu_predictor.cc | 20 +- src/predictor/gpu_predictor.cu | 20 +- src/tree/common_row_partitioner.h | 44 +- src/tree/fit_stump.cu | 6 +- src/tree/gpu_hist/evaluate_splits.cu | 14 +- src/tree/hist/evaluate_splits.h | 9 +- src/tree/hist/histogram.h | 42 +- src/tree/hist/param.cc | 20 +- src/tree/updater_gpu_hist.cu | 30 +- src/tree/updater_refresh.cc | 27 +- src/tree/updater_sync.cc | 7 +- tests/ci_build/Dockerfile.jvm_cross | 2 +- tests/ci_build/build_jvm_packages.sh | 1 - tests/ci_build/build_mock_cmake.sh | 10 - tests/ci_build/test_r_package.py | 1 - tests/cpp/collective/net_test.h | 41 - tests/cpp/collective/test_allgather.cc | 31 + tests/cpp/collective/test_allreduce.cc | 57 + tests/cpp/collective/test_communicator.cc | 63 -- .../collective/test_in_memory_communicator.cc | 237 ----- tests/cpp/collective/test_loop.cc | 8 +- .../test_nccl_device_communicator.cu | 99 -- .../cpp/collective/test_rabit_communicator.cc | 70 -- tests/cpp/collective/test_tracker.cc | 4 + tests/cpp/collective/test_worker.h | 72 +- tests/cpp/common/test_io.cc | 4 +- tests/cpp/common/test_quantile.cc | 76 +- tests/cpp/common/test_quantile.cu | 93 +- tests/cpp/common/test_quantile.h | 12 +- tests/cpp/data/test_metainfo.cc | 11 +- tests/cpp/data/test_simple_dmatrix.cc | 3 +- tests/cpp/helpers.h | 61 -- tests/cpp/metric/test_auc.cc | 68 -- tests/cpp/metric/test_auc.cu | 5 - tests/cpp/metric/test_auc.h | 31 +- tests/cpp/metric/test_distributed_metric.cc | 192 ++++ tests/cpp/metric/test_elementwise_metric.cc | 106 -- tests/cpp/metric/test_elementwise_metric.cu | 5 - tests/cpp/metric/test_elementwise_metric.h | 56 +- tests/cpp/metric/test_multiclass_metric.cc | 29 - tests/cpp/metric/test_multiclass_metric.cu | 5 - tests/cpp/metric/test_multiclass_metric.h | 20 +- tests/cpp/metric/test_rank_metric.cc | 89 +- tests/cpp/metric/test_rank_metric.cu | 5 - tests/cpp/metric/test_rank_metric.h 
| 16 +- tests/cpp/metric/test_survival_metric.cc | 11 +- tests/cpp/metric/test_survival_metric.cu | 26 +- tests/cpp/metric/test_survival_metric.h | 10 +- tests/cpp/objective/test_objective.cc | 2 +- tests/cpp/objective_helpers.cc | 20 +- tests/cpp/objective_helpers.h | 8 +- .../plugin/federated/test_federated_coll.cu | 26 + tests/cpp/plugin/federated/test_worker.h | 81 +- tests/cpp/plugin/helpers.h | 99 -- tests/cpp/plugin/test_federated_adapter.cu | 97 -- .../cpp/plugin/test_federated_communicator.cc | 161 --- tests/cpp/plugin/test_federated_data.cc | 10 +- tests/cpp/plugin/test_federated_learner.cc | 51 +- tests/cpp/plugin/test_federated_metrics.cc | 243 ----- tests/cpp/plugin/test_federated_server.cc | 133 --- tests/cpp/predictor/test_cpu_predictor.cc | 10 +- tests/cpp/predictor/test_gpu_predictor.cu | 32 +- tests/cpp/predictor/test_predictor.cc | 43 +- tests/cpp/rabit/allreduce_base_test.cc | 42 - tests/cpp/rabit/test_utils.cc | 6 - tests/cpp/test_learner.cc | 153 +-- tests/cpp/test_main.cc | 11 +- .../cpp/tree/gpu_hist/test_evaluate_splits.cu | 28 +- tests/cpp/tree/hist/test_histogram.cc | 5 +- tests/cpp/tree/test_approx.cc | 15 +- tests/cpp/tree/test_evaluate_splits.h | 8 +- tests/cpp/tree/test_fit_stump.cc | 7 +- tests/cpp/tree/test_gpu_hist.cu | 21 +- tests/cpp/tree/test_histmaker.cc | 7 +- .../cpp/tree/test_multi_target_tree_model.cc | 26 +- tests/cpp/tree/test_quantile_hist.cc | 13 +- tests/python/test_collective.py | 60 +- tests/python/test_tracker.py | 161 ++- tests/python/test_with_arrow.py | 11 +- tests/python/test_with_sklearn.py | 5 +- 195 files changed, 2768 insertions(+), 9234 deletions(-) rename jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/{IRabitTracker.java => ITracker.java} (56%) delete mode 100644 plugin/federated/federated.old.proto delete mode 100644 plugin/federated/federated_client.h delete mode 100644 plugin/federated/federated_communicator.h delete mode 100644 plugin/federated/federated_server.cc delete mode 100644 plugin/federated/federated_server.h delete mode 100644 rabit/CMakeLists.txt delete mode 100644 rabit/LICENSE delete mode 100644 rabit/README.md delete mode 100644 rabit/include/rabit/base.h delete mode 100644 rabit/include/rabit/c_api.h delete mode 100644 rabit/include/rabit/internal/engine.h delete mode 100644 rabit/include/rabit/internal/io.h delete mode 100644 rabit/include/rabit/internal/rabit-inl.h delete mode 100644 rabit/include/rabit/internal/utils.h delete mode 100644 rabit/include/rabit/rabit.h delete mode 100644 rabit/include/rabit/serializable.h delete mode 100644 rabit/src/allreduce_base.cc delete mode 100644 rabit/src/allreduce_base.h delete mode 100644 rabit/src/allreduce_mock.h delete mode 100644 rabit/src/engine.cc delete mode 100644 rabit/src/engine_mock.cc delete mode 100644 rabit/src/rabit_c_api.cc delete mode 100644 src/collective/communicator-inl.cc delete mode 100644 src/collective/communicator-inl.cuh delete mode 100644 src/collective/communicator.cc delete mode 100644 src/collective/communicator.cu delete mode 100644 src/collective/communicator.h delete mode 100644 src/collective/device_communicator.cuh delete mode 100644 src/collective/device_communicator_adapter.cuh delete mode 100644 src/collective/in_memory_communicator.cc delete mode 100644 src/collective/nccl_device_communicator.cu delete mode 100644 src/collective/nccl_device_communicator.cuh delete mode 100644 src/collective/noop_communicator.h delete mode 100644 src/collective/rabit_communicator.h delete mode 100755 tests/ci_build/build_mock_cmake.sh 
delete mode 100644 tests/cpp/collective/net_test.h delete mode 100644 tests/cpp/collective/test_communicator.cc delete mode 100644 tests/cpp/collective/test_in_memory_communicator.cc delete mode 100644 tests/cpp/collective/test_nccl_device_communicator.cu delete mode 100644 tests/cpp/collective/test_rabit_communicator.cc delete mode 100644 tests/cpp/metric/test_auc.cc delete mode 100644 tests/cpp/metric/test_auc.cu create mode 100644 tests/cpp/metric/test_distributed_metric.cc delete mode 100644 tests/cpp/metric/test_elementwise_metric.cc delete mode 100644 tests/cpp/metric/test_elementwise_metric.cu delete mode 100644 tests/cpp/metric/test_multiclass_metric.cc delete mode 100644 tests/cpp/metric/test_multiclass_metric.cu delete mode 100644 tests/cpp/metric/test_rank_metric.cu delete mode 100644 tests/cpp/plugin/helpers.h delete mode 100644 tests/cpp/plugin/test_federated_adapter.cu delete mode 100644 tests/cpp/plugin/test_federated_communicator.cc delete mode 100644 tests/cpp/plugin/test_federated_metrics.cc delete mode 100644 tests/cpp/plugin/test_federated_server.cc delete mode 100644 tests/cpp/rabit/allreduce_base_test.cc delete mode 100644 tests/cpp/rabit/test_utils.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index dbfa1cdc2..c69b0d2a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,7 +69,6 @@ option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF) option(USE_DEVICE_DEBUG "Generate CUDA device debug info." OFF) option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF) set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header") -option(RABIT_MOCK "Build rabit with mock" OFF) option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF) option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF) ## CUDA @@ -282,9 +281,6 @@ if(MSVC) endif() endif() -# rabit -add_subdirectory(rabit) - # core xgboost add_subdirectory(${xgboost_SOURCE_DIR}/src) target_link_libraries(objxgboost PUBLIC dmlc) diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 69cdd09a3..93cfb8e5b 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -106,10 +106,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/comm_group.o \ $(PKGROOT)/src/collective/coll.o \ - $(PKGROOT)/src/collective/communicator-inl.o \ $(PKGROOT)/src/collective/tracker.o \ - $(PKGROOT)/src/collective/communicator.o \ - $(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ $(PKGROOT)/src/collective/loop.o \ $(PKGROOT)/src/collective/socket.o \ @@ -134,7 +131,4 @@ OBJECTS= \ $(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/c_api/c_api.o \ $(PKGROOT)/src/c_api/c_api_error.o \ - $(PKGROOT)/amalgamation/dmlc-minimum0.o \ - $(PKGROOT)/rabit/src/engine.o \ - $(PKGROOT)/rabit/src/rabit_c_api.o \ - $(PKGROOT)/rabit/src/allreduce_base.o + $(PKGROOT)/amalgamation/dmlc-minimum0.o diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index b34d8c649..f160930e8 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -106,10 +106,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/comm_group.o \ $(PKGROOT)/src/collective/coll.o \ - $(PKGROOT)/src/collective/communicator-inl.o \ $(PKGROOT)/src/collective/tracker.o \ - $(PKGROOT)/src/collective/communicator.o \ - $(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ 
$(PKGROOT)/src/collective/loop.o \ $(PKGROOT)/src/collective/socket.o \ @@ -134,7 +131,4 @@ OBJECTS= \ $(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/c_api/c_api.o \ $(PKGROOT)/src/c_api/c_api_error.o \ - $(PKGROOT)/amalgamation/dmlc-minimum0.o \ - $(PKGROOT)/rabit/src/engine.o \ - $(PKGROOT)/rabit/src/rabit_c_api.o \ - $(PKGROOT)/rabit/src/allreduce_base.o + $(PKGROOT)/amalgamation/dmlc-minimum0.o diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 9006bb0ea..317a71c00 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -151,6 +151,7 @@ function(xgboost_set_cuda_flags target) target_include_directories( ${target} PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap + ${xgboost_SOURCE_DIR}/rabit/include ${CUDAToolkit_INCLUDE_DIRS}) if(MSVC) diff --git a/demo/dask/cpu_training.py b/demo/dask/cpu_training.py index 59471da7c..b3a389458 100644 --- a/demo/dask/cpu_training.py +++ b/demo/dask/cpu_training.py @@ -16,7 +16,7 @@ def main(client: Client) -> None: m = 100000 n = 100 rng = da.random.default_rng(1) - X = rng.normal(size=(m, n)) + X = rng.normal(size=(m, n), chunks=(10000, -1)) y = X.sum(axis=1) # DaskDMatrix acts like normal DMatrix, works as a proxy for local diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 19b93c644..4b60fe01a 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values, * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface, - char const *c_json_config, DMatrixHandle m, +XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values, + char const *config, DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim, const float **out_result); @@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config, * * @brief Experimental support for exposing internal communicator in XGBoost. * + * @note This is still under development. + * + * The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has + * changed significantly since its adoption. It consists of a tracker and a set of + * workers. The tracker is responsible for bootstrapping the communication group and + * handling centralized tasks like logging. The workers are actual communicators + * performing collective tasks like allreduce. + * + * To use the collective implementation, one needs to first create a tracker with + * corresponding parameters, then get the arguments for workers using + * XGTrackerWorkerArgs(). The obtained arguments can then be passed to the + * XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a + * XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses + * `std::thread` in C++, which has undefined behavior in a C++ destructor due to the + * runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the + * runtime is shutting down. This requirement is similar to a Python thread or socket, + * which should not be relied upon in a `__del__` function. + * + * Since it's used as a part of XGBoost, errors will be returned when a XGBoost function + * is called, for instance, training a booster might return a connection error. + * * @{ */ /** - * @brief Handle to tracker. + * @brief Handle to the tracker. 
* * There are currently two types of tracker in XGBoost, first one is `rabit`, while the - * other one is `federated`. + * other one is `federated`. `rabit` is used for normal collective communication, while + * `federated` is used for federated learning. * - * This is still under development. */ typedef void *TrackerHandle; /* NOLINT */ @@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */ * * @param config JSON encoded parameters. * - * - dmlc_communicator: String, the type of tracker to create. Available options are `rabit` - * and `federated`. + * - dmlc_communicator: String, the type of tracker to create. Available options are + * `rabit` and `federated`. See @ref TrackerHandle for more info. * - n_workers: Integer, the number of workers. * - port: (Optional) Integer, the port this tracker should listen to. - * - timeout: (Optional) Integer, timeout in seconds for various networking operations. + * - timeout: (Optional) Integer, timeout in seconds for various networking + operations. Default is 300 seconds. * * Some configurations are `rabit` specific: + * * - host: (Optional) String, Used by the `rabit` tracker to specify the address of the host. + * This can be useful when the communicator cannot reliably obtain the host address. + * - sortby: (Optional) Integer. + * + 0: Sort workers by their host name. + * + 1: Sort workers by task IDs. * * Some `federated` specific configurations: - * - federated_secure: Boolean, whether this is a secure server. + * - federated_secure: Boolean, whether this is a secure server. False for testing. * - server_key_path: Path to the server key. Used only if this is a secure server. * - server_cert_path: Path to the server certificate. Used only if this is a secure server. * - client_cert_path: Path to the client certificate. Used only if this is a secure server. @@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config); */ XGB_DLL int XGTrackerFree(TrackerHandle handle); -/*! - * \brief Initialize the collective communicator. +/** + * @brief Initialize the collective communicator. * * Currently the communicator API is experimental, function signatures may change in the future * without notice. * - * Call this once before using anything. + * Call this once in the worker process before using anything. Please make sure + * XGCommunicatorFinalize() is called after use. The initialized communicator is a global + * thread-local variable. * - * The additional configuration is not required. Usually the communicator will detect settings - * from environment variables. - * - * \param config JSON encoded configuration. Accepted JSON keys are: - * - xgboost_communicator: The type of the communicator. Can be set as an environment variable. + * @param config JSON encoded configuration. Accepted JSON keys are: + * - dmlc_communicator: The type of the communicator, this should match the tracker type. * * rabit: Use Rabit. This is the default if the type is unspecified. * * federated: Use the gRPC interface for Federated Learning. - * Only applicable to the Rabit communicator (these are case-sensitive): - * - rabit_tracker_uri: Hostname of the tracker. - * - rabit_tracker_port: Port number of the tracker. - * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment. - * - rabit_world_size: Total number of workers. - * - rabit_timeout: Enable timeout. - * - rabit_timeout_sec: Timeout in seconds. 
- * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as - * environment variables): - * - DMLC_TRACKER_URI: Hostname of the tracker. - * - DMLC_TRACKER_PORT: Port number of the tracker. - * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment. - * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. - * - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with - * `USE_DLOPEN_NCCL`. - * Only applicable to the Federated communicator (use upper case for environment variables, use + * + * Only applicable to the `rabit` communicator: + * - dmlc_tracker_uri: Hostname or IP address of the tracker. + * - dmlc_tracker_port: Port number of the tracker. + * - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment. + * - dmlc_retry: The number of retries for connection failure. + * - dmlc_timeout: Timeout in seconds. + * - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`. + * + * Only applicable to the `federated` communicator (use upper case for environment variables, use * lower case for runtime configuration): * - federated_server_address: Address of the federated server. * - federated_world_size: Number of federated workers. * - federated_rank: Rank of the current worker. - * - federated_server_cert: Server certificate file path. Only needed for the SSL mode. - * - federated_client_key: Client key file path. Only needed for the SSL mode. - * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. - * \return 0 for success, -1 for failure. + * - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode. + * - federated_client_key_path: Client key file path. Only needed for the SSL mode. + * - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode. + * + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorInit(char const* config); -/*! - * \brief Finalize the collective communicator. +/** + * @brief Finalize the collective communicator. * - * Call this function after you finished all jobs. + * Call this function after you have finished all jobs. * - * \return 0 for success, -1 for failure. + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorFinalize(void); -/*! - * \brief Get rank of current process. +/** + * @brief Get rank of the current process. * - * \return Rank of the worker. + * @return Rank of the worker. */ XGB_DLL int XGCommunicatorGetRank(void); -/*! - * \brief Get total number of processes. +/** + * @brief Get the total number of processes. * - * \return Total world size. + * @return Total world size. */ XGB_DLL int XGCommunicatorGetWorldSize(void); -/*! - * \brief Get if the communicator is distributed. +/** + * @brief Get if the communicator is distributed. * - * \return True if the communicator is distributed. + * @return True if the communicator is distributed. */ XGB_DLL int XGCommunicatorIsDistributed(void); -/*! - * \brief Print the message to the communicator. +/** + * @brief Print the message to the tracker. * - * This function can be used to communicate the information of the progress to the user who monitors - * the communicator. + * This function can be used to communicate the information of the progress to the user + * who monitors the tracker. * - * \param message The message to be printed. - * \return 0 for success, -1 for failure. + * @param message The message to be printed. 
+ * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorPrint(char const *message); -/*! - * \brief Get the name of the processor. +/** + * @brief Get the name of the processor. * - * \param name_str Pointer to received returned processor name. - * \return 0 for success, -1 for failure. + * @param name_str Pointer to received returned processor name. + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str); -/*! - * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe. +/** + * @brief Broadcast a memory region to all others from root. This function is NOT + * thread-safe. * * Example: - * \code + * @code * int a = 1; * Broadcast(&a, sizeof(a), root); - * \endcode + * @endcode * - * \param send_receive_buffer Pointer to the send or receive buffer. - * \param size Size of the data. - * \param root The process rank to broadcast from. - * \return 0 for success, -1 for failure. + * @param send_receive_buffer Pointer to the send or receive buffer. + * @param size Size of the data in bytes. + * @param root The process rank to broadcast from. + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root); -/*! - * \brief Perform in-place allreduce. This function is NOT thread-safe. +/** + * @brief Perform in-place allreduce. This function is NOT thread-safe. * * Example Usage: the following code gives sum of the result - * \code - * vector data(10); + * @code + * enum class Op { + * kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 + * }; + * std::vector data(10); * ... - * Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum); + * Allreduce(data.data(), data.size(), DataType:kInt32, Op::kSum); * ... - * \endcode + * @endcode - * \param send_receive_buffer Buffer for both sending and receiving data. - * \param count Number of elements to be reduced. - * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. - * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. - * \return 0 for success, -1 for failure. + * @param send_receive_buffer Buffer for both sending and receiving data. + * @param count Number of elements to be reduced. + * @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. + * @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. + * + * @return 0 for success, -1 for failure. 
*/ XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op); diff --git a/include/xgboost/collective/result.h b/include/xgboost/collective/result.h index 23e70a8e6..c126366a0 100644 --- a/include/xgboost/collective/result.h +++ b/include/xgboost/collective/result.h @@ -55,10 +55,9 @@ struct ResultImpl { #if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__) #define __builtin_FILE() nullptr #define __builtin_LINE() (-1) -std::string MakeMsg(std::string&& msg, char const*, std::int32_t); -#else -std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line); #endif + +std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line); } // namespace detail /** diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 0e098052c..c5dd977f6 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -16,6 +16,10 @@ #include // std::error_code, std::system_category #include // std::swap +#if defined(__linux__) +#include // for TIOCOUTQ, FIONREAD +#endif // defined(__linux__) + #if !defined(xgboost_IS_MINGW) #if defined(__MINGW32__) @@ -319,7 +323,8 @@ class TCPSocket { std::int32_t domain; socklen_t len = sizeof(domain); xgboost_CHECK_SYS_CALL( - getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast(&domain), &len), 0); + getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast(&domain), &len), + 0); return ret_iafamily(domain); #else struct sockaddr sa; @@ -426,6 +431,35 @@ class TCPSocket { return Success(); } + [[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) { + socklen_t optlen; + auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast(n_bytes), + &optlen); + if (rc != 0 || optlen != sizeof(std::int32_t)) { + return system::FailWithCode("getsockopt"); + } + return Success(); + } + [[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) { + socklen_t optlen; + auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast(n_bytes), + &optlen); + if (rc != 0 || optlen != sizeof(std::int32_t)) { + return system::FailWithCode("getsockopt"); + } + return Success(); + } +#if defined(__linux__) + [[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const { + return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success() + : system::FailWithCode("ioctl"); + } + [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const { + return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success() + : system::FailWithCode("ioctl"); + } +#endif // defined(__linux__) + [[nodiscard]] Result SetKeepAlive() { std::int32_t keepalive = 1; auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&keepalive), @@ -436,10 +470,9 @@ class TCPSocket { return Success(); } - [[nodiscard]] Result SetNoDelay() { - std::int32_t tcp_no_delay = 1; - auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&tcp_no_delay), - sizeof(tcp_no_delay)); + [[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) { + auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&no_delay), + sizeof(no_delay)); if (rc != 0) { return system::FailWithCode("Failed to set TCP no delay."); } @@ -602,45 +635,47 @@ class TCPSocket { } /** - * \brief Send data, without error then all data should be sent. + * @brief Send data, without error then all data should be sent. 
*/ - [[nodiscard]] auto SendAll(void const *buf, std::size_t len) { + [[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) { char const *_buf = reinterpret_cast(buf); - std::size_t ndone = 0; + std::size_t &ndone = *n_sent; + ndone = 0; while (ndone < len) { ssize_t ret = send(handle_, _buf, len - ndone, 0); if (ret == -1) { if (system::LastErrorWouldBlock()) { - return ndone; + return Success(); } - system::ThrowAtError("send"); + return system::FailWithCode("send"); } _buf += ret; ndone += ret; } - return ndone; + return Success(); } /** - * \brief Receive data, without error then all data should be received. + * @brief Receive data, without error then all data should be received. */ - [[nodiscard]] auto RecvAll(void *buf, std::size_t len) { + [[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) { char *_buf = reinterpret_cast(buf); - std::size_t ndone = 0; + std::size_t &ndone = *n_recv; + ndone = 0; while (ndone < len) { ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL); if (ret == -1) { if (system::LastErrorWouldBlock()) { - return ndone; + return Success(); } - system::ThrowAtError("recv"); + return system::FailWithCode("recv"); } if (ret == 0) { - return ndone; + return Success(); } _buf += ret; ndone += ret; } - return ndone; + return Success(); } /** * \brief Send data using the socket diff --git a/jvm-packages/create_jni.py b/jvm-packages/create_jni.py index 865d07fe8..693546862 100755 --- a/jvm-packages/create_jni.py +++ b/jvm-packages/create_jni.py @@ -23,6 +23,7 @@ CONFIG = { "USE_NCCL": "OFF", "JVM_BINDINGS": "ON", "LOG_CAPI_INVOCATION": "OFF", + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", } @@ -97,10 +98,6 @@ def native_build(args): args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()] - # if enviorment set rabit_mock - if os.getenv("RABIT_MOCK", None) is not None: - args.append("-DRABIT_MOCK:BOOL=ON") - # if enviorment set GPU_ARCH_FLAG gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None) if gpu_arch_flag is not None: @@ -162,12 +159,6 @@ def native_build(args): maybe_makedirs(output_folder) cp("../lib/" + library_name, output_folder) - print("copying pure-Python tracker") - cp( - "../python-package/xgboost/tracker.py", - "{}/src/main/resources".format(xgboost4j), - ) - print("copying train/test files") maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark)) with cd("../demo/CLI/regression"): diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index e6faa35c3..c5354aad7 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -489,6 +489,11 @@ kryo 5.6.0 + + com.fasterxml.jackson.core + jackson-databind + 2.14.2 + commons-logging commons-logging diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java index 7a5e3ac68..99608b927 100644 --- a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java +++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java @@ -54,9 +54,9 @@ public class XGBoost { private final Map params; private final int round; - private final Map workerEnvs; + private final Map workerEnvs; - public MapFunction(Map params, int round, Map workerEnvs) { + public MapFunction(Map params, int round, Map workerEnvs) { this.params = params; this.round = round; this.workerEnvs = workerEnvs; @@ -174,9 +174,9 @@ public class XGBoost { int numBoostRound) throws Exception { final RabitTracker 
tracker = new RabitTracker(dtrain.getExecutionEnvironment().getParallelism()); - if (tracker.start(0L)) { + if (tracker.start()) { return dtrain - .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs())) + .mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs())) .reduce((x, y) -> x) .collect() .get(0); diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 5a1af886f..e17c68355 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2023 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ import scala.collection.mutable import scala.util.Random import scala.collection.JavaConverters._ -import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker} +import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker} import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} @@ -38,21 +38,17 @@ import org.apache.spark.sql.SparkSession /** * Rabit tracker configurations. * - * @param workerConnectionTimeout The timeout for all workers to connect to the tracker. - * Set timeout length to zero to disable timeout. - * Use a finite, non-zero timeout value to prevent tracker from - * hanging indefinitely (in milliseconds) - * (supported by "scala" implementation only.) - * @param hostIp The Rabit Tracker host IP address which is only used for python implementation. + * @param timeout The number of seconds before timeout waiting for workers to connect. and + * for the tracker to shutdown. + * @param hostIp The Rabit Tracker host IP address. * This is only needed if the host IP cannot be automatically guessed. - * @param pythonExec The python executed path for Rabit Tracker, - * which is only used for python implementation. + * @param port The port number for the tracker to listen to. Use a system allocated one by + * default. 
*/ -case class TrackerConf(workerConnectionTimeout: Long, - hostIp: String = "", pythonExec: String = "") +case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0) object TrackerConf { - def apply(): TrackerConf = TrackerConf(0L) + def apply(): TrackerConf = TrackerConf(0) } private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long) @@ -421,7 +417,7 @@ object XGBoost extends XGBoostStageLevel { private def buildDistributedBooster( buildWatches: () => Watches, xgbExecutionParam: XGBoostExecutionParams, - rabitEnv: java.util.Map[String, String], + rabitEnv: java.util.Map[String, Object], obj: ObjectiveTrait, eval: EvalTrait, prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = { @@ -430,7 +426,6 @@ object XGBoost extends XGBoostStageLevel { val taskId = TaskContext.getPartitionId().toString val attempt = TaskContext.get().attemptNumber.toString rabitEnv.put("DMLC_TASK_ID", taskId) - rabitEnv.put("DMLC_NUM_ATTEMPT", attempt) val numRounds = xgbExecutionParam.numRounds val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0 @@ -481,16 +476,15 @@ object XGBoost extends XGBoostStageLevel { } /** visiable for testing */ - private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { - val tracker: IRabitTracker = new PyRabitTracker( - nWorkers, trackerConf.hostIp, trackerConf.pythonExec - ) + private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = { + val tracker: ITracker = new RabitTracker( + nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout) tracker } - private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { + private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = { val tracker = getTracker(nWorkers, trackerConf) - require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker") + require(tracker.start(), "FAULT: Failed to start tracker") tracker } @@ -525,8 +519,8 @@ object XGBoost extends XGBoostStageLevel { // Train for every ${savingRound} rounds and save the partially completed booster val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) val (booster, metrics) = try { - tracker.getWorkerEnvs().putAll(xgbRabitParams) - val rabitEnv = tracker.getWorkerEnvs + tracker.workerArgs().putAll(xgbRabitParams) + val rabitEnv = tracker.workerArgs val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => { var optionWatches: Option[() => Watches] = None @@ -548,11 +542,6 @@ object XGBoost extends XGBoostStageLevel { // of the training task fails the training stage can retry. ResultStage won't retry when // it fails. 
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0) - val trackerReturnVal = tracker.waitFor(0L) - logger.info(s"Rabit returns with exit code $trackerReturnVal") - if (trackerReturnVal != 0) { - throw new XGBoostError("XGBoostModel training failed.") - } (booster, metrics) } finally { tracker.stop() diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index b85f4dc8b..fafbd816a 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -145,28 +145,28 @@ private[spark] trait GeneralParams extends Params { * Rabit tracker configurations. The parameter must be provided as an instance of the * TrackerConf class, which has the following definition: * - * case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration, - * trackerImpl: String) + * case class TrackerConf(timeout: Int, hostIp: String, port: Int) * * See below for detailed explanations. * - * - trackerImpl: Select the implementation of Rabit tracker. - * default: "python" - * - * Choice between "python" or "scala". The former utilizes the Java wrapper of the - * Python Rabit tracker (in dmlc_core), and does not support timeout settings. - * The "scala" version removes Python components, and fully supports timeout settings. - * - * - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker. - * default: 0 millisecond (no timeout) + * - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds) + * default: 0 (no timeout) * + * Timeout for constructing the communication group and waiting for the tracker to + * shutdown when it's instructed to, doesn't apply to communication when tracking + * is running. * The timeout value should take the time of data loading and pre-processing into account, - * due to the lazy execution of Spark's operations. Alternatively, you may force Spark to + * due to potential lazy execution. Alternatively, you may force Spark to * perform data transformation before calling XGBoost.train(), so that this timeout truly * reflects the connection delay. Set a reasonable timeout value to prevent model * training/testing from hanging indefinitely, possible due to network issues. * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf). - * Ignored if the tracker implementation is "python". + * + * - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP + * cannot be automatically guessed. + * + * - port : The port number for the tracker to listen to. Use a system allocated one by + * default. 
*/ final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations") setDefault(trackerConf, TrackerConf()) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala index 5445cd1bf..108053af5 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,8 +20,7 @@ import java.util.concurrent.LinkedBlockingDeque import scala.util.Random -import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker} -import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus +import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker} import ml.dmlc.xgboost4j.scala.DMatrix import org.scalatest.funsuite.AnyFunSuite @@ -33,50 +32,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { xgbParamsFactory.buildXGBRuntimeParams } - test("Customize host ip and python exec for Rabit tracker") { - val hostIp = "192.168.22.111" - val pythonExec = "/usr/bin/python3" - - val paramMap = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, hostIp)) - val xgbExecParams = getXGBoostExecutionParams(paramMap) - val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) - tracker match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.contains(hostIp)) - assert(cmd.startsWith("python")) - case _ => assert(false, "expected python tracker implementation") - } - - val paramMap1 = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, "", pythonExec)) - val xgbExecParams1 = getXGBoostExecutionParams(paramMap1) - val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf) - tracker1 match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.startsWith(pythonExec)) - assert(!cmd.contains(hostIp)) - case _ => assert(false, "expected python tracker implementation") - } - - val paramMap2 = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, hostIp, pythonExec)) - val xgbExecParams2 = getXGBoostExecutionParams(paramMap2) - val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf) - tracker2 match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.startsWith(pythonExec)) - assert(cmd.contains(s" --host-ip=${hostIp}")) - case _ => assert(false, "expected python tracker implementation") - } - } - test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") { /* Deliberately create new instances of SparkContext in each unit test to avoid reusing the @@ -88,9 +43,9 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { */ val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() - val tracker = new PyRabitTracker(numWorkers) - tracker.start(0) - val trackerEnvs = tracker.getWorkerEnvs + val tracker = new RabitTracker(numWorkers) + tracker.start() + val trackerEnvs = tracker. 
workerArgs val workerCount: Int = numWorkers /* @@ -99,22 +54,8 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { thrown: the thread running the dummy spark job (sparkThread) catches the exception and delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself. - The Java RabitTracker class reacts to exceptions by killing the spawned process running - the Python tracker. If at least one Rabit worker has yet connected to the tracker before - it is killed, the resulted connection failure will trigger the Rabit worker to call - "exit(-1);" in the native C++ code, effectively ending the dummy Spark task. - - In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are - isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks - running in separate containers. However, as unit tests are run in Spark local mode, in which - tasks are executed by threads belonging to the same process, one thread calling "exit(-1);" - ultimately kills the entire process, which also happens to host the Spark driver, causing - the entire Spark application to crash. - To prevent unit tests from crashing, deterministic delays were introduced to make sure that the exception is thrown at last, ideally after all worker connections have been established. - For the same reason, the Java RabitTracker class delays the killing of the Python tracker - process to ensure that pending worker connections are handled. */ val dummyTasks = rdd.mapPartitions { iter => Communicator.init(trackerEnvs) @@ -137,7 +78,32 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { sparkThread.setUncaughtExceptionHandler(tracker) sparkThread.start() - assert(tracker.waitFor(0) != 0) + } + + test("Communicator allreduce works.") { + val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() + val tracker = new RabitTracker(numWorkers) + tracker.start() + val trackerEnvs = tracker.workerArgs + + val workerCount: Int = numWorkers + + rdd.mapPartitions { iter => + val index = iter.next() + Communicator.init(trackerEnvs) + val a = Array(1.0f, 2.0f, 3.0f) + System.out.println(a.mkString(", ")) + val b = Communicator.allReduce(a, Communicator.OpType.SUM) + for (i <- 0 to 2) { + assert(a(i) * workerCount == b(i)) + } + val c = Communicator.allReduce(a, Communicator.OpType.MIN); + for (i <- 0 to 2) { + assert(a(i) == c(i)) + } + Communicator.shutdown() + Iterator(index) + }.collect() } test("should allow the dataframe containing communicator calls to be partially evaluated for" + diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index f187f7394..20a95f2a2 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.SparkException import org.apache.spark.ml.param.ParamMap class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll { - test("XGBoost and Spark parameters synchronize correctly") { val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic", "objective_type" -> "classification") @@ -50,7 +49,6 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll { intercept[SparkException] { xgb.fit(trainingDF) } - } test("fail training elegantly 
with unsupported eval metrics") { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala index 86b82e63c..136d39e8b 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala @@ -47,11 +47,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest { val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)) .fit(training) - assert(Communicator.communicatorEnvs.asScala.size > 3) - Communicator.communicatorEnvs.asScala.foreach( item => { - if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1") - }) - val prediction2 = model2.transform(testDF).select("prediction").collect() // check parity w/o rabit cache prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) => @@ -70,10 +65,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest { val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1) ).fit(training) - assert(Communicator.communicatorEnvs.asScala.size > 3) - Communicator.communicatorEnvs.asScala.foreach( item => { - if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1") - }) // check the equality of single instance prediction val prediction2 = model2.transform(testDF).select("prediction").collect() // check parity w/o rabit cache @@ -81,25 +72,4 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest { assert(math.abs(p1 - p2) < predictionErrorMin) } } - - test("test rabit timeout fail handle") { - val training = buildDataFrame(Classification.train) - // mock rank 0 failure during 8th allreduce synchronization - Communicator.mockList = Array("0,8,0,0").toList.asJava - - intercept[SparkException] { - new XGBoostClassifier(Map( - "eta" -> "0.1", - "max_depth" -> "10", - "verbosity" -> "1", - "objective" -> "binary:logistic", - "num_round" -> 5, - "num_workers" -> numWorkers, - "rabit_timeout" -> 0)) - .fit(training) - } - - Communicator.mockList = Array.empty.toList.asJava - } - } diff --git a/jvm-packages/xgboost4j-tester/generate_pom.py b/jvm-packages/xgboost4j-tester/generate_pom.py index eb7cf94b3..ad729b3a6 100644 --- a/jvm-packages/xgboost4j-tester/generate_pom.py +++ b/jvm-packages/xgboost4j-tester/generate_pom.py @@ -51,6 +51,11 @@ pom_template = """ commons-logging 1.2 + + com.fasterxml.jackson.core + jackson-databind + 2.14.2 + org.scalatest scalatest_${{scala.binary.version}} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java index 795e7d99e..ee1bc7b4a 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java @@ -7,6 +7,9 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + /** * Collective communicator global class for synchronization. 
* @@ -30,8 +33,9 @@ public class Communicator { } public enum DataType implements Serializable { - INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4), - INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8); + FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8), + INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8), + UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8); private final int enumOp; private final int size; @@ -56,30 +60,20 @@ public class Communicator { } } - // used as way to test/debug passed communicator init parameters - public static Map communicatorEnvs; - public static List mockList = new LinkedList<>(); - /** * Initialize the collective communicator on current working thread. * * @param envs The additional environment variables to pass to the communicator. * @throws XGBoostError */ - public static void init(Map envs) throws XGBoostError { - communicatorEnvs = envs; - String[] args = new String[envs.size() * 2 + mockList.size() * 2]; - int idx = 0; - for (java.util.Map.Entry e : envs.entrySet()) { - args[idx++] = e.getKey(); - args[idx++] = e.getValue(); + public static void init(Map envs) throws XGBoostError { + ObjectMapper mapper = new ObjectMapper(); + try { + String jconfig = mapper.writeValueAsString(envs); + checkCall(XGBoostJNI.CommunicatorInit(jconfig)); + } catch (JsonProcessingException ex) { + throw new XGBoostError("Failed to read arguments for the communicator.", ex); } - // pass list of rabit mock strings eg mock=0,1,0,0 - for (String mock : mockList) { - args[idx++] = "mock"; - args[idx++] = mock; - } - checkCall(XGBoostJNI.CommunicatorInit(args)); } /** diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java similarity index 56% rename from jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java index 984fb80e6..1bfef677d 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java @@ -1,14 +1,13 @@ package ml.dmlc.xgboost4j.java; import java.util.Map; -import java.util.concurrent.TimeUnit; /** - * Interface for Rabit tracker implementations with three public methods: + * Interface for a tracker implementations with three public methods: * - * - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given - * timeout value (in milliseconds.) - * - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients. + * - start(timeout): Start the tracker awaiting for worker connections, with a given + * timeout value (in seconds). + * - workerArgs(): Return the arguments needed to initialize Rabit clients. * - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout` * milliseconds. * @@ -21,7 +20,7 @@ import java.util.concurrent.TimeUnit; * The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and * brokers connections between workers. 
*/ -public interface IRabitTracker extends Thread.UncaughtExceptionHandler { +public interface ITracker extends Thread.UncaughtExceptionHandler { enum TrackerStatus { SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3); @@ -36,9 +35,11 @@ public interface IRabitTracker extends Thread.UncaughtExceptionHandler { } } - Map getWorkerEnvs(); - boolean start(long workerConnectionTimeout); - void stop(); - // taskExecutionTimeout has no effect in current version of XGBoost. - int waitFor(long taskExecutionTimeout); + Map workerArgs() throws XGBoostError; + + boolean start() throws XGBoostError; + + void stop() throws XGBoostError; + + void waitFor(long taskExecutionTimeout) throws XGBoostError; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index 0a05b3de0..914a493cc 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -1,101 +1,40 @@ package ml.dmlc.xgboost4j.java; -import java.io.*; -import java.util.HashMap; import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; /** * Java implementation of the Rabit tracker to coordinate distributed workers. - * As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both - * start() and waitFor() methods (i.e., the timeout is infinite.) - * - * For systems lacking Python environment, or for timeout functionality, consider using the Scala - * Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and - * provides timeout support. * * The tracker must be started on driver node before running distributed jobs. */ -public class RabitTracker implements IRabitTracker { +public class RabitTracker implements ITracker { // Maybe per tracker logger? private static final Log logger = LogFactory.getLog(RabitTracker.class); - // tracker python file. - private static String tracker_py = null; - private static TrackerProperties trackerProperties = TrackerProperties.getInstance(); - // environment variable to be pased. - private Map envs = new HashMap(); - // number of workers to be submitted. - private int numWorkers; - private String hostIp = ""; - private String pythonExec = ""; - private AtomicReference trackerProcess = new AtomicReference(); + private long handle = 0; + private Thread tracker_daemon; - static { - try { - initTrackerPy(); - } catch (IOException ex) { - logger.error("load tracker library failed."); - logger.error(ex); - } + public RabitTracker(int numWorkers) throws XGBoostError { + this(numWorkers, ""); } - /** - * Tracker logger that logs output from tracker. 
- */ - private class TrackerProcessLogger implements Runnable { - public void run() { - - Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class); - BufferedReader reader = new BufferedReader(new InputStreamReader( - trackerProcess.get().getErrorStream())); - String line; - try { - while ((line = reader.readLine()) != null) { - trackerProcessLogger.info(line); - } - trackerProcess.get().waitFor(); - int exitValue = trackerProcess.get().exitValue(); - if (exitValue != 0) { - trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue); - } else { - trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue); - } - } catch (IOException ex) { - trackerProcessLogger.error(ex.toString()); - } catch (InterruptedException ie) { - // we should not get here as RabitTracker is accessed in the main thread - ie.printStackTrace(); - logger.error("the RabitTracker thread is terminated unexpectedly"); - } - } - } - - private static void initTrackerPy() throws IOException { - try { - tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py"); - } catch (IOException ioe) { - logger.trace("cannot access tracker python script"); - throw ioe; - } - } - - public RabitTracker(int numWorkers) + public RabitTracker(int numWorkers, String hostIp) throws XGBoostError { + this(numWorkers, hostIp, 0, 300); + } + public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError { if (numWorkers < 1) { throw new XGBoostError("numWorkers must be greater equal to one"); } - this.numWorkers = numWorkers; - } - public RabitTracker(int numWorkers, String hostIp, String pythonExec) - throws XGBoostError { - this(numWorkers); - this.hostIp = hostIp; - this.pythonExec = pythonExec; + long[] out = new long[1]; + XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out)); + this.handle = out[0]; } public void uncaughtException(Thread t, Throwable e) { @@ -105,7 +44,7 @@ public class RabitTracker implements IRabitTracker { } catch (InterruptedException ex) { logger.error(ex); } finally { - trackerProcess.get().destroy(); + this.tracker_daemon.interrupt(); } } @@ -113,115 +52,43 @@ public class RabitTracker implements IRabitTracker { * Get environments that can be used to pass to worker. * @return The environment settings. 
*/ - public Map getWorkerEnvs() { - return envs; + public Map workerArgs() throws XGBoostError { + // fixme: timeout + String[] args = new String[1]; + XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args)); + ObjectMapper mapper = new ObjectMapper(); + TypeReference> typeRef = new TypeReference>() { + }; + Map config; + try { + config = mapper.readValue(args[0], typeRef); + } catch (JsonProcessingException ex) { + throw new XGBoostError("Failed to get worker arguments.", ex); + } + return config; } - private void loadEnvs(InputStream ins) throws IOException { - try { - BufferedReader reader = new BufferedReader(new InputStreamReader(ins)); - assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START"); - String line; - while ((line = reader.readLine()) != null) { - if (line.trim().equals("DMLC_TRACKER_ENV_END")) { - break; - } - String[] sep = line.split("="); - if (sep.length == 2) { - envs.put(sep[0], sep[1]); - } + public void stop() throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle)); + } + + public boolean start() throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle)); + this.tracker_daemon = new Thread(() -> { + try { + XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0)); + } catch (XGBoostError ex) { + logger.error(ex); + return; // exit the thread } - reader.close(); - } catch (IOException ioe){ - logger.error("cannot get runtime configuration from tracker process"); - ioe.printStackTrace(); - throw ioe; - } + }); + this.tracker_daemon.setDaemon(true); + this.tracker_daemon.start(); + + return this.tracker_daemon.isAlive(); } - /** visible for testing */ - public String getRabitTrackerCommand() { - StringBuilder sb = new StringBuilder(); - if (pythonExec == null || pythonExec.isEmpty()) { - sb.append("python "); - } else { - sb.append(pythonExec + " "); - } - sb.append(" " + tracker_py + " "); - sb.append(" --log-level=DEBUG" + " "); - sb.append(" --num-workers=" + numWorkers + " "); - - // we first check the property then check the parameter - String hostIpFromProperties = trackerProperties.getHostIp(); - if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) { - logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties"); - sb.append(" --host-ip=" + hostIpFromProperties + " "); - } else if (hostIp != null & !hostIp.isEmpty()) { - logger.debug("Using the parametr host-ip: " + hostIp); - sb.append(" --host-ip=" + hostIp + " "); - } - return sb.toString(); - } - - private boolean startTrackerProcess() { - try { - String cmd = getRabitTrackerCommand(); - trackerProcess.set(Runtime.getRuntime().exec(cmd)); - loadEnvs(trackerProcess.get().getInputStream()); - return true; - } catch (IOException ioe) { - ioe.printStackTrace(); - return false; - } - } - - public void stop() { - if (trackerProcess.get() != null) { - trackerProcess.get().destroy(); - } - } - - public boolean start(long timeout) { - if (timeout > 0L) { - logger.warn("Python RabitTracker does not support timeout. " + - "The tracker will wait for all workers to connect indefinitely, unless " + - "it is interrupted manually. 
Use the Scala RabitTracker for timeout support."); - } - - if (startTrackerProcess()) { - logger.debug("Tracker started, with env=" + envs.toString()); - System.out.println("Tracker started, with env=" + envs.toString()); - // also start a tracker logger - Thread logger_thread = new Thread(new TrackerProcessLogger()); - logger_thread.setDaemon(true); - logger_thread.start(); - return true; - } else { - logger.error("FAULT: failed to start tracker process"); - stop(); - return false; - } - } - - public int waitFor(long timeout) { - if (timeout > 0L) { - logger.warn("Python RabitTracker does not support timeout. " + - "The tracker will wait for either all workers to finish tasks and send " + - "shutdown signal, or manual interruptions. " + - "Use the Scala RabitTracker for timeout support."); - } - - try { - trackerProcess.get().waitFor(); - int returnVal = trackerProcess.get().exitValue(); - logger.info("Tracker Process ends with exit code " + returnVal); - stop(); - return returnVal; - } catch (InterruptedException e) { - // we should not get here as RabitTracker is accessed in the main thread - e.printStackTrace(); - logger.error("the RabitTracker thread is terminated unexpectedly"); - return TrackerStatus.INTERRUPTED.getStatusCode(); - } + public void waitFor(long timeout) throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout)); } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 2be62a343..71b4ff3f2 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2023 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
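For reference, a minimal driver-side sketch of how the reworked JVM tracker and communicator above are meant to fit together. It is illustrative rather than part of the patch: the class name and the two-worker count are made up, and the worker-argument map is assumed to be Map<String, Object>, matching the JSON-based Communicator.init shown earlier.

import java.util.Map;

import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class TrackerUsageSketch {
  public static void main(String[] args) throws XGBoostError {
    // The constructor now creates a native tracker handle instead of locating tracker.py.
    RabitTracker tracker = new RabitTracker(2);
    if (!tracker.start()) {
      throw new XGBoostError("Failed to start the tracker.");
    }
    // JSON-derived arguments that each worker passes to Communicator.init(...) before training.
    Map<String, Object> workerArgs = tracker.workerArgs();
    // ... ship workerArgs to the worker processes and run distributed training there ...
    // Block until the workers finish (0 mirrors the value used by the tracker's own
    // daemon thread), then release the native handle.
    tracker.waitFor(0);
    tracker.stop();
  }
}

The removed code above shows what this replaces: the old RabitTracker launched an external tracker.py process and parsed its DMLC_TRACKER_ENV_START/END output, while the new one delegates to the native tracker through the XGBoostJNI entry points that follow.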
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 236d53e90..b410d2be1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -54,7 +54,7 @@ class XGBoostJNI { public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); final static native int XGDMatrixCreateFromDataIter(java.util.Iterator iter, - String cache_info, long[] out); + String cache_info, long[] out); public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, int shapeParam, @@ -146,12 +146,24 @@ class XGBoostJNI { public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds); // communicator functions - public final static native int CommunicatorInit(String[] args); + public final static native int CommunicatorInit(String args); public final static native int CommunicatorFinalize(); public final static native int CommunicatorPrint(String msg); public final static native int CommunicatorGetRank(int[] out); public final static native int CommunicatorGetWorldSize(int[] out); + // Tracker functions + public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout, + long[] out); + + public final static native int TrackerRun(long handle); + + public final static native int TrackerWaitFor(long handle, long timeout); + + public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out); + + public final static native int TrackerFree(long handle); + // Perform Allreduce operation on data in sendrecvbuf. final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count, int enum_dtype, int enum_op); @@ -168,5 +180,4 @@ class XGBoostJNI { public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features); public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out); - } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java index 501a9cfe1..e3857a1d4 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java @@ -42,5 +42,4 @@ public final class UtilUnsafe { throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e); } } - } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 50d86c893..561b97ff3 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
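The TrackerCreate/TrackerRun/TrackerWorkerArgs/TrackerWaitFor/TrackerFree natives declared in XGBoostJNI above amount to a small lifecycle around a native handle. Below is a hypothetical sketch of that lifecycle, written inside the ml.dmlc.xgboost4j.java package because XGBoostJNI is package-private; the host, worker count, port, and timeout values are illustrative only.

package ml.dmlc.xgboost4j.java;

final class TrackerLifecycleSketch {
  // Mirrors the JNI calls made by RabitTracker, with hard-coded example values.
  static void run() throws XGBoostError {
    long[] out = new long[1];
    // An empty host string is skipped by the JNI wrapper, leaving the address choice to the
    // native tracker; sortby=0 and timeout=300 (seconds) match what RabitTracker passes by default.
    XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate("", 2, 0, 0, 300, out));
    long handle = out[0];

    XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(handle));

    String[] workerArgs = new String[1];
    XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(handle, 0, workerArgs));
    // workerArgs[0] holds the JSON document that RabitTracker.workerArgs() decodes with Jackson.

    XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(handle, 0));
    XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(handle));
  }
}

RabitTracker wraps this same sequence: its constructor calls TrackerCreate, start() calls TrackerRun and spawns a daemon thread on TrackerWaitFor, workerArgs() calls TrackerWorkerArgs, waitFor() calls TrackerWaitFor, and stop() calls TrackerFree.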
@@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams { } } } - - diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 9ba944d5a..67b6a0ee4 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -1,20 +1,21 @@ /** - Copyright (c) 2014-2023 by Contributors - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + * Copyright 2014-2024, XGBoost Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "./xgboost4j.h" -#include #include #include #include @@ -23,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -1016,23 +1016,107 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit - * Signature: ([Ljava/lang/String;)I + * Signature: (Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit - (JNIEnv *jenv, jclass jcls, jobjectArray jargs) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv, + jclass jcls, + jstring jargs) { xgboost::Json config{xgboost::Object{}}; - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs); - assert(len % 2 == 0); - for (bst_ulong i = 0; i < len / 2; ++i) { - jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i); - std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key)); - jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1); - std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value)); - config[key_str] = xgboost::String(value_str); + const char *args = jenv->GetStringUTFChars(jargs, nullptr); + JVM_CHECK_CALL(XGCommunicatorInit(args)); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerCreate + * Signature: (Ljava/lang/String;IIIJ[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate( + JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout, + jlongArray jout) { + using namespace xgboost; // NOLINT + + TrackerHandle handle; + Json config{Object{}}; + std::string shost{jenv->GetStringUTFChars(host, nullptr), + static_cast(jenv->GetStringLength(host))}; + if (!shost.empty()) { + config["host"] = shost; } - std::string json_str; - xgboost::Json::Dump(config, &json_str); - JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str())); + 
config["port"] = Integer{static_cast(port)}; + config["n_workers"] = Integer{static_cast(n_workers)}; + config["timeout"] = Integer{static_cast(timeout)}; + config["sortby"] = Integer{static_cast(sortby)}; + config["dmlc_communicator"] = String{"rabit"}; + std::string sconfig = Json::Dump(config); + JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle)); + setHandle(jenv, jout, handle); + + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerRun + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass, + jlong jhandle) { + auto handle = reinterpret_cast(jhandle); + JVM_CHECK_CALL(XGTrackerRun(handle, nullptr)); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerWaitFor + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass, + jlong jhandle, + jlong timeout) { + using namespace xgboost; // NOLINT + + auto handle = reinterpret_cast(jhandle); + Json config{Object{}}; + config["timeout"] = Integer{static_cast(timeout)}; + std::string sconfig = Json::Dump(config); + JVM_CHECK_CALL(XGTrackerWaitFor(handle, sconfig.c_str())); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerWorkerArgs + * Signature: (JJ[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs( + JNIEnv *jenv, jclass, jlong jhandle, jlong timeout, jobjectArray jout) { + using namespace xgboost; // NOLINT + + Json config{Object{}}; + config["timeout"] = Integer{static_cast(timeout)}; + std::string sconfig = Json::Dump(config); + auto handle = reinterpret_cast(jhandle); + char const *args; + JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args)); + auto jargs = Json::Load(StringView{args}); + + jstring jret = jenv->NewStringUTF(args); + jenv->SetObjectArrayElement(jout, 0, jret); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerFree + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass, + jlong jhandle) { + auto handle = reinterpret_cast(jhandle); + JVM_CHECK_CALL(XGTrackerFree(handle)); return 0; } @@ -1041,8 +1125,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit * Method: CommunicatorFinalize * Signature: ()I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize - (JNIEnv *jenv, jclass jcls) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *, + jclass) { JVM_CHECK_CALL(XGCommunicatorFinalize()); return 0; } diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index cc4ad53d4..c8e48cfc9 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit - * Signature: ([Ljava/lang/String;)I + * Signature: (Ljava/lang/String;)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit - (JNIEnv *, jclass, jobjectArray); + (JNIEnv *, jclass, jstring); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI @@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan JNIEXPORT jint JNICALL 
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize (JNIEnv *, jclass, jintArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerCreate + * Signature: (Ljava/lang/String;IIIJ[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate + (JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerRun + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun + (JNIEnv *, jclass, jlong); + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerWaitFor + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerWorkerArgs + * Signature: (JJ[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs + (JNIEnv *, jclass, jlong, jlong, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerFree + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree + (JNIEnv *, jclass, jlong); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorAllreduce diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index d658c5529..b6ffe84e3 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -298,7 +298,7 @@ public class DMatrixTest { @Test public void testTrainWithDenseMatrixRef() throws XGBoostError { - Map rabitEnv = new HashMap<>(); + Map rabitEnv = new HashMap<>(); rabitEnv.put("DMLC_TASK_ID", "0"); Communicator.init(rabitEnv); DMatrix trainMat = null; diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 0865e756e..9a51f59d5 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -31,31 +31,13 @@ protobuf_generate( PLUGIN "protoc-gen-grpc=\$" PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") -add_library(federated_old_proto STATIC federated.old.proto) -target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) -target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) -xgboost_target_properties(federated_old_proto) - -protobuf_generate( - TARGET federated_old_proto - LANGUAGE cpp - PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") -protobuf_generate( - TARGET federated_old_proto - LANGUAGE grpc - GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc - PLUGIN "protoc-gen-grpc=\$" - PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") - # Wrapper for the gRPC client. add_library(federated_client INTERFACE) -target_sources(federated_client INTERFACE federated_client.h) target_link_libraries(federated_client INTERFACE federated_proto) -target_link_libraries(federated_client INTERFACE federated_old_proto) # Rabit engine for Federated Learning. 
target_sources( - objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc + objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc ) if(USE_CUDA) target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu) diff --git a/plugin/federated/federated.old.proto b/plugin/federated/federated.old.proto deleted file mode 100644 index 8450659fd..000000000 --- a/plugin/federated/federated.old.proto +++ /dev/null @@ -1,81 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -syntax = "proto3"; - -package xgboost.federated; - -service Federated { - rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} - rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {} - rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} - rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} -} - -enum DataType { - INT8 = 0; - UINT8 = 1; - INT32 = 2; - UINT32 = 3; - INT64 = 4; - UINT64 = 5; - FLOAT = 6; - DOUBLE = 7; -} - -enum ReduceOperation { - MAX = 0; - MIN = 1; - SUM = 2; - BITWISE_AND = 3; - BITWISE_OR = 4; - BITWISE_XOR = 5; -} - -message AllgatherRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; -} - -message AllgatherReply { - bytes receive_buffer = 1; -} - -message AllgatherVRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; -} - -message AllgatherVReply { - bytes receive_buffer = 1; -} - -message AllreduceRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; - DataType data_type = 4; - ReduceOperation reduce_operation = 5; -} - -message AllreduceReply { - bytes receive_buffer = 1; -} - -message BroadcastRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; - // The root rank to broadcast from. - int32 root = 4; -} - -message BroadcastReply { - bytes receive_buffer = 1; -} diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h deleted file mode 100644 index 0122a5cfe..000000000 --- a/plugin/federated/federated_client.h +++ /dev/null @@ -1,132 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once -#include -#include -#include - -#include -#include -#include -#include - -namespace xgboost::federated { -/** - * @brief A wrapper around the gRPC client. - */ -class FederatedClient { - public: - FederatedClient(std::string const &server_address, int rank, std::string const &server_cert, - std::string const &client_key, std::string const &client_cert) - : stub_{[&] { - grpc::SslCredentialsOptions options; - options.pem_root_certs = server_cert; - options.pem_private_key = client_key; - options.pem_cert_chain = client_cert; - grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - auto channel = - grpc::CreateCustomChannel(server_address, grpc::SslCredentials(options), args); - channel->WaitForConnected( - gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); - return Federated::NewStub(channel); - }()}, - rank_{rank} {} - - /** @brief Insecure client for connecting to localhost only. 
*/ - FederatedClient(std::string const &server_address, int rank) - : stub_{[&] { - grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - return Federated::NewStub( - grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args)); - }()}, - rank_{rank} {} - - std::string Allgather(std::string_view send_buffer) { - AllgatherRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer.data(), send_buffer.size()); - - AllgatherReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->Allgather(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("Allgather RPC failed"); - } - } - - std::string AllgatherV(std::string_view send_buffer) { - AllgatherVRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer.data(), send_buffer.size()); - - AllgatherVReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->AllgatherV(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("AllgatherV RPC failed"); - } - } - - std::string Allreduce(std::string const &send_buffer, DataType data_type, - ReduceOperation reduce_operation) { - AllreduceRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer); - request.set_data_type(data_type); - request.set_reduce_operation(reduce_operation); - - AllreduceReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->Allreduce(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("Allreduce RPC failed"); - } - } - - std::string Broadcast(std::string const &send_buffer, int root) { - BroadcastRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer); - request.set_root(root); - - BroadcastReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->Broadcast(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("Broadcast RPC failed"); - } - } - - private: - std::unique_ptr const stub_; - int const rank_; - uint64_t sequence_number_{}; -}; -} // namespace xgboost::federated diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h deleted file mode 100644 index 46c6b0fda..000000000 --- a/plugin/federated/federated_communicator.h +++ /dev/null @@ -1,195 +0,0 @@ -/*! 
- * Copyright 2022 XGBoost contributors - */ -#pragma once -#include - -#include "../../src/c_api/c_api_utils.h" -#include "../../src/collective/communicator.h" -#include "../../src/common/io.h" -#include "federated_client.h" - -namespace xgboost::collective { -/** - * @brief A Federated Learning communicator class that handles collective communication. - */ -class FederatedCommunicator : public Communicator { - public: - /** - * @brief Create a new communicator based on JSON configuration. - * @param config JSON configuration. - * @return Communicator as specified by the JSON configuration. - */ - static Communicator *Create(Json const &config) { - std::string server_address{}; - int world_size{0}; - int rank{-1}; - std::string server_cert{}; - std::string client_key{}; - std::string client_cert{}; - - // Parse environment variables first. - auto *value = getenv("FEDERATED_SERVER_ADDRESS"); - if (value != nullptr) { - server_address = value; - } - value = getenv("FEDERATED_WORLD_SIZE"); - if (value != nullptr) { - world_size = std::stoi(value); - } - value = getenv("FEDERATED_RANK"); - if (value != nullptr) { - rank = std::stoi(value); - } - value = getenv("FEDERATED_SERVER_CERT"); - if (value != nullptr) { - server_cert = value; - } - value = getenv("FEDERATED_CLIENT_KEY"); - if (value != nullptr) { - client_key = value; - } - value = getenv("FEDERATED_CLIENT_CERT"); - if (value != nullptr) { - client_cert = value; - } - - // Runtime configuration overrides, optional as users can specify them as env vars. - server_address = OptionalArg(config, "federated_server_address", server_address); - world_size = - OptionalArg(config, "federated_world_size", static_cast(world_size)); - rank = OptionalArg(config, "federated_rank", static_cast(rank)); - server_cert = OptionalArg(config, "federated_server_cert", server_cert); - client_key = OptionalArg(config, "federated_client_key", client_key); - client_cert = OptionalArg(config, "federated_client_cert", client_cert); - - if (server_address.empty()) { - LOG(FATAL) << "Federated server address must be set."; - } - if (world_size == 0) { - LOG(FATAL) << "Federated world size must be set."; - } - if (rank == -1) { - LOG(FATAL) << "Federated rank must be set."; - } - return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key, - client_cert); - } - - /** - * @brief Construct a new federated communicator. - * - * @param world_size Total number of processes. - * @param rank Rank of the current process. - * @param server_address Address of the federated server (host:port). - * @param server_cert_path Path to the server cert file. - * @param client_key_path Path to the client key file. - * @param client_cert_path Path to the client cert file. - */ - FederatedCommunicator(int world_size, int rank, std::string const &server_address, - std::string const &server_cert_path, std::string const &client_key_path, - std::string const &client_cert_path) - : Communicator{world_size, rank} { - if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) { - client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); - } else { - client_.reset(new xgboost::federated::FederatedClient( - server_address, rank, xgboost::common::ReadAll(server_cert_path), - xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path))); - } - } - - /** - * @brief Construct an insecure federated communicator without using SSL. - * @param world_size Total number of processes. 
- * @param rank Rank of the current process. - * @param server_address Address of the federated server (host:port). - */ - FederatedCommunicator(int world_size, int rank, std::string const &server_address) - : Communicator{world_size, rank} { - client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); - } - - ~FederatedCommunicator() override { client_.reset(); } - - /** - * \brief Get if the communicator is distributed. - * \return True. - */ - [[nodiscard]] bool IsDistributed() const override { return true; } - - /** - * \brief Get if the communicator is federated. - * \return True. - */ - [[nodiscard]] bool IsFederated() const override { return true; } - - /** - * \brief Perform allgather. - * \param input Buffer for sending data. - */ - std::string AllGather(std::string_view input) override { - return client_->Allgather(input); - } - - /** - * \brief Perform variable-length allgather. - * \param input Buffer for sending data. - */ - std::string AllGatherV(std::string_view input) override { - return client_->AllgatherV(input); - } - - /** - * \brief Perform in-place allreduce. - * \param send_receive_buffer Buffer for both sending and receiving data. - * \param count Number of elements to be reduced. - * \param data_type Enumeration of data type. - * \param op Enumeration of operation type. - */ - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override { - std::string const send_buffer(reinterpret_cast(send_receive_buffer), - count * GetTypeSize(data_type)); - auto const received = - client_->Allreduce(send_buffer, static_cast(data_type), - static_cast(op)); - received.copy(reinterpret_cast(send_receive_buffer), count * GetTypeSize(data_type)); - } - - /** - * \brief Broadcast a memory region to all others from root. - * \param send_receive_buffer Pointer to the send or receive buffer. - * \param size Size of the data. - * \param root The process rank to broadcast from. - */ - void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { - if (GetWorldSize() == 1) return; - if (GetRank() == root) { - std::string const send_buffer(reinterpret_cast(send_receive_buffer), size); - client_->Broadcast(send_buffer, root); - } else { - auto const received = client_->Broadcast("", root); - received.copy(reinterpret_cast(send_receive_buffer), size); - } - } - - /** - * \brief Get the name of the processor. - * \return Name of the processor. - */ - std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); } - - /** - * \brief Print the message to the communicator. - * \param message The message to be printed. - */ - void Print(const std::string &message) override { LOG(CONSOLE) << message; } - - protected: - void Shutdown() override {} - - private: - std::unique_ptr client_{}; -}; -} // namespace xgboost::collective diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc deleted file mode 100644 index 9dd97c2e1..000000000 --- a/plugin/federated/federated_server.cc +++ /dev/null @@ -1,86 +0,0 @@ -/*! 
- * Copyright 2022 XGBoost contributors - */ -#include "federated_server.h" - -#include -#include // for Server -#include -#include - -#include - -#include "../../src/collective/comm.h" -#include "../../src/common/io.h" -#include "../../src/common/json_utils.h" - -namespace xgboost::federated { -grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request, - AllgatherReply* reply) { - handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(), - reply->mutable_receive_buffer(), request->sequence_number(), request->rank()); - return grpc::Status::OK; -} - -grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request, - AllgatherVReply* reply) { - handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(), - reply->mutable_receive_buffer(), request->sequence_number(), request->rank()); - return grpc::Status::OK; -} - -grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request, - AllreduceReply* reply) { - handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), - reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), - static_cast(request->data_type()), - static_cast(request->reduce_operation())); - return grpc::Status::OK; -} - -grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request, - BroadcastReply* reply) { - handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(), - reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), - request->root()); - return grpc::Status::OK; -} - -void RunServer(int port, std::size_t world_size, char const* server_key_file, - char const* server_cert_file, char const* client_cert_file) { - std::string const server_address = "0.0.0.0:" + std::to_string(port); - FederatedService service{static_cast(world_size)}; - - grpc::ServerBuilder builder; - auto options = - grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); - options.pem_root_certs = xgboost::common::ReadAll(client_cert_file); - auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); - key.private_key = xgboost::common::ReadAll(server_key_file); - key.cert_chain = xgboost::common::ReadAll(server_cert_file); - options.pem_key_cert_pairs.push_back(key); - builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); - builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); - builder.RegisterService(&service); - std::unique_ptr server(builder.BuildAndStart()); - LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " - << world_size; - - server->Wait(); -} - -void RunInsecureServer(int port, std::size_t world_size) { - std::string const server_address = "0.0.0.0:" + std::to_string(port); - FederatedService service{static_cast(world_size)}; - - grpc::ServerBuilder builder; - builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - builder.RegisterService(&service); - std::unique_ptr server(builder.BuildAndStart()); - LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size " - << world_size; - - server->Wait(); -} -} // namespace xgboost::federated diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h deleted file mode 100644 index 4692ad6c2..000000000 --- 
a/plugin/federated/federated_server.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2022-2024, XGBoost contributors - */ -#pragma once - -#include - -#include // for int32_t - -#include "../../src/collective/in_memory_handler.h" - -namespace xgboost::federated { -class FederatedService final : public Federated::Service { - public: - explicit FederatedService(std::int32_t world_size) : handler_{world_size} {} - - grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, - AllgatherReply* reply) override; - - grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request, - AllgatherVReply* reply) override; - - grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, - AllreduceReply* reply) override; - - grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, - BroadcastReply* reply) override; - - private: - xgboost::collective::InMemoryHandler handler_; -}; - -void RunServer(int port, std::size_t world_size, char const* server_key_file, - char const* server_cert_file, char const* client_cert_file); - -void RunInsecureServer(int port, std::size_t world_size); -} // namespace xgboost::federated diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc index 5051d43cb..95c0824d9 100644 --- a/plugin/federated/federated_tracker.cc +++ b/plugin/federated/federated_tracker.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023, XGBoost contributors + * Copyright 2022-2024, XGBoost contributors */ #include "federated_tracker.h" @@ -8,13 +8,12 @@ #include // for int32_t #include // for exception +#include // for future, async #include // for numeric_limits #include // for string -#include // for sleep_for #include "../../src/common/io.h" // for ReadAll #include "../../src/common/json_utils.h" // for RequiredArg -#include "../../src/common/timer.h" // for Timer namespace xgboost::collective { namespace federated { @@ -36,8 +35,8 @@ grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest AllreduceReply* reply) { handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), - static_cast(request->data_type()), - static_cast(request->reduce_operation())); + static_cast(request->data_type()), + static_cast(request->reduce_operation())); return grpc::Status::OK; } @@ -53,9 +52,13 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} { auto is_secure = RequiredArg(config, "federated_secure", __func__); if (is_secure) { + StringView msg{"Empty certificate path."}; server_key_path_ = RequiredArg(config, "server_key_path", __func__); + CHECK(!server_key_path_.empty()) << msg; server_cert_file_ = RequiredArg(config, "server_cert_path", __func__); + CHECK(!server_cert_file_.empty()) << msg; client_cert_file_ = RequiredArg(config, "client_cert_path", __func__); + CHECK(!client_cert_file_.empty()) << msg; } } diff --git a/plugin/sycl/device_manager.cc b/plugin/sycl/device_manager.cc index 0254cdd6a..072c9fd55 100644 --- a/plugin/sycl/device_manager.cc +++ b/plugin/sycl/device_manager.cc @@ -5,11 +5,12 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wtautological-constant-compare" #pragma GCC diagnostic ignored "-W#pragma-messages" -#include #pragma GCC diagnostic pop #include "../sycl/device_manager.h" +#include 
"../../src/collective/communicator-inl.h" + namespace xgboost { namespace sycl { @@ -21,22 +22,23 @@ namespace sycl { } bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) || - (rabit::IsDistributed()); + (collective::IsDistributed()); if (not_use_default_selector) { DeviceRegister& device_register = GetDevicesRegister(); - const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal; + const int device_idx = + collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal; if (device_spec.IsSyclDefault()) { - auto& devices = device_register.devices; - CHECK_LT(device_idx, devices.size()); - return devices[device_idx]; + auto& devices = device_register.devices; + CHECK_LT(device_idx, devices.size()); + return devices[device_idx]; } else if (device_spec.IsSyclCPU()) { - auto& cpu_devices = device_register.cpu_devices; - CHECK_LT(device_idx, cpu_devices.size()); - return cpu_devices[device_idx]; + auto& cpu_devices = device_register.cpu_devices; + CHECK_LT(device_idx, cpu_devices.size()); + return cpu_devices[device_idx]; } else { - auto& gpu_devices = device_register.gpu_devices; - CHECK_LT(device_idx, gpu_devices.size()); - return gpu_devices[device_idx]; + auto& gpu_devices = device_register.gpu_devices; + CHECK_LT(device_idx, gpu_devices.size()); + return gpu_devices[device_idx]; } } else { if (device_spec.IsSyclCPU()) { @@ -62,24 +64,25 @@ namespace sycl { } bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) || - (rabit::IsDistributed()); + (collective::IsDistributed()); std::lock_guard guard(queue_registering_mutex); if (not_use_default_selector) { - DeviceRegister& device_register = GetDevicesRegister(); - const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal; - if (device_spec.IsSyclDefault()) { - auto& devices = device_register.devices; - CHECK_LT(device_idx, devices.size()); - queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]); - } else if (device_spec.IsSyclCPU()) { - auto& cpu_devices = device_register.cpu_devices; - CHECK_LT(device_idx, cpu_devices.size()); - queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);; - } else if (device_spec.IsSyclGPU()) { - auto& gpu_devices = device_register.gpu_devices; - CHECK_LT(device_idx, gpu_devices.size()); - queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]); - } + DeviceRegister& device_register = GetDevicesRegister(); + const int device_idx = + collective::IsDistributed() ? 
collective::GetRank() : device_spec.ordinal; + if (device_spec.IsSyclDefault()) { + auto& devices = device_register.devices; + CHECK_LT(device_idx, devices.size()); + queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]); + } else if (device_spec.IsSyclCPU()) { + auto& cpu_devices = device_register.cpu_devices; + CHECK_LT(device_idx, cpu_devices.size()); + queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]); + } else if (device_spec.IsSyclGPU()) { + auto& gpu_devices = device_register.gpu_devices; + CHECK_LT(device_idx, gpu_devices.size()); + queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]); + } } else { if (device_spec.IsSyclCPU()) { queue_register[device_spec.Name()] = ::sycl::queue(::sycl::cpu_selector_v); diff --git a/plugin/sycl/objective/multiclass_obj.cc b/plugin/sycl/objective/multiclass_obj.cc index 3104dd35e..16efe2a45 100644 --- a/plugin/sycl/objective/multiclass_obj.cc +++ b/plugin/sycl/objective/multiclass_obj.cc @@ -6,7 +6,6 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wtautological-constant-compare" #pragma GCC diagnostic ignored "-W#pragma-messages" -#include #pragma GCC diagnostic pop #include diff --git a/plugin/sycl/objective/regression_obj.cc b/plugin/sycl/objective/regression_obj.cc index 985498717..82467a7c4 100644 --- a/plugin/sycl/objective/regression_obj.cc +++ b/plugin/sycl/objective/regression_obj.cc @@ -9,7 +9,6 @@ #include #include #pragma GCC diagnostic pop -#include #include #include diff --git a/plugin/sycl/predictor/predictor.cc b/plugin/sycl/predictor/predictor.cc index 943949c2a..c941bca10 100755 --- a/plugin/sycl/predictor/predictor.cc +++ b/plugin/sycl/predictor/predictor.cc @@ -4,7 +4,6 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wtautological-constant-compare" #pragma GCC diagnostic ignored "-W#pragma-messages" -#include #pragma GCC diagnostic pop #include diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index a41d296bf..468d38942 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -1,17 +1,17 @@ """XGBoost collective communication related API.""" import ctypes -import json import logging import os import pickle +import platform from enum import IntEnum, unique from typing import Any, Dict, List, Optional import numpy as np from ._typing import _T -from .core import _LIB, _check_call, build_info, c_str, from_pystr_to_cstr, py_str +from .core import _LIB, _check_call, build_info, c_str, make_jcargs, py_str LOGGER = logging.getLogger("[xgboost.collective]") @@ -21,49 +21,35 @@ def init(**args: Any) -> None: Parameters ---------- - args: Dict[str, Any] + args : Keyword arguments representing the parameters and their values. Accepted parameters: - - xgboost_communicator: The type of the communicator. Can be set as an environment - variable. + - dmlc_communicator: The type of the communicator. * rabit: Use Rabit. This is the default if the type is unspecified. * federated: Use the gRPC interface for Federated Learning. - Only applicable to the Rabit communicator (these are case sensitive): - -- rabit_tracker_uri: Hostname of the tracker. - -- rabit_tracker_port: Port number of the tracker. - -- rabit_task_id: ID of the current task, can be used to obtain deterministic rank - assignment. - -- rabit_world_size: Total number of workers. - -- rabit_hadoop_mode: Enable Hadoop support. - -- rabit_tree_reduce_minsize: Minimal size for tree reduce. 
- -- rabit_reduce_ring_mincount: Minimal count to perform ring reduce. - -- rabit_reduce_buffer: Size of the reduce buffer. - -- rabit_bootstrap_cache: Size of the bootstrap cache. - -- rabit_debug: Enable debugging. - -- rabit_timeout: Enable timeout. - -- rabit_timeout_sec: Timeout in seconds. - -- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. - Only applicable to the Rabit communicator (these are case-sensitive, and can be set as - environment variables): - -- DMLC_TRACKER_URI: Hostname of the tracker. - -- DMLC_TRACKER_PORT: Port number of the tracker. - -- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank - assignment. - -- DMLC_ROLE: Role of the current task, "worker" or "server". - -- DMLC_NUM_ATTEMPT: Number of attempts after task failure. - -- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. - Only applicable to the Federated communicator (use upper case for environment variables, use - lower case for runtime configuration): - -- federated_server_address: Address of the federated server. - -- federated_world_size: Number of federated workers. - -- federated_rank: Rank of the current worker. - -- federated_server_cert: Server certificate file path. Only needed for the SSL mode. - -- federated_client_key: Client key file path. Only needed for the SSL mode. - -- federated_client_cert: Client certificate file path. Only needed for the SSL mode. + + Only applicable to the Rabit communicator: + - dmlc_tracker_uri: Hostname of the tracker. + - dmlc_tracker_port: Port number of the tracker. + - dmlc_task_id: ID of the current task, can be used to obtain deterministic + - dmlc_retry: The number of retry when handling network errors. + - dmlc_timeout: Timeout in seconds. + - dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication. + + Only applicable to the Federated communicator (use upper case for environment + variables, use lower case for runtime configuration): + + - federated_server_address: Address of the federated server. + - federated_world_size: Number of federated workers. + - federated_rank: Rank of the current worker. + - federated_server_cert: Server certificate file path. Only needed for the SSL + mode. + - federated_client_key: Client key file path. Only needed for the SSL mode. + - federated_client_cert: Client certificate file path. Only needed for the SSL + mode. 
""" - config = from_pystr_to_cstr(json.dumps(args)) - _check_call(_LIB.XGCommunicatorInit(config)) + _check_call(_LIB.XGCommunicatorInit(make_jcargs(**args))) def finalize() -> None: @@ -157,7 +143,7 @@ def broadcast(data: _T, root: int) -> _T: assert data is not None, "need to pass in data when broadcasting" s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) length.value = len(s) - # run first broadcast + # Run first broadcast _check_call( _LIB.XGCommunicatorBroadcast( ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root @@ -184,16 +170,27 @@ def broadcast(data: _T, root: int) -> _T: # enumeration of dtypes -DTYPE_ENUM__ = { - np.dtype("int8"): 0, - np.dtype("uint8"): 1, - np.dtype("int32"): 2, - np.dtype("uint32"): 3, - np.dtype("int64"): 4, - np.dtype("uint64"): 5, - np.dtype("float32"): 6, - np.dtype("float64"): 7, -} +def _map_dtype(dtype: np.dtype) -> int: + dtype_map = { + np.dtype("float16"): 0, + np.dtype("float32"): 1, + np.dtype("float64"): 2, + np.dtype("int8"): 4, + np.dtype("int16"): 5, + np.dtype("int32"): 6, + np.dtype("int64"): 7, + np.dtype("uint8"): 8, + np.dtype("uint16"): 9, + np.dtype("uint32"): 10, + np.dtype("uint64"): 11, + } + if platform.system() != "Windows": + dtype_map.update({np.dtype("float128"): 3}) + + if dtype not in dtype_map: + raise TypeError(f"data type {dtype} is not supported on the current platform.") + + return dtype_map[dtype] @unique @@ -229,24 +226,23 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid """ if not isinstance(data, np.ndarray): raise TypeError("allreduce only takes in numpy.ndarray") - buf = data.ravel() - if buf.base is data.base: - buf = buf.copy() - if buf.dtype not in DTYPE_ENUM__: - raise TypeError(f"data type {buf.dtype} not supported") + buf = data.ravel().copy() _check_call( _LIB.XGCommunicatorAllreduce( buf.ctypes.data_as(ctypes.c_void_p), buf.size, - DTYPE_ENUM__[buf.dtype], + _map_dtype(buf.dtype), int(op), - None, - None, ) ) return buf +def signal_error() -> None: + """Kill the process.""" + _check_call(_LIB.XGCommunicatorSignalError()) + + class CommunicatorContext: """A context controlling collective communicator initialization and finalization.""" diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index b54fef40c..76251d65c 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -295,7 +295,7 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None: if device and device.find(":") != -1: raise ValueError( "Distributed training doesn't support selecting device ordinal as GPUs are" - " managed by the distributed framework. use `device=cuda` or `device=gpu`" + " managed by the distributed frameworks. use `device=cuda` or `device=gpu`" " instead." 
) diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py index bdc40360b..c74caecdb 100644 --- a/python-package/xgboost/dask/__init__.py +++ b/python-package/xgboost/dask/__init__.py @@ -71,6 +71,7 @@ from xgboost.core import ( Metric, Objective, QuantileDMatrix, + XGBoostError, _check_distributed_params, _deprecate_positional_args, _expect, @@ -90,7 +91,7 @@ from xgboost.sklearn import ( _wrap_evaluation_matrices, xgboost_model_doc, ) -from xgboost.tracker import RabitTracker, get_host_ip +from xgboost.tracker import RabitTracker from xgboost.training import train as worker_train from .utils import get_n_threads @@ -160,36 +161,38 @@ def _try_start_tracker( n_workers: int, addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]], ) -> Dict[str, Union[int, str]]: - env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers} + env: Dict[str, Union[int, str]] = {} try: if isinstance(addrs[0], tuple): host_ip = addrs[0][0] port = addrs[0][1] rabit_tracker = RabitTracker( - host_ip=get_host_ip(host_ip), n_workers=n_workers, + host_ip=host_ip, port=port, - use_logger=False, + sortby="task", ) else: addr = addrs[0] assert isinstance(addr, str) or addr is None - host_ip = get_host_ip(addr) rabit_tracker = RabitTracker( - host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task" + n_workers=n_workers, host_ip=addr, sortby="task" ) - env.update(rabit_tracker.worker_envs()) - rabit_tracker.start(n_workers) - thread = Thread(target=rabit_tracker.join) + + rabit_tracker.start() + thread = Thread(target=rabit_tracker.wait_for) thread.daemon = True thread.start() - except socket.error as e: - if len(addrs) < 2 or e.errno != 99: + env.update(rabit_tracker.worker_args()) + + except XGBoostError as e: + if len(addrs) < 2: raise LOGGER.warning( - "Failed to bind address '%s', trying to use '%s' instead.", + "Failed to bind address '%s', trying to use '%s' instead. Error:\n %s", str(addrs[0]), str(addrs[1]), + str(e), ) env = _try_start_tracker(n_workers, addrs[1:]) diff --git a/python-package/xgboost/federated.py b/python-package/xgboost/federated.py index 0214e4e20..dcba9ec81 100644 --- a/python-package/xgboost/federated.py +++ b/python-package/xgboost/federated.py @@ -1,45 +1,85 @@ -"""XGBoost Federated Learning related API.""" +"""XGBoost Experimental Federated Learning related API.""" -from .core import _LIB, XGBoostError, _check_call, build_info, c_str +import ctypes +from threading import Thread +from typing import Any, Dict, Optional + +from .core import _LIB, _check_call, make_jcargs +from .tracker import RabitTracker -def run_federated_server( - port: int, - world_size: int, - server_key_path: str = "", - server_cert_path: str = "", - client_cert_path: str = "", -) -> None: - """Run the Federated Learning server. +class FederatedTracker(RabitTracker): + """Tracker for federated training. Parameters ---------- - port : int - The port to listen on. - world_size: int + n_workers : The number of federated workers. - server_key_path: str - Path to the server private key file. SSL is turned off if empty. - server_cert_path: str - Path to the server certificate file. SSL is turned off if empty. - client_cert_path: str - Path to the client certificate file. SSL is turned off if empty. + + port : + The port to listen on. + + secure : + Whether this is a secure instance. If True, then the following arguments for SSL + must be provided. + + server_key_path : + Path to the server private key file. 
+ + server_cert_path : + Path to the server certificate file. + + client_cert_path : + Path to the client certificate file. + """ - if build_info()["USE_FEDERATED"]: - if not server_key_path or not server_cert_path or not client_cert_path: - _check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size)) - else: - _check_call( - _LIB.XGBRunFederatedServer( - port, - world_size, - c_str(server_key_path), - c_str(server_cert_path), - c_str(client_cert_path), - ) - ) - else: - raise XGBoostError( - "XGBoost needs to be built with the federated learning plugin " - "enabled in order to use this module" + + def __init__( # pylint: disable=R0913, W0231 + self, + n_workers: int, + port: int, + secure: bool, + server_key_path: str = "", + server_cert_path: str = "", + client_cert_path: str = "", + timeout: int = 300, + ) -> None: + handle = ctypes.c_void_p() + args = make_jcargs( + n_workers=n_workers, + port=port, + dmlc_communicator="federated", + federated_secure=secure, + server_key_path=server_key_path, + server_cert_path=server_cert_path, + client_cert_path=client_cert_path, + timeout=int(timeout), ) + _check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle))) + self.handle = handle + + +def run_federated_server( # pylint: disable=too-many-arguments + n_workers: int, + port: int, + server_key_path: Optional[str] = None, + server_cert_path: Optional[str] = None, + client_cert_path: Optional[str] = None, + timeout: int = 300, +) -> Dict[str, Any]: + """See :py:class:`~xgboost.federated.FederatedTracker` for more info.""" + args: Dict[str, Any] = {"n_workers": n_workers} + secure = all( + path is not None + for path in [server_key_path, server_cert_path, client_cert_path] + ) + tracker = FederatedTracker( + n_workers=n_workers, port=port, secure=secure, timeout=timeout + ) + tracker.start() + + thread = Thread(target=tracker.wait_for) + thread.daemon = True + thread.start() + args.update(tracker.worker_args()) + return args diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 1403596c0..0a421031e 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -47,21 +47,21 @@ class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods """Context with PySpark specific task ID.""" def __init__(self, context: BarrierTaskContext, **args: Any) -> None: - args["DMLC_TASK_ID"] = str(context.partitionId()) + args["dmlc_task_id"] = str(context.partitionId()) super().__init__(**args) def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]: """Start Rabit tracker with n_workers""" - env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers} + args: Dict[str, Any] = {"n_workers": n_workers} host = _get_host_ip(context) - rabit_context = RabitTracker(host_ip=host, n_workers=n_workers, sortby="task") - env.update(rabit_context.worker_envs()) - rabit_context.start(n_workers) - thread = Thread(target=rabit_context.join) + tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task") + tracker.start() + thread = Thread(target=tracker.wait_for) thread.daemon = True thread.start() - return env + args.update(tracker.worker_args()) + return args def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]: diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 77907c42f..b85c0f325 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -111,8 +111,6 @@ def 
no_sklearn() -> PytestSkip: def no_dask() -> PytestSkip: - if sys.platform.startswith("win"): - return {"reason": "Unsupported platform.", "condition": True} return no_mod("dask") @@ -193,6 +191,10 @@ def no_multiple(*args: Any) -> PytestSkip: return {"condition": condition, "reason": reason} +def skip_win() -> PytestSkip: + return {"reason": "Unsupported platform.", "condition": is_windows()} + + def skip_s390x() -> PytestSkip: condition = platform.machine() == "s390x" reason = "Known to fail on s390x" @@ -968,18 +970,18 @@ def run_with_rabit( exception_queue.put(e) tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size) - tracker.start(world_size) + tracker.start() workers = [] for _ in range(world_size): - worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),)) + worker = threading.Thread(target=run_worker, args=(tracker.worker_args(),)) workers.append(worker) worker.start() for worker in workers: worker.join() assert exception_queue.empty(), f"Worker failed: {exception_queue.get()}" - tracker.join() + tracker.wait_for() def column_split_feature_names( diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 606c63791..d88b25640 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -1,64 +1,12 @@ -# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches -""" -This script is a variant of dmlc-core/dmlc_tracker/tracker.py, -which is a specialized version for xgboost tasks. -""" -import argparse -import logging +"""Tracker for XGBoost collective.""" + +import ctypes +import json import socket -import struct -import sys -from threading import Thread -from typing import Dict, List, Optional, Set, Tuple, Union +from enum import IntEnum, unique +from typing import Dict, Optional, Union -_RingMap = Dict[int, Tuple[int, int]] -_TreeMap = Dict[int, List[int]] - - -class ExSocket: - """ - Extension of socket to handle recv and send of special data - """ - - def __init__(self, sock: socket.socket) -> None: - self.sock = sock - - def recvall(self, nbytes: int) -> bytes: - """Receive number of bytes.""" - res = [] - nread = 0 - while nread < nbytes: - chunk = self.sock.recv(min(nbytes - nread, 1024)) - nread += len(chunk) - res.append(chunk) - return b"".join(res) - - def recvint(self) -> int: - """Receive an integer of 32 bytes""" - return struct.unpack("@i", self.recvall(4))[0] - - def sendint(self, value: int) -> None: - """Send an integer of 32 bytes""" - self.sock.sendall(struct.pack("@i", value)) - - def sendstr(self, value: str) -> None: - """Send a Python string""" - self.sendint(len(value)) - self.sock.sendall(value.encode()) - - def recvstr(self) -> str: - """Receive a Python string""" - slen = self.recvint() - return self.recvall(slen).decode() - - -# magic number used to verify existence of data -MAGIC_NUM = 0xFF99 - - -def get_some_ip(host: str) -> str: - """Get ip from host""" - return socket.getaddrinfo(host, None)[0][4][0] +from .core import _LIB, _check_call, make_jcargs def get_family(addr: str) -> int: @@ -66,439 +14,95 @@ def get_family(addr: str) -> int: return socket.getaddrinfo(addr, None)[0][0] -class WorkerEntry: - """Hanlder to each worker.""" - - def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]): - worker = ExSocket(sock) - self.sock = worker - self.host = get_some_ip(s_addr[0]) - magic = worker.recvint() - assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}" - worker.sendint(MAGIC_NUM) - self.rank = 
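# A minimal sketch of the tracker lifecycle exercised by run_with_rabit above
# (start, worker_args, wait_for), assuming workers join the group through
# xgboost.collective.CommunicatorContext; the worker body is illustrative only.
import threading

from xgboost import collective
from xgboost.tracker import RabitTracker

def run_worker(args: dict) -> None:
    # Join the collective group, do some work, then leave on context exit.
    with collective.CommunicatorContext(**args):
        print("rank", collective.get_rank(), "of", collective.get_world_size())

n_workers = 2
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
tracker.start()                      # non-blocking; runs the native tracker
args = tracker.worker_args()         # connection info shared by all workers

workers = [
    threading.Thread(target=run_worker, args=(args,)) for _ in range(n_workers)
]
for w in workers:
    w.start()
for w in workers:
    w.join()

tracker.wait_for()                   # block until the tracker has shut down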
worker.recvint() - self.world_size = worker.recvint() - self.task_id = worker.recvstr() - self.cmd = worker.recvstr() - self.wait_accept = 0 - self.port: Optional[int] = None - - def print(self, use_logger: bool) -> None: - """Execute the print command from worker.""" - msg = self.sock.recvstr() - # On dask we use print to avoid setting global verbosity. - if use_logger: - logging.info(msg.strip()) - else: - print(msg.strip(), flush=True) - - def decide_rank(self, job_map: Dict[str, int]) -> int: - """Get the rank of current entry.""" - if self.rank >= 0: - return self.rank - if self.task_id != "NULL" and self.task_id in job_map: - return job_map[self.task_id] - return -1 - - def assign_rank( - self, - rank: int, - wait_conn: Dict[int, "WorkerEntry"], - tree_map: _TreeMap, - parent_map: Dict[int, int], - ring_map: _RingMap, - ) -> List[int]: - """Assign the rank for current entry.""" - self.rank = rank - nnset = set(tree_map[rank]) - rprev, next_rank = ring_map[rank] - self.sock.sendint(rank) - # send parent rank - self.sock.sendint(parent_map[rank]) - # send world size - self.sock.sendint(len(tree_map)) - self.sock.sendint(len(nnset)) - # send the rprev and next link - for r in nnset: - self.sock.sendint(r) - # send prev link - if rprev not in (-1, rank): - nnset.add(rprev) - self.sock.sendint(rprev) - else: - self.sock.sendint(-1) - # send next link - if next_rank not in (-1, rank): - nnset.add(next_rank) - self.sock.sendint(next_rank) - else: - self.sock.sendint(-1) - - return self._get_remote(wait_conn, nnset) - - def _get_remote( - self, wait_conn: Dict[int, "WorkerEntry"], badset: Set[int] - ) -> List[int]: - while True: - conset = [] - for r in badset: - if r in wait_conn: - conset.append(r) - self.sock.sendint(len(conset)) - self.sock.sendint(len(badset) - len(conset)) - for r in conset: - self.sock.sendstr(wait_conn[r].host) - port = wait_conn[r].port - assert port is not None - # send port of this node to other workers so that they can call connect - self.sock.sendint(port) - self.sock.sendint(r) - nerr = self.sock.recvint() - if nerr != 0: - continue - self.port = self.sock.recvint() - rmset = [] - # all connection was successuly setup - for r in conset: - wait_conn[r].wait_accept -= 1 - if wait_conn[r].wait_accept == 0: - rmset.append(r) - for r in rmset: - wait_conn.pop(r, None) - self.wait_accept = len(badset) - len(conset) - return rmset - - class RabitTracker: - """ - tracker for rabit - """ - - def __init__( - self, - host_ip: str, - n_workers: int, - port: int = 0, - use_logger: bool = False, - sortby: str = "host", - ) -> None: - """A Python implementation of RABIT tracker. - - Parameters - .......... - use_logger: - Use logging.info for tracker print command. When set to False, Python print - function is used instead. - - sortby: - How to sort the workers for rank assignment. The default is host, but users - can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain - deterministic rank assignment. 
Available options are: - - host - - task - - """ - sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM) - sock.bind((host_ip, port)) - self.port = sock.getsockname()[1] - sock.listen(256) - self.sock = sock - self.host_ip = host_ip - self.thread: Optional[Thread] = None - self.n_workers = n_workers - self._use_logger = use_logger - self._sortby = sortby - logging.info("start listen on %s:%d", host_ip, self.port) - - def __del__(self) -> None: - if hasattr(self, "sock"): - self.sock.close() - - @staticmethod - def _get_neighbor(rank: int, n_workers: int) -> List[int]: - rank = rank + 1 - ret = [] - if rank > 1: - ret.append(rank // 2 - 1) - if rank * 2 - 1 < n_workers: - ret.append(rank * 2 - 1) - if rank * 2 < n_workers: - ret.append(rank * 2) - return ret - - def worker_envs(self) -> Dict[str, Union[str, int]]: - """ - get environment variables for workers - can be passed in as args or envs - """ - return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port} - - def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]: - tree_map: _TreeMap = {} - parent_map: Dict[int, int] = {} - for r in range(n_workers): - tree_map[r] = self._get_neighbor(r, n_workers) - parent_map[r] = (r + 1) // 2 - 1 - return tree_map, parent_map - - def find_share_ring( - self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int - ) -> List[int]: - """ - get a ring structure that tends to share nodes with the tree - return a list starting from rank - """ - nset = set(tree_map[rank]) - cset = nset - {parent_map[rank]} - if not cset: - return [rank] - rlst = [rank] - cnt = 0 - for v in cset: - vlst = self.find_share_ring(tree_map, parent_map, v) - cnt += 1 - if cnt == len(cset): - vlst.reverse() - rlst += vlst - return rlst - - def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap: - """ - get a ring connection used to recover local data - """ - assert parent_map[0] == -1 - rlst = self.find_share_ring(tree_map, parent_map, 0) - assert len(rlst) == len(tree_map) - ring_map: _RingMap = {} - n_workers = len(tree_map) - for r in range(n_workers): - rprev = (r + n_workers - 1) % n_workers - rnext = (r + 1) % n_workers - ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) - return ring_map - - def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]: - """ - get the link map, this is a bit hacky, call for better algorithm - to place similar nodes together - """ - tree_map, parent_map = self._get_tree(n_workers) - ring_map = self.get_ring(tree_map, parent_map) - rmap = {0: 0} - k = 0 - for i in range(n_workers - 1): - k = ring_map[k][1] - rmap[k] = i + 1 - - ring_map_: _RingMap = {} - tree_map_: _TreeMap = {} - parent_map_: Dict[int, int] = {} - for k, v in ring_map.items(): - ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]]) - for k, tree_nodes in tree_map.items(): - tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes] - for k, parent in parent_map.items(): - if k != 0: - parent_map_[rmap[k]] = rmap[parent] - else: - parent_map_[rmap[k]] = -1 - return tree_map_, parent_map_, ring_map_ - - def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]: - if self._sortby == "host": - pending.sort(key=lambda s: s.host) - elif self._sortby == "task": - pending.sort(key=lambda s: s.task_id) - return pending - - def accept_workers(self, n_workers: int) -> None: - """Wait for all workers to connect to the tracker.""" - - # set of nodes that finishes the job - shutdown: Dict[int, WorkerEntry] = {} - # set of nodes that is waiting for 
connections - wait_conn: Dict[int, WorkerEntry] = {} - # maps job id to rank - job_map: Dict[str, int] = {} - # list of workers that is pending to be assigned rank - pending: List[WorkerEntry] = [] - # lazy initialize tree_map - tree_map = None - - while len(shutdown) != n_workers: - fd, s_addr = self.sock.accept() - s = WorkerEntry(fd, s_addr) - if s.cmd == "print": - s.print(self._use_logger) - continue - if s.cmd == "shutdown": - assert s.rank >= 0 and s.rank not in shutdown - assert s.rank not in wait_conn - shutdown[s.rank] = s - logging.debug("Received %s signal from %d", s.cmd, s.rank) - continue - assert s.cmd == "start" - # lazily initialize the workers - if tree_map is None: - assert s.cmd == "start" - if s.world_size > 0: - n_workers = s.world_size - tree_map, parent_map, ring_map = self.get_link_map(n_workers) - # set of nodes that is pending for getting up - todo_nodes = list(range(n_workers)) - else: - assert s.world_size in (-1, n_workers) - if s.cmd == "recover": - assert s.rank >= 0 - - rank = s.decide_rank(job_map) - # batch assignment of ranks - if rank == -1: - assert todo_nodes - pending.append(s) - if len(pending) == len(todo_nodes): - pending = self._sort_pending(pending) - for s in pending: - rank = todo_nodes.pop(0) - if s.task_id != "NULL": - job_map[s.task_id] = rank - s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) - if s.wait_accept > 0: - wait_conn[rank] = s - logging.debug( - "Received %s signal from %s; assign rank %d", - s.cmd, - s.host, - s.rank, - ) - if not todo_nodes: - logging.info("@tracker All of %d nodes getting started", n_workers) - else: - s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) - logging.debug("Received %s signal from %d", s.cmd, s.rank) - if s.wait_accept > 0: - wait_conn[rank] = s - logging.info("@tracker All nodes finishes job") - - def start(self, n_workers: int) -> None: - """Strat the tracker, it will wait for `n_workers` to connect.""" - - def run() -> None: - self.accept_workers(n_workers) - - self.thread = Thread(target=run, args=(), daemon=True) - self.thread.start() - - def join(self) -> None: - """Wait for the tracker to finish.""" - while self.thread is not None and self.thread.is_alive(): - self.thread.join(100) - - def alive(self) -> bool: - """Wether the tracker thread is alive""" - return self.thread is not None and self.thread.is_alive() - - -def get_host_ip(host_ip: Optional[str] = None) -> str: - """Get the IP address of current host. If `host_ip` is not none then it will be - returned as it's - - """ - if host_ip is None or host_ip == "auto": - host_ip = "ip" - - if host_ip == "dns": - host_ip = socket.getfqdn() - elif host_ip == "ip": - from socket import gaierror - - try: - host_ip = socket.gethostbyname(socket.getfqdn()) - except gaierror: - logging.debug( - "gethostbyname(socket.getfqdn()) failed... trying on hostname()" - ) - host_ip = socket.gethostbyname(socket.gethostname()) - if host_ip.startswith("127."): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # doesn't have to be reachable - s.connect(("10.255.255.255", 1)) - host_ip = s.getsockname()[0] - - assert host_ip is not None - return host_ip - - -def start_rabit_tracker(args: argparse.Namespace) -> None: - """Standalone function to start rabit tracker. + """Tracker for the collective used in XGBoost, acting as a coordinator between + workers. Parameters - ---------- - args: arguments to start the rabit tracker. + .......... + sortby: + + How to sort the workers for rank assignment. 
The default is host, but users can + set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain + deterministic rank assignment. Available options are: + - host + - task + + timeout : + + Timeout for constructing the communication group and waiting for the tracker to + shutdown when it's instructed to, doesn't apply to communication when tracking + is running. + + The timeout value should take the time of data loading and pre-processing into + account, due to potential lazy execution. + + The :py:meth:`.wait_for` method has a different timeout parameter that can stop + the tracker even if the tracker is still being used. A value error is raised + when timeout is reached. + """ - envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers} - rabit = RabitTracker( - host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True - ) - envs.update(rabit.worker_envs()) - rabit.start(args.num_workers) - sys.stdout.write("DMLC_TRACKER_ENV_START\n") - # simply write configuration to stdout - for k, v in envs.items(): - sys.stdout.write(f"{k}={v}\n") - sys.stdout.write("DMLC_TRACKER_ENV_END\n") - sys.stdout.flush() - rabit.join() + @unique + class _SortBy(IntEnum): + HOST = 0 + TASK = 1 -def main() -> None: - """Main function if tracker is executed in standalone mode.""" - parser = argparse.ArgumentParser(description="Rabit Tracker start.") - parser.add_argument( - "--num-workers", - required=True, - type=int, - help="Number of worker process to be launched.", - ) - parser.add_argument( - "--num-servers", - default=0, - type=int, - help="Number of server process to be launched. Only used in PS jobs.", - ) - parser.add_argument( - "--host-ip", - default=None, - type=str, - help=( - "Host IP addressed, this is only needed " - + "if the host IP cannot be automatically guessed." - ), - ) - parser.add_argument( - "--log-level", - default="INFO", - type=str, - choices=["INFO", "DEBUG"], - help="Logging level of the logger.", - ) - args = parser.parse_args() + def __init__( # pylint: disable=too-many-arguments + self, + n_workers: int, + host_ip: Optional[str], + port: int = 0, + sortby: str = "host", + timeout: int = 0, + ) -> None: - fmt = "%(asctime)s %(levelname)s %(message)s" - if args.log_level == "INFO": - level = logging.INFO - elif args.log_level == "DEBUG": - level = logging.DEBUG - else: - raise RuntimeError(f"Unknown logging level {args.log_level}") + handle = ctypes.c_void_p() + if sortby not in ("host", "task"): + raise ValueError("Expecting either 'host' or 'task' for sortby.") + if host_ip is not None: + get_family(host_ip) # use python socket to stop early for invalid address + args = make_jcargs( + host=host_ip, + n_workers=n_workers, + port=port, + dmlc_communicator="rabit", + sortby=self._SortBy.HOST if sortby == "host" else self._SortBy.TASK, + timeout=int(timeout), + ) + _check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle))) + self.handle = handle - logging.basicConfig(format=fmt, level=level) + def free(self) -> None: + """Internal function for testing.""" + if hasattr(self, "handle"): + handle = self.handle + del self.handle + _check_call(_LIB.XGTrackerFree(handle)) - if args.num_servers == 0: - start_rabit_tracker(args) - else: - raise RuntimeError("Do not yet support start ps tracker in standalone mode.") + def __del__(self) -> None: + self.free() + def start(self) -> None: + """Start the tracker. 
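# A minimal sketch of deterministic rank assignment with sortby="task", assuming
# each worker passes a stable "dmlc_task_id" the way the PySpark integration
# above does; the helper below is hypothetical.
from xgboost import collective
from xgboost.tracker import RabitTracker

tracker = RabitTracker(n_workers=2, host_ip="127.0.0.1", sortby="task")
tracker.start()

def worker_context(task_id: str) -> collective.CommunicatorContext:
    # The tracker sorts workers by task ID, so a stable ID yields a stable rank.
    args = dict(tracker.worker_args())
    args["dmlc_task_id"] = task_id
    return collective.CommunicatorContext(**args)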
Once started, the client still need to call the + :py:meth:`wait_for` method in order to wait for it to finish (think of it as a + thread). -if __name__ == "__main__": - main() + """ + _check_call(_LIB.XGTrackerRun(self.handle, make_jcargs())) + + def wait_for(self, timeout: Optional[int] = None) -> None: + """Wait for the tracker to finish all the work and shutdown. When timeout is + reached, a value error is raised. By default we don't have timeout since we + don't know how long it takes for the model to finish training. + + """ + _check_call(_LIB.XGTrackerWaitFor(self.handle, make_jcargs(timeout=timeout))) + + def worker_args(self) -> Dict[str, Union[str, int]]: + """Get arguments for workers.""" + c_env = ctypes.c_char_p() + _check_call(_LIB.XGTrackerWorkerArgs(self.handle, ctypes.byref(c_env))) + assert c_env.value is not None + env = json.loads(c_env.value) + return env diff --git a/rabit/CMakeLists.txt b/rabit/CMakeLists.txt deleted file mode 100644 index 4562f864f..000000000 --- a/rabit/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -cmake_minimum_required(VERSION 3.18) - -find_package(Threads REQUIRED) - -set(RABIT_SOURCES - ${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc - ${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc) - -if(RABIT_MOCK) - list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc) -else() - list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine.cc) -endif() - -set(RABIT_SOURCES ${RABIT_SOURCES} PARENT_SCOPE) diff --git a/rabit/LICENSE b/rabit/LICENSE deleted file mode 100644 index 2485f4eaa..000000000 --- a/rabit/LICENSE +++ /dev/null @@ -1,28 +0,0 @@ -Copyright (c) 2014 by Contributors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -* Neither the name of rabit nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - diff --git a/rabit/README.md b/rabit/README.md deleted file mode 100644 index 0be1b7015..000000000 --- a/rabit/README.md +++ /dev/null @@ -1 +0,0 @@ -# This directory contains the CPU network module for XGBoost. 
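# A minimal sketch of bounding the shutdown wait, following the wait_for
# docstring above: both the constructor and wait_for accept a timeout, and per
# that docstring a value error signals that the timeout was reached.
from xgboost.tracker import RabitTracker

tracker = RabitTracker(n_workers=2, host_ip="127.0.0.1", timeout=60)
tracker.start()
# ... launch workers with tracker.worker_args() ...
tracker.wait_for(timeout=60)  # error out instead of blocking indefinitely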
The library originates from [RABIT](https://github.com/dmlc/rabit) \ No newline at end of file diff --git a/rabit/include/rabit/base.h b/rabit/include/rabit/base.h deleted file mode 100644 index ab3a285d1..000000000 --- a/rabit/include/rabit/base.h +++ /dev/null @@ -1,19 +0,0 @@ -/*! - * Copyright (c) 2020 by Contributors - * \file base.h - * \brief Macros common to all headers - * - * \author Hyunsu Cho - */ - -#ifndef RABIT_BASE_H_ -#define RABIT_BASE_H_ - -#ifndef _CRT_SECURE_NO_WARNINGS -#define _CRT_SECURE_NO_WARNINGS -#endif // _CRT_SECURE_NO_WARNINGS -#ifndef _CRT_SECURE_NO_DEPRECATE -#define _CRT_SECURE_NO_DEPRECATE -#endif // _CRT_SECURE_NO_DEPRECATE - -#endif // RABIT_BASE_H_ diff --git a/rabit/include/rabit/c_api.h b/rabit/include/rabit/c_api.h deleted file mode 100644 index 6c9e798be..000000000 --- a/rabit/include/rabit/c_api.h +++ /dev/null @@ -1,157 +0,0 @@ -/*! - * Copyright by Contributors - * \file c_api.h - * \author Tianqi Chen - * \brief a C style API of rabit. - */ -#ifndef RABIT_C_API_H_ -#define RABIT_C_API_H_ - -#ifdef __cplusplus -#define RABIT_EXTERN_C extern "C" -#include -#else -#define RABIT_EXTERN_C -#include -#endif // __cplusplus - -#if defined(_MSC_VER) || defined(_WIN32) -#define RABIT_DLL RABIT_EXTERN_C __declspec(dllexport) -#else -#define RABIT_DLL RABIT_EXTERN_C __attribute__ ((visibility ("default"))) -#endif // defined(_MSC_VER) || defined(_WIN32) - -/*! \brief rabit unsigned long type */ -typedef unsigned long rbt_ulong; // NOLINT(*) - -/*! - * \brief initialize the rabit module, - * call this once before using anything - * The additional arguments is not necessary. - * Usually rabit will detect settings - * from environment variables. - * \param argc number of arguments in argv - * \param argv the array of input arguments - * \return true if rabit is initialized successfully otherwise false - */ -RABIT_DLL bool RabitInit(int argc, char *argv[]); - -/*! - * \brief finalize the rabit engine, - * call this function after you finished all jobs. - * \return true if rabit is initialized successfully otherwise false - */ -RABIT_DLL int RabitFinalize(void); - -/*! - * \brief get rank of previous process in ring topology - * \return rank number of worker - * */ -RABIT_DLL int RabitGetRingPrevRank(void); - -/*! - * \brief get rank of current process - * \return rank number of worker - * */ -RABIT_DLL int RabitGetRank(void); - -/*! - * \brief get total number of process - * \return total world size - * */ -RABIT_DLL int RabitGetWorldSize(void); - -/*! - * \brief get rank of current process - * \return if rabit is distributed - * */ -RABIT_DLL int RabitIsDistributed(void); - -/*! - * \brief print the msg to the tracker, - * this function can be used to communicate the information of the progress to - * the user who monitors the tracker - * \param msg the message to be printed - */ -RABIT_DLL int RabitTrackerPrint(const char *msg); -/*! - * \brief get name of processor - * \param out_name hold output string - * \param out_len hold length of output string - * \param max_len maximum buffer length of input - */ -RABIT_DLL void RabitGetProcessorName(char *out_name, - rbt_ulong *out_len, - rbt_ulong max_len); -/*! - * \brief broadcast an memory region to all others from root - * - * Example: int a = 1; Broadcast(&a, sizeof(a), root); - * \param sendrecv_data the pointer to send or receive buffer, - * \param size the size of the data - * \param root the root of process - */ -RABIT_DLL int RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root); - -/*! 
- * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param beginIndex beginning of the current slice in sendrecvbuf of type enum_dtype - * \param size_node_slice size of the current node slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - * \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -RABIT_DLL int RabitAllgather(void *sendrecvbuf, size_t total_size, - size_t beginIndex, size_t size_node_slice, - size_t size_prev_slice, int enum_dtype); - -/*! - * \brief perform in-place allreduce, on sendrecvbuf - * this function is NOT thread-safe - * - * Example Usage: the following code gives sum of the result - * vector data(10); - * ... - * Allreduce(&data[0], data.size()); - * ... - * \param sendrecvbuf buffer for both sending and receiving data - * \param count number of elements to be reduced - * \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include - * \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit - * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to passed into the lazy preprocessing function - */ -RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, - int enum_op, void (*prepare_fun)(void *arg), - void *prepare_arg); - -/*! - * \return version number of current stored model, - * which means how many calls to CheckPoint we made so far - * \return rabit version number - */ -RABIT_DLL int RabitVersionNumber(void); - - -/*! - * \brief a Dummy function, - * used to cause force link of C API into the DLL. - * \code - * \/\/force link rabit C API library. - * static int must_link_rabit_ = RabitLinkTag(); - * \endcode - * \return a dummy integer. - */ -RABIT_DLL int RabitLinkTag(void); - -#endif // RABIT_C_API_H_ diff --git a/rabit/include/rabit/internal/engine.h b/rabit/include/rabit/internal/engine.h deleted file mode 100644 index aa074fb39..000000000 --- a/rabit/include/rabit/internal/engine.h +++ /dev/null @@ -1,197 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file engine.h - * \brief This file defines the core interface of rabit library - * \author Tianqi Chen, Nacho, Tianyi - */ -#ifndef RABIT_INTERNAL_ENGINE_H_ -#define RABIT_INTERNAL_ENGINE_H_ -#include -#include "rabit/serializable.h" - -namespace MPI { // NOLINT -/*! \brief MPI data type just to be compatible with MPI reduce function*/ -class Datatype; -} - -/*! \brief namespace of rabit */ -namespace rabit { -/*! \brief core interface of the engine */ -namespace engine { -/*! 
\brief interface of core Allreduce engine */ -class IEngine { - public: - /*! - * \brief Preprocessing function, that is called before AllReduce, - * used to prepare the data used by AllReduce - * \param arg additional possible argument used to invoke the preprocessor - */ - typedef void (PreprocFunction) (void *arg); // NOLINT - /*! - * \brief reduce function, the same form of MPI reduce function is used, - * to be compatible with MPI interface - * In all the functions, the memory is ensured to aligned to 64-bit - * which means it is OK to cast src,dst to double* int* etc - * \param src pointer to source space - * \param dst pointer to destination reduction - * \param count total number of elements to be reduced (note this is total number of elements instead of bytes) - * the definition of the reduce function should be type aware - * \param dtype the data type object, to be compatible with MPI reduce - */ - typedef void (ReduceFunction) (const void *src, // NOLINT - void *dst, int count, - const MPI::Datatype &dtype); - /*! \brief virtual destructor */ - ~IEngine() = default; - /*! - * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - */ - virtual void Allgather(void *sendrecvbuf, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice) = 0; - /*! - * \brief performs in-place Allreduce, on sendrecvbuf - * this function is NOT thread-safe - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the number of bytes the type has - * \param count number of elements to be reduced - * \param reducer reduce function - * \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to pass into the lazy preprocessing function - */ - virtual void Allreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer, - PreprocFunction prepare_fun = nullptr, - void *prepare_arg = nullptr) = 0; - /*! - * \brief broadcasts data from root to every other node - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - */ - virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0; - /*! - * deprecated - */ - virtual int LoadCheckPoint() = 0; - /*! - * \brief Increase internal version number. Deprecated. - */ - virtual void CheckPoint() = 0; - /*! - * \return version number of the current stored model, - * which means how many calls to CheckPoint we made so far - * \sa LoadCheckPoint, CheckPoint - */ - virtual int VersionNumber() const = 0; - /*! 
\brief gets rank of previous node in ring topology */ - virtual int GetRingPrevRank() const = 0; - /*! \brief gets rank of current node */ - virtual int GetRank() const = 0; - /*! \brief gets total number of nodes */ - virtual int GetWorldSize() const = 0; - /*! \brief whether we run in distribted mode */ - virtual bool IsDistributed() const = 0; - /*! \brief gets the host name of the current node */ - virtual std::string GetHost() const = 0; - /*! - * \brief prints the msg in the tracker, - * this function can be used to communicate progress information to - * the user who monitors the tracker - * \param msg message to be printed in the tracker - */ - virtual void TrackerPrint(const std::string &msg) = 0; -}; - -/*! \brief initializes the engine module */ -bool Init(int argc, char *argv[]); -/*! \brief finalizes the engine module */ -bool Finalize(); -/*! \brief singleton method to get engine */ -IEngine *GetEngine(); - -/*! \brief namespace that contains stubs to be compatible with MPI */ -namespace mpi { -/*!\brief enum of all operators */ -enum OpType { - kMax = 0, - kMin = 1, - kSum = 2, - kBitwiseAND = 3, - kBitwiseOR = 4, - kBitwiseXOR = 5, -}; -/*!\brief enum of supported data types */ -enum DataType { - kChar = 0, - kUChar = 1, - kInt = 2, - kUInt = 3, - kLong = 4, - kULong = 5, - kFloat = 6, - kDouble = 7, - kLongLong = 8, - kULongLong = 9 -}; -} // namespace mpi -/*! - * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - */ -void Allgather(void* sendrecvbuf, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice); -/*! - * \brief perform in-place Allreduce, on sendrecvbuf - * this is an internal function used by rabit to be able to compile with MPI - * do not use this function directly - * \param sendrecvbuf buffer for both sending and receiving data - * \param type_nbytes the number of bytes the type has - * \param count number of elements to be reduced - * \param reducer reduce function - * \param dtype the data type - * \param op the reduce operator type - * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to pass into the lazy preprocessing function. 
- */ -void Allreduce_(void *sendrecvbuf, // NOLINT - size_t type_nbytes, - size_t count, - IEngine::ReduceFunction red, - mpi::DataType dtype, - mpi::OpType op, - IEngine::PreprocFunction prepare_fun = nullptr, - void *prepare_arg = nullptr); -} // namespace engine -} // namespace rabit -#endif // RABIT_INTERNAL_ENGINE_H_ diff --git a/rabit/include/rabit/internal/io.h b/rabit/include/rabit/internal/io.h deleted file mode 100644 index d5d0fee4d..000000000 --- a/rabit/include/rabit/internal/io.h +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2014-2023, XGBoost Contributors - * \file io.h - * \brief utilities with different serializable implementations - * \author Tianqi Chen - */ -#ifndef RABIT_INTERNAL_IO_H_ -#define RABIT_INTERNAL_IO_H_ - -#include -#include // for size_t -#include -#include // for memcpy -#include -#include -#include -#include - -#include "dmlc/io.h" -#include "xgboost/logging.h" - -namespace rabit::utils { -/*! \brief re-use definition of dmlc::SeekStream */ -using SeekStream = dmlc::SeekStream; -/** - * @brief Fixed size memory buffer as a stream. - */ -struct MemoryFixSizeBuffer : public SeekStream { - public: - // similar to SEEK_END in libc - static std::size_t constexpr kSeekEnd = std::numeric_limits::max(); - - public: - /** - * @brief Ctor - * - * @param p_buffer Pointer to the source buffer with size `buffer_size`. - * @param buffer_size Size of the source buffer - */ - MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size) - : p_buffer_(reinterpret_cast(p_buffer)), buffer_size_(buffer_size) {} - ~MemoryFixSizeBuffer() override = default; - - std::size_t Read(void *ptr, std::size_t size) override { - std::size_t nread = std::min(buffer_size_ - curr_ptr_, size); - if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); - curr_ptr_ += nread; - return nread; - } - void Write(const void *ptr, std::size_t size) override { - if (size == 0) return; - CHECK_LE(curr_ptr_ + size, buffer_size_); - std::memcpy(p_buffer_ + curr_ptr_, ptr, size); - curr_ptr_ += size; - } - void Seek(std::size_t pos) override { - if (pos == kSeekEnd) { - curr_ptr_ = buffer_size_; - } else { - curr_ptr_ = static_cast(pos); - } - } - /** - * @brief Current position in the buffer (stream). - */ - std::size_t Tell() override { return curr_ptr_; } - [[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; } - - protected: - /*! \brief in memory buffer */ - char *p_buffer_{nullptr}; - /*! \brief current pointer */ - std::size_t buffer_size_{0}; - /*! \brief current pointer */ - std::size_t curr_ptr_{0}; -}; - -/*! 
\brief a in memory buffer that can be read and write as stream interface */ -struct MemoryBufferStream : public SeekStream { - public: - explicit MemoryBufferStream(std::string *p_buffer) - : p_buffer_(p_buffer) { - curr_ptr_ = 0; - } - ~MemoryBufferStream() override = default; - size_t Read(void *ptr, size_t size) override { - CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length"; - size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); - if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); - curr_ptr_ += nread; - return nread; - } - void Write(const void *ptr, size_t size) override { - if (size == 0) return; - if (curr_ptr_ + size > p_buffer_->length()) { - p_buffer_->resize(curr_ptr_+size); - } - std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); - curr_ptr_ += size; - } - void Seek(size_t pos) override { - curr_ptr_ = static_cast(pos); - } - size_t Tell() override { - return curr_ptr_; - } - virtual bool AtEnd() const { - return curr_ptr_ == p_buffer_->length(); - } - - private: - /*! \brief in memory buffer */ - std::string *p_buffer_; - /*! \brief current pointer */ - size_t curr_ptr_; -}; // class MemoryBufferStream -} // namespace rabit::utils -#endif // RABIT_INTERNAL_IO_H_ diff --git a/rabit/include/rabit/internal/rabit-inl.h b/rabit/include/rabit/internal/rabit-inl.h deleted file mode 100644 index 49b086320..000000000 --- a/rabit/include/rabit/internal/rabit-inl.h +++ /dev/null @@ -1,234 +0,0 @@ -/*! - * Copyright (c) 2014-2019 by Contributors - * \file rabit-inl.h - * \brief implementation of inline template function for rabit interface - * - * \author Tianqi Chen - */ -#ifndef RABIT_INTERNAL_RABIT_INL_H_ -#define RABIT_INTERNAL_RABIT_INL_H_ -// use engine for implementation -#include -#include -#include "rabit/internal/io.h" -#include "rabit/internal/utils.h" -#include "rabit/rabit.h" - -namespace rabit { -namespace engine { -namespace mpi { -// template function to translate type to enum indicator -template -inline DataType GetType(); -template<> -inline DataType GetType() { - return kChar; -} -template<> -inline DataType GetType() { - return kUChar; -} -template<> -inline DataType GetType() { - return kInt; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kUInt; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kLong; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kULong; -} -template<> -inline DataType GetType() { - return kFloat; -} -template<> -inline DataType GetType() { - return kDouble; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kLongLong; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kULongLong; -} -} // namespace mpi -} // namespace engine - -namespace op { -struct Max { - static const engine::mpi::OpType kType = engine::mpi::kMax; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - if (dst < src) dst = src; - } -}; -struct Min { - static const engine::mpi::OpType kType = engine::mpi::kMin; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - if (dst > src) dst = src; - } -}; -struct Sum { - static const engine::mpi::OpType kType = engine::mpi::kSum; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - dst += src; - } -}; -struct BitAND { - static const engine::mpi::OpType kType = engine::mpi::kBitwiseAND; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - dst &= src; - } 
-}; -struct BitOR { - static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - dst |= src; - } -}; -struct BitXOR { - static const engine::mpi::OpType kType = engine::mpi::kBitwiseXOR; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - dst ^= src; - } -}; -template -inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &) { - const DType *src = static_cast(src_); - DType *dst = (DType *)dst_; // NOLINT(*) - for (int i = 0; i < len; i++) { - OP::Reduce(dst[i], src[i]); - } -} -} // namespace op - -// initialize the rabit engine -inline bool Init(int argc, char *argv[]) { - return engine::Init(argc, argv); -} -// finalize the rabit engine -inline bool Finalize() { - return engine::Finalize(); -} -// get the rank of the previous worker in ring topology -inline int GetRingPrevRank() { - return engine::GetEngine()->GetRingPrevRank(); -} -// get the rank of current process -inline int GetRank() { - return engine::GetEngine()->GetRank(); -} -// the the size of the world -inline int GetWorldSize() { - return engine::GetEngine()->GetWorldSize(); -} -// whether rabit is distributed -inline bool IsDistributed() { - return engine::GetEngine()->IsDistributed(); -} -// get the name of current processor -inline std::string GetProcessorName() { - return engine::GetEngine()->GetHost(); -} -// broadcast data to all other nodes from root -inline void Broadcast(void *sendrecv_data, size_t size, int root) { - engine::GetEngine()->Broadcast(sendrecv_data, size, root); -} -template -inline void Broadcast(std::vector *sendrecv_data, int root) { - size_t size = sendrecv_data->size(); - Broadcast(&size, sizeof(size), root); - if (sendrecv_data->size() != size) { - sendrecv_data->resize(size); - } - if (size != 0) { - Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root); - } -} -inline void Broadcast(std::string *sendrecv_data, int root) { - size_t size = sendrecv_data->length(); - Broadcast(&size, sizeof(size), root); - if (sendrecv_data->length() != size) { - sendrecv_data->resize(size); - } - if (size != 0) { - Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root); - } -} - -// perform inplace Allreduce -template -inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *arg), - void *prepare_arg) { - engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer, - engine::mpi::GetType(), OP::kType, prepare_fun, prepare_arg); -} - -// C++11 support for lambda prepare function -#if DMLC_USE_CXX11 -inline void InvokeLambda(void *fun) { - (*static_cast*>(fun))(); -} -template -inline void Allreduce(DType *sendrecvbuf, size_t count, - std::function prepare_fun) { - engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer, - engine::mpi::GetType(), OP::kType, InvokeLambda, &prepare_fun); -} - -// Performs inplace Allgather -template -inline void Allgather(DType *sendrecvbuf, - size_t totalSize, - size_t beginIndex, - size_t sizeNodeSlice, - size_t sizePrevSlice) { - engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType), - (beginIndex + sizeNodeSlice) * sizeof(DType), - sizePrevSlice * sizeof(DType)); -} -#endif // C++11 - -// print message to the tracker -inline void TrackerPrint(const std::string &msg) { - engine::GetEngine()->TrackerPrint(msg); -} -#ifndef RABIT_STRICT_CXX98_ -inline void TrackerPrintf(const char *fmt, ...) 
{ - const int kPrintBuffer = 1 << 10; - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - msg.resize(strlen(msg.c_str())); - TrackerPrint(msg); -} - -#endif // RABIT_STRICT_CXX98_ - -// deprecated, planned for removal after checkpoing from JVM package is removed. -inline int LoadCheckPoint() { return engine::GetEngine()->LoadCheckPoint(); } -// deprecated, increase internal version number -inline void CheckPoint() { engine::GetEngine()->CheckPoint(); } -// return the version number of currently stored model -inline int VersionNumber() { - return engine::GetEngine()->VersionNumber(); -} -} // namespace rabit -#endif // RABIT_INTERNAL_RABIT_INL_H_ diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index cec246efd..3701146d4 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file socket.h * \author Tianqi Chen */ @@ -95,7 +95,10 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) template std::enable_if_t, xgboost::collective::Result> PollError(E const& revents) { if ((revents & POLLERR) != 0) { - return xgboost::system::FailWithCode("Poll error condition."); + auto err = errno; + auto str = strerror(err); + return xgboost::system::FailWithCode(std::string{"Poll error condition:"} + std::string{str} + + " code:" + std::to_string(err)); } if ((revents & POLLNVAL) != 0) { return xgboost::system::FailWithCode("Invalid polling request."); @@ -211,12 +214,7 @@ struct PollHelper { } auto revents = pfd.revents & pfd.events; - if (!revents) { - // FIXME(jiamingy): remove this once rabit is replaced. - fds.erase(pfd.fd); - } else { - fds[pfd.fd].events = revents; - } + fds[pfd.fd].events = revents; } return xgboost::collective::Success(); } diff --git a/rabit/include/rabit/internal/utils.h b/rabit/include/rabit/internal/utils.h deleted file mode 100644 index c1739ce79..000000000 --- a/rabit/include/rabit/internal/utils.h +++ /dev/null @@ -1,146 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file utils.h - * \brief simple utils to support the code - * \author Tianqi Chen - */ -#ifndef RABIT_INTERNAL_UTILS_H_ -#define RABIT_INTERNAL_UTILS_H_ - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "dmlc/io.h" -#include "xgboost/logging.h" - -#if !defined(__GNUC__) || defined(__FreeBSD__) -#define fopen64 std::fopen -#endif // !defined(__GNUC__) || defined(__FreeBSD__) - -#ifndef _MSC_VER - -#ifdef _FILE_OFFSET_BITS -#if _FILE_OFFSET_BITS == 32 -#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit") -#endif // _FILE_OFFSET_BITS == 32 -#endif // _FILE_OFFSET_BITS - -#ifdef __APPLE__ -#define off64_t off_t -#define fopen64 std::fopen -#endif // __APPLE__ - -extern "C" { -#include -} -#endif // _MSC_VER - -#include - -namespace rabit { -/*! \brief namespace for helper utils of the project */ -namespace utils { - -/*! 
\brief error message buffer length */ -const int kPrintBuffer = 1 << 12; - -/* \brief Case-insensitive string comparison */ -inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) { -#ifdef _MSC_VER - return _stricmp(s1, s2); -#else // _MSC_VER - return strcasecmp(s1, s2); -#endif // _MSC_VER -} - -/* \brief parse config string too bool*/ -inline bool StringToBool(const char* s) { - return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0; -} - -/*! \brief printf, prints messages to the console */ -inline void Printf(const char *fmt, ...) { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - LOG(CONSOLE) << msg; -} - -/*! \brief assert a condition is true, use this to handle debug information */ -inline void Assert(bool exp, const char *fmt, ...) { - if (!exp) { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - LOG(FATAL) << msg; - } -} - -/*!\brief same as assert, but this is intended to be used as a message for users */ -inline void Check(bool exp, const char *fmt, ...) { - if (!exp) { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - LOG(FATAL) << msg; - } -} - -/*! \brief report error message, same as check */ -inline void Error(const char *fmt, ...) { - { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - LOG(FATAL) << msg; - } -} -} // namespace utils - -// Can not use std::min on Windows with msvc due to: -// error C2589: '(': illegal token on right side of '::' -template -auto Min(T const& l, T const& r) { - return l < r ? l : r; -} -// same with Min -template -auto Max(T const& l, T const& r) { - return l > r ? l : r; -} - -// easy utils that can be directly accessed in xgboost -/*! \brief get the beginning address of a vector */ -template -inline T *BeginPtr(std::vector &vec) { // NOLINT(*) - if (vec.size() == 0) { - return nullptr; - } else { - return &vec[0]; - } -} -inline char* BeginPtr(std::string &str) { // NOLINT(*) - if (str.length() == 0) return nullptr; - return &str[0]; -} -inline const char* BeginPtr(const std::string &str) { - if (str.length() == 0) return nullptr; - return &str[0]; -} -} // namespace rabit -#endif // RABIT_INTERNAL_UTILS_H_ diff --git a/rabit/include/rabit/rabit.h b/rabit/include/rabit/rabit.h deleted file mode 100644 index 10ea9a47f..000000000 --- a/rabit/include/rabit/rabit.h +++ /dev/null @@ -1,237 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file rabit.h - * \brief This file defines rabit's Allreduce/Broadcast interface - * The rabit engine contains the actual implementation - * Code that only uses this header can also be compiled with MPI Allreduce (non fault-tolerant), - * - * rabit.h and serializable.h is all what the user needs to use the rabit interface - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#ifndef RABIT_RABIT_H_ // NOLINT(*) -#define RABIT_RABIT_H_ // NOLINT(*) -#include -#include -#include -// engine definition of rabit, defines internal implementation -// to use rabit interface, there is no need to read engine.h -// rabit.h and serializable.h are enough to use the interface -#include "./internal/engine.h" - -/*! \brief rabit namespace */ -namespace rabit { -/*! 
- * \brief defines stream used in rabit - * see definition of Stream in dmlc/io.h - */ -using Stream = dmlc::Stream; -/*! - * \brief defines serializable objects used in rabit - * see definition of Serializable in dmlc/io.h - */ -using Serializable = dmlc::Serializable; - -/*! - * \brief reduction operators namespace - */ -namespace op { -/*! - * \class rabit::op::Max - * \brief maximum reduction operator - */ -struct Max; -/*! - * \class rabit::op::Min - * \brief minimum reduction operator - */ -struct Min; -/*! - * \class rabit::op::Sum - * \brief sum reduction operator - */ -struct Sum; -/*! - * \class rabit::op::BitAND - * \brief bitwise AND reduction operator - */ -struct BitAND; -/*! - * \class rabit::op::BitOR - * \brief bitwise OR reduction operator - */ -struct BitOR; -/*! - * \class rabit::op::BitXOR - * \brief bitwise XOR reduction operator - */ -struct BitXOR; -} // namespace op -/*! - * \brief initializes rabit, call this once at the beginning of your program - * \param argc number of arguments in argv - * \param argv the array of input arguments - * \return true if initialized successfully, otherwise false - */ -inline bool Init(int argc, char *argv[]); -/*! - * \brief finalizes the rabit engine, call this function after you finished with all the jobs - * \return true if finalized successfully, otherwise false - */ -inline bool Finalize(); -/*! \brief gets rank of the current process - * \return rank number of worker*/ -inline int GetRank(); -/*! \brief gets total number of processes - * \return total world size*/ -inline int GetWorldSize(); -/*! \brief whether rabit env is in distributed mode - * \return is distributed*/ -inline bool IsDistributed(); - -/*! \brief gets processor's name - * \return processor name*/ -inline std::string GetProcessorName(); -/*! - * \brief prints the msg to the tracker, - * this function can be used to communicate progress information to - * the user who monitors the tracker - * \param msg the message to be printed - */ -inline void TrackerPrint(const std::string &msg); - -#ifndef RABIT_STRICT_CXX98_ -/*! - * \brief prints the msg to the tracker, this function may not be available - * in very strict c++98 compilers, though it usually is. - * this function can be used to communicate progress information to - * the user who monitors the tracker - * \param fmt the format string - */ -inline void TrackerPrintf(const char *fmt, ...); -#endif // RABIT_STRICT_CXX98_ -/*! - * \brief broadcasts a memory region to every node from the root - * - * Example: int a = 1; Broadcast(&a, sizeof(a), root); - * \param sendrecv_data the pointer to the send/receive buffer, - * \param size the data size - * \param root the process root - */ -inline void Broadcast(void *sendrecv_data, size_t size, int root); - -/*! - * \brief broadcasts an std::vector to every node from root - * \param sendrecv_data the pointer to send/receive vector, - * for the receiver, the vector does not need to be pre-allocated - * \param root the process root - * \tparam DType the data type stored in the vector, has to be a simple data type - * that can be directly transmitted by sending the sizeof(DType) - */ -template -inline void Broadcast(std::vector *sendrecv_data, int root); -/*! 
- * \brief broadcasts a std::string to every node from the root - * \param sendrecv_data the pointer to the send/receive buffer, - * for the receiver, the vector does not need to be pre-allocated - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - * \param root the process root - */ -inline void Broadcast(std::string *sendrecv_data, int root); -/*! - * \brief performs in-place Allreduce on sendrecvbuf - * this function is NOT thread-safe - * - * Example Usage: the following code does an Allreduce and outputs the sum as the result - * \code{.cpp} - * vector data(10); - * ... - * Allreduce(&data[0], data.size()); - * ... - * \endcode - * - * \param sendrecvbuf buffer for both sending and receiving data - * \param count number of elements to be reduced - * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to pass into the lazy preprocessing function - * \tparam OP see namespace op, reduce operator - * \tparam DType data type - */ -template -inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *) = nullptr, - void *prepare_arg = nullptr); - -/*! -* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, -* the data provided by current node k is [slice_begin, slice_end), -* the next node's segment must start with slice_end -* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments -* use a ring based algorithm -* -* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually -* \param total_size total size of data to be gathered -* \param slice_begin beginning of the current slice -* \param slice_end end of the current slice -* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size -*/ -template -inline void Allgather(DType *sendrecvbuf_, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice); - -// C++11 support for lambda prepare function -#if DMLC_USE_CXX11 -/*! - * \brief performs in-place Allreduce, on sendrecvbuf - * with a prepare function specified by a lambda function - * - * Example Usage: - * \code{.cpp} - * // the following code does an Allreduce and outputs the sum as the result - * vector data(10); - * ... - * Allreduce(&data[0], data.size(), [&]() { - * for (int i = 0; i < 10; ++i) { - * data[i] = i; - * } - * }); - * ... - * \endcode - * \param sendrecvbuf buffer for both sending and receiving data - * \param count number of elements to be reduced - * \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked - * by the function before performing Allreduce in order to initialize the data in sendrecvbuf. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \tparam OP see namespace op, reduce operator - * \tparam DType data type - */ -template -inline void Allreduce(DType *sendrecvbuf, size_t count, - std::function prepare_fun); -#endif // C++11 - -/*! - * \brief deprecated, planned for removal after checkpoing from JVM package is removed. 
- */ -inline int LoadCheckPoint(); -/*! - * \brief deprecated, planned for removal after checkpoing from JVM package is removed. - */ -inline void CheckPoint(); - -/*! - * \return version number of the current stored model, - * which means how many calls to CheckPoint we made so far - * \sa LoadCheckPoint, CheckPoint - */ -inline int VersionNumber(); -} // namespace rabit -// implementation of template functions -#include "./internal/rabit-inl.h" -#endif // RABIT_RABIT_H_ // NOLINT(*) diff --git a/rabit/include/rabit/serializable.h b/rabit/include/rabit/serializable.h deleted file mode 100644 index 77508292a..000000000 --- a/rabit/include/rabit/serializable.h +++ /dev/null @@ -1,26 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file serializable.h - * \brief defines serializable interface of rabit - * \author Tianqi Chen - */ -#ifndef RABIT_SERIALIZABLE_H_ -#define RABIT_SERIALIZABLE_H_ -#include -#include -#include "rabit/internal/utils.h" - -namespace rabit { -/*! - * \brief defines stream used in rabit - * see definition of Stream in dmlc/io.h - */ -using Stream = dmlc::Stream ; -/*! - * \brief defines serializable objects used in rabit - * see definition of Serializable in dmlc/io.h - */ -using Serializable = dmlc::Serializable; - -} // namespace rabit -#endif // RABIT_SERIALIZABLE_H_ diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc deleted file mode 100644 index fcf80b414..000000000 --- a/rabit/src/allreduce_base.cc +++ /dev/null @@ -1,997 +0,0 @@ -/** - * Copyright 2014-2023, XGBoost Contributors - * \file allreduce_base.cc - * \brief Basic implementation of AllReduce - * - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#if !defined(NOMINMAX) && defined(_WIN32) -#define NOMINMAX -#endif // !defined(NOMINMAX) - -#include "allreduce_base.h" - -#include "rabit/base.h" -#include "rabit/internal/rabit-inl.h" -#include "xgboost/collective/result.h" - -#ifndef _WIN32 -#include -#endif // _WIN32 - -#include -#include - -namespace rabit::engine { -// constructor -AllreduceBase::AllreduceBase() { - tracker_uri = "NULL"; - tracker_port = 9000; - host_uri = ""; - rank = 0; - world_size = -1; - connect_retry = 5; - hadoop_mode = false; - version_number = 0; - // 32 K items - reduce_ring_mincount = 32 << 10; - // 1M reducer size each time - tree_reduce_minsize = 1 << 20; - // tracker URL - task_id = "NULL"; - err_link = nullptr; - dmlc_role = "worker"; - this->SetParam("rabit_reduce_buffer", "256MB"); - // setup possible environment variable of interest - // include dmlc support direct variables - env_vars.emplace_back("DMLC_TASK_ID"); - env_vars.emplace_back("DMLC_ROLE"); - env_vars.emplace_back("DMLC_NUM_ATTEMPT"); - env_vars.emplace_back("DMLC_TRACKER_URI"); - env_vars.emplace_back("DMLC_TRACKER_PORT"); - env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY"); -} - -// initialization function -bool AllreduceBase::Init(int argc, char* argv[]) { - // setup from environment variables - // handler to get variables from env - for (auto & env_var : env_vars) { - const char *value = getenv(env_var.c_str()); - if (value != nullptr) { - this->SetParam(env_var.c_str(), value); - } - } - // pass in arguments override env variable. 
- for (int i = 0; i < argc; ++i) { - char name[256], val[256]; - if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { - this->SetParam(name, val); - } - } - - { - // handling for hadoop - const char *task_id = getenv("mapred_tip_id"); - if (task_id == nullptr) { - task_id = getenv("mapreduce_task_id"); - } - if (hadoop_mode) { - utils::Check(task_id != nullptr, - "hadoop_mode is set but cannot find mapred_task_id"); - } - if (task_id != nullptr) { - this->SetParam("rabit_task_id", task_id); - this->SetParam("rabit_hadoop_mode", "1"); - } - const char *attempt_id = getenv("mapred_task_id"); - if (attempt_id != nullptr) { - const char *att = strrchr(attempt_id, '_'); - int num_trial; - if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) { - this->SetParam("rabit_num_trial", att + 1); - } - } - // handling for hadoop - const char *num_task = getenv("mapred_map_tasks"); - if (num_task == nullptr) { - num_task = getenv("mapreduce_job_maps"); - } - if (hadoop_mode) { - utils::Check(num_task != nullptr, - "hadoop_mode is set but cannot find mapred_map_tasks"); - } - if (num_task != nullptr) { - this->SetParam("rabit_world_size", num_task); - } - } - if (dmlc_role != "worker") { - LOG(FATAL) << "Rabit Module currently only works with dmlc worker"; - } - - // clear the setting before start reconnection - this->rank = -1; - //--------------------- - // start socket - xgboost::system::SocketStartup(); - utils::Assert(all_links.size() == 0, "can only call Init once"); - auto rc = xgboost::collective::GetHostName(&this->host_uri); - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } - // get information from tracker - rc = this->ReConnectLinks(); - if (rc.OK()) { - return true; - } - LOG(FATAL) << rc.Report(); - return false; -} - -bool AllreduceBase::Shutdown() { - try { - for (auto &all_link : all_links) { - if (!all_link.sock.IsClosed()) { - SafeColl(all_link.sock.Close()); - } - } - all_links.clear(); - tree_links.plinks.clear(); - - if (tracker_uri == "NULL") return true; - // notify tracker rank i have shutdown - xgboost::collective::TCPSocket tracker; - auto rc = this->ConnectTracker(&tracker); - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } - tracker.Send(xgboost::StringView{"shutdown"}); - SafeColl(tracker.Close()); - xgboost::system::SocketFinalize(); - return true; - } catch (std::exception const &e) { - LOG(WARNING) << "Failed to shutdown due to" << e.what(); - return false; - } -} - -void AllreduceBase::TrackerPrint(const std::string &msg) { - if (tracker_uri == "NULL") { - utils::Printf("%s", msg.c_str()); return; - } - xgboost::collective::TCPSocket tracker; - auto rc = this->ConnectTracker(&tracker); - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } - - tracker.Send(xgboost::StringView{"print"}); - tracker.Send(xgboost::StringView{msg}); - SafeColl(tracker.Close()); -} - -// util to parse data with unit suffix -inline size_t ParseUnit(const char *name, const char *val) { - char unit; - unsigned long amt; // NOLINT(*) - int n = sscanf(val, "%lu%c", &amt, &unit); - size_t amount = amt; - if (n == 2) { - switch (unit) { - case 'B': return amount; - case 'K': return amount << 10UL; - case 'M': return amount << 20UL; - case 'G': return amount << 30UL; - default: utils::Error("invalid format for %s", name); return 0; - } - } else if (n == 1) { - return amount; - } else { - utils::Error("invalid format for %s," \ - "shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name); - return 0; - } -} -/*! 
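To make the unit parsing above concrete, here is a standalone sketch (ParseUnitSketch is a hypothetical name, not part of the patch) showing how "256MB", the default the constructor sets for rabit_reduce_buffer, becomes a byte count, and why SetParam below shifts the result right by three: the reduce buffer is stored in 8-byte words.

#include <cstdio>
#include <cstdlib>

// Hypothetical standalone copy of the parser above, for illustration only.
static unsigned long ParseUnitSketch(const char* val) {
  char unit;
  unsigned long amt;
  if (std::sscanf(val, "%lu%c", &amt, &unit) == 2) {
    switch (unit) {
      case 'B': return amt;
      case 'K': return amt << 10UL;
      case 'M': return amt << 20UL;   // "256MB" matches here: 256 followed by 'M'
      case 'G': return amt << 30UL;
      default: std::abort();
    }
  }
  return amt;                          // a bare integer is taken as-is
}

int main() {
  unsigned long bytes = ParseUnitSketch("256MB");  // 268435456
  unsigned long words = (bytes + 7) >> 3;          // SetParam keeps 8-byte words
  std::printf("%lu bytes -> %lu words\n", bytes, words);
  return 0;
}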
- * \brief set parameters to the engine - * \param name parameter name - * \param val parameter value - */ -void AllreduceBase::SetParam(const char *name, const char *val) { - if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val; - if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val); - if (!strcmp(name, "rabit_task_id")) task_id = val; - if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val; - if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val); - if (!strcmp(name, "DMLC_TASK_ID")) task_id = val; - if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val; - if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); - if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val); - if (!strcmp(name, "rabit_tree_reduce_minsize")) tree_reduce_minsize = atoi(val); - if (!strcmp(name, "rabit_reduce_ring_mincount")) { - reduce_ring_mincount = atoi(val); - utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0"); - } - if (!strcmp(name, "rabit_reduce_buffer")) { - reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3; - } - if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) { - connect_retry = atoi(val); - } - if (!strcmp(name, "rabit_timeout")) { - rabit_timeout = utils::StringToBool(val); - } - if (!strcmp(name, "rabit_timeout_sec")) { - timeout_sec = std::chrono::seconds(atoi(val)); - utils::Assert(timeout_sec.count() >= 0, "rabit_timeout_sec should be non negative second"); - } - if (!strcmp(name, "rabit_enable_tcp_no_delay")) { - if (!strcmp(val, "true")) { - rabit_enable_tcp_no_delay = true; - } else { - rabit_enable_tcp_no_delay = false; - } - } -} - -/*! - * \brief initialize connection to the tracker - * \return a socket that initializes the connection - */ -[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker( - xgboost::collective::TCPSocket *out) const { - int magic = kMagic; - // get information from tracker - xgboost::collective::TCPSocket &tracker = *out; - - auto rc = - Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker); - if (!rc.OK()) { - return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc)); - } - - using utils::Assert; - if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) { - return xgboost::collective::Fail("Failed to send the verification number."); - } - if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) { - return xgboost::collective::Fail("Failed to recieve the verification number."); - } - if (magic != kMagic) { - return xgboost::collective::Fail("Invalid verification number."); - } - if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) { - return xgboost::collective::Fail("Failed to send the local rank back to the tracker."); - } - if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) { - return xgboost::collective::Fail("Failed to send the world size back to the tracker."); - } - if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) { - return xgboost::collective::Fail("Failed to send the task ID back to the tracker."); - } - - return xgboost::collective::Success(); -} -/*! 
- * \brief connect to the tracker to fix the missing links - * this function is also used when the engine start up - */ -[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) { - // single node mode - if (tracker_uri == "NULL") { - rank = 0; - world_size = 1; - return xgboost::collective::Success(); - } - - xgboost::collective::TCPSocket tracker; - auto rc = this->ConnectTracker(&tracker); - if (!rc.OK()) { - return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc)); - } - - LOG(INFO) << "task " << task_id << " connected to the tracker"; - tracker.Send(xgboost::StringView{cmd}); - - try { - // the rank of previous link, next link in ring - int prev_rank, next_rank; - // the rank of neighbors - std::map tree_neighbors; - using utils::Assert; - // get new ranks - int newrank, num_neighbors; - Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), - "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \ - sizeof(parent_rank), "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), - "ReConnectLink failure 4"); - Assert(rank == -1 || newrank == rank, - "must keep rank to same if the node already have one"); - rank = newrank; - - if (rank == -1) { - LOG(FATAL) << "tracker got overwhelmed and not able to assign correct rank"; - } - - LOG(CONSOLE) << "task " << task_id << " got new rank " << rank; - - Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \ - sizeof(num_neighbors), "ReConnectLink failure 4"); - for (int i = 0; i < num_neighbors; ++i) { - int nrank; - Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank), - "ReConnectLink failure 4"); - tree_neighbors[nrank] = 1; - } - Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank), - "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank), - "ReConnectLink failure 4"); - - auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())}; - // create listening socket - std::int32_t port{0}; - SafeColl(sock_listen.BindHost(&port)); - SafeColl(sock_listen.Listen()); - - // get number of to connect and number of to accept nodes from tracker - int num_conn, num_accept, num_error = 1; - do { - for (auto & all_link : all_links) { - SafeColl(all_link.sock.Close()); - } - // tracker construct goodset - Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), - "ReConnectLink failure 7"); - Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept), - "ReConnectLink failure 8"); - num_error = 0; - for (int i = 0; i < num_conn; ++i) { - LinkRecord r; - int hport, hrank; - std::string hname; - SafeColl(tracker.Recv(&hname)); - Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); - Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); - // connect to peer - if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry, - timeout_sec, &r.sock) - .OK()) { - num_error += 1; - SafeColl(r.sock.Close()); - continue; - } - Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), - "ReConnectLink failure 12"); - Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), - "ReConnectLink failure 13"); - utils::Check(hrank == r.rank, - "ReConnectLink failure, link rank inconsistent"); - bool match = false; - for (auto & all_link : all_links) { - if 
(all_link.rank == hrank) { - Assert(all_link.sock.IsClosed(), "Override a link that is active"); - all_link.sock = std::move(r.sock); - match = true; - break; - } - } - if (!match) all_links.emplace_back(std::move(r)); - } - Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), - "ReConnectLink failure 14"); - } while (num_error != 0); - // send back socket listening port to tracker - Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); - // close connection to tracker - SafeColl(tracker.Close()); - - // listen to incoming links - for (int i = 0; i < num_accept; ++i) { - LinkRecord r; - r.sock = sock_listen.Accept(); - Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), - "ReConnectLink failure 15"); - Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), - "ReConnectLink failure 15"); - bool match = false; - for (auto & all_link : all_links) { - if (all_link.rank == r.rank) { - utils::Assert(all_link.sock.IsClosed(), - "Override a link that is active"); - all_link.sock = std::move(r.sock); - match = true; - break; - } - } - if (!match) all_links.emplace_back(std::move(r)); - } - SafeColl(sock_listen.Close()); - - this->parent_index = -1; - // setup tree links and ring structure - tree_links.plinks.clear(); - for (auto &all_link : all_links) { - utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket"); - // set the socket to non-blocking mode, enable TCP keepalive - CHECK(all_link.sock.NonBlocking(true).OK()); - CHECK(all_link.sock.SetKeepAlive().OK()); - if (rabit_enable_tcp_no_delay) { - CHECK(all_link.sock.SetNoDelay().OK()); - } - if (tree_neighbors.count(all_link.rank) != 0) { - if (all_link.rank == parent_rank) { - parent_index = static_cast(tree_links.plinks.size()); - } - tree_links.plinks.push_back(&all_link); - } - if (all_link.rank == prev_rank) ring_prev = &all_link; - if (all_link.rank == next_rank) ring_next = &all_link; - } - Assert(parent_rank == -1 || parent_index != -1, - "cannot find parent in the link"); - Assert(prev_rank == -1 || ring_prev != nullptr, - "cannot find prev ring in the link"); - Assert(next_rank == -1 || ring_next != nullptr, - "cannot find next ring in the link"); - return xgboost::collective::Success(); - } catch (const std::exception& e) { - std::stringstream ss; - ss << "Failed in ReconnectLink " << e.what(); - return xgboost::collective::Fail(ss.str()); - } -} -/*! - * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure - * - * NOTE on Allreduce: - * The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce. - * It only means the current node get the correct result of Allreduce. - * However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceBase::ReturnType -AllreduceBase::TryAllreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - if (count > reduce_ring_mincount) { - return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer); - } else { - return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer); - } -} -/*! 
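The dispatch above is purely size-based: with the default set in the constructor (reduce_ring_mincount = 32 << 10 elements), payloads larger than 32Ki elements take the ring path, smaller ones the tree path. A small illustrative sketch of that decision, using example sizes:

#include <cstdio>

int main() {
  unsigned long const reduce_ring_mincount = 32UL << 10;  // constructor default
  unsigned long const counts[] = {1024UL, 32768UL, 1048576UL};
  for (unsigned long count : counts) {
    bool const ring = count > reduce_ring_mincount;
    std::printf("%8lu elements -> %s allreduce\n", count, ring ? "ring" : "tree");
  }
  return 0;
}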
- * \brief perform in-place allreduce, on sendrecvbuf, - * this function implements tree-shape reduction - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceBase::ReturnType -AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - RefLinkVector &links = tree_links; - if (links.Size() == 0 || count == 0) return kSuccess; - // total size of message - const size_t total_size = type_nbytes * count; - // number of links - const int nlink = static_cast(links.Size()); - // send recv buffer - char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); - // size of space that we already performs reduce in up pass - size_t size_up_reduce = 0; - // size of space that we have already passed to parent - size_t size_up_out = 0; - // size of message we received, and send in the down pass - size_t size_down_in = 0; - // minimal size of each reducer - const size_t eachreduce = (tree_reduce_minsize / type_nbytes * type_nbytes); - - // initialize the link ring-buffer and pointer - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - links[i].InitBuffer(type_nbytes, count, reduce_buffer_size); - } - links[i].ResetSize(); - } - // if no children, no need to reduce - if (nlink == static_cast(parent_index != -1)) { - size_up_reduce = total_size; - } - // while we have not passed the messages out - while (true) { - // select helper - bool finished = true; - utils::PollHelper watcher; - for (int i = 0; i < nlink; ++i) { - if (i == parent_index) { - if (size_down_in != total_size) { - watcher.WatchRead(links[i].sock); - // only watch for exception in live channels - watcher.WatchException(links[i].sock); - finished = false; - } - if (size_up_out != total_size && size_up_out < size_up_reduce) { - watcher.WatchWrite(links[i].sock); - } - } else { - if (links[i].size_read != total_size) { - watcher.WatchRead(links[i].sock); - } - // size_write <= size_read - if (links[i].size_write != total_size) { - if (links[i].size_write < size_down_in) { - watcher.WatchWrite(links[i].sock); - } - // only watch for exception in live channels - watcher.WatchException(links[i].sock); - finished = false; - } - } - } - // finish running allreduce - if (finished) { - break; - } - // select must return - auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos - if (!poll_res.OK()) { - LOG(FATAL) << poll_res.Report(); - } - - // read data from childs - for (int i = 0; i < nlink; ++i) { - if (i != parent_index && watcher.CheckRead(links[i].sock)) { - // make sure to receive minimal reducer size - // since each child reduce and sends the minimal reducer size - while (links[i].size_read < total_size - && links[i].size_read - size_up_reduce < eachreduce) { - ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - } - } - } - // this node have children, perform reduce - if (nlink > static_cast(parent_index != -1)) { - size_t buffer_size = 0; - // do upstream reduce - size_t max_reduce = total_size; - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - max_reduce = std::min(max_reduce, links[i].size_read); - utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size, - 
"buffer size inconsistent"); - buffer_size = links[i].buffer_size; - } - } - utils::Assert(buffer_size != 0, "must assign buffer_size"); - // round to type_n4bytes - max_reduce = (max_reduce / type_nbytes * type_nbytes); - - // if max reduce is less than total size, we reduce multiple times of - // each reduce size - if (max_reduce < total_size) { - max_reduce = max_reduce - max_reduce % eachreduce; - } - - // perform reduce, can be at most two rounds - while (size_up_reduce < max_reduce) { - // start position - size_t start = size_up_reduce % buffer_size; - // perform read till end of buffer - size_t nread = std::min(buffer_size - start, - max_reduce - size_up_reduce); - utils::Assert(nread % type_nbytes == 0, "Allreduce: size check"); - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - reducer(links[i].buffer_head + start, - sendrecvbuf + size_up_reduce, - static_cast(nread / type_nbytes), - MPI::Datatype(type_nbytes)); - } - } - size_up_reduce += nread; - } - } - if (parent_index != -1) { - // pass message up to parent, can pass data that are already been reduced - if (size_up_out < size_up_reduce) { - ssize_t len = links[parent_index].sock. - Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); - if (len != -1) { - size_up_out += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) { - return ReportError(&links[parent_index], ret); - } - } - } - // read data from parent - if (watcher.CheckRead(links[parent_index].sock) && - total_size > size_down_in) { - size_t left_size = total_size-size_down_in; - size_t reduce_size_min = std::min(left_size, eachreduce); - size_t recved = 0; - while (recved < reduce_size_min) { - ssize_t len = links[parent_index].sock. - Recv(sendrecvbuf + size_down_in, total_size - size_down_in); - - if (len == 0) { - SafeColl(links[parent_index].sock.Close()); - return ReportError(&links[parent_index], kRecvZeroLen); - } - if (len != -1) { - size_down_in += static_cast(len); - utils::Assert(size_down_in <= size_up_out, - "Allreduce: boundary error"); - recved+=len; - - // if it receives more data than each reduce, it means the next block is sent. - // we double the reduce_size_min or add to left_size - while (recved > reduce_size_min) { - reduce_size_min += std::min(left_size-reduce_size_min, eachreduce); - } - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) { - return ReportError(&links[parent_index], ret); - } - } - } - } - } else { - // this is root, can use reduce as most recent point - size_down_in = size_up_out = size_up_reduce; - } - // can pass message down to children - for (int i = 0; i < nlink; ++i) { - if (i != parent_index && links[i].size_write < size_down_in) { - ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - } - } - } - return kSuccess; -} -/*! 
- * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param total_size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceBase::ReturnType -AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { - RefLinkVector &links = tree_links; - if (links.Size() == 0 || total_size == 0) return kSuccess; - utils::Check(root < world_size, - "Broadcast: root should be smaller than world size"); - // number of links - const int nlink = static_cast(links.Size()); - // size of space already read from data - size_t size_in = 0; - // input link, -2 means unknown yet, -1 means this is root - int in_link = -2; - - // initialize the link statistics - for (int i = 0; i < nlink; ++i) { - links[i].ResetSize(); - } - // root have all the data - if (this->rank == root) { - size_in = total_size; - in_link = -1; - } - // while we have not passed the messages out - while (true) { - bool finished = true; - // select helper - utils::PollHelper watcher; - for (int i = 0; i < nlink; ++i) { - if (in_link == -2) { - watcher.WatchRead(links[i].sock); finished = false; - } - if (i == in_link && links[i].size_read != total_size) { - watcher.WatchRead(links[i].sock); finished = false; - } - if (in_link != -2 && i != in_link && links[i].size_write != total_size) { - if (links[i].size_write < size_in) { - watcher.WatchWrite(links[i].sock); - } - finished = false; - } - } - // finish running - if (finished) break; - // select - auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos - if (!poll_res.OK()) { - LOG(FATAL) << poll_res.Report(); - } - if (in_link == -2) { - // probe in-link - for (int i = 0; i < nlink; ++i) { - if (watcher.CheckRead(links[i].sock)) { - ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - size_in = links[i].size_read; - if (size_in != 0) { - in_link = i; break; - } - } - } - } else { - // read from in link - if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) { - ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size); - if (ret != kSuccess) { - return ReportError(&links[in_link], ret); - } - size_in = links[in_link].size_read; - } - } - // send data to all out-link - for (int i = 0; i < nlink; ++i) { - if (i != in_link && links[i].size_write < size_in) { - ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - } - } - } - return kSuccess; -} -/*! - * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. 
slice of node (rank - 1) % world_size - */ -AllreduceBase::ReturnType -AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice) { - // read from next link and send to prev one - LinkRecord &prev = *ring_prev, &next = *ring_next; - // need to reply on special rank structure - utils::Assert(next.rank == (rank + 1) % world_size && - rank == (prev.rank + 1) % world_size, - "need to assume rank structure"); - // send recv buffer - char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); - const size_t stop_read = total_size + slice_begin; - const size_t stop_write = total_size + slice_begin - size_prev_slice; - size_t write_ptr = slice_begin; - size_t read_ptr = slice_end; - - while (true) { - // select helper - bool finished = true; - utils::PollHelper watcher; - if (read_ptr != stop_read) { - watcher.WatchRead(next.sock); - finished = false; - } - if (write_ptr != stop_write) { - if (write_ptr < read_ptr) { - watcher.WatchWrite(prev.sock); - } - finished = false; - } - if (finished) { - break; - } - - auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos - if (!poll_res.OK()) { - LOG(FATAL) << poll_res.Report(); - } - if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { - size_t size = stop_read - read_ptr; - size_t start = read_ptr % total_size; - if (start + size > total_size) { - size = total_size - start; - } - ssize_t len = next.sock.Recv(sendrecvbuf + start, size); - if (len != -1) { - read_ptr += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) { - auto err = ReportError(&next, ret); - return err; - } - } - } - if (write_ptr < read_ptr && write_ptr != stop_write) { - size_t size = std::min(read_ptr, stop_write) - write_ptr; - size_t start = write_ptr % total_size; - if (start + size > total_size) { - size = total_size - start; - } - ssize_t len = prev.sock.Send(sendrecvbuf + start, size); - if (len != -1) { - write_ptr += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) { - auto err = ReportError(&prev, ret); - return err; - } - } - } - } - return kSuccess; -} -/*! 
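Both ring routines rely on the same partitioning: the reduce-scatter that follows splits the count elements into world_size near-equal slices with step = ceil(count / world_size), so rank k owns [k * step, min((k + 1) * step, count)), and the ring allgather above then circulates those slices. A minimal sketch of the slice arithmetic, with count = 10 and world_size = 4 as example values:

#include <algorithm>
#include <cstdio>

int main() {
  unsigned long const count = 10, world_size = 4;
  unsigned long const step = (count + world_size - 1) / world_size;   // ceil -> 3
  for (unsigned long k = 0; k < world_size; ++k) {
    unsigned long const begin = std::min(k * step, count);
    unsigned long const end = std::min((k + 1) * step, count);
    std::printf("rank %lu owns [%lu, %lu)\n", k, begin, end);          // sizes 3, 3, 3, 1
  }
  return 0;
}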
- * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, - * and will return the cause of failure - * - * Ring-based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType, TryAllreduce - */ -AllreduceBase::ReturnType -AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - // read from next link and send to prev one - LinkRecord &prev = *ring_prev, &next = *ring_next; - // need to reply on special rank structure - utils::Assert(next.rank == (rank + 1) % world_size && - rank == (prev.rank + 1) % world_size, - "need to assume rank structure"); - // total size of message - const size_t total_size = type_nbytes * count; - size_t n = static_cast(world_size); - size_t step = (count + n - 1) / n; - size_t r = static_cast(next.rank); - size_t write_ptr = std::min(r * step, count) * type_nbytes; - size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes; - size_t reduce_ptr = read_ptr; - // send recv buffer - char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); - // position to stop reading - const size_t stop_read = total_size + write_ptr; - // position to stop writing - size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes; - if (stop_write > stop_read) { - stop_write -= total_size; - utils::Assert(write_ptr <= stop_write, "write ptr boundary check"); - } - // use ring buffer in next position - next.InitBuffer(type_nbytes, step, reduce_buffer_size); - // set size_read to read pointer for ring buffer to work properly - next.size_read = read_ptr; - - while (true) { - // select helper - bool finished = true; - utils::PollHelper watcher; - if (read_ptr != stop_read) { - watcher.WatchRead(next.sock); - finished = false; - } - if (write_ptr != stop_write) { - if (write_ptr < reduce_ptr) { - watcher.WatchWrite(prev.sock); - } - finished = false; - } - if (finished) { - break; - } - auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos - if (!poll_res.OK()) { - LOG(FATAL) << poll_res.Report(); - } - if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { - ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read); - if (ret != kSuccess) { - return ReportError(&next, ret); - } - // sync the rate - read_ptr = next.size_read; - utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank); - const size_t buffer_size = next.buffer_size; - size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes; - while (reduce_ptr < max_reduce) { - size_t bstart = reduce_ptr % buffer_size; - size_t nread = std::min(buffer_size - bstart, - max_reduce - reduce_ptr); - size_t rstart = reduce_ptr % total_size; - nread = std::min(nread, total_size - rstart); - reducer(next.buffer_head + bstart, - sendrecvbuf + rstart, - static_cast(nread / type_nbytes), - MPI::Datatype(type_nbytes)); - reduce_ptr += nread; - } - } - if (write_ptr < reduce_ptr && write_ptr != stop_write) { - size_t size = std::min(reduce_ptr, stop_write) - write_ptr; - size_t start = write_ptr % total_size; - if (start + size > total_size) { - size = total_size - start; - } - ssize_t len = prev.sock.Send(sendrecvbuf + start, size); - if (len != -1) { - write_ptr += static_cast(len); - } else { - ReturnType ret = 
Errno2Return(); - if (ret != kSuccess) return ReportError(&prev, ret); - } - } - } - return kSuccess; -} -/*! - * \brief perform in-place allreduce, on sendrecvbuf - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceBase::ReturnType -AllreduceBase::TryAllreduceRing(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer); - if (ret != kSuccess) return ret; - size_t n = static_cast(world_size); - size_t step = (count + n - 1) / n; - size_t begin = std::min(rank * step, count) * type_nbytes; - size_t end = std::min((rank + 1) * step, count) * type_nbytes; - // previous rank - int prank = ring_prev->rank; - // get rank of previous - return TryAllgatherRing - (sendrecvbuf_, type_nbytes * count, - begin, end, - (std::min((prank + 1) * step, count) - - std::min(prank * step, count)) * type_nbytes); -} -} // namespace rabit::engine diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h deleted file mode 100644 index 9991c2138..000000000 --- a/rabit/src/allreduce_base.h +++ /dev/null @@ -1,501 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file allreduce_base.h - * \brief Basic implementation of AllReduce - * using TCP non-block socket and tree-shape reduction. - * - * This implementation provides basic utility of AllReduce and Broadcast - * without considering node failure - * - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#ifndef RABIT_ALLREDUCE_BASE_H_ -#define RABIT_ALLREDUCE_BASE_H_ - -#include -#include -#include -#include -#include - -#include "rabit/internal/engine.h" -#include "rabit/internal/socket.h" -#include "rabit/internal/utils.h" -#include "xgboost/collective/result.h" - -#ifdef RABIT_CXXTESTDEFS_H -#define private public -#define protected public -#endif // RABIT_CXXTESTDEFS_H - - -namespace MPI { // NOLINT -// MPI data type to be compatible with existing MPI interface -class Datatype { - public: - size_t type_size; - explicit Datatype(size_t type_size) : type_size(type_size) {} -}; -} -namespace rabit { -namespace engine { - -/*! \brief implementation of basic Allreduce engine */ -class AllreduceBase : public IEngine { - public: - // magic number to verify server - static const int kMagic = 0xff99; - // constant one byte out of band message to indicate error happening - AllreduceBase(); - virtual ~AllreduceBase() = default; - // initialize the manager - virtual bool Init(int argc, char* argv[]); - // shutdown the engine - virtual bool Shutdown(); - /*! - * \brief set parameters to the engine - * \param name parameter name - * \param val parameter value - */ - virtual void SetParam(const char *name, const char *val); - /*! - * \brief print the msg in the tracker, - * this function can be used to communicate the information of the progress to - * the user who monitors the tracker - * \param msg message to be printed in the tracker - */ - void TrackerPrint(const std::string &msg) override; - - /*! \brief get rank of previous node in ring topology*/ - int GetRingPrevRank() const override { - return ring_prev->rank; - } - /*! \brief get rank */ - int GetRank() const override { - return rank; - } - /*! 
\brief get rank */ - int GetWorldSize() const override { - if (world_size == -1) return 1; - return world_size; - } - /*! \brief whether is distributed or not */ - bool IsDistributed() const override { - return tracker_uri != "NULL"; - } - /*! \brief get rank */ - std::string GetHost() const override { - return host_uri; - } - - /*! - * \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - */ - void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, - size_t slice_end, size_t size_prev_slice) override { - if (world_size == 1 || world_size == -1) { - return; - } - utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin, - slice_end, size_prev_slice) == kSuccess, - "AllgatherRing failed"); - } - /*! - * \brief perform in-place allreduce, on sendrecvbuf - * this function is NOT thread-safe - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to passed into the lazy preprocessing function - */ - void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer, PreprocFunction prepare_fun = nullptr, - void *prepare_arg = nullptr) override { - if (prepare_fun != nullptr) prepare_fun(prepare_arg); - if (world_size == 1 || world_size == -1) return; - utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == - kSuccess, - "Allreduce failed"); - } - /*! - * \brief broadcast data from root to all nodes - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ - void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override { - if (world_size == 1 || world_size == -1) return; - utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, - "Broadcast failed"); - } - /*! - * \brief deprecated - * \sa CheckPoint, VersionNumber - */ - int LoadCheckPoint() override { return 0; } - - // deprecated, increase internal version number - void CheckPoint() override { version_number += 1; } - /*! 
- * \return version number of current stored model, - * which means how many calls to CheckPoint we made so far - * \sa LoadCheckPoint, CheckPoint - */ - int VersionNumber() const override { - return version_number; - } - /*! - * \brief report current status to the job tracker - * depending on the job tracker we are in - */ - inline void ReportStatus() const { - if (hadoop_mode != 0) { - LOG(CONSOLE) << "reporter:status:Rabit Phase[" << version_number << "] Operation " << seq_counter << "\n"; - } - } - - protected: - /*! \brief enumeration of possible returning results from Try functions */ - enum ReturnTypeEnum { - /*! \brief execution is successful */ - kSuccess, - /*! \brief a link was reset by peer */ - kConnReset, - /*! \brief received a zero length message */ - kRecvZeroLen, - /*! \brief a neighbor node go down, the connection is dropped */ - kSockError, - /*! - * \brief another node which is not my neighbor go down, - * get Out-of-Band exception notification from my neighbor - */ - kGetExcept - }; - /*! \brief struct return type to avoid implicit conversion to int/bool */ - struct ReturnType { - /*! \brief internal return type */ - ReturnTypeEnum value; - // constructor - ReturnType() = default; - ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*) - inline bool operator==(const ReturnTypeEnum &v) const { - return value == v; - } - inline bool operator!=(const ReturnTypeEnum &v) const { - return value != v; - } - }; - /*! \brief translate errno to return type */ - static ReturnType Errno2Return() { - int errsv = xgboost::system::LastError(); - if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess; -#ifdef _WIN32 - if (errsv == WSAEWOULDBLOCK) return kSuccess; - if (errsv == WSAECONNRESET) return kConnReset; -#endif // _WIN32 - if (errsv == ECONNRESET) return kConnReset; - return kSockError; - } - // link record to a neighbor - struct LinkRecord { - public: - // socket to get data from/to link - xgboost::collective::TCPSocket sock; - // rank of the node in this link - int rank; - // size of data readed from link - size_t size_read; - // size of data sent to the link - size_t size_write; - // pointer to buffer head - char *buffer_head {nullptr}; - // buffer size, in bytes - size_t buffer_size {0}; - // constructor - LinkRecord() = default; - // initialize buffer - void InitBuffer(size_t type_nbytes, size_t count, - size_t reduce_buffer_size) { - size_t n = (type_nbytes * count + 7)/ 8; - auto to = Min(reduce_buffer_size, n); - buffer_.resize(to); - // make sure align to type_nbytes - buffer_size = - buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; - utils::Assert(type_nbytes <= buffer_size, - "too large type_nbytes=%lu, buffer_size=%lu", - type_nbytes, buffer_size); - // set buffer head - buffer_head = reinterpret_cast(BeginPtr(buffer_)); - } - // reset the recv and sent size - inline void ResetSize() { - size_write = size_read = 0; - } - /*! 
- * \brief read data into ring-buffer, with care not to existing useful override data - * position after protect_start - * \param protect_start all data start from protect_start is still needed in buffer - * read shall not override this - * \param max_size_read maximum logical amount we can read, size_read cannot exceed this value - * \return the type of reading - */ - inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) { - utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated"); - utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check"); - size_t ngap = size_read - protect_start; - utils::Assert(ngap <= buffer_size, "Allreduce: boundary check"); - size_t offset = size_read % buffer_size; - size_t nmax = max_size_read - size_read; - nmax = Min(nmax, buffer_size - ngap); - nmax = Min(nmax, buffer_size - offset); - if (nmax == 0) return kSuccess; - ssize_t len = sock.Recv(buffer_head + offset, nmax); - // length equals 0, remote disconnected - if (len == 0) { - SafeColl(sock.Close()); return kRecvZeroLen; - } - if (len == -1) return Errno2Return(); - size_read += static_cast(len); - return kSuccess; - } - /*! - * \brief read data into array, - * this function can not be used together with ReadToRingBuffer - * a link can either read into the ring buffer, or existing array - * \param max_size maximum size of array - * \return true if it is a successful read, false if there is some error happens, check errno - */ - inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) { - if (max_size == size_read) return kSuccess; - char *p = static_cast(recvbuf_); - ssize_t len = sock.Recv(p + size_read, max_size - size_read); - // length equals 0, remote disconnected - if (len == 0) { - SafeColl(sock.Close()); return kRecvZeroLen; - } - if (len == -1) return Errno2Return(); - size_read += static_cast(len); - return kSuccess; - } - /*! - * \brief write data in array to sock - * \param sendbuf_ head of array - * \param max_size maximum size of array - * \return true if it is a successful write, false if there is some error happens, check errno - */ - inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) { - const char *p = static_cast(sendbuf_); - ssize_t len = sock.Send(p + size_write, max_size - size_write); - if (len == -1) return Errno2Return(); - size_write += static_cast(len); - return kSuccess; - } - - private: - // recv buffer to get data from child - // aligned with 64 bits, will be able to perform 64 bits operations freely - std::vector buffer_; - }; - /*! - * \brief simple data structure that works like a vector - * but takes reference instead of space - */ - struct RefLinkVector { - std::vector plinks; - inline LinkRecord &operator[](size_t i) { - return *plinks[i]; - } - inline size_t Size() const { - return plinks.size(); - } - }; - /*! - * \brief initialize connection to the tracker - * \return a socket that initializes the connection - */ - [[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const; - /*! - * \brief connect to the tracker to fix the missing links - * this function is also used when the engine start up - * \param cmd possible command to sent to tracker - */ - [[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start"); - /*! 
- * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure - * - * NOTE on Allreduce: - * The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce. - * It only means the current node get the correct result of Allreduce. - * However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryAllreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer); - /*! - * \brief broadcast data from root to all nodes, this function can fail, and will return the cause of failure - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root); - /*! - * \brief perform in-place allreduce, on sendrecvbuf, - * this function implements tree-shape reduction - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryAllreduceTree(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer); - /*! - * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size, - size_t slice_begin, size_t slice_end, - size_t size_prev_slice); - /*! 
- * \brief perform in-place allreduce, reduce on the sendrecvbuf, - * - * after the function, node k get k-th segment of the reduction result - * the k-th segment is defined by [k * step, min((k + 1) * step,count) ) - * where step = ceil(count / world_size) - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType, TryAllreduce - */ - ReturnType TryReduceScatterRing(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer); - /*! - * \brief perform in-place allreduce, on sendrecvbuf - * use a ring based algorithm, reduce-scatter + allgather - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryAllreduceRing(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer); - /*! - * \brief function used to report error when a link goes wrong - * \param link the pointer to the link who causes the error - * \param err the error type - */ - inline ReturnType ReportError(LinkRecord *link, ReturnType err) { - err_link = link; return err; - } - //---- data structure related to model ---- - // call sequence counter, records how many calls we made so far - // from last call to CheckPoint, LoadCheckPoint - int seq_counter{0}; // NOLINT - // version number of model - int version_number {0}; // NOLINT - // whether the job is running in Hadoop - bool hadoop_mode; // NOLINT - //---- local data related to link ---- - // index of parent link, can be -1, meaning this is root of the tree - int parent_index; // NOLINT - // rank of parent node, can be -1 - int parent_rank; // NOLINT - // sockets of all links this connects to - std::vector all_links; // NOLINT - // used to record the link where things goes wrong - LinkRecord *err_link; // NOLINT - // all the links in the reduction tree connection - RefLinkVector tree_links; // NOLINT - // pointer to links in the ring - LinkRecord *ring_prev, *ring_next; // NOLINT - //----- meta information----- - // list of enviroment variables that are of possible interest - std::vector env_vars; // NOLINT - // unique identifier of the possible job this process is doing - // used to assign ranks, optional, default to NULL - std::string task_id; // NOLINT - // uri of current host, to be set by Init - std::string host_uri; // NOLINT - // uri of tracker - std::string tracker_uri; // NOLINT - // role in dmlc jobs - std::string dmlc_role; // NOLINT - // port of tracker address - int tracker_port; // NOLINT - // reduce buffer size - size_t reduce_buffer_size; // NOLINT - // reduction method - int reduce_method; // NOLINT - // minimum count of cells to use ring based method - size_t reduce_ring_mincount; // NOLINT - // minimum block size per tree reduce - size_t tree_reduce_minsize; // NOLINT - // current rank - int rank; // NOLINT - // world size - int world_size; // NOLINT - // connect retry time - int connect_retry; // NOLINT - // by default, if rabit worker not recover in half an hour exit - std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT - // 
flag to enable rabit_timeout - bool rabit_timeout = false; // NOLINT - // Enable TCP node delay - bool rabit_enable_tcp_no_delay = false; // NOLINT -}; -} // namespace engine -} // namespace rabit -#endif // RABIT_ALLREDUCE_BASE_H_ diff --git a/rabit/src/allreduce_mock.h b/rabit/src/allreduce_mock.h deleted file mode 100644 index b24346586..000000000 --- a/rabit/src/allreduce_mock.h +++ /dev/null @@ -1,147 +0,0 @@ -/*! - * Copyright by Contributors - * \file allreduce_mock.h - * \brief Mock test module of AllReduce engine, - * insert failures in certain call point, to test if the engine is robust to failure - * - * \author Ignacio Cano, Tianqi Chen - */ -#ifndef RABIT_ALLREDUCE_MOCK_H_ -#define RABIT_ALLREDUCE_MOCK_H_ -#include -#include -#include -#include -#include "rabit/internal/engine.h" -#include "allreduce_base.h" - -namespace rabit { -namespace engine { -class AllreduceMock : public AllreduceBase { - public: - // constructor - AllreduceMock() { - num_trial_ = 0; - force_local_ = 0; - report_stats_ = 0; - tsum_allreduce_ = 0.0; - tsum_allgather_ = 0.0; - } - // destructor - ~AllreduceMock() override = default; - void SetParam(const char *name, const char *val) override { - AllreduceBase::SetParam(name, val); - // additional parameters - if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val); - if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val); - if (!strcmp(name, "report_stats")) report_stats_ = atoi(val); - if (!strcmp(name, "force_local")) force_local_ = atoi(val); - if (!strcmp(name, "mock")) { - MockKey k; - utils::Check(sscanf(val, "%d,%d,%d,%d", - &k.rank, &k.version, &k.seqno, &k.ntrial) == 4, - "invalid mock parameter"); - mock_map_[k] = 1; - } - } - void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer, PreprocFunction prepare_fun, - void *prepare_arg) override { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce"); - double tstart = dmlc::GetTime(); - AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer, - prepare_fun, prepare_arg); - tsum_allreduce_ += dmlc::GetTime() - tstart; - } - void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, - size_t slice_end, size_t size_prev_slice) override { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather"); - double tstart = dmlc::GetTime(); - AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end, - size_prev_slice); - tsum_allgather_ += dmlc::GetTime() - tstart; - } - void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast"); - AllreduceBase::Broadcast(sendrecvbuf_, total_size, root); - } - int LoadCheckPoint() override { - tsum_allreduce_ = 0.0; - tsum_allgather_ = 0.0; - time_checkpoint_ = dmlc::GetTime(); - if (force_local_ == 0) { - return AllreduceBase::LoadCheckPoint(); - } else { - return AllreduceBase::LoadCheckPoint(); - } - } - void CheckPoint() override { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint"); - double tstart = dmlc::GetTime(); - double tbet_chkpt = tstart - time_checkpoint_; - AllreduceBase::CheckPoint(); - time_checkpoint_ = dmlc::GetTime(); - double tcost = dmlc::GetTime() - tstart; - if (report_stats_ != 0 && rank == 0) { - std::stringstream ss; - ss << "[v" << version_number << "] global_size=" - << ",check_tcost="<< tcost <<" sec" - << ",allreduce_tcost=" << tsum_allreduce_ << " sec" - << 
",allgather_tcost=" << tsum_allgather_ << " sec" - << ",between_chpt=" << tbet_chkpt << "sec\n"; - this->TrackerPrint(ss.str()); - } - tsum_allreduce_ = 0.0; - tsum_allgather_ = 0.0; - } - - protected: - // force checkpoint to local - int force_local_; - // whether report statistics - int report_stats_; - // sum of allreduce - double tsum_allreduce_; - // sum of allgather - double tsum_allgather_; - double time_checkpoint_; - - private: - // key to identify the mock stage - struct MockKey { - int rank; - int version; - int seqno; - int ntrial; - MockKey() = default; - MockKey(int rank, int version, int seqno, int ntrial) - : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {} - inline bool operator==(const MockKey &b) const { - return rank == b.rank && - version == b.version && - seqno == b.seqno && - ntrial == b.ntrial; - } - inline bool operator<(const MockKey &b) const { - if (rank != b.rank) return rank < b.rank; - if (version != b.version) return version < b.version; - if (seqno != b.seqno) return seqno < b.seqno; - return ntrial < b.ntrial; - } - }; - // number of failure trials - int num_trial_; - // record all mock actions - std::map mock_map_; - // used to generate all kinds of exceptions - inline void Verify(const MockKey &key, const char *name) { - if (mock_map_.count(key) != 0) { - num_trial_ += 1; - // data processing frameworks runs on shared process - throw dmlc::Error(std::to_string(rank) + "@@@Hit Mock Error: " + name); - } - } -}; -} // namespace engine -} // namespace rabit -#endif // RABIT_ALLREDUCE_MOCK_H_ diff --git a/rabit/src/engine.cc b/rabit/src/engine.cc deleted file mode 100644 index 89f25fa1e..000000000 --- a/rabit/src/engine.cc +++ /dev/null @@ -1,106 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file engine.cc - * \brief this file governs which implementation of engine we are actually using - * provides an singleton of engine interface - * - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#include -#include - -#include -#include "rabit/internal/engine.h" -#include "allreduce_base.h" - -namespace rabit { -namespace engine { -// singleton sync manager -#ifndef RABIT_USE_BASE -#ifndef RABIT_USE_MOCK -using Manager = AllreduceBase; -#else -typedef AllreduceMock Manager; -#endif // RABIT_USE_MOCK -#else -typedef AllreduceBase Manager; -#endif // RABIT_USE_BASE - -/*! \brief entry to to easily hold returning information */ -struct ThreadLocalEntry { - /*! \brief stores the current engine */ - std::unique_ptr engine; - /*! \brief whether init has been called */ - bool initialized{false}; - /*! \brief constructor */ - ThreadLocalEntry() = default; -}; - -// define the threadlocal store. -using EngineThreadLocal = dmlc::ThreadLocalStore; - -/*! \brief intiialize the synchronization module */ -bool Init(int argc, char *argv[]) { - ThreadLocalEntry* e = EngineThreadLocal::Get(); - if (e->engine.get() == nullptr) { - e->initialized = true; - e->engine.reset(new Manager()); - return e->engine->Init(argc, argv); - } else { - return true; - } -} - -/*! \brief finalize syncrhonization module */ -bool Finalize() { - ThreadLocalEntry* e = EngineThreadLocal::Get(); - if (e->engine.get() != nullptr) { - if (e->engine->Shutdown()) { - e->engine.reset(nullptr); - e->initialized = false; - return true; - } else { - return false; - } - } else { - return true; - } -} - -/*! \brief singleton method to get engine */ -IEngine *GetEngine() { - // un-initialized default manager. 
- static AllreduceBase default_manager; - ThreadLocalEntry* e = EngineThreadLocal::Get(); - IEngine* ptr = e->engine.get(); - if (ptr == nullptr) { - utils::Check(!e->initialized, "the rabit has not been initialized"); - return &default_manager; - } else { - return ptr; - } -} - -// perform in-place allgather, on sendrecvbuf -void Allgather(void *sendrecvbuf_, size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice) { - GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, - slice_end, size_prev_slice); -} - - -// perform in-place allreduce, on sendrecvbuf -void Allreduce_(void *sendrecvbuf, // NOLINT - size_t type_nbytes, - size_t count, - IEngine::ReduceFunction red, - mpi::DataType, - mpi::OpType , - IEngine::PreprocFunction prepare_fun, - void *prepare_arg) { - GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg); -} -} // namespace engine -} // namespace rabit diff --git a/rabit/src/engine_mock.cc b/rabit/src/engine_mock.cc deleted file mode 100644 index 5c0f8505e..000000000 --- a/rabit/src/engine_mock.cc +++ /dev/null @@ -1,14 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file engine_mock.cc - * \brief this is an engine implementation that will - * insert failures in certain call point, to test if the engine is robust to failure - * \author Tianqi Chen - */ -// define use MOCK, os we will use mock Manager -#define NOMINMAX -// switch engine to AllreduceMock -#define RABIT_USE_MOCK -#include -#include "allreduce_mock.h" -#include "engine.cc" diff --git a/rabit/src/rabit_c_api.cc b/rabit/src/rabit_c_api.cc deleted file mode 100644 index c90fae830..000000000 --- a/rabit/src/rabit_c_api.cc +++ /dev/null @@ -1,342 +0,0 @@ -// Copyright by Contributors -// implementations in ctypes -#include -#include -#include -#include "rabit/rabit.h" -#include "rabit/c_api.h" - -#include "../../src/c_api/c_api_error.h" - -namespace rabit { -namespace c_api { -// helper use to avoid BitOR operator -template -struct FHelper { - static void - Allreduce(DType *senrecvbuf_, - size_t count, - void (*prepare_fun)(void *arg), - void *prepare_arg) { - rabit::Allreduce(senrecvbuf_, count, - prepare_fun, prepare_arg); - } -}; - -template -struct FHelper { - static void - Allreduce(DType *, - size_t , - void (*)(void *arg), - void *) { - utils::Error("DataType does not support bitwise AND operation"); - } -}; - -template -struct FHelper { - static void - Allreduce(DType *, - size_t , - void (*)(void *arg), - void *) { - utils::Error("DataType does not support bitwise OR operation"); - } -}; - -template -struct FHelper { - static void - Allreduce(DType *, - size_t , - void (*)(void *arg), - void *) { - utils::Error("DataType does not support bitwise XOR operation"); - } -}; - -template -void Allreduce(void *sendrecvbuf_, - size_t count, - engine::mpi::DataType enum_dtype, - void (*prepare_fun)(void *arg), - void *prepare_arg) { - using namespace engine::mpi; // NOLINT - switch (enum_dtype) { - case kChar: - rabit::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kUChar: - rabit::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kInt: - rabit::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kUInt: - rabit::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kLong: - rabit::Allreduce - (static_cast(sendrecvbuf_), // NOLINT(*) - count, prepare_fun, prepare_arg); - 
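/*
 * The removed engine.cc above keeps one engine instance per thread behind
 * dmlc::ThreadLocalStore and falls back to an uninitialized default manager when
 * Init() was never called. A rough sketch of that lazily-initialized,
 * thread-local singleton shape, using plain C++ thread_local instead of the dmlc
 * helper; the class and function names are illustrative.
 */
#include <memory>

class EngineSketch {
 public:
  virtual ~EngineSketch() = default;
};

struct ThreadLocalEngine {
  std::unique_ptr<EngineSketch> engine;
  bool initialized{false};

  static ThreadLocalEngine* Get() {
    static thread_local ThreadLocalEngine entry;
    return &entry;
  }
};

inline bool InitEngine() {
  auto* e = ThreadLocalEngine::Get();
  if (!e->engine) {
    e->engine = std::make_unique<EngineSketch>();
    e->initialized = true;
  }
  return true;
}

inline bool FinalizeEngine() {
  auto* e = ThreadLocalEngine::Get();
  e->engine.reset();
  e->initialized = false;
  return true;
}

inline EngineSketch* GetEngineSketch() {
  // Like the static default manager above: usable before Init(), not distributed.
  static EngineSketch default_engine;
  auto* e = ThreadLocalEngine::Get();
  return e->engine ? e->engine.get() : &default_engine;
}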
return; - case kULong: - rabit::Allreduce - (static_cast(sendrecvbuf_), // NOLINT(*) - count, prepare_fun, prepare_arg); - return; - case kFloat: - FHelper::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kDouble: - FHelper::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - default: utils::Error("unknown data_type"); - } -} -void Allreduce(void *sendrecvbuf, - size_t count, - engine::mpi::DataType enum_dtype, - engine::mpi::OpType enum_op, - void (*prepare_fun)(void *arg), - void *prepare_arg) { - using namespace engine::mpi; // NOLINT - switch (enum_op) { - case kMax: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kMin: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kSum: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kBitwiseAND: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kBitwiseOR: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kBitwiseXOR: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - default: utils::Error("unknown enum_op"); - } -} - -void Allgather(void *sendrecvbuf_, - size_t total_size, - size_t beginIndex, - size_t size_node_slice, - size_t size_prev_slice, - int enum_dtype) { - using namespace engine::mpi; // NOLINT - size_t type_size = 0; - switch (enum_dtype) { - case kChar: - type_size = sizeof(char); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kUChar: - type_size = sizeof(unsigned char); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kInt: - type_size = sizeof(int); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kUInt: - type_size = sizeof(unsigned); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kLong: - type_size = sizeof(int64_t); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kULong: - type_size = sizeof(uint64_t); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kFloat: - type_size = sizeof(float); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kDouble: - type_size = sizeof(double); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - default: utils::Error("unknown data_type"); - } -} - -// wrapper for serialization -struct ReadWrapper : public Serializable { - std::string *p_str; - explicit 
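/*
 * The removed rabit_c_api.cc above dispatches an untyped buffer on runtime
 * (data type, op) enums, and the FHelper specializations reject bitwise
 * reductions for float/double. A compact sketch of the same guard expressed
 * with if constexpr; the enum and function names are illustrative only.
 */
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <type_traits>

enum class ReduceOp { kMax, kMin, kSum, kBitwiseAND, kBitwiseOR, kBitwiseXOR };

template <typename T>
void ReduceInto(T* inout, T const* in, std::size_t n, ReduceOp op) {
  for (std::size_t i = 0; i < n; ++i) {
    switch (op) {
      case ReduceOp::kMax: inout[i] = std::max(inout[i], in[i]); break;
      case ReduceOp::kMin: inout[i] = std::min(inout[i], in[i]); break;
      case ReduceOp::kSum: inout[i] += in[i]; break;
      default:
        // Bitwise reductions are only defined for integral element types,
        // mirroring the FHelper error specializations above.
        if constexpr (std::is_integral_v<T>) {
          if (op == ReduceOp::kBitwiseAND) {
            inout[i] &= in[i];
          } else if (op == ReduceOp::kBitwiseOR) {
            inout[i] |= in[i];
          } else {
            inout[i] ^= in[i];
          }
        } else {
          throw std::invalid_argument("bitwise reduction is not defined for floating point");
        }
    }
  }
}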
ReadWrapper(std::string *p_str) - : p_str(p_str) {} - void Load(Stream *fi) override { - uint64_t sz; - utils::Assert(fi->Read(&sz, sizeof(sz)) != 0, - "Read pickle string"); - p_str->resize(sz); - if (sz != 0) { - utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0, - "Read pickle string"); - } - } - void Save(Stream *) const override { - utils::Error("not implemented"); - } -}; - -struct WriteWrapper : public Serializable { - const char *data; - size_t length; - explicit WriteWrapper(const char *data, - size_t length) - : data(data), length(length) { - } - void Load(Stream *) override { - utils::Error("not implemented"); - } - void Save(Stream *fo) const override { - uint64_t sz = static_cast(length); - fo->Write(&sz, sizeof(sz)); - fo->Write(data, length * sizeof(char)); - } -}; -} // namespace c_api -} // namespace rabit - -RABIT_DLL bool RabitInit(int argc, char *argv[]) { - auto ret = rabit::Init(argc, argv); - if (!ret) { - XGBAPISetLastError("Failed to initialize RABIT."); - } - return ret; -} - -RABIT_DLL int RabitFinalize() { - auto ret = rabit::Finalize(); - if (!ret) { - XGBAPISetLastError("Failed to shutdown RABIT worker."); - } - return static_cast(ret); -} - -RABIT_DLL int RabitGetRingPrevRank() { - return rabit::GetRingPrevRank(); -} - -RABIT_DLL int RabitGetRank() { - return rabit::GetRank(); -} - -RABIT_DLL int RabitGetWorldSize() { - return rabit::GetWorldSize(); -} - -RABIT_DLL int RabitIsDistributed() { - return rabit::IsDistributed(); -} - -RABIT_DLL int RabitTrackerPrint(const char *msg) { - API_BEGIN() - std::string m(msg); - rabit::TrackerPrint(m); - API_END() -} - -RABIT_DLL void RabitGetProcessorName(char *out_name, - rbt_ulong *out_len, - rbt_ulong max_len) { - std::string s = rabit::GetProcessorName(); - if (s.length() > max_len) { - s.resize(max_len - 1); - } - strcpy(out_name, s.c_str()); // NOLINT(*) - *out_len = static_cast(s.length()); -} - -RABIT_DLL int RabitBroadcast(void *sendrecv_data, - rbt_ulong size, int root) { - API_BEGIN() - rabit::Broadcast(sendrecv_data, size, root); - API_END() -} - -RABIT_DLL int RabitAllgather(void *sendrecvbuf_, size_t total_size, - size_t beginIndex, size_t size_node_slice, - size_t size_prev_slice, int enum_dtype) { - API_BEGIN() - rabit::c_api::Allgather( - sendrecvbuf_, total_size, beginIndex, size_node_slice, size_prev_slice, - static_cast(enum_dtype)); - API_END() -} - -RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, - int enum_op, void (*prepare_fun)(void *arg), - void *prepare_arg) { - API_BEGIN() - rabit::c_api::Allreduce(sendrecvbuf, count, - static_cast(enum_dtype), - static_cast(enum_op), - prepare_fun, prepare_arg); - API_END() -} - -RABIT_DLL int RabitVersionNumber() { - return rabit::VersionNumber(); -} - -RABIT_DLL int RabitLinkTag() { - return 0; -} diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 79d9793e6..45160baea 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -15,9 +15,9 @@ #include // for pair #include // for vector -#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor... #include "../common/api_entry.h" // for XGBAPIThreadLocalEntry #include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch... +#include "../common/error_msg.h" // for NoFederated #include "../common/hist_util.h" // for HistogramCuts #include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf... 
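/*
 * The removed ReadWrapper/WriteWrapper earlier in this hunk frame a string as a
 * 64-bit length prefix followed by the raw bytes. A minimal sketch of the same
 * framing over standard iostreams; the helper names are illustrative.
 */
#include <cstdint>
#include <istream>
#include <ostream>
#include <string>

inline void SaveLengthPrefixed(std::ostream& os, std::string const& s) {
  std::uint64_t sz = s.size();
  os.write(reinterpret_cast<char const*>(&sz), sizeof(sz));
  os.write(s.data(), static_cast<std::streamsize>(s.size()));
}

inline bool LoadLengthPrefixed(std::istream& is, std::string* out) {
  std::uint64_t sz = 0;
  if (!is.read(reinterpret_cast<char*>(&sz), sizeof(sz))) {
    return false;
  }
  out->resize(sz);
  return sz == 0 || static_cast<bool>(is.read(&(*out)[0], static_cast<std::streamsize>(sz)));
}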
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor @@ -27,11 +27,10 @@ #include "../data/simple_dmatrix.h" // for SimpleDMatrix #include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN #include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM... -#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED +#include "dmlc/base.h" // for BeginPtr #include "dmlc/io.h" // for Stream #include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager #include "dmlc/thread_local.h" // for ThreadLocalStore -#include "rabit/c_api.h" // for RabitLinkTag #include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat... #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage @@ -46,10 +45,6 @@ #include "xgboost/string_view.h" // for StringView, operator<< #include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS... -#if defined(XGBOOST_USE_FEDERATED) -#include "../../plugin/federated/federated_server.h" -#endif - using namespace xgboost; // NOLINT(*); XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) { @@ -1759,76 +1754,3 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config, *out_features = dmlc::BeginPtr(feature_names_c); API_END(); } - -XGB_DLL int XGCommunicatorInit(char const* json_config) { - API_BEGIN(); - xgboost_CHECK_C_ARG_PTR(json_config); - Json config{Json::Load(StringView{json_config})}; - collective::Init(config); - API_END(); -} - -XGB_DLL int XGCommunicatorFinalize() { - API_BEGIN(); - collective::Finalize(); - API_END(); -} - -XGB_DLL int XGCommunicatorGetRank(void) { - return collective::GetRank(); -} - -XGB_DLL int XGCommunicatorGetWorldSize(void) { - return collective::GetWorldSize(); -} - -XGB_DLL int XGCommunicatorIsDistributed(void) { - return collective::IsDistributed(); -} - -XGB_DLL int XGCommunicatorPrint(char const *message) { - API_BEGIN(); - collective::Print(message); - API_END(); -} - -XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) { - API_BEGIN(); - auto& local = *GlobalConfigAPIThreadLocalStore::Get(); - local.ret_str = collective::GetProcessorName(); - xgboost_CHECK_C_ARG_PTR(name_str); - *name_str = local.ret_str.c_str(); - API_END(); -} - -XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) { - API_BEGIN(); - collective::Broadcast(send_receive_buffer, size, root); - API_END(); -} - -XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, - int enum_op) { - API_BEGIN(); - collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op); - API_END(); -} - -#if defined(XGBOOST_USE_FEDERATED) -XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path, - char const *server_cert_path, char const *client_cert_path) { - API_BEGIN(); - federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path); - API_END(); -} - -// Run a server without SSL for local testing. 
-XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) { - API_BEGIN(); - federated::RunInsecureServer(port, world_size); - API_END(); -} -#endif - -// force link rabit -static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag(); diff --git a/src/c_api/c_api_error.cc b/src/c_api/c_api_error.cc index 10e864c80..59dfb8854 100644 --- a/src/c_api/c_api_error.cc +++ b/src/c_api/c_api_error.cc @@ -1,22 +1,28 @@ -/*! - * Copyright (c) 2015 by Contributors +/** + * Copyright 2015-2023, XGBoost Contributors * \file c_api_error.cc * \brief C error handling */ -#include -#include "xgboost/c_api.h" #include "./c_api_error.h" +#include + +#include "xgboost/c_api.h" +#include "../collective/comm.h" +#include "../collective/comm_group.h" + struct XGBAPIErrorEntry { std::string last_error; + std::int32_t code{-1}; }; using XGBAPIErrorStore = dmlc::ThreadLocalStore; -XGB_DLL const char *XGBGetLastError() { - return XGBAPIErrorStore::Get()->last_error.c_str(); -} +XGB_DLL const char* XGBGetLastError() { return XGBAPIErrorStore::Get()->last_error.c_str(); } void XGBAPISetLastError(const char* msg) { XGBAPIErrorStore::Get()->last_error = msg; + XGBAPIErrorStore::Get()->code = -1; } + +XGB_DLL int XGBGetLastErrorCode() { return XGBAPIErrorStore::Get()->code; } diff --git a/src/c_api/c_api_error.h b/src/c_api/c_api_error.h index 11c440384..0ad4ac073 100644 --- a/src/c_api/c_api_error.h +++ b/src/c_api/c_api_error.h @@ -10,6 +10,7 @@ #include #include "c_api_utils.h" +#include "xgboost/collective/result.h" /*! \brief macro to guard beginning and end section of all functions */ #ifdef LOG_CAPI_INVOCATION @@ -30,7 +31,7 @@ #define API_END() \ } catch (dmlc::Error & _except_) { \ return XGBAPIHandleException(_except_); \ - } catch (std::exception const &_except_) { \ + } catch (std::exception const& _except_) { \ return XGBAPIHandleException(dmlc::Error(_except_.what())); \ } \ return 0; // NOLINT(*) @@ -48,7 +49,7 @@ void XGBAPISetLastError(const char* msg); * \param e the exception * \return the return value of API after exception is handled */ -inline int XGBAPIHandleException(const dmlc::Error &e) { +inline int XGBAPIHandleException(const dmlc::Error& e) { XGBAPISetLastError(e.what()); return -1; } diff --git a/src/c_api/coll_c_api.cc b/src/c_api/coll_c_api.cc index fba2647cc..1da226103 100644 --- a/src/c_api/coll_c_api.cc +++ b/src/c_api/coll_c_api.cc @@ -9,10 +9,15 @@ #include // for is_same_v, remove_pointer_t #include // for pair -#include "../collective/comm.h" // for DefaultTimeoutSec -#include "../collective/tracker.h" // for RabitTracker -#include "../common/timer.h" // for Timer -#include "c_api_error.h" // for API_BEGIN +#include "../collective/allgather.h" // for Allgather +#include "../collective/allreduce.h" // for Allreduce +#include "../collective/broadcast.h" // for Broadcast +#include "../collective/comm.h" // for DefaultTimeoutSec +#include "../collective/comm_group.h" // for GlobalCommGroup +#include "../collective/communicator-inl.h" // for GetProcessorName +#include "../collective/tracker.h" // for RabitTracker +#include "../common/timer.h" // for Timer +#include "c_api_error.h" // for API_BEGIN #include "xgboost/c_api.h" #include "xgboost/collective/result.h" // for Result #include "xgboost/json.h" // for Json @@ -20,10 +25,36 @@ #if defined(XGBOOST_USE_FEDERATED) #include "../../plugin/federated/federated_tracker.h" // for FederatedTracker -#else -#include "../common/error_msg.h" // for NoFederated #endif +namespace xgboost::collective { +void 
Allreduce(void *send_receive_buffer, std::size_t count, std::int32_t data_type, int op) { + Context ctx; + DispatchDType(static_cast(data_type), [&](auto t) { + using T = decltype(t); + auto data = linalg::MakeTensorView( + &ctx, common::Span{static_cast(send_receive_buffer), count}, count); + auto rc = Allreduce(&ctx, *GlobalCommGroup(), data, static_cast(op)); + SafeColl(rc); + }); +} + +void Broadcast(void *send_receive_buffer, std::size_t size, int root) { + Context ctx; + auto rc = Broadcast(&ctx, *GlobalCommGroup(), + linalg::MakeVec(static_cast(send_receive_buffer), size), root); + SafeColl(rc); +} + +void Allgather(void *send_receive_buffer, std::size_t size) { + Context ctx; + auto const &comm = GlobalCommGroup(); + auto rc = Allgather(&ctx, *comm, + linalg::MakeVec(reinterpret_cast(send_receive_buffer), size)); + SafeColl(rc); +} +} // namespace xgboost::collective + using namespace xgboost; // NOLINT namespace { @@ -44,7 +75,8 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore; void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) { constexpr std::int64_t kDft{collective::DefaultTimeoutSec()}; - std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft}; + std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count()) + : kDft}; common::Timer timer; timer.Start(); @@ -62,7 +94,7 @@ void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) { break; } - if (timer.Duration() > timeout && timeout.count() != 0) { + if (timer.Duration() > timeout && collective::HasTimeout(timeout)) { collective::SafeColl(collective::Fail("Timeout waiting for the tracker.")); } } @@ -141,7 +173,7 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) { // Make sure no one else is waiting on the tracker. 
while (!ptr->first.unique()) { auto ela = timer.Duration().count(); - if (ela > ptr->first->Timeout().count()) { + if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) { LOG(WARNING) << "Time out " << ptr->first->Timeout().count() << " seconds reached for TrackerFree, killing the tracker."; break; @@ -151,3 +183,71 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) { delete ptr; API_END(); } + +XGB_DLL int XGCommunicatorInit(char const *json_config) { + API_BEGIN(); + xgboost_CHECK_C_ARG_PTR(json_config); + Json config{Json::Load(StringView{json_config})}; + collective::GlobalCommGroupInit(config); + API_END(); +} + +XGB_DLL int XGCommunicatorFinalize(void) { + API_BEGIN(); + collective::GlobalCommGroupFinalize(); + API_END(); +} + +XGB_DLL int XGCommunicatorGetRank(void) { + API_BEGIN(); + return collective::GetRank(); + API_END(); +} + +XGB_DLL int XGCommunicatorGetWorldSize(void) { return collective::GetWorldSize(); } + +XGB_DLL int XGCommunicatorIsDistributed(void) { return collective::IsDistributed(); } + +XGB_DLL int XGCommunicatorPrint(char const *message) { + API_BEGIN(); + collective::Print(message); + API_END(); +} + +XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) { + API_BEGIN(); + auto &local = *CollAPIThreadLocalStore::Get(); + local.ret_str = collective::GetProcessorName(); + xgboost_CHECK_C_ARG_PTR(name_str); + *name_str = local.ret_str.c_str(); + API_END(); +} + +XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) { + API_BEGIN(); + collective::Broadcast(send_receive_buffer, size, root); + API_END(); +} + +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, + int enum_op) { + API_BEGIN(); + collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op); + API_END(); +} + +// Not exposed to the public since the previous implementation didn't and we don't want to +// add unnecessary communicator API to a machine learning library. +XGB_DLL int XGCommunicatorAllgather(void *send_receive_buffer, size_t count) { + API_BEGIN(); + collective::Allgather(send_receive_buffer, count); + API_END(); +} + +// Not yet exposed to the public, error recovery is still WIP. +XGB_DLL int XGCommunicatorSignalError() { + API_BEGIN(); + auto msg = XGBGetLastError(); + SafeColl(xgboost::collective::GlobalCommGroup()->SignalError(xgboost::collective::Fail(msg))); + API_END() +} diff --git a/src/cli_main.cc b/src/cli_main.cc index 276d67da8..54a345027 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -22,7 +22,6 @@ #include #include #include -#include "collective/communicator-inl.h" #include "common/common.h" #include "common/config.h" #include "common/io.h" @@ -193,10 +192,6 @@ class CLI { void CLITrain() { const double tstart_data_load = dmlc::GetTime(); - if (collective::IsDistributed()) { - std::string pname = collective::GetProcessorName(); - LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank(); - } // load in data. 
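/*
 * A hedged usage sketch for the XGCommunicator* entry points consolidated into
 * coll_c_api.cc above (XGCommunicatorInit/GetRank/GetWorldSize/Print/Finalize).
 * Passing an empty JSON config is an assumption here: it is expected to fall back
 * to a single-process, non-distributed communicator, while real deployments pass
 * tracker parameters through the JSON document instead.
 */
#include <cstdio>
#include <string>

#include <xgboost/c_api.h>

int RunCollectiveHelloWorld() {
  if (XGCommunicatorInit("{}") != 0) {  // JSON-encoded configuration
    std::fprintf(stderr, "init failed: %s\n", XGBGetLastError());
    return -1;
  }
  std::string msg = "hello from rank " + std::to_string(XGCommunicatorGetRank()) + " of " +
                    std::to_string(XGCommunicatorGetWorldSize());
  XGCommunicatorPrint(msg.c_str());  // routed through the tracker when distributed
  return XGCommunicatorFinalize();
}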
std::shared_ptr dtrain(DMatrix::Load( param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), @@ -235,15 +230,9 @@ class CLI { version += 1; } std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names); - if (collective::IsDistributed()) { - if (collective::GetRank() == 0) { - LOG(TRACKER) << res; - } - } else { - LOG(CONSOLE) << res; - } - if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 && - collective::GetRank() == 0) { + LOG(CONSOLE) << res; + + if (param_.save_period != 0 && (i + 1) % param_.save_period == 0) { std::ostringstream os; os << param_.model_dir << '/' << std::setfill('0') << std::setw(4) << i + 1 << ".model"; @@ -256,8 +245,7 @@ class CLI { << " sec"; // always save final round if ((param_.save_period == 0 || - param_.num_round % param_.save_period != 0) && - collective::GetRank() == 0) { + param_.num_round % param_.save_period != 0)) { std::ostringstream os; if (param_.model_out == CLIParam::kNull) { os << param_.model_dir << '/' << std::setfill('0') << std::setw(4) @@ -465,13 +453,6 @@ class CLI { } } - // Initialize the collective communicator. - Json json{JsonObject()}; - for (auto& kv : cfg) { - json[kv.first] = String(kv.second); - } - collective::Init(json); - param_.Configure(cfg); } @@ -507,10 +488,6 @@ class CLI { } return 0; } - - ~CLI() { - collective::Finalize(); - } }; } // namespace xgboost diff --git a/src/collective/aggregator.cuh b/src/collective/aggregator.cuh index 66766470b..d85e328aa 100644 --- a/src/collective/aggregator.cuh +++ b/src/collective/aggregator.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023 by XGBoost contributors + * Copyright 2023-2024, XGBoost contributors * * Higher level functions built on top the Communicator API, taking care of behavioral differences * between row-split vs column-split distributed training, and horizontal vs vertical federated @@ -13,7 +13,8 @@ #include #include -#include "communicator-inl.cuh" +#include "allreduce.h" +#include "xgboost/collective/result.h" // for Result namespace xgboost::collective { @@ -24,15 +25,17 @@ namespace xgboost::collective { * column-wise (vertically), the original values are returned. * * @tparam T The type of the values. + * * @param info MetaInfo about the DMatrix. - * @param device The device id. * @param values Pointer to the inputs to sum. * @param size Number of values to sum. 
*/ -template -void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) { +template +[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, + linalg::TensorView values) { if (info.IsRowSplit()) { - collective::AllReduce(device.ordinal, values, size); + return collective::Allreduce(ctx, values, collective::Op::kSum); } + return Success(); } } // namespace xgboost::collective diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h index bc652f2e8..a328a6120 100644 --- a/src/collective/aggregator.h +++ b/src/collective/aggregator.h @@ -11,11 +11,44 @@ #include #include +#include "allreduce.h" +#include "broadcast.h" +#include "comm.h" #include "communicator-inl.h" #include "xgboost/collective/result.h" // for Result #include "xgboost/data.h" // for MetaINfo namespace xgboost::collective { +namespace detail { +template +[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) { + std::string msg; + if (collective::GetRank() == 0) { + try { + fn(); + } catch (dmlc::Error const& e) { + msg = e.what(); + } + } + std::size_t msg_size{msg.size()}; + auto rc = Success() << [&] { + auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0); + return rc; + } << [&] { + if (msg_size > 0) { + msg.resize(msg_size); + return collective::Broadcast(ctx, linalg::MakeVec(msg.data(), msg.size()), 0); + } + return Success(); + } << [&] { + if (msg_size > 0) { + LOG(FATAL) << msg; + } + return Success(); + }; + return rc; +} +} // namespace detail /** * @brief Apply the given function where the labels are. @@ -30,29 +63,19 @@ namespace xgboost::collective { * @param size The size of the buffer. * @param function The function used to calculate the results. */ -template -void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size, - FN&& function) { +template +void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std::size_t size, + Fn&& fn) { if (info.IsVerticalFederated()) { - // We assume labels are only available on worker 0, so the calculation is done there and result - // broadcast to other workers. - std::string message; - if (collective::GetRank() == 0) { - try { - std::forward(function)(); - } catch (dmlc::Error& e) { - message = e.what(); - } - } - - collective::Broadcast(&message, 0); - if (message.empty()) { - collective::Broadcast(buffer, size, 0); - } else { - LOG(FATAL) << &message[0]; - } + auto rc = detail::TryApplyWithLabels(ctx, fn) << [&] { + // We assume labels are only available on worker 0, so the calculation is done there and + // result broadcast to other workers. + return collective::Broadcast( + ctx, linalg::MakeVec(reinterpret_cast(buffer), size), 0); + }; + SafeColl(rc); } else { - std::forward(function)(); + std::forward(fn)(); } } @@ -69,37 +92,24 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si * @param result The HostDeviceVector storing the results. * @param function The function used to calculate the results. */ -template -void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector* result, - Function&& function) { +template +void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector* result, + Fn&& fn) { if (info.IsVerticalFederated()) { // We assume labels are only available on worker 0, so the calculation is done there and result // broadcast to other workers. 
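/*
 * A small sketch of the Result-chaining style used by detail::TryApplyWithLabels
 * above and throughout the new collective code: each lambda runs only when the
 * previous step succeeded, and the accumulated Result is checked once at the end
 * (e.g. with SafeColl). Assumes the Success/Fail helpers and the chaining
 * operator from xgboost/collective/result.h; the step logic is illustrative.
 */
#include "xgboost/collective/result.h"

namespace xgboost::collective {
inline Result RunPipelineSketch(bool step1_ok, bool step2_ok) {
  auto rc = Success() << [&] {
    return step1_ok ? Success() : Fail("step 1 failed");
  } << [&] {
    // Skipped entirely when step 1 already failed.
    return step2_ok ? Success() : Fail("step 2 failed");
  };
  return rc;  // callers may chain further, or call SafeColl(std::move(rc)) to abort on error.
}
}  // namespace xgboost::collective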
- std::string message; - if (collective::GetRank() == 0) { - try { - std::forward(function)(); - } catch (dmlc::Error& e) { - message = e.what(); - } - } + auto rc = detail::TryApplyWithLabels(ctx, fn); - collective::Broadcast(&message, 0); - if (!message.empty()) { - LOG(FATAL) << &message[0]; - return; - } - - std::size_t size{}; - if (collective::GetRank() == 0) { - size = result->Size(); - } - collective::Broadcast(&size, sizeof(std::size_t), 0); - - result->Resize(size); - collective::Broadcast(result->HostPointer(), size * sizeof(T), 0); + std::size_t size{result->Size()}; + rc = std::move(rc) << [&] { + return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0); + } << [&] { + result->Resize(size); + return collective::Broadcast(ctx, linalg::MakeVec(result->HostPointer(), size), 0); + }; + SafeColl(rc); } else { - std::forward(function)(); + std::forward(fn)(); } } @@ -115,11 +125,12 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector* * @return The global max of the input. */ template -std::enable_if_t, T> GlobalMax(Context const*, +std::enable_if_t, T> GlobalMax(Context const* ctx, MetaInfo const& info, T value) { if (info.IsRowSplit()) { - collective::Allreduce(&value, 1); + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&value, 1), collective::Op::kMax); + SafeColl(rc); } return value; } @@ -136,19 +147,14 @@ std::enable_if_t, T> GlobalMax(Context co * @param size Number of values to sum. */ template -[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info, +[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, linalg::TensorView values) { if (info.IsRowSplit()) { - collective::Allreduce(values.Values().data(), values.Size()); + return collective::Allreduce(ctx, values, collective::Op::kSum); } return Success(); } -template -[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) { - return GlobalSum(ctx, info, values->data(), values->size()); -} - /** * @brief Find the global ratio of the given two values across all workers. 
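/*
 * A usage sketch for the linalg-based aggregation helpers introduced just above:
 * GlobalSum is a no-op for column-split/vertically federated data and an
 * allreduce-sum otherwise. The function name below is illustrative; the call
 * shape follows the new GlobalSum(ctx, info, view) signature with a view built
 * by linalg::MakeVec, and the include path assumes a caller under src/.
 */
#include <vector>

#include "../collective/aggregator.h"  // relative path as used by other files under src/
#include "xgboost/context.h"           // for Context
#include "xgboost/data.h"              // for MetaInfo
#include "xgboost/linalg.h"            // for MakeVec

void SumPartialsSketch(xgboost::Context const* ctx, xgboost::MetaInfo const& info,
                       std::vector<double>* partial_sums) {
  namespace coll = xgboost::collective;
  auto rc = coll::GlobalSum(
      ctx, info, xgboost::linalg::MakeVec(partial_sums->data(), partial_sums->size()));
  coll::SafeColl(rc);
}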
* diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index 5d1ec664e..c2c7a500f 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -47,7 +47,7 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size return comm.Block(); }; if (!rc.OK()) { - return rc; + return Fail("Ring allgather failed, current iteration:" + std::to_string(r), std::move(rc)); } } @@ -61,7 +61,8 @@ Result BroadcastAllgatherV(Comm const& comm, common::Span si auto as_bytes = sizes[r]; auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r); if (!rc.OK()) { - return rc; + return Fail("Broadcast AllgatherV failed, current iteration:" + std::to_string(r), + std::move(rc)); } offset += as_bytes; } @@ -102,7 +103,7 @@ namespace detail { return prev_ch->Block(); }; if (!rc.OK()) { - return rc; + return Fail("Ring AllgatherV failed, current iterataion:" + std::to_string(r), std::move(rc)); } } return comm.Block(); diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc index 55c5c8854..3b201c99d 100644 --- a/src/collective/allreduce.cc +++ b/src/collective/allreduce.cc @@ -36,7 +36,7 @@ Result RingAllreduceSmall(Comm const& comm, common::Span data, Func auto rc = RingAllgather(comm, typed); if (!rc.OK()) { - return rc; + return Fail("Ring allreduce small failed.", std::move(rc)); } auto first = s_buffer.subspan(0, data.size_bytes()); CHECK_EQ(first.size(), data.size()); @@ -64,7 +64,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data, auto next_ch = comm.Chan(dst_rank); auto prev_ch = comm.Chan(src_rank); - std::vector buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0); + std::vector buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, -1); auto s_buf = common::Span{buffer.data(), buffer.size()}; for (std::int32_t r = 0; r < world - 1; ++r) { @@ -97,6 +97,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data, } << [&] { return comm.Block(); }; + if (!rc.OK()) { + return Fail("Ring scatter reduce failed, current iteration:" + std::to_string(r), + std::move(rc)); + } // accumulate to recv_seg CHECK_EQ(seg.size(), recv_seg.size()); @@ -128,7 +132,7 @@ Result RingAllreduce(Comm const& comm, common::Span data, Func cons auto n_bytes_in_seg = (n / world) * sizeof(T); auto rc = RingScatterReduceTyped(comm, data, n_bytes_in_seg, op); if (!rc.OK()) { - return rc; + return Fail("Ring Allreduce failed.", std::move(rc)); } auto prev = BootstrapPrev(comm.Rank(), comm.World()); diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 50a14aaaf..543ece639 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -150,9 +150,12 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st } auto rank = comm.Rank(); - auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank())); - if (n_bytes != sizeof(comm.Rank())) { - return Fail("Failed to send rank."); + std::size_t n_bytes{0}; + auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes); + if (!rc.OK()) { + return rc; + } else if (n_bytes != sizeof(comm.Rank())) { + return Fail("Failed to send rank.", std::move(rc)); } workers[r] = std::move(worker); } @@ -169,8 +172,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st return rc; } std::int32_t rank{-1}; - auto n_bytes = peer->RecvAll(&rank, sizeof(rank)); - if (n_bytes != sizeof(comm.Rank())) { + std::size_t n_bytes{0}; + auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes); + if (!rc.OK()) { + return rc; + } else if (n_bytes != 
sizeof(comm.Rank())) { return Fail("Failed to recv rank."); } workers[rank] = std::move(peer); diff --git a/src/collective/comm.h b/src/collective/comm.h index a41f47be9..0a0f24aad 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -94,7 +94,7 @@ class Comm : public std::enable_shared_from_this { [[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; } void Submit(Loop::Op op) const { CHECK(loop_); - loop_->Submit(op); + loop_->Submit(std::move(op)); } [[nodiscard]] virtual Result Block() const { return loop_->Block(); } diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc index 18a5ba8a7..a9b58ecb5 100644 --- a/src/collective/comm_group.cc +++ b/src/collective/comm_group.cc @@ -76,7 +76,7 @@ CommGroup::CommGroup() // Common args auto retry = get_param("dmlc_retry", static_cast(DefaultRetry()), Integer{}); auto timeout = - get_param("dmlc_timeout_sec", static_cast(DefaultTimeoutSec()), Integer{}); + get_param("dmlc_timeout", static_cast(DefaultTimeoutSec()), Integer{}); auto task_id = get_param("dmlc_task_id", std::string{}, String{}); if (type == "rabit") { @@ -123,4 +123,30 @@ void GlobalCommGroupFinalize() { sptr.reset(); SafeColl(rc); } + +void Init(Json const& config) { GlobalCommGroupInit(config); } + +void Finalize() { GlobalCommGroupFinalize(); } + +std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); } + +std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); } + +bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); } + +[[nodiscard]] bool IsFederated() { + return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated(); +} + +void Print(std::string const& message) { + auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message); + SafeColl(rc); +} + +std::string GetProcessorName() { + std::string out; + auto rc = GlobalCommGroup()->ProcessorName(&out); + SafeColl(rc); + return out; +} } // namespace xgboost::collective diff --git a/src/collective/communicator-inl.cc b/src/collective/communicator-inl.cc deleted file mode 100644 index 4164855f1..000000000 --- a/src/collective/communicator-inl.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2024, XGBoost contributors - */ -#include "communicator-inl.h" - -namespace xgboost::collective { -[[nodiscard]] std::vector> VectorAllgatherV( - std::vector> const &input) { - auto n_inputs = input.size(); - std::vector sizes(n_inputs); - std::transform(input.cbegin(), input.cend(), sizes.begin(), - [](auto const &vec) { return vec.size(); }); - - std::vector global_sizes = AllgatherV(sizes); - std::vector offset(global_sizes.size() + 1); - offset[0] = 0; - for (std::size_t i = 1; i < offset.size(); i++) { - offset[i] = offset[i - 1] + global_sizes[i - 1]; - } - - std::vector collected; - for (auto const &vec : input) { - collected.insert(collected.end(), vec.cbegin(), vec.cend()); - } - auto out = AllgatherV(collected); - - std::vector> result; - for (std::size_t i = 1; i < offset.size(); ++i) { - std::vector local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]); - result.emplace_back(std::move(local)); - } - return result; -} -} // namespace xgboost::collective diff --git a/src/collective/communicator-inl.cuh b/src/collective/communicator-inl.cuh deleted file mode 100644 index 200a9ff4a..000000000 --- a/src/collective/communicator-inl.cuh +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2023 by XGBoost contributors - */ -#pragma once -#include -#include - -#include "communicator.h" -#include 
"device_communicator.cuh" - -namespace xgboost { -namespace collective { - -/** - * @brief Reduce values from all processes and distribute the result back to all processes. - * @param device ID of the device. - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - */ -template -inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op); -} - -template -inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op); -} - -template -inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op); -} - -template -inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op); -} - -template -inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op); -} - -template -inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); -} - -template -inline void AllReduce(int device, float *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op); -} - -template -inline void AllReduce(int device, double *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op); -} - -/** - * @brief Gather values from all all processes. - * - * This assumes all ranks have the same size. - * - * @param send_buffer Buffer storing the data to be sent. - * @param receive_buffer Buffer storing the gathered data. - * @param send_size Size of the sent data in bytes. - */ -inline void AllGather(int device, void const *send_buffer, void *receive_buffer, - std::size_t send_size) { - Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size); -} - -/** - * @brief Gather variable-length values from all processes. - * @param device ID of the device. - * @param send_buffer Buffer storing the input data. - * @param length_bytes Length in bytes of the input data. - * @param segments Size of each segment. - * @param receive_buffer Buffer storing the output data. - */ -inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) { - Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer); -} - -/** - * @brief Synchronize device operations. - * @param device ID of the device. 
- */ -inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); } - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h index 991e19f2c..263200700 100644 --- a/src/collective/communicator-inl.h +++ b/src/collective/communicator-inl.h @@ -3,308 +3,63 @@ */ #pragma once #include -#include -#include "communicator.h" +#include "xgboost/json.h" // for Json -namespace xgboost { -namespace collective { +namespace xgboost::collective { +/** + * @brief Initialize the collective communicator. + */ +void Init(Json const& config); /** - * \brief Initialize the collective communicator. - * - * Currently the communicator API is experimental, function signatures may change in the future - * without notice. - * - * Call this once before using anything. - * - * The additional configuration is not required. Usually the communicator will detect settings - * from environment variables. - * - * \param json_config JSON encoded configuration. Accepted JSON keys are: - * - xgboost_communicator: The type of the communicator. Can be set as an environment variable. - * * rabit: Use Rabit. This is the default if the type is unspecified. - * * mpi: Use MPI. - * * federated: Use the gRPC interface for Federated Learning. - * Only applicable to the Rabit communicator (these are case-sensitive): - * - rabit_tracker_uri: Hostname of the tracker. - * - rabit_tracker_port: Port number of the tracker. - * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment. - * - rabit_world_size: Total number of workers. - * - rabit_hadoop_mode: Enable Hadoop support. - * - rabit_tree_reduce_minsize: Minimal size for tree reduce. - * - rabit_reduce_ring_mincount: Minimal count to perform ring reduce. - * - rabit_reduce_buffer: Size of the reduce buffer. - * - rabit_bootstrap_cache: Size of the bootstrap cache. - * - rabit_debug: Enable debugging. - * - rabit_timeout: Enable timeout. - * - rabit_timeout_sec: Timeout in seconds. - * - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. - * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as - * environment variables): - * - DMLC_TRACKER_URI: Hostname of the tracker. - * - DMLC_TRACKER_PORT: Port number of the tracker. - * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment. - * - DMLC_ROLE: Role of the current task, "worker" or "server". - * - DMLC_NUM_ATTEMPT: Number of attempts after task failure. - * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. - * Only applicable to the Federated communicator (use upper case for environment variables, use - * lower case for runtime configuration): - * - federated_server_address: Address of the federated server. - * - federated_world_size: Number of federated workers. - * - federated_rank: Rank of the current worker. - * - federated_server_cert: Server certificate file path. Only needed for the SSL mode. - * - federated_client_key: Client key file path. Only needed for the SSL mode. - * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. - */ -inline void Init(Json const &config) { Communicator::Init(config); } - -/*! - * \brief Finalize the collective communicator. + * @brief Finalize the collective communicator. * * Call this function after you finished all jobs. */ -inline void Finalize() { Communicator::Finalize(); } +void Finalize(); -/*! 
- * \brief Get rank of current process. +/** + * @brief Get rank of current process. * - * \return Rank of the worker. + * @return Rank of the worker. */ -inline int GetRank() { return Communicator::Get()->GetRank(); } +[[nodiscard]] std::int32_t GetRank() noexcept; -/*! - * \brief Get total number of processes. +/** + * @brief Get total number of processes. * - * \return Total world size. + * @return Total world size. */ -inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); } +[[nodiscard]] std::int32_t GetWorldSize() noexcept; -/*! - * \brief Get if the communicator is distributed. +/** + * @brief Get if the communicator is distributed. * - * \return True if the communicator is distributed. + * @return True if the communicator is distributed. */ -inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); } +[[nodiscard]] bool IsDistributed() noexcept; -/*! - * \brief Get if the communicator is federated. +/** + * @brief Get if the communicator is federated. * - * \return True if the communicator is federated. + * @return True if the communicator is federated. */ -inline bool IsFederated() { return Communicator::Get()->IsFederated(); } +[[nodiscard]] bool IsFederated(); -/*! - * \brief Print the message to the communicator. +/** + * @brief Print the message to the communicator. * * This function can be used to communicate the information of the progress to the user who monitors * the communicator. * - * \param message The message to be printed. + * @param message The message to be printed. */ -inline void Print(char const *message) { Communicator::Get()->Print(message); } - -inline void Print(std::string const &message) { Communicator::Get()->Print(message); } - -/*! - * \brief Get the name of the processor. - * - * \return Name of the processor. - */ -inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); } - -/*! - * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe. - * - * Example: - * int a = 1; - * Broadcast(&a, sizeof(a), root); - * - * \param send_receive_buffer Pointer to the send or receive buffer. - * \param size Size of the data. - * \param root The process rank to broadcast from. - */ -inline void Broadcast(void *send_receive_buffer, size_t size, int root) { - Communicator::Get()->Broadcast(send_receive_buffer, size, root); -} - -inline void Broadcast(std::string *sendrecv_data, int root) { - size_t size = sendrecv_data->length(); - Broadcast(&size, sizeof(size), root); - if (sendrecv_data->length() != size) { - sendrecv_data->resize(size); - } - if (size != 0) { - Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root); - } -} - +void Print(std::string const& message); /** - * @brief Gathers a single value all processes and distributes the result to all processes. + * @brief Get the name of the processor. * - * @param input The single value. + * @return Name of the processor. */ -template -inline std::vector Allgather(T const &input) { - std::string_view str_input{reinterpret_cast(&input), sizeof(T)}; - auto const output = Communicator::Get()->AllGather(str_input); - CHECK_EQ(output.size() % sizeof(T), 0); - std::vector result(output.size() / sizeof(T)); - std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); - return result; -} - -/** - * @brief Gathers data from all processes and distributes it to all processes. - * - * This assumes all ranks have the same size. - * - * @param input Buffer storing the data. 
- */ -template -inline std::vector Allgather(std::vector const &input) { - if (input.empty()) { - return input; - } - std::string_view str_input{reinterpret_cast(input.data()), - input.size() * sizeof(T)}; - auto const output = Communicator::Get()->AllGather(str_input); - CHECK_EQ(output.size() % sizeof(T), 0); - std::vector result(output.size() / sizeof(T)); - std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); - return result; -} - -/** - * @brief Gathers variable-length data from all processes and distributes it to all processes. - * @param input Buffer storing the data. - */ -template -inline std::vector AllgatherV(std::vector const &input) { - std::string_view str_input{reinterpret_cast(input.data()), - input.size() * sizeof(T)}; - auto const output = Communicator::Get()->AllGatherV(str_input); - CHECK_EQ(output.size() % sizeof(T), 0); - std::vector result(output.size() / sizeof(T)); - if (!output.empty()) { - std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); - } - return result; -} - -/** - * @brief Gathers variable-length data from all processes and distributes it to all processes. - * - * @param inputs All the inputs from the local worker. The number of inputs can vary - * across different workers. Along with which, the size of each vector in - * the input can also vary. - * - * @return The AllgatherV result, containing vectors from all workers. - */ -[[nodiscard]] std::vector> VectorAllgatherV( - std::vector> const &input); - -/** - * @brief Gathers variable-length strings from all processes and distributes them to all processes. - * @param input Variable-length list of variable-length strings. - */ -inline std::vector AllgatherStrings(std::vector const &input) { - std::size_t total_size{0}; - for (auto const &s : input) { - total_size += s.length() + 1; // +1 for null-terminators - } - std::string flat_string; - flat_string.reserve(total_size); - for (auto const &s : input) { - flat_string.append(s); - flat_string.push_back('\0'); // Append a null-terminator after each string - } - - auto const output = Communicator::Get()->AllGatherV(flat_string); - - std::vector result; - std::size_t start_index = 0; - // Iterate through the output, find each null-terminated substring. - for (std::size_t i = 0; i < output.size(); i++) { - if (output[i] == '\0') { - // Construct a std::string from the char* substring - result.emplace_back(&output[start_index]); - // Move to the next substring - start_index = i + 1; - } - } - return result; -} - -/*! - * \brief Perform in-place allreduce. This function is NOT thread-safe. - * - * Example Usage: the following code gives sum of the result - * vector data(10); - * ... - * Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum); - * ... - * \param send_receive_buffer Buffer for both sending and receiving data. - * \param count Number of elements to be reduced. - * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. - * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. 
- */ -inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) { - Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast(data_type), - static_cast(op)); -} - -inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) { - Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op); -} - -template -inline void Allreduce(int8_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op); -} - -template -inline void Allreduce(uint8_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op); -} - -template -inline void Allreduce(int32_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op); -} - -template -inline void Allreduce(uint32_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op); -} - -template -inline void Allreduce(int64_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op); -} - -template -inline void Allreduce(uint64_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); -} - -// Specialization for size_t, which is implementation defined, so it might or might not -// be one of uint64_t/uint32_t/unsigned long long/unsigned long. -template {} && !std::is_same{}> > -inline void Allreduce(T *send_receive_buffer, size_t count) { - static_assert(sizeof(T) == sizeof(uint64_t)); - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); -} - -template -inline void Allreduce(float *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op); -} - -template -inline void Allreduce(double *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op); -} -} // namespace collective -} // namespace xgboost +std::string GetProcessorName(); +} // namespace xgboost::collective diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc deleted file mode 100644 index 7fabe50b4..000000000 --- a/src/collective/communicator.cc +++ /dev/null @@ -1,63 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "communicator.h" - -#include "comm.h" -#include "in_memory_communicator.h" -#include "noop_communicator.h" -#include "rabit_communicator.h" - -#if defined(XGBOOST_USE_FEDERATED) -#include "../../plugin/federated/federated_communicator.h" -#endif - -namespace xgboost::collective { -thread_local std::unique_ptr Communicator::communicator_{new NoOpCommunicator()}; -thread_local CommunicatorType Communicator::type_{}; -thread_local std::string Communicator::nccl_path_{}; - -void Communicator::Init(Json const& config) { - auto nccl = OptionalArg(config, "dmlc_nccl_path", std::string{DefaultNcclName()}); - nccl_path_ = nccl; - - auto type = GetTypeFromEnv(); - auto const arg = GetTypeFromConfig(config); - if (arg != CommunicatorType::kUnknown) { - type = arg; - } - if (type == CommunicatorType::kUnknown) { - // Default to Rabit if unspecified. 
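/*
 * A usage sketch for the slimmed-down facade declared in the new
 * communicator-inl.h above: only Init/Finalize, rank/world queries, Print and
 * GetProcessorName remain, while typed Allreduce/Broadcast now go through the
 * linalg-based free functions. An empty JSON config is assumed to select the
 * default (rabit) backend in single-worker mode, as with the removed
 * Communicator::Init; real jobs pass tracker settings in the config. The include
 * path assumes a caller under src/.
 */
#include <string>

#include "../collective/communicator-inl.h"
#include "xgboost/json.h"

void ReportWorkerSketch() {
  using namespace xgboost;  // NOLINT
  Json config{JsonObject()};
  collective::Init(config);
  if (collective::IsDistributed()) {
    collective::Print("worker " + std::to_string(collective::GetRank()) + "/" +
                      std::to_string(collective::GetWorldSize()) + " on " +
                      collective::GetProcessorName());
  }
  collective::Finalize();
}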
- type = CommunicatorType::kRabit; - } - type_ = type; - switch (type) { - case CommunicatorType::kRabit: { - communicator_.reset(RabitCommunicator::Create(config)); - break; - } - case CommunicatorType::kFederated: { -#if defined(XGBOOST_USE_FEDERATED) - communicator_.reset(FederatedCommunicator::Create(config)); -#else - LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; -#endif - break; - } - case CommunicatorType::kInMemory: - case CommunicatorType::kInMemoryNccl: { - communicator_.reset(InMemoryCommunicator::Create(config)); - break; - } - case CommunicatorType::kUnknown: - LOG(FATAL) << "Unknown communicator type."; - } -} - -#ifndef XGBOOST_USE_CUDA -void Communicator::Finalize() { - communicator_->Shutdown(); - communicator_.reset(new NoOpCommunicator()); -} -#endif -} // namespace xgboost::collective diff --git a/src/collective/communicator.cu b/src/collective/communicator.cu deleted file mode 100644 index a7552d356..000000000 --- a/src/collective/communicator.cu +++ /dev/null @@ -1,54 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "communicator.h" -#include "device_communicator.cuh" -#include "device_communicator_adapter.cuh" -#include "noop_communicator.h" -#ifdef XGBOOST_USE_NCCL -#include "nccl_device_communicator.cuh" -#endif - -namespace xgboost { -namespace collective { - -thread_local std::unique_ptr Communicator::device_communicator_{}; - -void Communicator::Finalize() { - communicator_->Shutdown(); - communicator_.reset(new NoOpCommunicator()); - device_communicator_.reset(nullptr); -} - -DeviceCommunicator* Communicator::GetDevice(int device_ordinal) { - thread_local auto old_device_ordinal = -1; - // If the number of GPUs changes, we need to re-initialize NCCL. - thread_local auto old_world_size = -1; - if (!device_communicator_ || device_ordinal != old_device_ordinal || - communicator_->GetWorldSize() != old_world_size) { - old_device_ordinal = device_ordinal; - old_world_size = communicator_->GetWorldSize(); -#ifdef XGBOOST_USE_NCCL - switch (type_) { - case CommunicatorType::kRabit: - device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_)); - break; - case CommunicatorType::kFederated: - case CommunicatorType::kInMemory: - device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal)); - break; - case CommunicatorType::kInMemoryNccl: - device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_)); - break; - default: - device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_)); - } -#else - device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal)); -#endif - } - return device_communicator_.get(); -} - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/communicator.h b/src/collective/communicator.h deleted file mode 100644 index b6910b80f..000000000 --- a/src/collective/communicator.h +++ /dev/null @@ -1,247 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once -#include -#include - -#include -#include - -namespace xgboost { -namespace collective { - -/** @brief Defines the integral and floating data types. */ -enum class DataType { - kInt8 = 0, - kUInt8 = 1, - kInt32 = 2, - kUInt32 = 3, - kInt64 = 4, - kUInt64 = 5, - kFloat = 6, - kDouble = 7 -}; - -/** @brief Get the size of the data type. 
*/ -inline std::size_t GetTypeSize(DataType data_type) { - std::size_t size{0}; - switch (data_type) { - case DataType::kInt8: - size = sizeof(std::int8_t); - break; - case DataType::kUInt8: - size = sizeof(std::uint8_t); - break; - case DataType::kInt32: - size = sizeof(std::int32_t); - break; - case DataType::kUInt32: - size = sizeof(std::uint32_t); - break; - case DataType::kInt64: - size = sizeof(std::int64_t); - break; - case DataType::kUInt64: - size = sizeof(std::uint64_t); - break; - case DataType::kFloat: - size = sizeof(float); - break; - case DataType::kDouble: - size = sizeof(double); - break; - default: - LOG(FATAL) << "Unknown data type."; - } - return size; -} - -/** @brief Defines the reduction operation. */ -enum class Operation { - kMax = 0, - kMin = 1, - kSum = 2, - kBitwiseAND = 3, - kBitwiseOR = 4, - kBitwiseXOR = 5 -}; - -class DeviceCommunicator; - -enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl }; - -/** \brief Case-insensitive string comparison. */ -inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { -#ifdef _MSC_VER - return _stricmp(s1, s2); -#else // _MSC_VER - return strcasecmp(s1, s2); -#endif // _MSC_VER -} - -/** - * @brief A communicator class that handles collective communication. - */ -class Communicator { - public: - /** - * @brief Initialize the communicator. This can only be done once. - * - * @param config JSON configuration for the communicator. - */ - static void Init(Json const &config); - - /** @brief Finalize the communicator. */ - static void Finalize(); - - /** @brief Get the communicator instance. */ - static Communicator *Get() { return communicator_.get(); } - -#if defined(XGBOOST_USE_CUDA) - /** - * @brief Get the device communicator. - * - * @param device_ordinal ID of the device. - * @return An instance of device communicator. - */ - static DeviceCommunicator *GetDevice(int device_ordinal); -#endif - - virtual ~Communicator() = default; - - /** @brief Get the total number of processes. */ - int GetWorldSize() const { return world_size_; } - - /** @brief Get the rank of the current processes. */ - int GetRank() const { return rank_; } - - /** @brief Whether the communicator is running in distributed mode. */ - virtual bool IsDistributed() const = 0; - - /** @brief Whether the communicator is running in federated mode. */ - virtual bool IsFederated() const = 0; - - /** - * @brief Gathers data from all processes and distributes it to all processes. - * - * This assumes all ranks have the same size. - * - * @param input Buffer storing the data. - */ - virtual std::string AllGather(std::string_view input) = 0; - - /** - * @brief Gathers variable-length data from all processes and distributes it to all processes. - * @param input Buffer storing the data. - */ - virtual std::string AllGatherV(std::string_view input) = 0; - - /** - * @brief Combines values from all processes and distributes the result back to all processes. - * - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - * @param data_type Data type stored in the buffer. - * @param op The operation to perform. - */ - virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) = 0; - - /** - * @brief Broadcasts a message from the process with rank `root` to all other processes of the - * group. - * - * @param send_receive_buffer Buffer storing the data. - * @param size Size of the data in bytes. 
- * @param root Rank of broadcast root. - */ - virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0; - - /** - * @brief Gets the name of the processor. - */ - virtual std::string GetProcessorName() = 0; - - /** - * @brief Prints the message. - */ - virtual void Print(std::string const &message) = 0; - - /** @brief Get the communicator type from environment variables. Visible for testing. */ - static CommunicatorType GetTypeFromEnv() { - auto *env = std::getenv("XGBOOST_COMMUNICATOR"); - if (env != nullptr) { - return StringToType(env); - } else { - return CommunicatorType::kUnknown; - } - } - - /** @brief Get the communicator type from runtime configuration. Visible for testing. */ - static CommunicatorType GetTypeFromConfig(Json const &config) { - auto const &j_upper = config["XGBOOST_COMMUNICATOR"]; - if (IsA(j_upper)) { - return StringToType(get(j_upper).c_str()); - } - auto const &j_lower = config["xgboost_communicator"]; - if (IsA(j_lower)) { - return StringToType(get(j_lower).c_str()); - } - return CommunicatorType::kUnknown; - } - - protected: - /** - * @brief Construct a new communicator. - * - * @param world_size Total number of processes. - * @param rank Rank of the current process. - */ - Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) { - if (world_size < 1) { - LOG(FATAL) << "World size " << world_size << " is less than 1."; - } - if (rank < 0) { - LOG(FATAL) << "Rank " << rank << " is less than 0."; - } - if (rank >= world_size) { - LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << "."; - } - } - - /** - * @brief Shuts down the communicator. - */ - virtual void Shutdown() = 0; - - private: - static CommunicatorType StringToType(char const *str) { - CommunicatorType result = CommunicatorType::kUnknown; - if (!CompareStringsCaseInsensitive("rabit", str)) { - result = CommunicatorType::kRabit; - } else if (!CompareStringsCaseInsensitive("federated", str)) { - result = CommunicatorType::kFederated; - } else if (!CompareStringsCaseInsensitive("in-memory", str)) { - result = CommunicatorType::kInMemory; - } else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) { - result = CommunicatorType::kInMemoryNccl; - } else { - LOG(FATAL) << "Unknown communicator type " << str; - } - return result; - } - - static thread_local std::unique_ptr communicator_; - static thread_local CommunicatorType type_; - static thread_local std::string nccl_path_; -#if defined(XGBOOST_USE_CUDA) - static thread_local std::unique_ptr device_communicator_; -#endif - - int const world_size_; - int const rank_; -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/device_communicator.cuh b/src/collective/device_communicator.cuh deleted file mode 100644 index 69094b382..000000000 --- a/src/collective/device_communicator.cuh +++ /dev/null @@ -1,57 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once -#include - -#include "../common/device_helpers.cuh" - -namespace xgboost { -namespace collective { - -/** - * @brief Collective communicator for device buffers. - */ -class DeviceCommunicator { - public: - virtual ~DeviceCommunicator() = default; - - /** - * @brief Combines values from all processes and distributes the result back to all processes. - * - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - * @param data_type Data type stored in the buffer. - * @param op The operation to perform. 
- */ - virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) = 0; - - /** - * @brief Gather values from all all processes. - * - * This assumes all ranks have the same size. - * - * @param send_buffer Buffer storing the data to be sent. - * @param receive_buffer Buffer storing the gathered data. - * @param send_size Size of the sent data in bytes. - */ - virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0; - - /** - * @brief Gather variable-length values from all processes. - * @param send_buffer Buffer storing the input data. - * @param length_bytes Length in bytes of the input data. - * @param segments Size of each segment. - * @param receive_buffer Buffer storing the output data. - */ - virtual void AllGatherV(void const *send_buffer, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) = 0; - - /** @brief Synchronize device operations. */ - virtual void Synchronize() = 0; -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh deleted file mode 100644 index 647c74b4e..000000000 --- a/src/collective/device_communicator_adapter.cuh +++ /dev/null @@ -1,94 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once - -#include // for accumulate - -#include "communicator.h" -#include "device_communicator.cuh" - -namespace xgboost { -namespace collective { - -class DeviceCommunicatorAdapter : public DeviceCommunicator { - public: - explicit DeviceCommunicatorAdapter(int device_ordinal) - : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} { - if (device_ordinal_ < 0) { - LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; - } - } - - ~DeviceCommunicatorAdapter() override = default; - - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto size = count * GetTypeSize(data_type); - host_buffer_.resize(size); - dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); - Allreduce(host_buffer_.data(), count, data_type, op); - dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); - } - - void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - host_buffer_.resize(send_size); - dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault)); - auto const output = Allgather(host_buffer_); - dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault)); - } - - void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, - dh::caching_device_vector *receive_buffer) override { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - - segments->clear(); - segments->resize(world_size_, 0); - segments->at(rank_) = length_bytes; - Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax); - auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); - receive_buffer->resize(total_bytes); - - host_buffer_.resize(total_bytes); - size_t offset = 0; - for (int32_t i = 0; i < world_size_; ++i) { - size_t as_bytes 
= segments->at(i); - if (i == rank_) { - dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_), - cudaMemcpyDefault)); - } - Broadcast(host_buffer_.data() + offset, as_bytes, i); - offset += as_bytes; - } - dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes, - cudaMemcpyDefault)); - } - - void Synchronize() override { - // Noop. - } - - private: - int const device_ordinal_; - int const world_size_; - int const rank_; - /// Host buffer used to call communicator functions. - std::vector host_buffer_{}; -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/in_memory_communicator.cc b/src/collective/in_memory_communicator.cc deleted file mode 100644 index 535a15bc9..000000000 --- a/src/collective/in_memory_communicator.cc +++ /dev/null @@ -1,12 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "in_memory_communicator.h" - -namespace xgboost { -namespace collective { - -InMemoryHandler InMemoryCommunicator::handler_{}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/in_memory_communicator.h b/src/collective/in_memory_communicator.h index c712d32a8..bd89be6e3 100644 --- a/src/collective/in_memory_communicator.h +++ b/src/collective/in_memory_communicator.h @@ -15,14 +15,14 @@ namespace collective { /** * An in-memory communicator, useful for testing. */ -class InMemoryCommunicator : public Communicator { +class InMemoryCommunicator { public: /** * @brief Create a new communicator based on JSON configuration. * @param config JSON configuration. * @return Communicator as specified by the JSON configuration. */ - static Communicator* Create(Json const& config) { + static InMemoryCommunicator* Create(Json const& config) { int world_size{0}; int rank{-1}; @@ -51,7 +51,7 @@ class InMemoryCommunicator : public Communicator { return new InMemoryCommunicator(world_size, rank); } - InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) { + InMemoryCommunicator(int world_size, int rank) { handler_.Init(world_size, rank); } diff --git a/src/collective/in_memory_handler.cc b/src/collective/in_memory_handler.cc index 944e5077b..468f09c53 100644 --- a/src/collective/in_memory_handler.cc +++ b/src/collective/in_memory_handler.cc @@ -1,14 +1,13 @@ -/*! - * Copyright 2022 XGBoost contributors +/** + * Copyright 2022-2023, XGBoost contributors */ #include "in_memory_handler.h" #include #include +#include "comm.h" -namespace xgboost { -namespace collective { - +namespace xgboost::collective { /** * @brief Functor for allgather. 
*/ @@ -16,7 +15,7 @@ class AllgatherFunctor { public: std::string const name{"Allgather"}; - AllgatherFunctor(std::size_t world_size, std::size_t rank) + AllgatherFunctor(std::int32_t world_size, std::int32_t rank) : world_size_{world_size}, rank_{rank} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { @@ -30,8 +29,8 @@ class AllgatherFunctor { } private: - std::size_t world_size_; - std::size_t rank_; + std::int32_t world_size_; + std::int32_t rank_; }; /** @@ -41,13 +40,13 @@ class AllgatherVFunctor { public: std::string const name{"AllgatherV"}; - AllgatherVFunctor(std::size_t world_size, std::size_t rank, + AllgatherVFunctor(std::int32_t world_size, std::int32_t rank, std::map* data) : world_size_{world_size}, rank_{rank}, data_{data} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { data_->emplace(rank_, std::string_view{input, bytes}); - if (data_->size() == world_size_) { + if (data_->size() == static_cast(world_size_)) { for (auto const& kv : *data_) { buffer->append(kv.second); } @@ -56,8 +55,8 @@ class AllgatherVFunctor { } private: - std::size_t world_size_; - std::size_t rank_; + std::int32_t world_size_; + std::int32_t rank_; std::map* data_; }; @@ -68,7 +67,7 @@ class AllreduceFunctor { public: std::string const name{"Allreduce"}; - AllreduceFunctor(DataType dataType, Operation operation) + AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation) : data_type_{dataType}, operation_{operation} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { @@ -76,23 +75,23 @@ class AllreduceFunctor { // Copy the input if this is the first request. buffer->assign(input, bytes); } else { + auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); }); // Apply the reduce_operation to the input and the buffer. 
- Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front()); + Accumulate(input, bytes / n_bytes_type, &buffer->front()); } } private: template ::value>* = nullptr> - void AccumulateBitwise(T* buffer, T const* input, std::size_t size, - Operation reduce_operation) const { + void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const { switch (reduce_operation) { - case Operation::kBitwiseAND: + case Op::kBitwiseAND: std::transform(buffer, buffer + size, input, buffer, std::bit_and()); break; - case Operation::kBitwiseOR: + case Op::kBitwiseOR: std::transform(buffer, buffer + size, input, buffer, std::bit_or()); break; - case Operation::kBitwiseXOR: + case Op::kBitwiseXOR: std::transform(buffer, buffer + size, input, buffer, std::bit_xor()); break; default: @@ -101,27 +100,27 @@ class AllreduceFunctor { } template ::value>* = nullptr> - void AccumulateBitwise(T*, T const*, std::size_t, Operation) const { + void AccumulateBitwise(T*, T const*, std::size_t, Op) const { LOG(FATAL) << "Floating point types do not support bitwise operations."; } template - void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const { + void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const { switch (reduce_operation) { - case Operation::kMax: + case Op::kMax: std::transform(buffer, buffer + size, input, buffer, [](T a, T b) { return std::max(a, b); }); break; - case Operation::kMin: + case Op::kMin: std::transform(buffer, buffer + size, input, buffer, [](T a, T b) { return std::min(a, b); }); break; - case Operation::kSum: + case Op::kSum: std::transform(buffer, buffer + size, input, buffer, std::plus()); break; - case Operation::kBitwiseAND: - case Operation::kBitwiseOR: - case Operation::kBitwiseXOR: + case Op::kBitwiseAND: + case Op::kBitwiseOR: + case Op::kBitwiseXOR: AccumulateBitwise(buffer, input, size, reduce_operation); break; default: @@ -130,36 +129,37 @@ class AllreduceFunctor { } void Accumulate(char const* input, std::size_t size, char* buffer) const { + using Type = ArrayInterfaceHandler::Type; switch (data_type_) { - case DataType::kInt8: + case Type::kI1: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kUInt8: + case Type::kU1: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kInt32: + case Type::kI4: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kUInt32: + case Type::kU4: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kInt64: + case Type::kI8: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kUInt64: + case Type::kU8: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kFloat: + case Type::kF4: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kDouble: + case Type::kF8: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; @@ -169,8 +169,8 @@ class AllreduceFunctor { } private: - DataType data_type_; - Operation operation_; + ArrayInterfaceHandler::Type data_type_; + Op operation_; }; /** @@ -180,7 +180,7 @@ class BroadcastFunctor { public: std::string const name{"Broadcast"}; - BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {} + 
BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { if (rank_ == root_) { @@ -190,11 +190,11 @@ class BroadcastFunctor { } private: - std::size_t rank_; - std::size_t root_; + std::int32_t rank_; + std::int32_t root_; }; -void InMemoryHandler::Init(std::size_t world_size, std::size_t) { +void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) { CHECK(world_size_ < world_size) << "In memory handler already initialized."; std::unique_lock lock(mutex_); @@ -204,7 +204,7 @@ void InMemoryHandler::Init(std::size_t world_size, std::size_t) { cv_.notify_all(); } -void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) { +void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) { CHECK(world_size_ > 0) << "In memory handler already shutdown."; std::unique_lock lock(mutex_); @@ -220,29 +220,29 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) { } void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank) { + std::size_t sequence_number, std::int32_t rank) { Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank}); } void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank) { + std::size_t sequence_number, std::int32_t rank) { Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_}); } void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, DataType data_type, - Operation op) { + std::size_t sequence_number, std::int32_t rank, + ArrayInterfaceHandler::Type data_type, Op op) { Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op}); } void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, std::size_t root) { + std::size_t sequence_number, std::int32_t rank, std::int32_t root) { Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root}); } template void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, + std::size_t sequence_number, std::int32_t rank, HandlerFunctor const& functor) { // Pass through if there is only 1 client. if (world_size_ == 1) { @@ -287,5 +287,4 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* cv_.notify_all(); } } -} // namespace collective -} // namespace xgboost +} // namespace xgboost::collective diff --git a/src/collective/in_memory_handler.h b/src/collective/in_memory_handler.h index e9c69f537..7c3465d08 100644 --- a/src/collective/in_memory_handler.h +++ b/src/collective/in_memory_handler.h @@ -1,16 +1,15 @@ -/*! - * Copyright 2022 XGBoost contributors +/** + * Copyright 2022-2023, XGBoost contributors */ #pragma once #include #include #include -#include "communicator.h" - -namespace xgboost { -namespace collective { +#include "../data/array_interface.h" +#include "comm.h" +namespace xgboost::collective { /** * @brief Handles collective communication primitives in memory. * @@ -28,12 +27,11 @@ class InMemoryHandler { /** * @brief Construct a handler with the given world size. - * @param world_size Number of workers. 
+ * @param world Number of workers. * * This is used when the handler only needs to be initialized once with a known world size. */ - explicit InMemoryHandler(std::int32_t worldSize) - : world_size_{static_cast(worldSize)} {} + explicit InMemoryHandler(std::int32_t world) : world_size_{world} {} /** * @brief Initialize the handler with the world size and rank. @@ -43,7 +41,7 @@ class InMemoryHandler { * This is used when multiple objects/threads are accessing the same handler and need to * initialize it collectively. */ - void Init(std::size_t world_size, std::size_t rank); + void Init(std::int32_t world_size, std::int32_t rank); /** * @brief Shut down the handler. @@ -53,7 +51,7 @@ class InMemoryHandler { * This is used when multiple objects/threads are accessing the same handler and need to * shut it down collectively. */ - void Shutdown(uint64_t sequence_number, std::size_t rank); + void Shutdown(uint64_t sequence_number, std::int32_t rank); /** * @brief Perform allgather. @@ -64,7 +62,7 @@ class InMemoryHandler { * @param rank Index of the worker. */ void Allgather(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank); + std::size_t sequence_number, std::int32_t rank); /** * @brief Perform variable-length allgather. @@ -75,7 +73,7 @@ class InMemoryHandler { * @param rank Index of the worker. */ void AllgatherV(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank); + std::size_t sequence_number, std::int32_t rank); /** * @brief Perform allreduce. @@ -88,7 +86,8 @@ class InMemoryHandler { * @param op The reduce operation. */ void Allreduce(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op); + std::size_t sequence_number, std::int32_t rank, + ArrayInterfaceHandler::Type data_type, Op op); /** * @brief Perform broadcast. @@ -100,7 +99,7 @@ class InMemoryHandler { * @param root Index of the worker to broadcast from. */ void Broadcast(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, std::size_t root); + std::size_t sequence_number, std::int32_t rank, std::int32_t root); private: /** @@ -115,17 +114,15 @@ class InMemoryHandler { */ template void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number, - std::size_t rank, HandlerFunctor const& functor); + std::int32_t rank, HandlerFunctor const& functor); - std::size_t world_size_{}; /// Number of workers. - std::size_t received_{}; /// Number of calls received with the current sequence. - std::size_t sent_{}; /// Number of calls completed with the current sequence. + std::int32_t world_size_{}; /// Number of workers. + std::int64_t received_{}; /// Number of calls received with the current sequence. + std::int64_t sent_{}; /// Number of calls completed with the current sequence. std::string buffer_{}; /// A shared common buffer. std::map aux_{}; /// A shared auxiliary map. uint64_t sequence_number_{}; /// Call sequence number. mutable std::mutex mutex_; /// Lock. mutable std::condition_variable cv_; /// Conditional variable to wait on. 
}; - -} // namespace collective -} // namespace xgboost +} // namespace xgboost::collective diff --git a/src/collective/loop.cc b/src/collective/loop.cc index 0cd41426d..1c384bb28 100644 --- a/src/collective/loop.cc +++ b/src/collective/loop.cc @@ -6,6 +6,8 @@ #include // for size_t #include // for int32_t #include // for exception, current_exception, rethrow_exception +#include // for promise +#include // for make_shared #include // for lock_guard, unique_lock #include // for queue #include // for string @@ -18,9 +20,10 @@ #include "xgboost/logging.h" // for CHECK namespace xgboost::collective { -Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { +Result Loop::ProcessQueue(std::queue* p_queue) const { timer_.Start(__func__); - auto error = [this] { + auto error = [this](Op op) { + op.pr->set_value(); timer_.Stop(__func__); }; @@ -38,7 +41,7 @@ Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { // Iterate through all the ops for poll for (std::size_t i = 0; i < n_ops; ++i) { - auto op = qcopy.front(); + auto op = std::move(qcopy.front()); qcopy.pop(); switch (op.code) { @@ -54,12 +57,12 @@ Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { break; } default: { - error(); + error(op); return Fail("Invalid socket operation."); } } - qcopy.push(op); + qcopy.push(std::move(op)); } // poll, work on fds that are ready. @@ -67,18 +70,18 @@ Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { if (!poll.fds.empty()) { auto rc = poll.Poll(timeout_); if (!rc.OK()) { - error(); + timer_.Stop(__func__); return rc; } } timer_.Stop("poll"); - // we wonldn't be here if the queue is empty. + // We wonldn't be here if the queue is empty. CHECK(!qcopy.empty()); // Iterate through all the ops for performing the operations for (std::size_t i = 0; i < n_ops; ++i) { - auto op = qcopy.front(); + auto op = std::move(qcopy.front()); qcopy.pop(); std::int32_t n_bytes_done{0}; @@ -93,8 +96,9 @@ Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { if (poll.CheckRead(*op.sock)) { n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off); if (n_bytes_done == 0) { - error(); - return Fail("Encountered EOF. The other end is likely closed."); + error(op); + return Fail("Encountered EOF. The other end is likely closed.", + op.sock->GetSockError()); } } break; @@ -112,14 +116,14 @@ Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { break; } default: { - error(); + error(op); return Fail("Invalid socket operation."); } } if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) { auto rc = system::FailWithCode("Invalid socket output."); - error(); + error(op); return rc; } @@ -127,14 +131,12 @@ Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { CHECK_LE(op.off, op.n); if (op.off != op.n) { - // not yet finished, push back to queue for next round. + // not yet finished, push back to queue for the next round. qcopy.push(op); + } else { + op.pr->set_value(); } } - - if (!blocking) { - break; - } } timer_.Stop(__func__); @@ -148,8 +150,7 @@ void Loop::Process() { }; // This loop cannot exit unless `stop_` is set to true. There must always be a thread to - // answer the blocking call even if there are errors, otherwise the blocking will wait - // forever. + // answer the call even if there are errors. while (true) { try { std::unique_lock lock{mu_}; @@ -170,44 +171,15 @@ void Loop::Process() { // Move the global queue into a local variable to unblock it. 
std::queue qcopy; - bool is_blocking = false; while (!queue_.empty()) { - auto op = queue_.front(); + auto op = std::move(queue_.front()); queue_.pop(); - if (op.code == Op::kBlock) { - is_blocking = true; - } else { - qcopy.push(op); - } + qcopy.push(op); } - lock.unlock(); - // Clear the local queue, if `is_blocking` is true, this is blocking the current - // worker thread (but not the client thread), wait until all operations are - // finished. - auto rc = this->ProcessQueue(&qcopy, is_blocking); - if (is_blocking && rc.OK()) { - CHECK(qcopy.empty()); - } - // Push back the remaining operations. - if (rc.OK()) { - std::unique_lock lock{mu_}; - while (!qcopy.empty()) { - queue_.push(qcopy.front()); - qcopy.pop(); - } - } - - // Notify the client thread who called block after all error conditions are set. - auto notify_if_block = [&] { - if (is_blocking) { - std::unique_lock lock{mu_}; - block_done_ = true; - lock.unlock(); - block_cv_.notify_one(); - } - }; + // Clear the local queue. + auto rc = this->ProcessQueue(&qcopy); // Handle error if (!rc.OK()) { @@ -215,8 +187,6 @@ void Loop::Process() { } else { CHECK(qcopy.empty()); } - - notify_if_block(); } catch (std::exception const& e) { curr_exce_ = std::current_exception(); set_rc(Fail("Exception inside the event loop:" + std::string{e.what()})); @@ -256,20 +226,28 @@ Result Loop::Stop() { stop_ = true; } } - if (!this->worker_.joinable()) { std::lock_guard guard{rc_lock_}; return Fail("Worker has stopped.", std::move(rc_)); } - this->Submit(Op{Op::kBlock}); { - // Wait for the block call to finish. std::unique_lock lock{mu_}; - block_cv_.wait(lock, [this] { return block_done_ || stop_; }); - block_done_ = false; + cv_.notify_one(); } + for (auto& fut : futures_) { + if (fut.valid()) { + try { + fut.get(); + } catch (std::future_error const&) { + // Do nothing. If something went wrong in the worker, we have a std::future_error + // due to broken promise. This function will transfer the rc back to the caller. + } + } + } + futures_.clear(); + { // Transfer the rc. 
std::lock_guard lock{rc_lock_}; @@ -278,13 +256,13 @@ Result Loop::Stop() { } void Loop::Submit(Op op) { + auto p = std::make_shared>(); + op.pr = std::move(p); + futures_.emplace_back(op.pr->get_future()); + CHECK_NE(op.n, 0); + std::unique_lock lock{mu_}; - if (op.code != Op::kBlock) { - CHECK_NE(op.n, 0); - } queue_.push(op); - lock.unlock(); - cv_.notify_one(); } Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { diff --git a/src/collective/loop.h b/src/collective/loop.h index a4de2a81b..0a830eb96 100644 --- a/src/collective/loop.h +++ b/src/collective/loop.h @@ -7,9 +7,12 @@ #include // for size_t #include // for int8_t, int32_t #include // for exception_ptr -#include // for unique_lock, mutex +#include // for future +#include // for shared_ptr +#include // for mutex #include // for queue #include // for thread +#include // for vector #include "../common/timer.h" // for Monitor #include "xgboost/collective/result.h" // for Result @@ -20,14 +23,15 @@ class Loop { public: struct Op { // kSleep is only for testing - enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code; + enum Code : std::int8_t { kRead = 0, kWrite = 1, kSleep = 3 } code; std::int32_t rank{-1}; std::int8_t* ptr{nullptr}; std::size_t n{0}; TCPSocket* sock{nullptr}; std::size_t off{0}; + std::shared_ptr> pr; - explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); } + explicit Op(Code c) : code{c} { CHECK(c == kSleep); } Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} Op(Op const&) = default; @@ -45,12 +49,11 @@ class Loop { private: std::thread worker_; // thread worker to execute the tasks - std::condition_variable cv_; // CV used to notify a new submit call - std::condition_variable block_cv_; // CV used to notify the blocking call - bool block_done_{false}; // Flag to indicate whether the blocking call has finished. + std::condition_variable cv_; // CV used to notify a new submit call std::queue queue_; // event queue - std::mutex mu_; // mutex to protect the queue, cv, and block_done + std::vector> futures_; + std::mutex mu_; // mutex to protect the queue, cv, and block_done std::chrono::seconds timeout_; @@ -61,7 +64,7 @@ class Loop { std::exception_ptr curr_exce_{nullptr}; common::Monitor mutable timer_; - Result ProcessQueue(std::queue* p_queue, bool blocking) const; + Result ProcessQueue(std::queue* p_queue) const; // The cunsumer function that runs inside a worker thread. void Process(); diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu deleted file mode 100644 index b896e7d06..000000000 --- a/src/collective/nccl_device_communicator.cu +++ /dev/null @@ -1,243 +0,0 @@ -/*! 
- * Copyright 2023 XGBoost contributors - */ -#if defined(XGBOOST_USE_NCCL) -#include // for accumulate - -#include "comm.cuh" -#include "nccl_device_communicator.cuh" - -namespace xgboost { -namespace collective { - -NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync, - StringView nccl_path) - : device_ordinal_{device_ordinal}, - needs_sync_{needs_sync}, - world_size_{GetWorldSize()}, - rank_{GetRank()} { - if (device_ordinal_ < 0) { - LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; - } - if (world_size_ == 1) { - return; - } - stub_ = std::make_shared(std::move(nccl_path)); - - std::vector uuids(world_size_ * kUuidLength, 0); - auto s_uuid = xgboost::common::Span{uuids.data(), uuids.size()}; - auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength); - GetCudaUUID(s_this_uuid); - - // TODO(rongou): replace this with allgather. - Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum); - - std::vector> converted(world_size_); - size_t j = 0; - for (size_t i = 0; i < uuids.size(); i += kUuidLength) { - converted[j] = xgboost::common::Span{uuids.data() + i, kUuidLength}; - j++; - } - - auto iter = std::unique(converted.begin(), converted.end()); - auto n_uniques = std::distance(converted.begin(), iter); - - CHECK_EQ(n_uniques, world_size_) - << "Multiple processes within communication group running on same CUDA " - << "device is not supported. " << PrintUUID(s_this_uuid) << "\n"; - - nccl_unique_id_ = GetUniqueId(); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_); - CHECK(rc.OK()) << rc.Report(); -} - -NcclDeviceCommunicator::~NcclDeviceCommunicator() { - if (world_size_ == 1) { - return; - } - if (nccl_comm_) { - auto rc = stub_->CommDestroy(nccl_comm_); - CHECK(rc.OK()) << rc.Report(); - } - if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { - LOG(CONSOLE) << "======== NCCL Statistics========"; - LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_; - LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576; - } -} - -namespace { -ncclDataType_t GetNcclDataType(DataType const &data_type) { - ncclDataType_t result{ncclInt8}; - switch (data_type) { - case DataType::kInt8: - result = ncclInt8; - break; - case DataType::kUInt8: - result = ncclUint8; - break; - case DataType::kInt32: - result = ncclInt32; - break; - case DataType::kUInt32: - result = ncclUint32; - break; - case DataType::kInt64: - result = ncclInt64; - break; - case DataType::kUInt64: - result = ncclUint64; - break; - case DataType::kFloat: - result = ncclFloat; - break; - case DataType::kDouble: - result = ncclDouble; - break; - default: - LOG(FATAL) << "Unknown data type."; - } - return result; -} - -bool IsBitwiseOp(Operation const &op) { - return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR || - op == Operation::kBitwiseXOR; -} - -ncclRedOp_t GetNcclRedOp(Operation const &op) { - ncclRedOp_t result{ncclMax}; - switch (op) { - case Operation::kMax: - result = ncclMax; - break; - case Operation::kMin: - result = ncclMin; - break; - case Operation::kSum: - result = ncclSum; - break; - default: - LOG(FATAL) << "Unsupported reduce operation."; - } - return result; -} - -template -void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size, - std::size_t size) { - dh::LaunchN(size, [=] __device__(std::size_t idx) { - auto result = device_buffer[idx]; - for (auto rank = 
1; rank < world_size; rank++) { - result = func(result, device_buffer[rank * size + idx]); - } - out_buffer[idx] = result; - }); -} -} // anonymous namespace - -void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count, - DataType data_type, Operation op) { - auto const size = count * GetTypeSize(data_type); - dh::caching_device_vector buffer(size * world_size_); - auto *device_buffer = buffer.data().get(); - - // First gather data from all the workers. - auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type), - nccl_comm_, dh::DefaultStream()); - CHECK(rc.OK()) << rc.Report(); - if (needs_sync_) { - dh::DefaultStream().Sync(); - } - - // Then reduce locally. - auto *out_buffer = static_cast(send_receive_buffer); - switch (op) { - case Operation::kBitwiseAND: - RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and(), world_size_, size); - break; - case Operation::kBitwiseOR: - RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or(), world_size_, size); - break; - case Operation::kBitwiseXOR: - RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor(), world_size_, size); - break; - default: - LOG(FATAL) << "Not a bitwise reduce operation."; - } -} - -void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count, - DataType data_type, Operation op) { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - if (IsBitwiseOp(op)) { - BitwiseAllReduce(send_receive_buffer, count, data_type, op); - } else { - auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count, - GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_, - dh::DefaultStream()); - CHECK(rc.OK()) << rc.Report(); - } - allreduce_bytes_ += count * GetTypeSize(data_type); - allreduce_calls_ += 1; -} - -void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer, - std::size_t send_size) { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_, - dh::DefaultStream()); - CHECK(rc.OK()) << rc.Report(); -} - -void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - - segments->clear(); - segments->resize(world_size_, 0); - segments->at(rank_) = length_bytes; - Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax); - auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); - receive_buffer->resize(total_bytes); - - size_t offset = 0; - auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] { - for (int32_t i = 0; i < world_size_; ++i) { - size_t as_bytes = segments->at(i); - auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes, - ncclChar, i, nccl_comm_, dh::DefaultStream()); - if (!rc.OK()) { - return rc; - } - offset += as_bytes; - } - return Success(); - } << [&] { return stub_->GroupEnd(); }; -} - -void NcclDeviceCommunicator::Synchronize() { - if (world_size_ == 1) { - return; - } - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - dh::DefaultStream().Sync(); -} - -} // namespace collective -} // namespace xgboost -#endif diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh 
deleted file mode 100644 index ef431b571..000000000 --- a/src/collective/nccl_device_communicator.cuh +++ /dev/null @@ -1,91 +0,0 @@ -/*! - * Copyright 2022-2023 XGBoost contributors - */ -#pragma once - -#include "../common/device_helpers.cuh" -#include "comm.cuh" -#include "communicator.h" -#include "device_communicator.cuh" -#include "nccl_stub.h" - -namespace xgboost { -namespace collective { - -class NcclDeviceCommunicator : public DeviceCommunicator { - public: - /** - * @brief Construct a new NCCL communicator. - * @param device_ordinal The GPU device id. - * @param needs_sync Whether extra CUDA stream synchronization is needed. - * - * In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes - * a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization - * makes sure that the NCCL kernels are caught up, thus avoiding the deadlock. - * - * The Rabit communicator runs with one process per GPU, so the additional synchronization is not - * needed. The in-memory communicator is used in tests with multiple threads, each thread - * representing a rank/worker, so the additional synchronization is needed to avoid deadlocks. - */ - explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path); - ~NcclDeviceCommunicator() override; - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override; - void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override; - void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, - dh::caching_device_vector *receive_buffer) override; - void Synchronize() override; - - private: - static constexpr std::size_t kUuidLength = - sizeof(std::declval().uuid) / sizeof(uint64_t); - - void GetCudaUUID(xgboost::common::Span const &uuid) const { - cudaDeviceProp prob{}; - dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_)); - std::memcpy(uuid.data(), static_cast(&(prob.uuid)), sizeof(prob.uuid)); - } - - static std::string PrintUUID(xgboost::common::Span const &uuid) { - std::stringstream ss; - for (auto v : uuid) { - ss << std::hex << v; - } - return ss.str(); - } - - /** - * \fn ncclUniqueId GetUniqueId() - * - * \brief Gets the Unique ID from NCCL to be used in setting up interprocess - * communication - * - * \return the Unique ID - */ - ncclUniqueId GetUniqueId() { - static const int kRootRank = 0; - ncclUniqueId id; - if (rank_ == kRootRank) { - auto rc = stub_->GetUniqueId(&id); - CHECK(rc.OK()) << rc.Report(); - } - Broadcast(static_cast(&id), sizeof(ncclUniqueId), static_cast(kRootRank)); - return id; - } - - void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op); - - int const device_ordinal_; - bool const needs_sync_; - int const world_size_; - int const rank_; - ncclComm_t nccl_comm_{}; - std::shared_ptr stub_; - ncclUniqueId nccl_unique_id_{}; - size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated. - size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls. -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h deleted file mode 100644 index 2d88fd802..000000000 --- a/src/collective/noop_communicator.h +++ /dev/null @@ -1,32 +0,0 @@ -/*! 
- * Copyright 2022 XGBoost contributors - */ -#pragma once -#include - -#include "communicator.h" - -namespace xgboost { -namespace collective { - -/** - * A no-op communicator, used for non-distributed training. - */ -class NoOpCommunicator : public Communicator { - public: - NoOpCommunicator() : Communicator(1, 0) {} - bool IsDistributed() const override { return false; } - bool IsFederated() const override { return false; } - std::string AllGather(std::string_view) override { return {}; } - std::string AllGatherV(std::string_view) override { return {}; } - void AllReduce(void *, std::size_t, DataType, Operation) override {} - void Broadcast(void *, std::size_t, int) override {} - std::string GetProcessorName() override { return {}; } - void Print(const std::string &message) override { LOG(CONSOLE) << message; } - - protected: - void Shutdown() override {} -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/protocol.h b/src/collective/protocol.h index 29e6c9619..222259403 100644 --- a/src/collective/protocol.h +++ b/src/collective/protocol.h @@ -41,20 +41,26 @@ struct Magic { [[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) { std::int32_t magic{kMagic}; - auto n_bytes = p_sock->SendAll(&magic, sizeof(magic)); - if (n_bytes != sizeof(magic)) { - return Fail("Failed to verify."); - } - - magic = 0; - n_bytes = p_sock->RecvAll(&magic, sizeof(magic)); - if (n_bytes != sizeof(magic)) { - return Fail("Failed to verify."); - } - if (magic != kMagic) { - return xgboost::collective::Fail("Invalid verification number."); - } - return Success(); + std::size_t n_sent{0}; + return Success() << [&] { + return p_sock->SendAll(&magic, sizeof(magic), &n_sent); + } << [&] { + if (n_sent != sizeof(magic)) { + return Fail("Failed to verify."); + } + return Success(); + } << [&] { + magic = 0; + return p_sock->RecvAll(&magic, sizeof(magic), &n_sent); + } << [&] { + if (n_sent != sizeof(magic)) { + return Fail("Failed to verify."); + } + if (magic != kMagic) { + return xgboost::collective::Fail("Invalid verification number."); + } + return Success(); + }; } }; @@ -227,31 +233,43 @@ struct Error { [[nodiscard]] Result SignalError(TCPSocket* worker) const { std::int32_t err{ErrorSignal()}; - auto n_sent = worker->SendAll(&err, sizeof(err)); - if (n_sent == sizeof(err)) { - return Success(); - } - return Fail("Failed to send error signal"); + std::size_t n_sent{0}; + return Success() << [&] { + return worker->SendAll(&err, sizeof(err), &n_sent); + } << [&] { + if (n_sent == sizeof(err)) { + return Success(); + } + return Fail("Failed to send error signal"); + }; } // self is localhost, we are sending the signal to the error handling thread for it to // close. [[nodiscard]] Result SignalShutdown(TCPSocket* self) const { std::int32_t err{ShutdownSignal()}; - auto n_sent = self->SendAll(&err, sizeof(err)); - if (n_sent == sizeof(err)) { - return Success(); - } - return Fail("Failed to send shutdown signal"); + std::size_t n_sent{0}; + return Success() << [&] { + return self->SendAll(&err, sizeof(err), &n_sent); + } << [&] { + if (n_sent == sizeof(err)) { + return Success(); + } + return Fail("Failed to send shutdown signal"); + }; } // get signal, either for error or for shutdown. 
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const { std::int32_t err{ShutdownSignal()}; - auto n_recv = peer->RecvAll(&err, sizeof(err)); - if (n_recv == sizeof(err)) { - *p_is_error = err == 1; - return Success(); - } - return Fail("Failed to receive error signal."); + std::size_t n_recv{0}; + return Success() << [&] { + return peer->RecvAll(&err, sizeof(err), &n_recv); + } << [&] { + if (n_recv == sizeof(err)) { + *p_is_error = err == 1; + return Success(); + } + return Fail("Failed to receive error signal."); + }; } }; } // namespace xgboost::collective::proto diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h deleted file mode 100644 index 452e9ad9c..000000000 --- a/src/collective/rabit_communicator.h +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2022-2023 by XGBoost contributors - */ -#pragma once -#include - -#include -#include - -#include "communicator-inl.h" -#include "communicator.h" -#include "xgboost/json.h" - -namespace xgboost { -namespace collective { - -class RabitCommunicator : public Communicator { - public: - static Communicator *Create(Json const &config) { - std::vector args_str; - for (auto &items : get(config)) { - switch (items.second.GetValue().Type()) { - case xgboost::Value::ValueKind::kString: { - args_str.push_back(items.first + "=" + get(items.second)); - break; - } - case xgboost::Value::ValueKind::kInteger: { - args_str.push_back(items.first + "=" + std::to_string(get(items.second))); - break; - } - case xgboost::Value::ValueKind::kBoolean: { - if (get(items.second)) { - args_str.push_back(items.first + "=1"); - } else { - args_str.push_back(items.first + "=0"); - } - break; - } - default: - break; - } - } - std::vector args; - for (auto &key_value : args_str) { - args.push_back(&key_value[0]); - } - if (!rabit::Init(static_cast(args.size()), &args[0])) { - LOG(FATAL) << "Failed to initialize Rabit"; - } - return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank()); - } - - RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {} - - bool IsDistributed() const override { return rabit::IsDistributed(); } - - bool IsFederated() const override { return false; } - - std::string AllGather(std::string_view input) override { - auto const per_rank = input.size(); - auto const total_size = per_rank * GetWorldSize(); - auto const index = per_rank * GetRank(); - std::string result(total_size, '\0'); - result.replace(index, per_rank, input); - rabit::Allgather(result.data(), total_size, index, per_rank, per_rank); - return result; - } - - std::string AllGatherV(std::string_view input) override { - auto const size_node_slice = input.size(); - auto const all_sizes = collective::Allgather(size_node_slice); - auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul); - auto const begin_index = - std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul); - auto const size_prev_slice = - GetRank() == 0 ? 
all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1]; - - std::string result(total_size, '\0'); - result.replace(begin_index, size_node_slice, input); - rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice); - return result; - } - - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override { - switch (data_type) { - case DataType::kInt8: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kUInt8: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kInt32: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kUInt32: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kInt64: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kUInt64: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kFloat: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kDouble: - DoAllReduce(send_receive_buffer, count, op); - break; - default: - LOG(FATAL) << "Unknown data type"; - } - } - - void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { - rabit::Broadcast(send_receive_buffer, size, root); - } - - std::string GetProcessorName() override { return rabit::GetProcessorName(); } - - void Print(const std::string &message) override { rabit::TrackerPrint(message); } - - protected: - void Shutdown() override { rabit::Finalize(); } - - private: - template ::value> * = nullptr> - void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { - switch (op) { - case Operation::kBitwiseAND: - rabit::Allreduce(static_cast(send_receive_buffer), - count); - break; - case Operation::kBitwiseOR: - rabit::Allreduce(static_cast(send_receive_buffer), count); - break; - case Operation::kBitwiseXOR: - rabit::Allreduce(static_cast(send_receive_buffer), - count); - break; - default: - LOG(FATAL) << "Unknown allreduce operation"; - } - } - - template ::value> * = nullptr> - void DoBitwiseAllReduce(void *, std::size_t, Operation) { - LOG(FATAL) << "Floating point types do not support bitwise operations."; - } - - template - void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { - switch (op) { - case Operation::kMax: - rabit::Allreduce(static_cast(send_receive_buffer), count); - break; - case Operation::kMin: - rabit::Allreduce(static_cast(send_receive_buffer), count); - break; - case Operation::kSum: - rabit::Allreduce(static_cast(send_receive_buffer), count); - break; - case Operation::kBitwiseAND: - case Operation::kBitwiseOR: - case Operation::kBitwiseXOR: - DoBitwiseAllReduce(send_receive_buffer, count, op); - break; - default: - LOG(FATAL) << "Unknown allreduce operation"; - } - } -}; -} // namespace collective -} // namespace xgboost diff --git a/src/collective/result.cc b/src/collective/result.cc index b11710572..140efa6d8 100644 --- a/src/collective/result.cc +++ b/src/collective/result.cc @@ -62,20 +62,15 @@ void ResultImpl::Concat(std::unique_ptr rhs) { ptr->prev = std::move(rhs); } -#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__) -std::string MakeMsg(std::string&& msg, char const*, std::int32_t) { - return std::forward(msg); -} -#else std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) { - auto name = std::filesystem::path{file}.filename(); + dmlc::DateLogger logger; if (file && line != -1) { - return "[" + name.string() + ":" + std::to_string(line) + // NOLINT + auto name 
= std::filesystem::path{ file }.filename(); + return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() + "]: " + std::forward(msg); } - return std::forward(msg); + return std::string{"["} + logger.HumanDate() + "]" + std::forward(msg); // NOLINT } -#endif } // namespace detail void SafeColl(Result const& rc) { diff --git a/src/collective/socket.cc b/src/collective/socket.cc index 737ce584e..99b02f665 100644 --- a/src/collective/socket.cc +++ b/src/collective/socket.cc @@ -60,24 +60,46 @@ std::size_t TCPSocket::Send(StringView str) { CHECK(!this->IsClosed()); CHECK_LT(str.size(), std::numeric_limits::max()); std::int32_t len = static_cast(str.size()); - CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length."; - auto bytes = this->SendAll(str.c_str(), str.size()); - CHECK_EQ(bytes, str.size()) << "Failed to send string."; - return bytes; + std::size_t n_bytes{0}; + auto rc = Success() << [&] { + return this->SendAll(&len, sizeof(len), &n_bytes); + } << [&] { + if (n_bytes != sizeof(len)) { + return Fail("Failed to send string length."); + } + return Success(); + } << [&] { + return this->SendAll(str.c_str(), str.size(), &n_bytes); + } << [&] { + if (n_bytes != str.size()) { + return Fail("Failed to send string."); + } + return Success(); + }; + SafeColl(rc); + return n_bytes; } [[nodiscard]] Result TCPSocket::Recv(std::string *p_str) { CHECK(!this->IsClosed()); std::int32_t len; - if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) { - return Fail("Failed to recv string length."); - } - p_str->resize(len); - auto bytes = this->RecvAll(&(*p_str)[0], len); - if (static_cast(bytes) != len) { - return Fail("Failed to recv string."); - } - return Success(); + std::size_t n_bytes{0}; + return Success() << [&] { + return this->RecvAll(&len, sizeof(len), &n_bytes); + } << [&] { + if (n_bytes != sizeof(len)) { + return Fail("Failed to recv string length."); + } + return Success(); + } << [&] { + p_str->resize(len); + return this->RecvAll(&(*p_str)[0], len, &n_bytes); + } << [&] { + if (static_cast>(n_bytes) != len) { + return Fail("Failed to recv string."); + } + return Success(); + }; } [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 142483ccf..756f6bb7f 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -31,14 +31,20 @@ #include "xgboost/json.h" // for Json namespace xgboost::collective { + Tracker::Tracker(Json const& config) : sortby_{static_cast( OptionalArg(config, "sortby", static_cast(SortBy::kHost)))}, n_workers_{ static_cast(RequiredArg(config, "n_workers", __func__))}, port_{static_cast(OptionalArg(config, "port", Integer::Int{0}))}, - timeout_{std::chrono::seconds{OptionalArg( - config, "timeout", static_cast(collective::DefaultTimeoutSec()))}} {} + timeout_{std::chrono::seconds{ + OptionalArg(config, "timeout", static_cast(0))}} { + using std::chrono_literals::operator""s; + // Some old configurations in JVM for the scala implementation (removed) use 0 to + // indicate blocking. We continue that convention here. + timeout_ = (timeout_ == 0s) ? 
-1s : timeout_; +} Result Tracker::WaitUntilReady() const { using namespace std::chrono_literals; // NOLINT @@ -49,7 +55,7 @@ Result Tracker::WaitUntilReady() const { timer.Start(); while (!this->Ready()) { auto ela = timer.Duration().count(); - if (ela > this->Timeout().count()) { + if (HasTimeout(this->Timeout()) && ela > this->Timeout().count()) { return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) + " seconds."); } @@ -250,8 +256,10 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { std::lock_guard lock{listener_mu_}; return listener_.NonBlocking(true); } << [&] { - std::lock_guard lock{listener_mu_}; - poll.WatchRead(listener_); + { + std::lock_guard lock{listener_mu_}; + poll.WatchRead(listener_); + } if (state.running) { // Don't timeout if the communicator group is up and running. return poll.Poll(std::chrono::seconds{-1}); diff --git a/src/collective/tracker.h b/src/collective/tracker.h index af30e0be7..b81cf6559 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -15,6 +15,7 @@ #include "xgboost/json.h" // for Json namespace xgboost::collective { +inline bool HasTimeout(std::chrono::seconds timeout) { return timeout.count() > 0; } /** * * @brief Implementation of RABIT tracker. @@ -52,7 +53,7 @@ class Tracker { protected: std::int32_t n_workers_{0}; std::int32_t port_{-1}; - std::chrono::seconds timeout_{0}; + std::chrono::seconds timeout_{-1}; std::atomic ready_{false}; public: diff --git a/src/common/io.h b/src/common/io.h index 5e9d27582..d2fcc9f92 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file io.h * \brief general stream interface for serialization, I/O * \author Tianqi Chen @@ -8,7 +8,6 @@ #define XGBOOST_COMMON_IO_H_ #include -#include // for MemoryFixSizeBuffer, MemoryBufferStream #include // for min, fill_n, copy_n #include // for array @@ -23,12 +22,99 @@ #include // for move #include // for vector -#include "common.h" +#include "common.h" // for DivRoundUp #include "xgboost/string_view.h" // for StringView namespace xgboost::common { -using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer; -using MemoryBufferStream = rabit::utils::MemoryBufferStream; +struct MemoryFixSizeBuffer : public dmlc::SeekStream { + public: + // similar to SEEK_END in libc + static std::size_t constexpr kSeekEnd = std::numeric_limits::max(); + + public: + /** + * @brief Ctor + * + * @param p_buffer Pointer to the source buffer with size `buffer_size`. + * @param buffer_size Size of the source buffer + */ + MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size) + : p_buffer_(reinterpret_cast(p_buffer)), buffer_size_(buffer_size) {} + ~MemoryFixSizeBuffer() override = default; + + std::size_t Read(void *ptr, std::size_t size) override { + std::size_t nread = std::min(buffer_size_ - curr_ptr_, size); + if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); + curr_ptr_ += nread; + return nread; + } + void Write(const void *ptr, std::size_t size) override { + if (size == 0) return; + CHECK_LE(curr_ptr_ + size, buffer_size_); + std::memcpy(p_buffer_ + curr_ptr_, ptr, size); + curr_ptr_ += size; + } + void Seek(std::size_t pos) override { + if (pos == kSeekEnd) { + curr_ptr_ = buffer_size_; + } else { + curr_ptr_ = static_cast(pos); + } + } + /** + * @brief Current position in the buffer (stream). 
+ */ + std::size_t Tell() override { return curr_ptr_; } + [[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; } + + protected: + /*! \brief in memory buffer */ + char *p_buffer_{nullptr}; + /*! \brief current pointer */ + std::size_t buffer_size_{0}; + /*! \brief current pointer */ + std::size_t curr_ptr_{0}; +}; + +/*! \brief a in memory buffer that can be read and write as stream interface */ +struct MemoryBufferStream : public dmlc::SeekStream { + public: + explicit MemoryBufferStream(std::string *p_buffer) + : p_buffer_(p_buffer) { + curr_ptr_ = 0; + } + ~MemoryBufferStream() override = default; + size_t Read(void *ptr, size_t size) override { + CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length"; + size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); + if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); + curr_ptr_ += nread; + return nread; + } + void Write(const void *ptr, size_t size) override { + if (size == 0) return; + if (curr_ptr_ + size > p_buffer_->length()) { + p_buffer_->resize(curr_ptr_+size); + } + std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); + curr_ptr_ += size; + } + void Seek(size_t pos) override { + curr_ptr_ = static_cast(pos); + } + size_t Tell() override { + return curr_ptr_; + } + virtual bool AtEnd() const { + return curr_ptr_ == p_buffer_->length(); + } + + private: + /*! \brief in memory buffer */ + std::string *p_buffer_; + /*! \brief current pointer */ + size_t curr_ptr_; +}; // class MemoryBufferStream /*! * \brief Input stream that support additional PeekRead operation, diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 8c743d940..05e2f762c 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -116,19 +116,19 @@ INSTANTIATE(ColumnarAdapterBatch) namespace { /** - * \brief A view over gathered sketch values. + * @brief A view over gathered sketch values. */ template struct QuantileAllreduce { common::Span global_values; common::Span worker_indptr; common::Span feature_indptr; - size_t n_features{0}; + bst_feature_t n_features{0}; /** - * \brief Get sketch values of the a feature from a worker. + * @brief Get sketch values of the a feature from a worker. * - * \param rank rank of target worker - * \param fidx feature idx + * @param rank rank of target worker + * @param fidx feature idx */ [[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const { // get span for worker @@ -154,7 +154,7 @@ void SketchContainerImpl::GatherSketchInfo( worker_segments.resize(1, 0); auto world = collective::GetWorldSize(); auto rank = collective::GetRank(); - auto n_columns = sketches_.size(); + bst_feature_t n_columns = sketches_.size(); // get the size of each feature. std::vector sketch_size; @@ -165,7 +165,7 @@ void SketchContainerImpl::GatherSketchInfo( sketch_size.push_back(reduced[i].size); } } - // turn the size into CSC indptr + // Turn the size into CSC indptr std::vector &sketches_scan = *p_sketches_scan; sketches_scan.resize((n_columns + 1) * world, 0); size_t beg_scan = rank * (n_columns + 1); // starting storage for current worker. 
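// Illustrative sketch (reader aid, not part of the patch): the socket and quantile
// hunks in this change adopt the collective::Result chaining style -- each lambda
// returns a Result, operator<< only runs the next step when the previous one
// succeeded, and SafeColl() surfaces the accumulated error. ReadHeader and
// kExpected below are hypothetical names used purely for illustration.
#include <cstddef>                       // for size_t
#include <cstdint>                       // for int32_t

#include "xgboost/collective/result.h"   // for Result, Success, Fail, SafeColl

namespace xgboost::collective {
inline Result ReadHeader(std::size_t *n_bytes) {  // hypothetical I/O step
  *n_bytes = sizeof(std::int32_t);
  return Success();
}

inline void ResultChainSketch() {
  std::size_t constexpr kExpected = sizeof(std::int32_t);
  std::size_t n_bytes{0};
  auto rc = Success() << [&] {
    return ReadHeader(&n_bytes);
  } << [&] {
    if (n_bytes != kExpected) {
      return Fail("Failed to read the header.");
    }
    return Success();
  };
  SafeColl(rc);  // checks rc and raises the chained error message if any step failed
}
}  // namespace xgboost::collective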
@@ -174,7 +174,10 @@ void SketchContainerImpl::GatherSketchInfo( // Gather all column pointers auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size())); - collective::SafeColl(rc); + if (!rc.OK()) { + collective::SafeColl(collective::Fail("Failed to get sketch scan.", std::move(rc))); + } + for (int32_t i = 0; i < world; ++i) { size_t back = (i + 1) * (n_columns + 1) - 1; auto n_entries = sketches_scan.at(back); @@ -206,7 +209,9 @@ void SketchContainerImpl::GatherSketchInfo( ctx, info, linalg::MakeVec(reinterpret_cast(global_sketches.data()), global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float))); - collective::SafeColl(rc); + if (!rc.OK()) { + collective::SafeColl(collective::Fail("Failed to get sketch.", std::move(rc))); + } } template @@ -260,7 +265,7 @@ void SketchContainerImpl::AllreduceCategories(Context const* ctx, Meta rc = collective::GlobalSum(ctx, info, linalg::MakeVec(global_categories.data(), global_categories.size())); QuantileAllreduce allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs, - categories_.size()}; + static_cast(categories_.size())}; ParallelFor(categories_.size(), n_threads_, [&](auto fidx) { if (!IsCat(feature_types_, fidx)) { return; @@ -285,8 +290,9 @@ void SketchContainerImpl::AllReduce( std::vector *p_reduced, std::vector *p_num_cuts) { monitor_.Start(__func__); - size_t n_columns = sketches_.size(); - collective::Allreduce(&n_columns, 1); + bst_feature_t n_columns = sketches_.size(); + auto rc = collective::Allreduce(ctx, &n_columns, collective::Op::kMax); + collective::SafeColl(rc); CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers"; AllreduceCategories(ctx, info); @@ -300,8 +306,8 @@ void SketchContainerImpl::AllReduce( // Prune the intermediate num cuts for synchronization. std::vector global_column_size(columns_size_); - auto rc = collective::GlobalSum( - ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size())); + rc = collective::GlobalSum(ctx, info, + linalg::MakeVec(global_column_size.data(), global_column_size.size())); collective::SafeColl(rc); ParallelFor(sketches_.size(), n_threads_, [&](size_t i) { diff --git a/src/common/quantile.cu b/src/common/quantile.cu index e7f09fc4d..d0356ae42 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -12,7 +12,8 @@ #include // for partial_sum #include -#include "../collective/communicator-inl.cuh" +#include "../collective/allgather.h" +#include "../collective/allreduce.h" #include "categorical.h" #include "common.h" #include "device_helpers.cuh" @@ -499,7 +500,7 @@ void SketchContainer::FixError() { }); } -void SketchContainer::AllReduce(Context const*, bool is_column_split) { +void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) { dh::safe_cuda(cudaSetDevice(device_.ordinal)); auto world = collective::GetWorldSize(); if (world == 1 || is_column_split) { @@ -508,16 +509,18 @@ void SketchContainer::AllReduce(Context const*, bool is_column_split) { timer_.Start(__func__); // Reduce the overhead on syncing. 
- size_t global_sum_rows = num_rows_; - collective::Allreduce(&global_sum_rows, 1); - size_t intermediate_num_cuts = + bst_idx_t global_sum_rows = num_rows_; + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&global_sum_rows, 1), collective::Op::kSum); + SafeColl(rc); + bst_idx_t intermediate_num_cuts = std::min(global_sum_rows, static_cast(num_bins_ * kFactor)); this->Prune(intermediate_num_cuts); auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1); size_t n = d_columns_ptr.size(); - collective::Allreduce(&n, 1); + rc = collective::Allreduce(ctx, linalg::MakeVec(&n, 1), collective::Op::kMax); + SafeColl(rc); CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers"; // Get the columns ptr from all workers @@ -527,18 +530,25 @@ void SketchContainer::AllReduce(Context const*, bool is_column_split) { auto offset = rank * d_columns_ptr.size(); thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(), gathered_ptrs.begin() + offset); - collective::AllReduce(device_.ordinal, gathered_ptrs.data().get(), - gathered_ptrs.size()); + rc = collective::Allreduce( + ctx, linalg::MakeVec(gathered_ptrs.data().get(), gathered_ptrs.size(), ctx->Device()), + collective::Op::kSum); + SafeColl(rc); // Get the data from all workers. - std::vector recv_lengths; - dh::caching_device_vector recvbuf; - collective::AllGatherV(device_.ordinal, this->Current().data().get(), - dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf); - collective::Synchronize(device_.ordinal); + std::vector recv_lengths; + HostDeviceVector recvbuf; + rc = collective::AllgatherV( + ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), device_), + &recv_lengths, &recvbuf); + collective::SafeColl(rc); + for (std::size_t i = 0; i < recv_lengths.size() - 1; ++i) { + recv_lengths[i] = recv_lengths[i + 1] - recv_lengths[i]; + } + recv_lengths.resize(recv_lengths.size() - 1); // Segment the received data. - auto s_recvbuf = dh::ToSpan(recvbuf); + auto s_recvbuf = recvbuf.DeviceSpan(); std::vector> allworkers; offset = 0; for (int32_t i = 0; i < world; ++i) { diff --git a/src/common/random.h b/src/common/random.h index ece6fa46f..6d7a1bb49 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -1,5 +1,5 @@ /** - * Copyright 2015-2020, XGBoost Contributors + * Copyright 2015-2024, XGBoost Contributors * \file random.h * \brief Utility related to random. * \author Tianqi Chen @@ -19,11 +19,13 @@ #include #include +#include "../collective/broadcast.h" // for Broadcast #include "../collective/communicator-inl.h" #include "algorithm.h" // ArgSort #include "common.h" #include "xgboost/context.h" // Context #include "xgboost/host_device_vector.h" +#include "xgboost/linalg.h" namespace xgboost::common { /*! 
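// Illustrative sketch (reader aid, not part of the patch): the quantile.cu hunks
// above reduce GPU-resident buffers through the same unified interface -- passing
// the device to linalg::MakeVec marks the buffer as device memory, the reduction
// is selected via collective::Op, and the returned Result is checked with
// SafeColl. The buffer size below is arbitrary; the relative include assumes the
// sketch sits under a src/ subdirectory, as the touched files do.
#include <thrust/device_vector.h>        // for device_vector

#include "../collective/allreduce.h"     // for Allreduce
#include "xgboost/context.h"             // for Context
#include "xgboost/linalg.h"              // for MakeVec

namespace xgboost {
inline void DeviceAllreduceSketch(Context const *ctx) {
  thrust::device_vector<double> d_sums(16, 1.0);  // this worker's partial sums
  auto rc = collective::Allreduce(
      ctx, linalg::MakeVec(d_sums.data().get(), d_sums.size(), ctx->Device()),
      collective::Op::kSum);
  collective::SafeColl(rc);
  // d_sums now holds the element-wise sum across all workers.
}
}  // namespace xgboost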
@@ -227,9 +229,10 @@ class ColumnSampler { } }; -inline auto MakeColumnSampler(Context const*) { +inline auto MakeColumnSampler(Context const* ctx) { std::uint32_t seed = common::GlobalRandomEngine()(); - collective::Broadcast(&seed, sizeof(seed), 0); + auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0); + collective::SafeColl(rc); auto cs = std::make_shared(seed); return cs; } diff --git a/src/data/array_interface.h b/src/data/array_interface.h index d645c9e75..fafe0b6ac 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -615,7 +615,12 @@ auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) { case ArrayInterfaceHandler::kF16: { using T = long double; CHECK(sizeof(T) == 16) << error::NoF128(); - return dispatch(T{}); + // Avoid invalid type. + if constexpr (sizeof(T) == 16) { + return dispatch(T{}); + } else { + return dispatch(double{}); + } } case ArrayInterfaceHandler::kI1: { return dispatch(std::int8_t{}); diff --git a/src/data/data.cc b/src/data/data.cc index 22854def8..f37a10fa3 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -18,7 +18,8 @@ #include // for remove_pointer_t, remove_reference #include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated -#include "../collective/communicator.h" // for Operation +#include "../collective/allgather.h" +#include "../collective/allreduce.h" #include "../common/algorithm.h" // for StableSort #include "../common/api_entry.h" // for XGBAPIThreadLocalEntry #include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData @@ -601,41 +602,42 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, } void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) { - if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) { + bool is_col_split = this->IsColumnSplit(); + + if (size != 0 && this->num_col_ != 0 && !is_col_split) { CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns."; CHECK(info); } - if (!std::strcmp(key, "feature_type")) { - feature_type_names.clear(); - for (size_t i = 0; i < size; ++i) { - auto elem = info[i]; - feature_type_names.emplace_back(elem); - } - if (IsColumnSplit()) { - feature_type_names = collective::AllgatherStrings(feature_type_names); - CHECK_EQ(feature_type_names.size(), num_col_) + // Gather column info when data is split by columns + auto gather_columns = [is_col_split, key, n_columns = this->num_col_](auto const& inputs) { + if (is_col_split) { + std::remove_const_t> result; + auto rc = collective::AllgatherStrings(inputs, &result); + collective::SafeColl(rc); + CHECK_EQ(result.size(), n_columns) << "Length of " << key << " must be equal to number of columns."; + return result; } + return inputs; + }; + + if (StringView{key} == "feature_type") { // NOLINT + this->feature_type_names.clear(); + std::copy(info, info + size, std::back_inserter(feature_type_names)); + feature_type_names = gather_columns(feature_type_names); auto& h_feature_types = feature_types.HostVector(); this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types); - } else if (!std::strcmp(key, "feature_name")) { - if (IsColumnSplit()) { - std::vector local_feature_names{}; + } else if (StringView{key} == "feature_name") { // NOLINT + feature_names.clear(); + if (is_col_split) { auto const rank = collective::GetRank(); - for (std::size_t i = 0; i < size; ++i) { - auto elem = std::to_string(rank) + "." 
+ info[i]; - local_feature_names.emplace_back(elem); - } - feature_names = collective::AllgatherStrings(local_feature_names); - CHECK_EQ(feature_names.size(), num_col_) - << "Length of " << key << " must be equal to number of columns."; + std::transform(info, info + size, std::back_inserter(feature_names), + [rank](char const* elem) { return std::to_string(rank) + "." + elem; }); } else { - feature_names.clear(); - for (size_t i = 0; i < size; ++i) { - feature_names.emplace_back(info[i]); - } + std::copy(info, info + size, std::back_inserter(feature_names)); } + feature_names = gather_columns(feature_names); } else { LOG(FATAL) << "Unknown feature info name: " << key; } @@ -728,12 +730,10 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col } } -void MetaInfo::SynchronizeNumberOfColumns(Context const*) { - if (IsColumnSplit()) { - collective::Allreduce(&num_col_, 1); - } else { - collective::Allreduce(&num_col_, 1); - } +void MetaInfo::SynchronizeNumberOfColumns(Context const* ctx) { + auto op = IsColumnSplit() ? collective::Op::kSum : collective::Op::kMax; + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&num_col_, 1), op); + collective::SafeColl(rc); } namespace { diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 0d75d0651..e581e90ca 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -9,11 +9,12 @@ #include // for underlying_type_t #include // for vector -#include "../collective/communicator-inl.h" -#include "../common/categorical.h" // common::IsCat +#include "../collective/allreduce.h" // for Allreduce +#include "../collective/communicator-inl.h" // for IsDistributed +#include "../common/categorical.h" // common::IsCat #include "../common/column_matrix.h" -#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter. -#include "batch_utils.h" // for RegenGHist +#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter. +#include "batch_utils.h" // for RegenGHist #include "gradient_index.h" #include "proxy_dmatrix.h" #include "simple_batch_iterator.h" @@ -95,13 +96,13 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr ref, bst_featur namespace { // Synchronize feature type in case of empty DMatrix -void SyncFeatureType(Context const*, std::vector* p_h_ft) { +void SyncFeatureType(Context const* ctx, std::vector* p_h_ft) { if (!collective::IsDistributed()) { return; } auto& h_ft = *p_h_ft; - auto n_ft = h_ft.size(); - collective::Allreduce(&n_ft, 1); + bst_idx_t n_ft = h_ft.size(); + collective::SafeColl(collective::Allreduce(ctx, &n_ft, collective::Op::kMax)); if (!h_ft.empty()) { // Check correct size if this is not an empty DMatrix. 
CHECK_EQ(h_ft.size(), n_ft); @@ -109,7 +110,8 @@ void SyncFeatureType(Context const*, std::vector* p_h_ft) { if (n_ft > 0) { h_ft.resize(n_ft); auto ptr = reinterpret_cast*>(h_ft.data()); - collective::Allreduce(ptr, h_ft.size()); + collective::SafeColl( + collective::Allreduce(ctx, linalg::MakeVec(ptr, h_ft.size()), collective::Op::kMax)); } } } // anonymous namespace @@ -175,7 +177,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p, // We use do while here as the first batch is fetched in ctor if (n_features == 0) { n_features = num_cols(); - collective::Allreduce(&n_features, 1); + collective::SafeColl(collective::Allreduce(ctx, &n_features, collective::Op::kMax)); column_sizes.clear(); column_sizes.resize(n_features, 0); info_.num_col_ = n_features; diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index 09a3976d7..69a7b1aa2 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -1,20 +1,18 @@ /** - * Copyright 2020-2023, XGBoost contributors + * Copyright 2020-2024, XGBoost contributors */ #include #include -#include +#include "../collective/allreduce.h" #include "../common/hist_util.cuh" #include "batch_utils.h" // for RegenGHist #include "device_adapter.cuh" #include "ellpack_page.cuh" -#include "gradient_index.h" #include "iterative_dmatrix.h" #include "proxy_dmatrix.cuh" #include "proxy_dmatrix.h" #include "simple_batch_iterator.h" -#include "sparse_page_source.h" namespace xgboost::data { void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, @@ -63,7 +61,8 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, dh::safe_cuda(cudaSetDevice(get_device().ordinal)); if (cols == 0) { cols = num_cols(); - collective::Allreduce(&cols, 1); + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&cols, 1), collective::Op::kMax); + SafeColl(rc); this->info_.num_col_ = cols; } else { CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index 7efff7af4..a29fde842 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -171,12 +171,13 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_ } else { LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name(); } - if constexpr (get_value) { - return std::invoke_result_t< - Fn, decltype(std::declval>()->Value())>(); - } else { - return std::invoke_result_t>())>(); - } + } + + if constexpr (get_value) { + return std::invoke_result_t>()->Value())>(); + } else { + return std::invoke_result_t>())>(); } } diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 4df1d5e53..f54d1c43e 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014~2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file simple_dmatrix.cc * \brief the input data structure for gradient boosting * \author Tianqi Chen @@ -13,6 +13,7 @@ #include #include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather +#include "../collective/allgather.h" #include "../common/error_msg.h" // for InconsistentMaxBin #include "./simple_batch_iterator.h" #include "adapter.h" @@ -76,8 +77,11 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) { void SimpleDMatrix::ReindexFeatures(Context const* ctx) { if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) { - auto const cols = collective::Allgather(info_.num_col_); 
- auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul); + std::vector buffer(collective::GetWorldSize()); + buffer[collective::GetRank()] = info_.num_col_; + auto rc = collective::Allgather(ctx, linalg::MakeVec(buffer.data(), buffer.size())); + SafeColl(rc); + auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0); if (offset == 0) { return; } diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index ebb5fdf24..00aeeb542 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -11,6 +11,7 @@ #include // for async #include // for unique_ptr #include // for mutex +#include // for partial_sum #include // for string #include // for pair, move #include // for vector diff --git a/src/learner.cc b/src/learner.cc index ca6704944..93db7f801 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -35,7 +35,6 @@ #include "collective/aggregator.h" // for ApplyWithLabels #include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed -#include "collective/communicator.h" // for Operation #include "common/api_entry.h" // for XGBAPIThreadLocalEntry #include "common/charconv.h" // for to_chars, to_chars_result, NumericLimits, from_... #include "common/common.h" // for ToString, Split @@ -208,7 +207,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter return dmlc::Parameter::UpdateAllowUnknown(kwargs); } // sanity check - void Validate(Context const*) { + void Validate(Context const* ctx) { if (!collective::IsDistributed()) { return; } @@ -229,7 +228,8 @@ struct LearnerModelParamLegacy : public dmlc::Parameter std::array sync; std::copy(data.cbegin(), data.cend(), sync.begin()); - collective::Broadcast(sync.data(), sync.size(), 0); + auto rc = collective::Broadcast(ctx, linalg::MakeVec(sync.data(), sync.size()), 0); + collective::SafeColl(rc); CHECK(std::equal(data.cbegin(), data.cend(), sync.cbegin())) << "Different model parameter across workers."; } @@ -754,7 +754,9 @@ class LearnerConfiguration : public Learner { num_feature = std::max(num_feature, static_cast(num_col)); } - collective::Allreduce(&num_feature, 1); + auto rc = + collective::Allreduce(&ctx_, linalg::MakeVec(&num_feature, 1), collective::Op::kMax); + collective::SafeColl(rc); if (num_feature > mparam_.num_feature) { mparam_.num_feature = num_feature; } diff --git a/src/logging.cc b/src/logging.cc index d24c6633d..4cf74207d 100644 --- a/src/logging.cc +++ b/src/logging.cc @@ -1,14 +1,13 @@ -/*! - * Copyright 2015-2018 by Contributors +/** + * Copyright 2015-2024, XGBoost Contributors * \file logging.cc * \brief Implementation of loggers. * \author Tianqi Chen */ -#include - -#include "xgboost/parameter.h" #include "xgboost/logging.h" +#include // for string + #include "collective/communicator-inl.h" #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 189c2b8e7..6de0d1f12 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -264,9 +264,14 @@ class EvalAUC : public MetricNoCache { info.weights_.SetDevice(ctx_->Device()); } // We use the global size to handle empty dataset. 
- std::array meta{info.labels.Size(), preds.Size()}; + std::array meta{info.labels.Size(), preds.Size()}; if (!info.IsVerticalFederated()) { - collective::Allreduce(meta.data(), meta.size()); + auto rc = collective::Allreduce( + ctx_, + linalg::MakeTensorView(DeviceOrd::CPU(), common::Span{meta.data(), meta.size()}, + meta.size()), + collective::Op::kMax); + collective::SafeColl(rc); } if (meta[0] == 0) { // Empty across all workers, which is not supported. diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 4ce10d094..59199b092 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -1,9 +1,9 @@ /** - * Copyright 2021-2023 by XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ +#include // for copy #include -#include #include #include // NOLINT #include @@ -11,7 +11,7 @@ #include #include -#include "../collective/communicator-inl.cuh" +#include "../collective/allreduce.h" #include "../common/algorithm.cuh" // SegmentedArgSort #include "../common/optional_weight.h" // OptionalWeights #include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads @@ -201,13 +201,16 @@ void Transpose(common::Span in, common::Span out, size_t m, }); } -double ScaleClasses(Context const *ctx, common::Span results, +double ScaleClasses(Context const *ctx, bool is_column_split, common::Span results, common::Span local_area, common::Span tp, common::Span auc, size_t n_classes) { - if (collective::IsDistributed()) { - int32_t device = dh::CurrentDevice(); + // With vertical federated learning, only the root has label, other parties are not + // evaluation metrics. + if (collective::IsDistributed() && !(is_column_split && collective::IsFederated())) { + std::int32_t device = dh::CurrentDevice(); CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device); - collective::AllReduce(device, results.data(), results.size()); + auto rc = collective::Allreduce( + ctx, linalg::MakeVec(results.data(), results.size(), ctx->Device()), collective::Op::kSum); } auto reduce_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { @@ -334,7 +337,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info, auto local_area = d_results.subspan(0, n_classes); auto tp = d_results.subspan(2 * n_classes, n_classes); auto auc = d_results.subspan(3 * n_classes, n_classes); - return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes); + return ScaleClasses(ctx, info.IsColumnSplit(), d_results, local_area, tp, auc, n_classes); } /** @@ -438,7 +441,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info, tp[c] = 1.0f; } }); - return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes); + return ScaleClasses(ctx, info.IsColumnSplit(), d_results, local_area, tp, auc, n_classes); } void MultiClassSortedIdx(Context const *ctx, common::Span predts, @@ -835,7 +838,7 @@ std::pair GPURankingPRAUC(Context const *ctx, InitCacheOnce(predts, p_cache); dh::device_vector group_ptr(info.group_ptr_.size()); - thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin()); + thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin()); // NOLINT auto d_group_ptr = dh::ToSpan(group_ptr); CHECK_GE(info.group_ptr_.size(), 1) << "Must have at least 1 query group for LTR."; size_t n_groups = info.group_ptr_.size() - 1; diff --git a/src/metric/auc.h b/src/metric/auc.h index 4fe2ecec4..f27a1dda6 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -1,18 +1,14 @@ /** - * Copyright 2021-2023, 
XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_METRIC_AUC_H_ #define XGBOOST_METRIC_AUC_H_ -#include #include -#include #include #include #include #include "../collective/communicator-inl.h" -#include "../common/common.h" -#include "../common/threading_utils.h" #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/metric.h" diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h index cbe69e79a..1a7aef051 100644 --- a/src/objective/adaptive.h +++ b/src/objective/adaptive.h @@ -9,8 +9,6 @@ #include // std::vector #include "../collective/aggregator.h" -#include "../collective/communicator-inl.h" -#include "../common/common.h" #include "xgboost/base.h" // bst_node_t #include "xgboost/context.h" // Context #include "xgboost/data.h" // MetaInfo @@ -42,7 +40,7 @@ inline void UpdateLeafValues(Context const* ctx, std::vector* p_quantiles auto& quantiles = *p_quantiles; auto const& h_node_idx = nidx; - size_t n_leaf = collective::GlobalMax(ctx, info, h_node_idx.size()); + bst_idx_t n_leaf = collective::GlobalMax(ctx, info, static_cast(h_node_idx.size())); CHECK(quantiles.empty() || quantiles.size() == n_leaf); if (quantiles.empty()) { quantiles.resize(n_leaf, std::numeric_limits::quiet_NaN()); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index f253493fc..32992e313 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023 by XGBoost Contributors + * Copyright 2017-2024, XGBoost Contributors */ #include // for max, fill, min #include // for any, any_cast @@ -12,7 +12,7 @@ #include // for vector #include "../collective/communicator-inl.h" // for Allreduce, IsDistributed -#include "../collective/communicator.h" // for Operation +#include "../collective/allreduce.h" #include "../common/bitfield.h" // for RBitField8 #include "../common/categorical.h" // for IsCat, Decision #include "../common/common.h" // for DivRoundUp @@ -461,11 +461,17 @@ class ColumnSplitHelper { return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id; } - void AllreduceBitVectors(Context const*) { - collective::Allreduce(decision_storage_.data(), - decision_storage_.size()); - collective::Allreduce(missing_storage_.data(), - missing_storage_.size()); + void AllreduceBitVectors(Context const *ctx) { + auto rc = collective::Success() << [&] { + return collective::Allreduce( + ctx, linalg::MakeVec(decision_storage_.data(), decision_storage_.size()), + collective::Op::kBitwiseOR); + } << [&] { + return collective::Allreduce( + ctx, linalg::MakeVec(missing_storage_.data(), missing_storage_.size()), + collective::Op::kBitwiseAND); + }; + collective::SafeColl(rc); } void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index aea1aa95d..29fb6bb6a 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023 by XGBoost Contributors + * Copyright 2017-2024, XGBoost Contributors */ #include #include @@ -11,7 +11,7 @@ #include // for any, any_cast #include -#include "../collective/communicator-inl.cuh" +#include "../collective/allreduce.h" #include "../common/bitfield.h" #include "../common/categorical.h" #include "../common/common.h" @@ -817,10 +817,18 @@ class ColumnSplitHelper { void AllReduceBitVectors(dh::caching_device_vector* decision_storage, dh::caching_device_vector* 
missing_storage) const { - collective::AllReduce( - ctx_->Ordinal(), decision_storage->data().get(), decision_storage->size()); - collective::AllReduce( - ctx_->Ordinal(), missing_storage->data().get(), missing_storage->size()); + auto rc = collective::Success() << [&] { + return collective::Allreduce( + ctx_, + linalg::MakeVec(decision_storage->data().get(), decision_storage->size(), ctx_->Device()), + collective::Op::kBitwiseOR); + } << [&] { + return collective::Allreduce( + ctx_, + linalg::MakeVec(missing_storage->data().get(), missing_storage->size(), ctx_->Device()), + collective::Op::kBitwiseAND); + }; + collective::SafeColl(rc); } void ResizeBitVectors(dh::caching_device_vector* decision_storage, diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h index 293e7d1d4..c3065ad5f 100644 --- a/src/tree/common_row_partitioner.h +++ b/src/tree/common_row_partitioner.h @@ -1,24 +1,28 @@ /** - * Copyright 2021-2023 XGBoost contributors + * Copyright 2021-2023, XGBoost contributors * \file common_row_partitioner.h * \brief Common partitioner logic for hist and approx methods. */ #ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ #define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ -#include // std::all_of -#include // std::uint32_t -#include // std::numeric_limits -#include +#include // for all_of, fill +#include // for uint32_t +#include // for numeric_limits +#include // for vector -#include "../collective/communicator-inl.h" -#include "../common/linalg_op.h" // cbegin -#include "../common/numeric.h" // Iota -#include "../common/partition_builder.h" -#include "hist/expand_entry.h" // CPUExpandEntry -#include "xgboost/base.h" -#include "xgboost/context.h" // Context -#include "xgboost/linalg.h" // TensorView +#include "../collective/allreduce.h" // for Allreduce +#include "../common/bitfield.h" // for RBitField8 +#include "../common/linalg_op.h" // for cbegin +#include "../common/numeric.h" // for Iota +#include "../common/partition_builder.h" // for PartitionBuilder +#include "../common/row_set.h" // for RowSetCollection +#include "../common/threading_utils.h" // for ParallelFor2d +#include "xgboost/base.h" // for bst_row_t +#include "xgboost/collective/result.h" // for Success, SafeColl +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for TensorView +#include "xgboost/span.h" // for Span namespace xgboost::tree { @@ -39,7 +43,7 @@ class ColumnSplitHelper { } template - void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads, + void Partition(Context const* ctx, common::BlockedSpace2d const& space, std::int32_t n_threads, GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix, std::vector const& nodes, std::vector const& split_conditions, RegTree const* p_tree) { @@ -56,10 +60,12 @@ class ColumnSplitHelper { }); // Then aggregate the bit vectors across all the workers. - collective::Allreduce(decision_storage_.data(), - decision_storage_.size()); - collective::Allreduce(missing_storage_.data(), - missing_storage_.size()); + auto rc = collective::Success() << [&] { + return collective::Allreduce(ctx, &decision_storage_, collective::Op::kBitwiseOR); + } << [&] { + return collective::Allreduce(ctx, &missing_storage_, collective::Op::kBitwiseAND); + }; + collective::SafeColl(rc); // Finally use the bit vectors to partition the rows. 
common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) { @@ -220,7 +226,7 @@ class CommonRowPartitioner { // Store results in intermediate buffers from partition_builder_ if (is_col_split_) { column_split_helper_.Partition( - space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree); + ctx, space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree); } else { common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { size_t begin = r.begin(); diff --git a/src/tree/fit_stump.cu b/src/tree/fit_stump.cu index 832d40754..dd71465df 100644 --- a/src/tree/fit_stump.cu +++ b/src/tree/fit_stump.cu @@ -47,8 +47,10 @@ void FitStump(Context const* ctx, MetaInfo const& info, thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it, thrust::make_discard_iterator(), dh::tbegin(d_sum.Values())); - collective::GlobalSum(info, ctx->Device(), reinterpret_cast(d_sum.Values().data()), - d_sum.Size() * 2); + auto rc = collective::GlobalSum(ctx, info, + linalg::MakeVec(reinterpret_cast(d_sum.Values().data()), + d_sum.Size() * 2, ctx->Device())); + SafeColl(rc); thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets, [=] XGBOOST_DEVICE(std::size_t i) mutable { diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index ceb322c28..5e225a13f 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -1,11 +1,11 @@ /** - * Copyright 2020-2023, XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #include // std::max #include #include -#include "../../collective/communicator-inl.cuh" +#include "../../collective/allgather.h" #include "../../common/categorical.h" #include "../../data/ellpack_page.cuh" #include "evaluate_splits.cuh" @@ -413,8 +413,14 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector all_candidate_storage(out_splits.size() * world_size); auto all_candidates = dh::ToSpan(all_candidate_storage); - collective::AllGather(device_.ordinal, out_splits.data(), all_candidates.data(), - out_splits.size() * sizeof(DeviceSplitCandidate)); + auto current_rank = + all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size()); + dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(), + out_splits.size() * sizeof(DeviceSplitCandidate), + cudaMemcpyDeviceToDevice)); + auto rc = collective::Allgather( + ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device())); + collective::SafeColl(rc); // Reduce to get the best candidate from all workers. dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(), diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index d25a41cb0..ba673d85f 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -12,6 +12,7 @@ #include // for move #include // for vector +#include "../../collective/allgather.h" #include "../../common/categorical.h" // for CatBitField #include "../../common/hist_util.h" // for GHistRow, HistogramCuts #include "../../common/linalg_op.h" // for cbegin, cend, begin @@ -35,7 +36,7 @@ template std::enable_if_t || std::is_same_v, std::vector> -AllgatherColumnSplit(std::vector const &entries) { +AllgatherColumnSplit(Context const *ctx, std::vector const &entries) { auto const n_entries = entries.size(); // First, gather all the primitive fields. 
@@ -52,7 +53,7 @@ AllgatherColumnSplit(std::vector const &entries) { serialized_entries.emplace_back(std::move(out)); } - auto all_serialized = collective::VectorAllgatherV(serialized_entries); + auto all_serialized = collective::VectorAllgatherV(ctx, serialized_entries); CHECK_GE(all_serialized.size(), local_entries.size()); std::vector all_entries(all_serialized.size()); @@ -401,7 +402,7 @@ class HistEvaluator { if (is_col_split_) { // With column-wise data split, we gather the best splits from all the workers and update the // expand entries accordingly. - auto all_entries = AllgatherColumnSplit(entries); + auto all_entries = AllgatherColumnSplit(ctx_, entries); for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) { for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { entries[nidx_in_set].split.Update( @@ -632,7 +633,7 @@ class HistMultiEvaluator { if (is_col_split_) { // With column-wise data split, we gather the best splits from all the workers and update the // expand entries accordingly. - auto all_entries = AllgatherColumnSplit(entries); + auto all_entries = AllgatherColumnSplit(ctx_, entries); for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) { for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { entries[nidx_in_set].split.Update( diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 033d2221e..e28cae165 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -1,5 +1,5 @@ /** - * Copyright 2021-2023 by XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_ #define XGBOOST_TREE_HIST_HISTOGRAM_H_ @@ -7,26 +7,24 @@ #include // for max #include // for size_t #include // for int32_t -#include // for function #include // for move #include // for vector -#include "../../collective/communicator-inl.h" // for Allreduce -#include "../../collective/communicator.h" // for Operation -#include "../../common/hist_util.h" // for GHistRow, ParallelGHi... -#include "../../common/row_set.h" // for RowSetCollection -#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d -#include "../../data/gradient_index.h" // for GHistIndexMatrix -#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry -#include "hist_cache.h" // for BoundedHistCollection -#include "param.h" // for HistMakerTrainParam -#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t -#include "xgboost/context.h" // for Context -#include "xgboost/data.h" // for BatchIterator, BatchSet -#include "xgboost/linalg.h" // for MatrixView, All, Vect... -#include "xgboost/logging.h" // for CHECK_GE -#include "xgboost/span.h" // for Span -#include "xgboost/tree_model.h" // for RegTree +#include "../../collective/allreduce.h" // for Allreduce +#include "../../common/hist_util.h" // for GHistRow, ParallelGHi... +#include "../../common/row_set.h" // for RowSetCollection +#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d +#include "../../data/gradient_index.h" // for GHistIndexMatrix +#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry +#include "hist_cache.h" // for BoundedHistCollection +#include "param.h" // for HistMakerTrainParam +#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for BatchIterator, BatchSet +#include "xgboost/linalg.h" // for MatrixView, All, Vect... 
+#include "xgboost/logging.h" // for CHECK_GE +#include "xgboost/span.h" // for Span +#include "xgboost/tree_model.h" // for RegTree namespace xgboost::tree { /** @@ -171,7 +169,7 @@ class HistogramBuilder { } } - void SyncHistogram(Context const *, RegTree const *p_tree, + void SyncHistogram(Context const *ctx, RegTree const *p_tree, std::vector const &nodes_to_build, std::vector const &nodes_to_trick) { auto n_total_bins = buffer_.TotalBins(); @@ -186,8 +184,10 @@ class HistogramBuilder { CHECK(!nodes_to_build.empty()); auto first_nidx = nodes_to_build.front(); std::size_t n = n_total_bins * nodes_to_build.size() * 2; - collective::Allreduce( - reinterpret_cast(this->hist_[first_nidx].data()), n); + auto rc = collective::Allreduce( + ctx, linalg::MakeVec(reinterpret_cast(this->hist_[first_nidx].data()), n), + collective::Op::kSum); + SafeColl(rc); } common::BlockedSpace2d const &subspace = diff --git a/src/tree/hist/param.cc b/src/tree/hist/param.cc index bd8d7a85c..10895d511 100644 --- a/src/tree/hist/param.cc +++ b/src/tree/hist/param.cc @@ -1,18 +1,22 @@ /** - * Copyright 2021-2023, XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #include "param.h" +#include // for binary #include // for string -#include "../../collective/communicator-inl.h" // for GetRank, Broadcast +#include "../../collective/broadcast.h" // for Broadcast +#include "../../collective/communicator-inl.h" // for GetRank #include "xgboost/json.h" // for Object, Json +#include "xgboost/linalg.h" // for MakeVec #include "xgboost/tree_model.h" // for RegTree namespace xgboost::tree { DMLC_REGISTER_PARAMETER(HistMakerTrainParam); -void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const* local_tree) const { +void HistMakerTrainParam::CheckTreesSynchronized(Context const* ctx, + RegTree const* local_tree) const { if (!this->debug_synchronize) { return; } @@ -24,7 +28,15 @@ void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const* local_tree->SaveModel(&model); } Json::Dump(model, &s_model, std::ios::binary); - collective::Broadcast(&s_model, 0); + + auto nchars{static_cast(s_model.size())}; + auto rc = collective::Success() << [&] { + return collective::Broadcast(ctx, linalg::MakeVec(&nchars, 1), 0); + } << [&] { + s_model.resize(nchars); + return collective::Broadcast(ctx, linalg::MakeVec(s_model.data(), s_model.size()), 0); + }; + collective::SafeColl(rc); RegTree ref_tree{}; // rank 0 tree auto j_ref_tree = Json::Load(StringView{s_model}, std::ios::binary); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index d53a25d17..958fa0331 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -13,7 +13,7 @@ #include #include "../collective/aggregator.h" -#include "../collective/aggregator.cuh" +#include "../collective/broadcast.h" #include "../common/bitfield.h" #include "../common/categorical.h" #include "../common/cuda_context.cuh" // CUDAContext @@ -410,11 +410,16 @@ struct GPUHistMakerDevice { } }); - collective::AllReduce( - ctx_->Ordinal(), decision_storage.data().get(), decision_storage.size()); - collective::AllReduce( - ctx_->Ordinal(), missing_storage.data().get(), missing_storage.size()); - collective::Synchronize(ctx_->Ordinal()); + auto rc = collective::Success() << [&] { + return collective::Allreduce( + ctx_, linalg::MakeTensorView(ctx_, dh::ToSpan(decision_storage), decision_storage.size()), + collective::Op::kBitwiseOR); + } << [&] { + return collective::Allreduce( + ctx_, 
linalg::MakeTensorView(ctx_, dh::ToSpan(missing_storage), missing_storage.size()), + collective::Op::kBitwiseAND); + }; + collective::SafeColl(rc); row_partitioner->UpdatePositionBatch( nidx, left_nidx, right_nidx, split_data, @@ -611,8 +616,11 @@ struct GPUHistMakerDevice { monitor.Start("AllReduce"); auto d_node_hist = hist.GetNodeHistogram(nidx).data(); using ReduceT = typename std::remove_pointer::type::ValueT; - collective::GlobalSum(info_, ctx_->Device(), reinterpret_cast(d_node_hist), - page->Cuts().TotalBins() * 2 * num_histograms); + auto rc = collective::GlobalSum( + ctx_, info_, + linalg::MakeVec(reinterpret_cast(d_node_hist), + page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device())); + SafeColl(rc); monitor.Stop("AllReduce"); } @@ -860,7 +868,9 @@ class GPUHistMaker : public TreeUpdater { // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); - collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); + auto rc = collective::Broadcast( + ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0); + SafeColl(rc); this->column_sampler_ = std::make_shared(column_sampling_seed); auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()}; @@ -1001,9 +1011,7 @@ class GPUGlobalApproxMaker : public TreeUpdater { monitor_.Start(__func__); CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal(); - // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); - collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); this->column_sampler_ = std::make_shared(column_sampling_seed); p_last_fmat_ = p_fmat; diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 941df7aec..23c8ec9e6 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023 by XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file updater_refresh.cc * \brief refresh the statistics and leaf value on the tree on the dataset * \author Tianqi Chen @@ -9,8 +9,7 @@ #include #include -#include "../collective/communicator-inl.h" -#include "../common/io.h" +#include "../collective/allreduce.h" #include "../common/threading_utils.h" #include "../predictor/predict_fn.h" #include "./param.h" @@ -39,7 +38,7 @@ class TreeRefresher : public TreeUpdater { } CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented(); const std::vector &gpair_h = gpair->Data()->ConstHostVector(); - // thread temporal space + // Thread local variables. std::vector > stemp; std::vector fvec_temp; // setup temp space for each thread @@ -61,9 +60,8 @@ class TreeRefresher : public TreeUpdater { }); } exc.Rethrow(); - // if it is C++11, use lazy evaluation for Allreduce, - // to gain speedup in recovery - auto lazy_get_stats = [&]() { + + auto get_stats = [&]() { const MetaInfo &info = p_fmat->Info(); // start accumulating statistics for (const auto &batch : p_fmat->GetBatches()) { @@ -93,12 +91,17 @@ class TreeRefresher : public TreeUpdater { } }); }; - lazy_get_stats(); - collective::Allreduce(&dmlc::BeginPtr(stemp[0])->sum_grad, - stemp[0].size() * 2); - int offset = 0; + get_stats(); + // Synchronize the aggregated result. + auto &sum_grad = stemp[0]; + // x2 for gradient and hessian. 
+ auto rc = collective::Allreduce( + ctx_, linalg::MakeVec(&sum_grad.data()->sum_grad, sum_grad.size() * 2), + collective::Op::kMax); + collective::SafeColl(rc); + bst_node_t offset = 0; for (auto tree : trees) { - this->Refresh(param, dmlc::BeginPtr(stemp[0]) + offset, 0, tree); + this->Refresh(param, dmlc::BeginPtr(sum_grad) + offset, 0, tree); offset += tree->NumNodes(); } } diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index f64f35483..6526e519c 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -1,14 +1,14 @@ /** - * Copyright 2014-2023 by XBGoost Contributors + * Copyright 2014-2024, XBGoost Contributors * \file updater_sync.cc * \brief synchronize the tree in all distributed nodes */ #include -#include #include #include +#include "../collective/broadcast.h" #include "../collective/communicator-inl.h" #include "../common/io.h" #include "xgboost/json.h" @@ -44,7 +44,8 @@ class TreeSyncher : public TreeUpdater { } } fs.Seek(0); - collective::Broadcast(&s_model, 0); + auto rc = collective::Broadcast(ctx_, linalg::MakeVec(s_model.data(), s_model.size()), 0); + SafeColl(rc); for (auto tree : trees) { tree->Load(&fs); } diff --git a/tests/ci_build/Dockerfile.jvm_cross b/tests/ci_build/Dockerfile.jvm_cross index 5c4bb569b..43988872d 100644 --- a/tests/ci_build/Dockerfile.jvm_cross +++ b/tests/ci_build/Dockerfile.jvm_cross @@ -1,6 +1,6 @@ FROM ubuntu:18.04 ARG JDK_VERSION=8 -ARG SPARK_VERSION=3.0.0 +ARG SPARK_VERSION=3.4.0 # Environment ENV DEBIAN_FRONTEND noninteractive diff --git a/tests/ci_build/build_jvm_packages.sh b/tests/ci_build/build_jvm_packages.sh index 84b41f2b1..97c056403 100755 --- a/tests/ci_build/build_jvm_packages.sh +++ b/tests/ci_build/build_jvm_packages.sh @@ -18,7 +18,6 @@ fi rm -rf build/ cd jvm-packages -export RABIT_MOCK=ON if [ "x$gpu_arch" != "x" ]; then export GPU_ARCH_FLAG=$gpu_arch diff --git a/tests/ci_build/build_mock_cmake.sh b/tests/ci_build/build_mock_cmake.sh deleted file mode 100755 index 8cbabd036..000000000 --- a/tests/ci_build/build_mock_cmake.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash -set -e - -rm -rf build -mkdir build -cd build -cmake -DRABIT_MOCK=ON -DCMAKE_VERBOSE_MAKEFILE=ON .. -make clean -make -j$(nproc) -cd .. 
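// Illustrative sketch (reader aid, not part of the patch): the updater hunks above
// replace the old byte-oriented Broadcast with the typed call -- rank 0's buffer is
// copied to every other worker and the Result is checked with SafeColl. The seed
// value is illustrative, and the relative includes assume the sketch sits under a
// src/ subdirectory like the updaters do.
#include <cstdint>                           // for uint32_t

#include "../collective/broadcast.h"         // for Broadcast
#include "../collective/communicator-inl.h"  // for GetRank
#include "xgboost/context.h"                 // for Context
#include "xgboost/linalg.h"                  // for MakeVec

namespace xgboost {
inline void BroadcastSeedSketch(Context const *ctx) {
  std::uint32_t seed{0};
  if (collective::GetRank() == 0) {
    seed = 42;  // only the root's value matters; other ranks are overwritten
  }
  auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0);
  collective::SafeColl(rc);
  // All workers now share rank 0's seed.
}
}  // namespace xgboost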
diff --git a/tests/ci_build/test_r_package.py b/tests/ci_build/test_r_package.py index ddcf48674..1fe1644ad 100644 --- a/tests/ci_build/test_r_package.py +++ b/tests/ci_build/test_r_package.py @@ -53,7 +53,6 @@ def pack_rpackage() -> Path: # rabit rabit = Path("rabit") os.mkdir(dest / "src" / rabit) - shutil.copytree(rabit / "src", dest / "src" / "rabit" / "src") shutil.copytree(rabit / "include", dest / "src" / "rabit" / "include") # dmlc-core dmlc_core = Path("dmlc-core") diff --git a/tests/cpp/collective/net_test.h b/tests/cpp/collective/net_test.h deleted file mode 100644 index ed15ed256..000000000 --- a/tests/cpp/collective/net_test.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2022-2023, XGBoost Contributors - */ -#pragma once - -#include -#include - -#include // ifstream - -#include "../helpers.h" // for FileExists - -namespace xgboost::collective { -class SocketTest : public ::testing::Test { - protected: - std::string skip_msg_{"Skipping IPv6 test"}; - - bool SkipTest() { - std::string path{"/sys/module/ipv6/parameters/disable"}; - if (FileExists(path)) { - std::ifstream fin(path); - if (!fin) { - return true; - } - std::string s_value; - fin >> s_value; - auto value = std::stoi(s_value); - if (value != 0) { - return true; - } - } else { - return true; - } - return false; - } - - protected: - void SetUp() override { system::SocketStartup(); } - void TearDown() override { system::SocketFinalize(); } -}; -} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index b25db54cb..61e34cb57 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -175,4 +175,35 @@ TEST_F(AllgatherTest, VAlgo) { worker.TestVAlgo(); }); } + +TEST(VectorAllgatherV, Basic) { + std::int32_t n_workers{3}; + TestDistributedGlobal(n_workers, []() { + auto n_workers = collective::GetWorldSize(); + ASSERT_EQ(n_workers, 3); + auto rank = collective::GetRank(); + // Construct input that has different length for each worker. 
+ std::vector> inputs; + for (std::int32_t i = 0; i < rank + 1; ++i) { + std::vector in; + for (std::int32_t j = 0; j < rank + 1; ++j) { + in.push_back(static_cast(j)); + } + inputs.emplace_back(std::move(in)); + } + + Context ctx; + auto outputs = VectorAllgatherV(&ctx, inputs); + + ASSERT_EQ(outputs.size(), (1 + n_workers) * n_workers / 2); + auto const& res = outputs; + + for (std::int32_t i = 0; i < n_workers; ++i) { + std::int32_t k = 0; + for (auto v : res[i]) { + ASSERT_EQ(v, k++); + } + } + }); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 13a6ca656..1ce2f35fd 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -39,6 +39,22 @@ class AllreduceWorker : public WorkerForTest { } } + void Restricted() { + this->LimitSockBuf(4096); + + std::size_t n = 4096 * 4; + std::vector data(comm_.World() * n, 1); + auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { + for (std::size_t i = 0; i < rhs.size(); ++i) { + rhs[i] += lhs[i]; + } + }); + ASSERT_TRUE(rc.OK()); + for (auto v : data) { + ASSERT_EQ(v, comm_.World()); + } + } + void Acc() { std::vector data(314, 1.5); auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { @@ -95,4 +111,45 @@ TEST_F(AllreduceTest, BitOr) { worker.BitOr(); }); } + +TEST_F(AllreduceTest, Restricted) { + std::int32_t n_workers = std::min(3u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + AllreduceWorker worker{host, port, timeout, n_workers, r}; + worker.Restricted(); + }); +} + +TEST(AllreduceGlobal, Basic) { + auto n_workers = 3; + TestDistributedGlobal(n_workers, [&]() { + std::vector values(n_workers * 2, 0); + auto rank = GetRank(); + auto s_values = common::Span{values.data(), values.size()}; + auto self = s_values.subspan(rank * 2, 2); + for (auto& v : self) { + v = 1.0f; + } + Context ctx; + auto rc = + Allreduce(&ctx, linalg::MakeVec(s_values.data(), s_values.size()), collective::Op::kSum); + SafeColl(rc); + for (auto v : s_values) { + ASSERT_EQ(v, 1); + } + }); +} + +TEST(AllreduceGlobal, Small) { + // Test when the data is not large enougth to be divided by the number of workers + auto n_workers = 8; + TestDistributedGlobal(n_workers, [&]() { + std::uint64_t value{1}; + Context ctx; + auto rc = Allreduce(&ctx, linalg::MakeVec(&value, 1), collective::Op::kSum); + SafeColl(rc); + ASSERT_EQ(value, n_workers); + }); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_communicator.cc b/tests/cpp/collective/test_communicator.cc deleted file mode 100644 index a0ca9e747..000000000 --- a/tests/cpp/collective/test_communicator.cc +++ /dev/null @@ -1,63 +0,0 @@ -/*! 
- * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include "../../../src/collective/communicator.h" - -namespace xgboost { -namespace collective { - -TEST(CommunicatorFactory, TypeFromEnv) { - EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "foo"); - EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "rabit"); - EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "Federated"); - EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "In-Memory"); - EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromEnv()); -} - -TEST(CommunicatorFactory, TypeFromArgs) { - Json config{JsonObject()}; - EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("rabit"); - EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("federated"); - EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("in-memory"); - EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("foo"); - EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error); -} - -TEST(CommunicatorFactory, TypeFromArgsUpperCase) { - Json config{JsonObject()}; - EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("rabit"); - EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("federated"); - EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("in-memory"); - EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("foo"); - EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error); -} - -} // namespace collective -} // namespace xgboost diff --git a/tests/cpp/collective/test_in_memory_communicator.cc b/tests/cpp/collective/test_in_memory_communicator.cc deleted file mode 100644 index 69c427a4e..000000000 --- a/tests/cpp/collective/test_in_memory_communicator.cc +++ /dev/null @@ -1,237 +0,0 @@ -/*! 
- * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include -#include - -#include "../../../src/collective/in_memory_communicator.h" - -namespace xgboost { -namespace collective { - -class InMemoryCommunicatorTest : public ::testing::Test { - public: - static void Verify(void (*function)(int)) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(function, rank); - } - for (auto &thread : threads) { - thread.join(); - } - } - - static void Allgather(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllgather(comm, rank); - } - - static void AllgatherV(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllgatherV(comm, rank); - } - - static void AllreduceMax(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceMax(comm, rank); - } - - static void AllreduceMin(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceMin(comm, rank); - } - - static void AllreduceSum(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceSum(comm); - } - - static void AllreduceBitwiseAND(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceBitwiseAND(comm, rank); - } - - static void AllreduceBitwiseOR(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceBitwiseOR(comm, rank); - } - - static void AllreduceBitwiseXOR(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceBitwiseXOR(comm, rank); - } - - static void Broadcast(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyBroadcast(comm, rank); - } - - static void Mixture(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - for (auto i = 0; i < 5; i++) { - VerifyAllgather(comm, rank); - VerifyAllreduceMax(comm, rank); - VerifyAllreduceMin(comm, rank); - VerifyAllreduceSum(comm); - VerifyAllreduceBitwiseAND(comm, rank); - VerifyAllreduceBitwiseOR(comm, rank); - VerifyAllreduceBitwiseXOR(comm, rank); - VerifyBroadcast(comm, rank); - } - } - - protected: - static void VerifyAllgather(InMemoryCommunicator &comm, int rank) { - std::string input{static_cast('0' + rank)}; - auto output = comm.AllGather(input); - for (auto i = 0; i < kWorldSize; i++) { - EXPECT_EQ(output[i], static_cast('0' + i)); - } - } - - static void VerifyAllgatherV(InMemoryCommunicator &comm, int rank) { - std::vector inputs{"a", "bb", "ccc"}; - auto output = comm.AllGatherV(inputs[rank]); - EXPECT_EQ(output, "abbccc"); - } - - static void VerifyAllreduceMax(InMemoryCommunicator &comm, int rank) { - int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax); - int expected[] = {3, 4, 5, 6, 7}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(buffer[i], expected[i]); - } - } - - static void VerifyAllreduceMin(InMemoryCommunicator &comm, int rank) { - int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMin); - int expected[] = {1, 2, 3, 4, 5}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(buffer[i], expected[i]); - } - } - - static void VerifyAllreduceSum(InMemoryCommunicator &comm) { - int buffer[] = {1, 2, 3, 4, 5}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); - int expected[] = {3, 6, 9, 12, 15}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(buffer[i], expected[i]); - } - } - - static void 
VerifyAllreduceBitwiseAND(InMemoryCommunicator &comm, int rank) { - std::bitset<2> original(rank); - auto buffer = original.to_ulong(); - comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseAND); - EXPECT_EQ(buffer, 0UL); - } - - static void VerifyAllreduceBitwiseOR(InMemoryCommunicator &comm, int rank) { - std::bitset<2> original(rank); - auto buffer = original.to_ulong(); - comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseOR); - std::bitset<2> actual(buffer); - std::bitset<2> expected{0b11}; - EXPECT_EQ(actual, expected); - } - - static void VerifyAllreduceBitwiseXOR(InMemoryCommunicator &comm, int rank) { - std::bitset<3> original(rank * 2); - auto buffer = original.to_ulong(); - comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseXOR); - std::bitset<3> actual(buffer); - std::bitset<3> expected{0b110}; - EXPECT_EQ(actual, expected); - } - - static void VerifyBroadcast(InMemoryCommunicator &comm, int rank) { - if (rank == 0) { - std::string buffer{"hello"}; - comm.Broadcast(&buffer[0], buffer.size(), 0); - EXPECT_EQ(buffer, "hello"); - } else { - std::string buffer{" "}; - comm.Broadcast(&buffer[0], buffer.size(), 0); - EXPECT_EQ(buffer, "hello"); - } - } - - static int const kWorldSize{3}; -}; - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { - auto construct = []() { InMemoryCommunicator comm{0, 0}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooSmall) { - auto construct = []() { InMemoryCommunicator comm{1, -1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooBig) { - auto construct = []() { InMemoryCommunicator comm{1, 1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { - auto construct = []() { - Json config{JsonObject()}; - config["in_memory_world_size"] = std::string("1"); - config["in_memory_rank"] = Integer(0); - auto *comm = InMemoryCommunicator::Create(config); - delete comm; - }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankNotInteger) { - auto construct = []() { - Json config{JsonObject()}; - config["in_memory_world_size"] = 1; - config["in_memory_rank"] = std::string("0"); - auto *comm = InMemoryCommunicator::Create(config); - delete comm; - }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, GetWorldSizeAndRank) { - InMemoryCommunicator comm{1, 0}; - EXPECT_EQ(comm.GetWorldSize(), 1); - EXPECT_EQ(comm.GetRank(), 0); -} - -TEST(InMemoryCommunicatorSimpleTest, IsDistributed) { - InMemoryCommunicator comm{1, 0}; - EXPECT_TRUE(comm.IsDistributed()); -} - -TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); } - -TEST_F(InMemoryCommunicatorTest, AllgatherV) { Verify(&AllgatherV); } - -TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); } - -TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); } - -TEST_F(InMemoryCommunicatorTest, AllreduceSum) { Verify(&AllreduceSum); } - -TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseAND) { Verify(&AllreduceBitwiseAND); } - -TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseOR) { Verify(&AllreduceBitwiseOR); } - -TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseXOR) { Verify(&AllreduceBitwiseXOR); } - -TEST_F(InMemoryCommunicatorTest, Broadcast) { Verify(&Broadcast); } - -TEST_F(InMemoryCommunicatorTest, Mixture) { Verify(&Mixture); } - -} // namespace collective -} // 
namespace xgboost diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc index 34e0c1de8..622b350aa 100644 --- a/tests/cpp/collective/test_loop.cc +++ b/tests/cpp/collective/test_loop.cc @@ -59,7 +59,7 @@ class LoopTest : public ::testing::Test { TEST_F(LoopTest, Timeout) { std::vector data(1); Loop::Op op{Loop::Op::kRead, 0, data.data(), data.size(), &pair_.second, 0}; - loop_->Submit(op); + loop_->Submit(std::move(op)); auto rc = loop_->Block(); ASSERT_FALSE(rc.OK()); ASSERT_EQ(rc.Code(), std::make_error_code(std::errc::timed_out)) << rc.Report(); @@ -75,8 +75,8 @@ TEST_F(LoopTest, Op) { Loop::Op wop{Loop::Op::kWrite, 0, wbuf.data(), wbuf.size(), &send, 0}; Loop::Op rop{Loop::Op::kRead, 0, rbuf.data(), rbuf.size(), &recv, 0}; - loop_->Submit(wop); - loop_->Submit(rop); + loop_->Submit(std::move(wop)); + loop_->Submit(std::move(rop)); auto rc = loop_->Block(); SafeColl(rc); @@ -90,7 +90,7 @@ TEST_F(LoopTest, Block) { common::Timer t; t.Start(); - loop_->Submit(op); + loop_->Submit(std::move(op)); t.Stop(); // submit is non-blocking ASSERT_LT(t.ElapsedSeconds(), 1); diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu deleted file mode 100644 index 47e86220d..000000000 --- a/tests/cpp/collective/test_nccl_device_communicator.cu +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2022-2023, XGBoost contributors - */ -#ifdef XGBOOST_USE_NCCL - -#include - -#include -#include // for string - -#include "../../../src/collective/comm.cuh" -#include "../../../src/collective/communicator-inl.cuh" -#include "../../../src/collective/nccl_device_communicator.cuh" -#include "../helpers.h" - -namespace xgboost { -namespace collective { - -TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) { - auto construct = []() { NcclDeviceCommunicator comm{-1, false, DefaultNcclName()}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(NcclDeviceCommunicatorSimpleTest, SystemError) { - auto stub = std::make_shared(DefaultNcclName()); - auto rc = stub->GetNcclResult(ncclSystemError); - auto msg = rc.Report(); - ASSERT_TRUE(msg.find("environment variables") != std::string::npos); -} - -namespace { -void VerifyAllReduceBitwiseAND() { - auto const rank = collective::GetRank(); - std::bitset<64> original{}; - original[rank] = true; - HostDeviceVector buffer({original.to_ullong()}, DeviceOrd::CUDA(rank)); - collective::AllReduce(rank, buffer.DevicePointer(), 1); - collective::Synchronize(rank); - EXPECT_EQ(buffer.HostVector()[0], 0ULL); -} -} // anonymous namespace - -TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseAND) { - auto const n_gpus = common::AllVisibleGPUs(); - if (n_gpus <= 1) { - GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseAND test with # GPUs = " << n_gpus; - } - auto constexpr kUseNccl = true; - RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseAND); -} - -namespace { -void VerifyAllReduceBitwiseOR() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - std::bitset<64> original{}; - original[rank] = true; - HostDeviceVector buffer({original.to_ullong()}, DeviceOrd::CUDA(rank)); - collective::AllReduce(rank, buffer.DevicePointer(), 1); - collective::Synchronize(rank); - EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1); -} -} // anonymous namespace - -TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseOR) { - auto const n_gpus = common::AllVisibleGPUs(); - if (n_gpus <= 1) { - GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseOR 
test with # GPUs = " << n_gpus; - } - auto constexpr kUseNccl = true; - RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseOR); -} - -namespace { -void VerifyAllReduceBitwiseXOR() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - std::bitset<64> original{~0ULL}; - original[rank] = false; - HostDeviceVector buffer({original.to_ullong()}, DeviceOrd::CUDA(rank)); - collective::AllReduce(rank, buffer.DevicePointer(), 1); - collective::Synchronize(rank); - EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1); -} -} // anonymous namespace - -TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseXOR) { - auto const n_gpus = common::AllVisibleGPUs(); - if (n_gpus <= 1) { - GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseXOR test with # GPUs = " << n_gpus; - } - auto constexpr kUseNccl = true; - RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseXOR); -} - -} // namespace collective -} // namespace xgboost - -#endif // XGBOOST_USE_NCCL diff --git a/tests/cpp/collective/test_rabit_communicator.cc b/tests/cpp/collective/test_rabit_communicator.cc deleted file mode 100644 index 9711e1aed..000000000 --- a/tests/cpp/collective/test_rabit_communicator.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2022-2024, XGBoost contributors - */ -#include - -#include "../../../src/collective/rabit_communicator.h" -#include "../helpers.h" - -namespace xgboost::collective { -TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { - auto construct = []() { RabitCommunicator comm{0, 0}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooSmall) { - auto construct = []() { RabitCommunicator comm{1, -1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooBig) { - auto construct = []() { RabitCommunicator comm{1, 1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(RabitCommunicatorSimpleTest, GetWorldSizeAndRank) { - RabitCommunicator comm{6, 3}; - EXPECT_EQ(comm.GetWorldSize(), 6); - EXPECT_EQ(comm.GetRank(), 3); -} - -TEST(RabitCommunicatorSimpleTest, IsNotDistributed) { - RabitCommunicator comm{2, 1}; - // Rabit is only distributed with a tracker. - EXPECT_FALSE(comm.IsDistributed()); -} - -namespace { -void VerifyVectorAllgatherV() { - auto n_workers = collective::GetWorldSize(); - ASSERT_EQ(n_workers, 3); - auto rank = collective::GetRank(); - // Construct input that has different length for each worker. 
- std::vector> inputs; - for (std::int32_t i = 0; i < rank + 1; ++i) { - std::vector in; - for (std::int32_t j = 0; j < rank + 1; ++j) { - in.push_back(static_cast(j)); - } - inputs.emplace_back(std::move(in)); - } - - auto outputs = VectorAllgatherV(inputs); - - ASSERT_EQ(outputs.size(), (1 + n_workers) * n_workers / 2); - auto const& res = outputs; - - for (std::int32_t i = 0; i < n_workers; ++i) { - std::int32_t k = 0; - for (auto v : res[i]) { - ASSERT_EQ(v, k++); - } - } -} -} // namespace - -TEST(VectorAllgatherV, Basic) { - std::int32_t n_workers{3}; - RunWithInMemoryCommunicator(n_workers, VerifyVectorAllgatherV); -} -} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_tracker.cc b/tests/cpp/collective/test_tracker.cc index 8d6cbeff2..e31e26628 100644 --- a/tests/cpp/collective/test_tracker.cc +++ b/tests/cpp/collective/test_tracker.cc @@ -29,6 +29,7 @@ class PrintWorker : public WorkerForTest { TEST_F(TrackerTest, Bootstrap) { RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; + ASSERT_TRUE(HasTimeout(tracker.Timeout())); ASSERT_FALSE(tracker.Ready()); auto fut = tracker.Run(); @@ -47,6 +48,9 @@ TEST_F(TrackerTest, Bootstrap) { w.join(); } SafeColl(fut.get()); + + ASSERT_FALSE(HasTimeout(std::chrono::seconds{-1})); + ASSERT_FALSE(HasTimeout(std::chrono::seconds{0})); } TEST_F(TrackerTest, Print) { diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index c84df528f..19e5e590a 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -16,6 +16,10 @@ #include "../../../src/collective/tracker.h" // for GetHostAddress #include "../helpers.h" // for FileExists +#if defined(XGBOOST_USE_FEDERATED) +#include "../plugin/federated/test_worker.h" +#endif // defined(XGBOOST_USE_FEDERATED) + namespace xgboost::collective { class WorkerForTest { std::string tracker_host_; @@ -45,6 +49,7 @@ class WorkerForTest { if (i != comm_.Rank()) { ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking()); ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK()); + ASSERT_TRUE(comm_.Chan(i)->Socket()->SetNoDelay().OK()); } } } @@ -126,15 +131,80 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { ASSERT_TRUE(fut.get().OK()); } + inline auto MakeDistributedTestConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Json config{Object{}}; config["dmlc_communicator"] = std::string{"rabit"}; config["dmlc_tracker_uri"] = host; config["dmlc_tracker_port"] = port; - config["dmlc_timeout_sec"] = static_cast(timeout.count()); + config["dmlc_timeout"] = static_cast(timeout.count()); config["dmlc_task_id"] = std::to_string(r); config["dmlc_retry"] = 2; return config; } + +template +void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true) { + system::SocketStartup(); + std::chrono::seconds timeout{1}; + + std::string host; + auto rc = GetHostAddress(&host); + SafeColl(rc); + + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; + auto fut = tracker.Run(); + + std::vector workers; + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { + auto config = MakeDistributedTestConfig(host, port, timeout, i); + Init(config); + worker_fn(); + if (need_finalize) { + Finalize(); + } + }); + } + + for (auto& t : workers) { + t.join(); + } + + ASSERT_TRUE(fut.get().OK()); + system::SocketFinalize(); +} + +class BaseMGPUTest : public ::testing::Test { + public: + 
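// Runs `fn` under the federated communicator when `is_federated` is true, otherwise under the default NCCL-enabled distributed communicator; the test is skipped if the required backend is not compiled in. +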
/** + * @param emulate_if_single Emulate multi-GPU for federated test if there's only one GPU + * available. + */ + template + auto DoTest(Fn&& fn, bool is_federated, bool emulate_if_single = false) const { + auto n_gpus = common::AllVisibleGPUs(); + if (is_federated) { +#if defined(XGBOOST_USE_FEDERATED) + if (n_gpus == 1 && emulate_if_single) { + // Emulate multiple GPUs. + // We don't use nccl and can have multiple communicators running on the same device. + n_gpus = 3; + } + TestFederatedGlobal(n_gpus, fn); +#else + GTEST_SKIP_("Not compiled with federated learning."); +#endif // defined(XGBOOST_USE_FEDERATED) + } else { +#if defined(XGBOOST_USE_NCCL) + TestDistributedGlobal(n_gpus, fn); +#else + GTEST_SKIP_("Not compiled with NCCL."); +#endif // defined(XGBOOST_USE_NCCL) + } + } +}; } // namespace xgboost::collective diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc index face21851..e7f72dc27 100644 --- a/tests/cpp/common/test_io.cc +++ b/tests/cpp/common/test_io.cc @@ -15,8 +15,8 @@ namespace xgboost::common { TEST(MemoryFixSizeBuffer, Seek) { size_t constexpr kSize { 64 }; std::vector memory( kSize ); - rabit::utils::MemoryFixSizeBuffer buf(memory.data(), memory.size()); - buf.Seek(rabit::utils::MemoryFixSizeBuffer::kSeekEnd); + MemoryFixSizeBuffer buf(memory.data(), memory.size()); + buf.Seek(MemoryFixSizeBuffer::kSeekEnd); size_t end = buf.Tell(); ASSERT_EQ(end, kSize); } diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 26937be76..fef7db9dc 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -1,12 +1,16 @@ /** - * Copyright 2020-2023 by XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #include "test_quantile.h" #include +#include // for int64_t + +#include "../../../src/collective/allreduce.h" #include "../../../src/common/hist_util.h" #include "../../../src/data/adapter.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "xgboost/context.h" namespace xgboost::common { @@ -90,6 +94,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { // Generate cuts for single node environment collective::Finalize(); + CHECK_EQ(collective::GetWorldSize(), 1); std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; }); m->Info().num_row_ = world * rows; @@ -145,7 +150,8 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { template void TestDistributedQuantile(size_t const rows, size_t const cols) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, DoTestDistributedQuantile, rows, cols); + collective::TestDistributedGlobal( + kWorkers, [=] { DoTestDistributedQuantile(rows, cols); }, false); } } // anonymous namespace @@ -272,7 +278,8 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) { template void TestColSplitQuantile(size_t rows, size_t cols) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, DoTestColSplitQuantile, rows, cols); + collective::TestDistributedGlobal(kWorkers, + [=] { DoTestColSplitQuantile(rows, cols); }); } } // anonymous namespace @@ -324,43 +331,56 @@ void TestSameOnAllWorkers() { cut_ptrs(cuts.Ptrs().size() * world, 0); std::vector cut_min_values(cuts.MinValues().size() * world, 0); - size_t value_size = cuts.Values().size(); - collective::Allreduce(&value_size, 1); - size_t ptr_size = cuts.Ptrs().size(); - collective::Allreduce(&ptr_size, 1); - CHECK_EQ(ptr_size, kCols + 1); - size_t min_value_size = cuts.MinValues().size(); - 
collective::Allreduce(&min_value_size, 1); - CHECK_EQ(min_value_size, kCols); + std::int64_t value_size = cuts.Values().size(); + std::int64_t ptr_size = cuts.Ptrs().size(); + std::int64_t min_value_size = cuts.MinValues().size(); - size_t value_offset = value_size * rank; - std::copy(cuts.Values().begin(), cuts.Values().end(), - cut_values.begin() + value_offset); - size_t ptr_offset = ptr_size * rank; - std::copy(cuts.Ptrs().cbegin(), cuts.Ptrs().cend(), - cut_ptrs.begin() + ptr_offset); - size_t min_values_offset = min_value_size * rank; + auto rc = collective::Success() << [&] { + return collective::Allreduce(&ctx, &value_size, collective::Op::kMax); + } << [&] { + return collective::Allreduce(&ctx, &ptr_size, collective::Op::kMax); + } << [&] { + return collective::Allreduce(&ctx, &min_value_size, collective::Op::kMax); + }; + collective::SafeColl(rc); + ASSERT_EQ(ptr_size, kCols + 1); + ASSERT_EQ(min_value_size, kCols); + + std::size_t value_offset = value_size * rank; + std::copy(cuts.Values().begin(), cuts.Values().end(), cut_values.begin() + value_offset); + std::size_t ptr_offset = ptr_size * rank; + std::copy(cuts.Ptrs().cbegin(), cuts.Ptrs().cend(), cut_ptrs.begin() + ptr_offset); + std::size_t min_values_offset = min_value_size * rank; std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(), cut_min_values.begin() + min_values_offset); - collective::Allreduce(cut_values.data(), cut_values.size()); - collective::Allreduce(cut_ptrs.data(), cut_ptrs.size()); - collective::Allreduce(cut_min_values.data(), cut_min_values.size()); + rc = std::move(rc) << [&] { + return collective::Allreduce(&ctx, linalg::MakeVec(cut_values.data(), cut_values.size()), + collective::Op::kSum); + } << [&] { + return collective::Allreduce(&ctx, linalg::MakeVec(cut_ptrs.data(), cut_ptrs.size()), + collective::Op::kSum); + } << [&] { + return collective::Allreduce( + &ctx, linalg::MakeVec(cut_min_values.data(), cut_min_values.size()), + collective::Op::kSum); + }; + collective::SafeColl(rc); - for (int32_t i = 0; i < world; i++) { - for (size_t j = 0; j < value_size; ++j) { + for (std::int32_t i = 0; i < world; i++) { + for (std::int64_t j = 0; j < value_size; ++j) { size_t idx = i * value_size + j; - EXPECT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps); + ASSERT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps); } - for (size_t j = 0; j < ptr_size; ++j) { + for (std::int64_t j = 0; j < ptr_size; ++j) { size_t idx = i * ptr_size + j; EXPECT_EQ(cuts.Ptrs().at(j), cut_ptrs.at(idx)); } - for (size_t j = 0; j < min_value_size; ++j) { + for (std::int64_t j = 0; j < min_value_size; ++j) { size_t idx = i * min_value_size + j; - EXPECT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx)); + ASSERT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx)); } } }); @@ -369,6 +389,6 @@ void TestSameOnAllWorkers() { TEST(Quantile, SameOnAllWorkers) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, TestSameOnAllWorkers); + collective::TestDistributedGlobal(kWorkers, [] { TestSameOnAllWorkers(); }); } } // namespace xgboost::common diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 070c705b5..80c9c5c71 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -1,12 +1,13 @@ /** - * Copyright 2020-2023, XGBoost contributors + * Copyright 2020-2024, XGBoost contributors */ #include -#include "../../../src/collective/communicator-inl.cuh" +#include "../../../src/collective/allreduce.h" #include 
"../../../src/common/hist_util.cuh" #include "../../../src/common/quantile.cuh" #include "../../../src/data/device_adapter.cuh" // CupyAdapter +#include "../collective/test_worker.h" // for BaseMGPUTest #include "../helpers.h" #include "test_quantile.h" @@ -18,9 +19,9 @@ struct IsSorted { } }; } -namespace common { -class MGPUQuantileTest : public BaseMGPUTest {}; +namespace common { +class MGPUQuantileTest : public collective::BaseMGPUTest {}; TEST(GPUQuantile, Basic) { constexpr size_t kRows = 1000, kCols = 100, kBins = 256; @@ -36,7 +37,8 @@ TEST(GPUQuantile, Basic) { void TestSketchUnique(float sparsity) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](std::int32_t seed, bst_bin_t n_bins, + MetaInfo const& info) { HostDeviceVector ft; SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU()); @@ -121,7 +123,7 @@ void TestQuantileElemRank(DeviceOrd device, Span in, TEST(GPUQuantile, Prune) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { HostDeviceVector ft; SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU()); @@ -190,7 +192,7 @@ TEST(GPUQuantile, MergeEmpty) { TEST(GPUQuantile, MergeBasic) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { HostDeviceVector ft; SketchContainer sketch_0(ft, n_bins, kCols, kRows, FstCU()); HostDeviceVector storage_0; @@ -260,9 +262,9 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { using Tuple = thrust::tuple; auto it = thrust::make_zip_iterator(tuple_it); thrust::transform(thrust::device, it, it + data_1.size(), data_1.data(), - [=] __device__(Tuple const &tuple) { + [=] XGBOOST_DEVICE(Tuple const& tuple) { auto i = thrust::get<0>(tuple); - if (thrust::get<0>(tuple) % 2 == 0) { + if (i % 2 == 0) { return 0.0f; } else { return thrust::get<1>(tuple); @@ -306,7 +308,7 @@ TEST(GPUQuantile, MergeDuplicated) { TEST(GPUQuantile, MultiMerge) { constexpr size_t kRows = 20, kCols = 1; int32_t world = 2; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { // Set up single node version HostDeviceVector ft; SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, FstCU()); @@ -368,16 +370,18 @@ namespace { void TestAllReduceBasic() { auto const world = collective::GetWorldSize(); constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { auto const device = DeviceOrd::CUDA(GPUIDX); auto ctx = MakeCUDACtx(device.ordinal); - // Set up single node version; + /** + * Set up single node version. 
+ */ HostDeviceVector ft({}, device); SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, device); - size_t intermediate_num_cuts = std::min( - kRows * world, static_cast(n_bins * WQSketch::kFactor)); + size_t intermediate_num_cuts = + std::min(kRows * world, static_cast(n_bins * WQSketch::kFactor)); std::vector containers; for (auto rank = 0; rank < world; ++rank) { HostDeviceVector storage({}, device); @@ -388,21 +392,22 @@ void TestAllReduceBasic() { data::CupyAdapter adapter(interface_str); HostDeviceVector ft({}, device); containers.emplace_back(ft, n_bins, kCols, kRows, device); - AdapterDeviceSketch(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), + AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &containers.back()); } - for (auto &sketch : containers) { + for (auto& sketch : containers) { sketch.Prune(intermediate_num_cuts); sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data()); sketch_on_single_node.FixError(); } sketch_on_single_node.Unique(); - TestQuantileElemRank(device, sketch_on_single_node.Data(), - sketch_on_single_node.ColumnsPtr(), true); + TestQuantileElemRank(device, sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr(), + true); - // Set up distributed version. We rely on using rank as seed to generate - // the exact same copy of data. + /** + * Set up distributed version. We rely on using rank as seed to generate + * the exact same copy of data. + */ auto rank = collective::GetRank(); SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, device); HostDeviceVector storage({}, device); @@ -411,22 +416,23 @@ void TestAllReduceBasic() { .Seed(rank + seed) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - AdapterDeviceSketch(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), + AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &sketch_distributed); + if (world == 1) { + auto n_samples_global = kRows * world; + intermediate_num_cuts = + std::min(n_samples_global, static_cast(n_bins * SketchContainer::kFactor)); + sketch_distributed.Prune(intermediate_num_cuts); + } sketch_distributed.AllReduce(&ctx, false); sketch_distributed.Unique(); - ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), - sketch_on_single_node.ColumnsPtr().size()); - ASSERT_EQ(sketch_distributed.Data().size(), - sketch_on_single_node.Data().size()); + ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), sketch_on_single_node.ColumnsPtr().size()); + ASSERT_EQ(sketch_distributed.Data().size(), sketch_on_single_node.Data().size()); - TestQuantileElemRank(device, sketch_distributed.Data(), - sketch_distributed.ColumnsPtr(), true); + TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true); - std::vector single_node_data( - sketch_on_single_node.Data().size()); + std::vector single_node_data(sketch_on_single_node.Data().size()); dh::CopyDeviceSpanToVector(&single_node_data, sketch_on_single_node.Data()); std::vector distributed_data(sketch_distributed.Data().size()); @@ -444,7 +450,8 @@ void TestAllReduceBasic() { } // anonymous namespace TEST_F(MGPUQuantileTest, AllReduceBasic) { - DoTest(TestAllReduceBasic); + this->DoTest([] { TestAllReduceBasic(); }, true); + this->DoTest([] { TestAllReduceBasic(); }, false); } namespace { @@ -490,7 +497,8 @@ void TestColumnSplit(DMatrix* dmat) { TEST_F(MGPUQuantileTest, ColumnSplitBasic) { std::size_t constexpr kRows = 1000, kCols = 100; auto dmat = 
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); - DoTest(TestColumnSplit, dmat.get()); + this->DoTest([&] { TestColumnSplit(dmat.get()); }, true); + this->DoTest([&] { TestColumnSplit(dmat.get()); }, false); } TEST_F(MGPUQuantileTest, ColumnSplitCategorical) { @@ -507,15 +515,15 @@ TEST_F(MGPUQuantileTest, ColumnSplitCategorical) { .Type(ft) .MaxCategory(13) .GenerateDMatrix(); - DoTest(TestColumnSplit, dmat.get()); + this->DoTest([&] { TestColumnSplit(dmat.get()); }, true); + this->DoTest([&] { TestColumnSplit(dmat.get()); }, false); } namespace { void TestSameOnAllWorkers() { auto world = collective::GetWorldSize(); constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, - MetaInfo const &info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { auto const rank = collective::GetRank(); auto const device = DeviceOrd::CUDA(GPUIDX); Context ctx = MakeCUDACtx(device.ordinal); @@ -536,7 +544,8 @@ void TestSameOnAllWorkers() { // Test for all workers having the same sketch. size_t n_data = sketch_distributed.Data().size(); - collective::Allreduce(&n_data, 1); + auto rc = collective::Allreduce(&ctx, linalg::MakeVec(&n_data, 1), collective::Op::kMax); + SafeColl(rc); ASSERT_EQ(n_data, sketch_distributed.Data().size()); size_t size_as_float = sketch_distributed.Data().size_bytes() / sizeof(float); @@ -549,9 +558,10 @@ void TestSameOnAllWorkers() { thrust::copy(thrust::device, local_data.data(), local_data.data() + local_data.size(), all_workers.begin() + local_data.size() * rank); - collective::AllReduce(device.ordinal, all_workers.data().get(), - all_workers.size()); - collective::Synchronize(device.ordinal); + rc = collective::Allreduce( + &ctx, linalg::MakeVec(all_workers.data().get(), all_workers.size(), ctx.Device()), + collective::Op::kSum); + SafeColl(rc); auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float); std::vector h_base_line(base_line.size()); @@ -573,7 +583,8 @@ void TestSameOnAllWorkers() { } // anonymous namespace TEST_F(MGPUQuantileTest, SameOnAllWorkers) { - DoTest(TestSameOnAllWorkers); + this->DoTest([] { TestSameOnAllWorkers(); }, true); + this->DoTest([] { TestSameOnAllWorkers(); }, false); } TEST(GPUQuantile, Push) { diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h index d34c5e0e4..38ace76c4 100644 --- a/tests/cpp/common/test_quantile.h +++ b/tests/cpp/common/test_quantile.h @@ -1,21 +1,22 @@ +/** + * Copyright 2020-2024, XGBoost Contributors + */ #ifndef XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_ #define XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_ #include -#include #include #include "../helpers.h" -namespace xgboost { -namespace common { +namespace xgboost::common { template void RunWithSeedsAndBins(size_t rows, Fn fn) { std::vector seeds(2); SimpleLCG lcg; SimpleRealUniformDistribution dist(3, 1000); std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); }); - std::vector bins(2); + std::vector bins(2); for (size_t i = 0; i < bins.size() - 1; ++i) { bins[i] = i * 35 + 2; } @@ -36,7 +37,6 @@ template void RunWithSeedsAndBins(size_t rows, Fn fn) { } } } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_ diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index a7d9a0c76..837ca7768 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -3,15 +3,16 @@ */ #include "test_metainfo.h" 
-#include #include +#include #include #include #include -#include "../filesystem.h" // dmlc::TemporaryDirectory -#include "../helpers.h" // for GMockTHrow +#include "../collective/test_worker.h" // for TestDistributedGlobal +#include "../filesystem.h" // dmlc::TemporaryDirectory +#include "../helpers.h" // for GMockTHrow #include "xgboost/base.h" namespace xgboost { @@ -118,8 +119,8 @@ void VerifyGetSetFeatureColumnSplit() { } // anonymous namespace TEST(MetaInfo, GetSetFeatureColumnSplit) { - auto constexpr kWorldSize{3}; - RunWithInMemoryCommunicator(kWorldSize, VerifyGetSetFeatureColumnSplit); + auto constexpr kWorkers{3}; + collective::TestDistributedGlobal(kWorkers, VerifyGetSetFeatureColumnSplit); } TEST(MetaInfo, SaveLoadBinary) { diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 6334d96c6..ea6eedbb2 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -9,6 +9,7 @@ #include "../../../src/data/adapter.h" // ArrayAdapter #include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" // RandomDataGenerator,CreateSimpleTestData #include "xgboost/base.h" @@ -444,5 +445,5 @@ void VerifyColumnSplit() { TEST(SimpleDMatrix, ColumnSplit) { auto constexpr kWorldSize{3}; - RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit); + collective::TestDistributedGlobal(kWorldSize, VerifyColumnSplit); } diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 273cc0f00..97f3db077 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -520,67 +520,6 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint inline std::int32_t AllThreadsForTest() { return Context{}.Threads(); } -template -void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&... args) { - auto run = [&](auto rank) { - Json config{JsonObject()}; - if constexpr (use_nccl) { - config["xgboost_communicator"] = String("in-memory-nccl"); - } else { - config["xgboost_communicator"] = String("in-memory"); - } - config["in_memory_world_size"] = world_size; - config["in_memory_rank"] = rank; - xgboost::collective::Init(config); - - std::forward(function)(std::forward(args)...); - - xgboost::collective::Finalize(); - }; -#if defined(_OPENMP) - common::ParallelFor(world_size, world_size, run); -#else - std::vector threads; - for (auto rank = 0; rank < world_size; rank++) { - threads.emplace_back(run, rank); - } - for (auto& thread : threads) { - thread.join(); - } -#endif -} - -class BaseMGPUTest : public ::testing::Test { - protected: - int world_size_; - bool use_nccl_{false}; - - void SetUp() override { - auto const n_gpus = common::AllVisibleGPUs(); - if (n_gpus <= 1) { - // Use a single GPU to simulate distributed environment. - world_size_ = 3; - // NCCL doesn't like sharing a single GPU, so we use the adapter instead. - use_nccl_ = false; - } else { - // Use multiple GPUs for real. - world_size_ = n_gpus; - use_nccl_ = true; - } - } - - template - void DoTest(Function&& function, Args&&... 
args) { - if (use_nccl_) { - RunWithInMemoryCommunicator(world_size_, function, args...); - } else { - RunWithInMemoryCommunicator(world_size_, function, args...); - } - } -}; - -class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{}; - inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); } inline auto GMockThrow(StringView msg) { diff --git a/tests/cpp/metric/test_auc.cc b/tests/cpp/metric/test_auc.cc deleted file mode 100644 index eea54fc32..000000000 --- a/tests/cpp/metric/test_auc.cc +++ /dev/null @@ -1,68 +0,0 @@ -#include "test_auc.h" - -#include - -namespace xgboost { -namespace metric { - -TEST(Metric, DeclareUnifiedTest(BinaryAUC)) { VerifyBinaryAUC(); } - -TEST(Metric, DeclareUnifiedTest(MultiClassAUC)) { VerifyMultiClassAUC(); } - -TEST(Metric, DeclareUnifiedTest(RankingAUC)) { VerifyRankingAUC(); } - -TEST(Metric, DeclareUnifiedTest(PRAUC)) { VerifyPRAUC(); } - -TEST(Metric, DeclareUnifiedTest(MultiClassPRAUC)) { VerifyMultiClassPRAUC(); } - -TEST(Metric, DeclareUnifiedTest(RankingPRAUC)) { VerifyRankingPRAUC(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), BinaryAUCRowSplit) { - DoTest(VerifyBinaryAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), BinaryAUCColumnSplit) { - DoTest(VerifyBinaryAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassAUCRowSplit) { - DoTest(VerifyMultiClassAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassAUCColumnSplit) { - DoTest(VerifyMultiClassAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RankingAUCRowSplit) { - DoTest(VerifyRankingAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RankingAUCColumnSplit) { - DoTest(VerifyRankingAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PRAUCRowSplit) { - DoTest(VerifyPRAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PRAUCColumnSplit) { - DoTest(VerifyPRAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassPRAUCRowSplit) { - DoTest(VerifyMultiClassPRAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassPRAUCColumnSplit) { - DoTest(VerifyMultiClassPRAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RankingPRAUCRowSplit) { - DoTest(VerifyRankingPRAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RankingPRAUCColumnSplit) { - DoTest(VerifyRankingPRAUC, DataSplitMode::kCol); -} -} // namespace metric -} // namespace xgboost diff --git a/tests/cpp/metric/test_auc.cu b/tests/cpp/metric/test_auc.cu deleted file mode 100644 index 430ab1d37..000000000 --- a/tests/cpp/metric/test_auc.cu +++ /dev/null @@ -1,5 +0,0 @@ -/*! - * Copyright 2021 XGBoost contributors - */ -// Dummy file to keep the CUDA conditional compile trick. 
-#include "test_auc.cc" \ No newline at end of file diff --git a/tests/cpp/metric/test_auc.h b/tests/cpp/metric/test_auc.h index cef6d9757..dc99ab2e9 100644 --- a/tests/cpp/metric/test_auc.h +++ b/tests/cpp/metric/test_auc.h @@ -7,11 +7,9 @@ #include "../helpers.h" -namespace xgboost { -namespace metric { - -inline void VerifyBinaryAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +namespace xgboost::metric { +inline void VerifyBinaryAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr uni_ptr{Metric::Create("auc", &ctx)}; Metric* metric = uni_ptr.get(); ASSERT_STREQ(metric->Name(), "auc"); @@ -53,8 +51,8 @@ inline void VerifyBinaryAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) 0.5, 1e-10); } -inline void VerifyMultiClassAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMultiClassAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr uni_ptr{Metric::Create("auc", &ctx)}; auto metric = uni_ptr.get(); @@ -114,8 +112,8 @@ inline void VerifyMultiClassAUC(DataSplitMode data_split_mode = DataSplitMode::k ASSERT_GT(auc, 0.714); } -inline void VerifyRankingAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyRankingAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("auc", &ctx)}; // single group @@ -148,8 +146,8 @@ inline void VerifyRankingAUC(DataSplitMode data_split_mode = DataSplitMode::kRow 0.769841f, 1e-6); } -inline void VerifyPRAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyPRAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric* metric = xgboost::Metric::Create("aucpr", &ctx); ASSERT_STREQ(metric->Name(), "aucpr"); @@ -185,8 +183,8 @@ inline void VerifyPRAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { delete metric; } -inline void VerifyMultiClassPRAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMultiClassPRAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("aucpr", &ctx)}; @@ -209,8 +207,8 @@ inline void VerifyMultiClassPRAUC(DataSplitMode data_split_mode = DataSplitMode: ASSERT_GT(auc, 0.699); } -inline void VerifyRankingPRAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyRankingPRAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("aucpr", &ctx)}; @@ -245,5 +243,4 @@ inline void VerifyRankingPRAUC(DataSplitMode data_split_mode = DataSplitMode::kR data_split_mode), 0.556021f, 0.001f); } -} // namespace metric -} // namespace xgboost +} // namespace xgboost::metric diff --git a/tests/cpp/metric/test_distributed_metric.cc b/tests/cpp/metric/test_distributed_metric.cc new file mode 100644 index 000000000..843ea5762 --- /dev/null +++ b/tests/cpp/metric/test_distributed_metric.cc @@ -0,0 +1,192 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#include +#include // for DeviceOrd +#include // for DataSplitMode + +#include // for min +#include // for int32_t +#include // for function +#include // for 
string +#include // for thread + +#include "../collective/test_worker.h" // for TestDistributedGlobal +#include "test_auc.h" +#include "test_elementwise_metric.h" +#include "test_multiclass_metric.h" +#include "test_rank_metric.h" +#include "test_survival_metric.h" + +#if defined(XGBOOST_USE_FEDERATED) + +#include "../plugin/federated/test_worker.h" // for TestFederatedGlobal + +#endif // defined(XGBOOST_USE_FEDERATED) + +namespace xgboost::metric { +namespace { +using Verifier = std::function; +struct Param { + bool is_dist; // is distributed + bool is_fed; // is federated learning + DataSplitMode split; // how to split data + Verifier v; // test function + std::string name; // metric name + DeviceOrd device; // device to run +}; + +class TestDistributedMetric : public ::testing::TestWithParam { + protected: + template + void Run(bool is_dist, bool is_fed, DataSplitMode split_mode, Fn fn, DeviceOrd device) { + if (!is_dist) { + fn(split_mode, device); + return; + } + + std::int32_t n_workers{0}; + if (device.IsCUDA()) { + n_workers = common::AllVisibleGPUs(); + } else { + n_workers = std::min(static_cast(std::thread::hardware_concurrency()), 3); + } + auto fn1 = [&]() { + auto r = collective::GetRank(); + if (device.IsCPU()) { + fn(split_mode, DeviceOrd::CPU()); + } else { + fn(split_mode, DeviceOrd::CUDA(r)); + } + }; + if (is_fed) { +#if defined(XGBOOST_USE_FEDERATED) + collective::TestFederatedGlobal(n_workers, fn1); +#endif // defined(XGBOOST_USE_FEDERATED) + } else { + collective::TestDistributedGlobal(n_workers, fn1); + } + } +}; +} // anonymous namespace + +TEST_P(TestDistributedMetric, BinaryAUCRowSplit) { + auto p = GetParam(); + this->Run(p.is_dist, p.is_fed, p.split, p.v, p.device); +} + +constexpr bool UseNCCL() { +#if defined(XGBOOST_USE_NCCL) + return true; +#else + return false; +#endif // defined(XGBOOST_USE_NCCL) +} + +constexpr bool UseCUDA() { +#if defined(XGBOOST_USE_CUDA) + return true; +#else + return false; +#endif // defined(XGBOOST_USE_CUDA) +} + +constexpr bool UseFederated() { +#if defined(XGBOOST_USE_FEDERATED) + return true; +#else + return false; +#endif +} + +auto MakeParamsForTest() { + std::vector cases; + + auto push = [&](std::string name, auto fn) { + for (bool is_federated : {false, true}) { + for (DataSplitMode m : {DataSplitMode::kCol, DataSplitMode::kRow}) { + for (auto d : {DeviceOrd::CPU(), DeviceOrd::CUDA(0)}) { + if (!is_federated && !UseNCCL() && d.IsCUDA()) { + // Federated doesn't use nccl. + continue; + } + if (!UseCUDA() && d.IsCUDA()) { + // skip CUDA tests + continue; + } + if (!UseFederated() && is_federated) { + // skip GRPC tests + continue; + } + + auto p = Param{true, is_federated, m, fn, name, d}; + cases.push_back(p); + if (!is_federated) { + // Add a local test. 
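+          // The verifier then runs in a single process, without initializing a communicator.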
+ p.is_dist = false; + cases.push_back(p); + } + } + } + } + }; + +#define REFLECT_NAME(name) push(#name, Verify##name) + // AUC + REFLECT_NAME(BinaryAUC); + REFLECT_NAME(MultiClassAUC); + REFLECT_NAME(RankingAUC); + REFLECT_NAME(PRAUC); + REFLECT_NAME(MultiClassPRAUC); + REFLECT_NAME(RankingPRAUC); + // Elementwise + REFLECT_NAME(RMSE); + REFLECT_NAME(RMSLE); + REFLECT_NAME(MAE); + REFLECT_NAME(MAPE); + REFLECT_NAME(MPHE); + REFLECT_NAME(LogLoss); + REFLECT_NAME(Error); + REFLECT_NAME(PoissonNegLogLik); + REFLECT_NAME(MultiRMSE); + REFLECT_NAME(Quantile); + // Multi-Class + REFLECT_NAME(MultiClassError); + REFLECT_NAME(MultiClassLogLoss); + // Ranking + REFLECT_NAME(Precision); + REFLECT_NAME(NDCG); + REFLECT_NAME(MAP); + REFLECT_NAME(NDCGExpGain); + // AFT + using namespace xgboost::common; // NOLINT + REFLECT_NAME(AFTNegLogLik); + REFLECT_NAME(IntervalRegressionAccuracy); + +#undef REFLECT_NAME + + return cases; +} + +INSTANTIATE_TEST_SUITE_P( + DistributedMetric, TestDistributedMetric, ::testing::ValuesIn(MakeParamsForTest()), + [](const ::testing::TestParamInfo& info) { + std::string result; + if (info.param.is_dist) { + result += "Dist_"; + } + if (info.param.is_fed) { + result += "Federated_"; + } + if (info.param.split == DataSplitMode::kRow) { + result += "RowSplit"; + } else { + result += "ColSplit"; + } + result += "_"; + result += info.param.device.IsCPU() ? "CPU" : "CUDA"; + result += "_"; + result += info.param.name; + return result; + }); +} // namespace xgboost::metric diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc deleted file mode 100644 index 11854ce88..000000000 --- a/tests/cpp/metric/test_elementwise_metric.cc +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2018-2023 by XGBoost contributors - */ -#include "test_elementwise_metric.h" - -namespace xgboost::metric { -TEST(Metric, DeclareUnifiedTest(RMSE)) { VerifyRMSE(); } - -TEST(Metric, DeclareUnifiedTest(RMSLE)) { VerifyRMSLE(); } - -TEST(Metric, DeclareUnifiedTest(MAE)) { VerifyMAE(); } - -TEST(Metric, DeclareUnifiedTest(MAPE)) { VerifyMAPE(); } - -TEST(Metric, DeclareUnifiedTest(MPHE)) { VerifyMPHE(); } - -TEST(Metric, DeclareUnifiedTest(LogLoss)) { VerifyLogLoss(); } - -TEST(Metric, DeclareUnifiedTest(Error)) { VerifyError(); } - -TEST(Metric, DeclareUnifiedTest(PoissonNegLogLik)) { VerifyPoissonNegLogLik(); } - -TEST(Metric, DeclareUnifiedTest(MultiRMSE)) { VerifyMultiRMSE(); } - -TEST(Metric, DeclareUnifiedTest(Quantile)) { VerifyQuantile(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RMSERowSplit) { - DoTest(VerifyRMSE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RMSEColumnSplit) { - DoTest(VerifyRMSE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RMSLERowSplit) { - DoTest(VerifyRMSLE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RMSLEColumnSplit) { - DoTest(VerifyRMSLE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAERowSplit) { - DoTest(VerifyMAE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAEColumnSplit) { - DoTest(VerifyMAE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAPERowSplit) { - DoTest(VerifyMAPE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAPEColumnSplit) { - DoTest(VerifyMAPE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MPHERowSplit) { - DoTest(VerifyMPHE, 
DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MPHEColumnSplit) { - DoTest(VerifyMPHE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), LogLossRowSplit) { - DoTest(VerifyLogLoss, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), LogLossColumnSplit) { - DoTest(VerifyLogLoss, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), ErrorRowSplit) { - DoTest(VerifyError, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), ErrorColumnSplit) { - DoTest(VerifyError, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PoissonNegLogLikRowSplit) { - DoTest(VerifyPoissonNegLogLik, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PoissonNegLogLikColumnSplit) { - DoTest(VerifyPoissonNegLogLik, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiRMSERowSplit) { - DoTest(VerifyMultiRMSE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiRMSEColumnSplit) { - DoTest(VerifyMultiRMSE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), QuantileRowSplit) { - DoTest(VerifyQuantile, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), QuantileColumnSplit) { - DoTest(VerifyQuantile, DataSplitMode::kCol); -} -} // namespace xgboost::metric diff --git a/tests/cpp/metric/test_elementwise_metric.cu b/tests/cpp/metric/test_elementwise_metric.cu deleted file mode 100644 index c45db8f7f..000000000 --- a/tests/cpp/metric/test_elementwise_metric.cu +++ /dev/null @@ -1,5 +0,0 @@ -/*! - * Copyright 2018 XGBoost contributors - */ -// Dummy file to keep the CUDA conditional compile trick. 
-#include "test_elementwise_metric.cc" \ No newline at end of file diff --git a/tests/cpp/metric/test_elementwise_metric.h b/tests/cpp/metric/test_elementwise_metric.h index 4435c0807..70a106798 100644 --- a/tests/cpp/metric/test_elementwise_metric.h +++ b/tests/cpp/metric/test_elementwise_metric.h @@ -42,8 +42,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) } } -inline void VerifyRMSE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyRMSE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("rmse", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "rmse"); @@ -68,11 +68,11 @@ inline void VerifyRMSE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.6708f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"rmse"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"rmse"}, device.ordinal); } -inline void VerifyRMSLE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyRMSLE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("rmsle", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "rmsle"); @@ -97,11 +97,11 @@ inline void VerifyRMSLE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.2415f, 1e-4); delete metric; - CheckDeterministicMetricElementWise(StringView{"rmsle"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"rmsle"}, device.ordinal); } -inline void VerifyMAE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMAE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("mae", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "mae"); @@ -126,11 +126,11 @@ inline void VerifyMAE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.54f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"mae"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"mae"}, device.ordinal); } -inline void VerifyMAPE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMAPE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("mape", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "mape"); @@ -155,11 +155,11 @@ inline void VerifyMAPE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 1.3250f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"mape"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"mape"}, device.ordinal); } -inline void VerifyMPHE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMPHE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{xgboost::Metric::Create("mphe", &ctx)}; metric->Configure({}); ASSERT_STREQ(metric->Name(), "mphe"); @@ -183,7 +183,7 @@ inline void VerifyMPHE(DataSplitMode data_split_mode = DataSplitMode::kRow) { { 1, 2, 9, 8}, {}, data_split_mode), 0.1922f, 1e-4); - CheckDeterministicMetricElementWise(StringView{"mphe"}, GPUIDX); + 
CheckDeterministicMetricElementWise(StringView{"mphe"}, device.ordinal); metric->Configure({{"huber_slope", "0.1"}}); EXPECT_NEAR(GetMetricEval(metric.get(), @@ -193,8 +193,8 @@ inline void VerifyMPHE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.0461686f, 1e-4); } -inline void VerifyLogLoss(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyLogLoss(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("logloss", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "logloss"); @@ -223,11 +223,11 @@ inline void VerifyLogLoss(DataSplitMode data_split_mode = DataSplitMode::kRow) { 1.3138f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"logloss"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"logloss"}, device.ordinal); } -inline void VerifyError(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyError(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("error", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "error"); @@ -285,11 +285,11 @@ inline void VerifyError(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.45f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"error@0.5"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"error@0.5"}, device.ordinal); } -inline void VerifyPoissonNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyPoissonNegLogLik(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("poisson-nloglik", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "poisson-nloglik"); @@ -318,11 +318,11 @@ inline void VerifyPoissonNegLogLik(DataSplitMode data_split_mode = DataSplitMode 1.5783f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"poisson-nloglik"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"poisson-nloglik"}, device.ordinal); } -inline void VerifyMultiRMSE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMultiRMSE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); size_t n_samples = 32, n_targets = 8; linalg::Tensor y{{n_samples, n_targets}, ctx.Device()}; auto &h_y = y.Data()->HostVector(); @@ -343,8 +343,8 @@ inline void VerifyMultiRMSE(DataSplitMode data_split_mode = DataSplitMode::kRow) ASSERT_FLOAT_EQ(ret, loss_w); } -inline void VerifyQuantile(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyQuantile(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("quantile", &ctx)}; HostDeviceVector predts{0.1f, 0.9f, 0.1f, 0.9f}; diff --git a/tests/cpp/metric/test_multiclass_metric.cc b/tests/cpp/metric/test_multiclass_metric.cc deleted file mode 100644 index 7fc8bc429..000000000 --- a/tests/cpp/metric/test_multiclass_metric.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright by Contributors -#include "test_multiclass_metric.h" - -#include - -namespace xgboost { -namespace metric { - -TEST(Metric, DeclareUnifiedTest(MultiClassError)) { 
VerifyMultiClassError(); } - -TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) { VerifyMultiClassLogLoss(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassErrorRowSplit) { - DoTest(VerifyMultiClassError, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassErrorColumnSplit) { - DoTest(VerifyMultiClassError, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassLogLossRowSplit) { - DoTest(VerifyMultiClassLogLoss, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassLogLossColumnSplit) { - DoTest(VerifyMultiClassLogLoss, DataSplitMode::kCol); -} -} // namespace metric -} // namespace xgboost diff --git a/tests/cpp/metric/test_multiclass_metric.cu b/tests/cpp/metric/test_multiclass_metric.cu deleted file mode 100644 index 8a087565b..000000000 --- a/tests/cpp/metric/test_multiclass_metric.cu +++ /dev/null @@ -1,5 +0,0 @@ -/*! - * Copyright 2019 XGBoost contributors - */ -// Dummy file to keep the CUDA conditional compile trick. -#include "test_multiclass_metric.cc" \ No newline at end of file diff --git a/tests/cpp/metric/test_multiclass_metric.h b/tests/cpp/metric/test_multiclass_metric.h index 5fdead596..002e38cb1 100644 --- a/tests/cpp/metric/test_multiclass_metric.h +++ b/tests/cpp/metric/test_multiclass_metric.h @@ -44,8 +44,8 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device) } } -inline void TestMultiClassError(int device, DataSplitMode data_split_mode) { - auto ctx = MakeCUDACtx(device); +inline void TestMultiClassError(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("merror", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "merror"); @@ -59,13 +59,13 @@ inline void TestMultiClassError(int device, DataSplitMode data_split_mode) { delete metric; } -inline void VerifyMultiClassError(DataSplitMode data_split_mode = DataSplitMode::kRow) { - TestMultiClassError(GPUIDX, data_split_mode); - CheckDeterministicMetricMultiClass(StringView{"merror"}, GPUIDX); +inline void VerifyMultiClassError(DataSplitMode data_split_mode, DeviceOrd device) { + TestMultiClassError(data_split_mode, device); + CheckDeterministicMetricMultiClass(StringView{"merror"}, device.ordinal); } -inline void TestMultiClassLogLoss(int device, DataSplitMode data_split_mode) { - auto ctx = MakeCUDACtx(device); +inline void TestMultiClassLogLoss(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("mlogloss", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "mlogloss"); @@ -80,9 +80,9 @@ inline void TestMultiClassLogLoss(int device, DataSplitMode data_split_mode) { delete metric; } -inline void VerifyMultiClassLogLoss(DataSplitMode data_split_mode = DataSplitMode::kRow) { - TestMultiClassLogLoss(GPUIDX, data_split_mode); - CheckDeterministicMetricMultiClass(StringView{"mlogloss"}, GPUIDX); +inline void VerifyMultiClassLogLoss(DataSplitMode data_split_mode, DeviceOrd device) { + TestMultiClassLogLoss(data_split_mode, device); + CheckDeterministicMetricMultiClass(StringView{"mlogloss"}, device.ordinal); } } // namespace metric diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index fbf0611b3..4c69847f8 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -1,84 +1,29 @@ /** - * Copyright 
2016-2023 by XGBoost Contributors + * Copyright 2016-2023, XGBoost Contributors */ -#include // for Test, EXPECT_NEAR, ASSERT_STREQ -#include // for Context -#include // for MetaInfo, DMatrix -#include // for Matrix -#include // for Metric - -#include // for max -#include // for unique_ptr -#include // for vector - #include "test_rank_metric.h" -#include "../helpers.h" // for GetMetricEval, CreateEmptyGe... -#include "xgboost/base.h" // for bst_float, kRtEps -#include "xgboost/host_device_vector.h" // for HostDeviceVector -#include "xgboost/json.h" // for Json, String, Object -namespace xgboost { -namespace metric { +#include // for Test, EXPECT_NEAR, ASSERT_STREQ +#include // for Context +#include // for Metric -#if !defined(__CUDACC__) +#include // for unique_ptr + +#include "../helpers.h" // for GetMetricEval, CreateEmptyGe... +#include "xgboost/base.h" // for bst_float, kRtEps + +namespace xgboost::metric { TEST(Metric, AMS) { auto ctx = MakeCUDACtx(GPUIDX); EXPECT_ANY_THROW(Metric::Create("ams", &ctx)); - Metric* metric = Metric::Create("ams@0.5f", &ctx); + std::unique_ptr metric{Metric::Create("ams@0.5f", &ctx)}; ASSERT_STREQ(metric->Name(), "ams@0.5"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0.311f, 0.001f); - EXPECT_NEAR(GetMetricEval(metric, - {0.1f, 0.9f, 0.1f, 0.9f}, - { 0, 0, 1, 1}), - 0.29710f, 0.001f); + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}), 0.311f, 0.001f); + EXPECT_NEAR(GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}), 0.29710f, + 0.001f); - delete metric; - metric = Metric::Create("ams@0", &ctx); + metric.reset(Metric::Create("ams@0", &ctx)); ASSERT_STREQ(metric->Name(), "ams@0"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0.311f, 0.001f); - - delete metric; + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}), 0.311f, 0.001f); } -#endif - -TEST(Metric, DeclareUnifiedTest(Precision)) { VerifyPrecision(); } - -TEST(Metric, DeclareUnifiedTest(NDCG)) { VerifyNDCG(); } - -TEST(Metric, DeclareUnifiedTest(MAP)) { VerifyMAP(); } - -TEST(Metric, DeclareUnifiedTest(NDCGExpGain)) { VerifyNDCGExpGain(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PrecisionRowSplit) { - DoTest(VerifyPrecision, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PrecisionColumnSplit) { - DoTest(VerifyPrecision, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), NDCGRowSplit) { - DoTest(VerifyNDCG, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), NDCGColumnSplit) { - DoTest(VerifyNDCG, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAPRowSplit) { - DoTest(VerifyMAP, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAPColumnSplit) { - DoTest(VerifyMAP, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), NDCGExpGainRowSplit) { - DoTest(VerifyNDCGExpGain, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), NDCGExpGainColumnSplit) { - DoTest(VerifyNDCGExpGain, DataSplitMode::kCol); -} -} // namespace metric -} // namespace xgboost +} // namespace xgboost::metric diff --git a/tests/cpp/metric/test_rank_metric.cu b/tests/cpp/metric/test_rank_metric.cu deleted file mode 100644 index 38b4c72e1..000000000 --- a/tests/cpp/metric/test_rank_metric.cu +++ /dev/null @@ -1,5 +0,0 @@ -/*! - * Copyright 2019 XGBoost contributors - */ -// Dummy file to keep the CUDA conditional compile trick. 
-#include "test_rank_metric.cc" diff --git a/tests/cpp/metric/test_rank_metric.h b/tests/cpp/metric/test_rank_metric.h index 5d5e87072..bb4096288 100644 --- a/tests/cpp/metric/test_rank_metric.h +++ b/tests/cpp/metric/test_rank_metric.h @@ -19,8 +19,8 @@ namespace xgboost::metric { -inline void VerifyPrecision(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyPrecision(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("pre", &ctx)}; ASSERT_STREQ(metric->Name(), "pre"); EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5, 1e-7); @@ -43,8 +43,8 @@ inline void VerifyPrecision(DataSplitMode data_split_mode = DataSplitMode::kRow) 0.5f, 1e-7); } -inline void VerifyNDCG(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyNDCG(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); Metric * metric = xgboost::Metric::Create("ndcg", &ctx); ASSERT_STREQ(metric->Name(), "ndcg"); EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}, {}, {}, data_split_mode)); @@ -101,8 +101,8 @@ inline void VerifyNDCG(DataSplitMode data_split_mode = DataSplitMode::kRow) { delete metric; } -inline void VerifyMAP(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMAP(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); Metric * metric = xgboost::Metric::Create("map", &ctx); ASSERT_STREQ(metric->Name(), "map"); EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 1, kRtEps); @@ -149,8 +149,8 @@ inline void VerifyMAP(DataSplitMode data_split_mode = DataSplitMode::kRow) { delete metric; } -inline void VerifyNDCGExpGain(DataSplitMode data_split_mode = DataSplitMode::kRow) { - Context ctx = MakeCUDACtx(GPUIDX); +inline void VerifyNDCGExpGain(DataSplitMode data_split_mode, DeviceOrd device) { + Context ctx = MakeCUDACtx(device.ordinal); auto p_fmat = xgboost::RandomDataGenerator{0, 0, 0}.GenerateDMatrix(); MetaInfo& info = p_fmat->Info(); diff --git a/tests/cpp/metric/test_survival_metric.cc b/tests/cpp/metric/test_survival_metric.cc index ded9c4b0e..1b02fd7e7 100644 --- a/tests/cpp/metric/test_survival_metric.cc +++ b/tests/cpp/metric/test_survival_metric.cc @@ -1,5 +1,5 @@ -/*! - * Copyright (c) by Contributors 2020 +/** + * Copyright 2020-2023, XGBoost Contributors */ #include #include @@ -16,8 +16,7 @@ // CUDA conditional compile trick. 
#include "test_survival_metric.cu" -namespace xgboost { -namespace common { +namespace xgboost::common { /** Tests for Survival metrics that should run only on CPU **/ @@ -113,6 +112,4 @@ TEST(AFTLoss, IntervalCensored) { { 8.0000, 4.8004, 2.8805, 1.7284, 1.0372, 0.6231, 0.3872, 0.3031, 0.3740, 0.5839, 0.8995, 1.2878, 1.7231, 2.1878, 2.6707, 3.1647, 3.6653, 4.1699, 4.6770, 5.1856 }); } - -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/tests/cpp/metric/test_survival_metric.cu b/tests/cpp/metric/test_survival_metric.cu index eec92dc99..ead8d11f2 100644 --- a/tests/cpp/metric/test_survival_metric.cu +++ b/tests/cpp/metric/test_survival_metric.cu @@ -7,28 +7,7 @@ /** Tests for Survival metrics that should run both on CPU and GPU **/ -namespace xgboost { -namespace common { -TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) { VerifyAFTNegLogLik(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikRowSplit) { - DoTest(VerifyAFTNegLogLik, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikColumnSplit) { - DoTest(VerifyAFTNegLogLik, DataSplitMode::kCol); -} - -TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) { VerifyIntervalRegressionAccuracy(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), IntervalRegressionAccuracyRowSplit) { - DoTest(VerifyIntervalRegressionAccuracy, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), IntervalRegressionAccuracyColumnSplit) { - DoTest(VerifyIntervalRegressionAccuracy, DataSplitMode::kCol); -} - +namespace xgboost::common { // Test configuration of AFT metric TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) { auto ctx = MakeCUDACtx(GPUIDX); @@ -44,5 +23,4 @@ TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) { CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX); } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/tests/cpp/metric/test_survival_metric.h b/tests/cpp/metric/test_survival_metric.h index 1626d3772..902c9aa6b 100644 --- a/tests/cpp/metric/test_survival_metric.h +++ b/tests/cpp/metric/test_survival_metric.h @@ -47,8 +47,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) } } -inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); /** * Test aggregate output from the AFT metric over a small test data set. 
@@ -78,8 +78,8 @@ inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kR } } -inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); auto p_fmat = EmptyDMatrix(); MetaInfo& info = p_fmat->Info(); @@ -101,7 +101,7 @@ inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = Dat info.labels_lower_bound_.HostVector()[0] = 70.0f; EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f); - CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, device.ordinal); } } // namespace common } // namespace xgboost diff --git a/tests/cpp/objective/test_objective.cc b/tests/cpp/objective/test_objective.cc index 21ffc7caf..efdd03612 100644 --- a/tests/cpp/objective/test_objective.cc +++ b/tests/cpp/objective/test_objective.cc @@ -50,7 +50,7 @@ class TestDefaultObjConfig : public ::testing::TestWithParam { public: void Run(std::string objective) { - auto Xy = MakeFmatForObjTest(objective); + auto Xy = MakeFmatForObjTest(objective, 10, 10); std::unique_ptr learner{Learner::Create({Xy})}; std::unique_ptr objfn{ObjFunction::Create(objective, &ctx_)}; diff --git a/tests/cpp/objective_helpers.cc b/tests/cpp/objective_helpers.cc index ed80f71d5..9ad4b5c39 100644 --- a/tests/cpp/objective_helpers.cc +++ b/tests/cpp/objective_helpers.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #include "objective_helpers.h" @@ -7,17 +7,17 @@ #include "helpers.h" // for RandomDataGenerator namespace xgboost { -std::shared_ptr MakeFmatForObjTest(std::string const& obj) { - auto constexpr kRows = 10, kCols = 10; - auto p_fmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); + +void MakeLabelForObjTest(std::shared_ptr p_fmat, std::string const& obj) { auto& h_upper = p_fmat->Info().labels_upper_bound_.HostVector(); auto& h_lower = p_fmat->Info().labels_lower_bound_.HostVector(); - h_lower.resize(kRows); - h_upper.resize(kRows); - for (size_t i = 0; i < kRows; ++i) { + h_lower.resize(p_fmat->Info().num_row_); + h_upper.resize(p_fmat->Info().num_row_); + for (size_t i = 0; i < p_fmat->Info().num_row_; ++i) { h_lower[i] = 1; h_upper[i] = 10; } + if (obj.find("rank:") != std::string::npos) { auto h_label = p_fmat->Info().labels.HostView(); std::size_t k = 0; @@ -26,6 +26,12 @@ std::shared_ptr MakeFmatForObjTest(std::string const& obj) { ++k; } } +} + +std::shared_ptr MakeFmatForObjTest(std::string const& obj, bst_idx_t n_samples, + bst_feature_t n_features) { + auto p_fmat = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); + MakeLabelForObjTest(p_fmat, obj); return p_fmat; }; } // namespace xgboost diff --git a/tests/cpp/objective_helpers.h b/tests/cpp/objective_helpers.h index 7f394ef8d..972747c36 100644 --- a/tests/cpp/objective_helpers.h +++ b/tests/cpp/objective_helpers.h @@ -32,5 +32,11 @@ inline std::string ObjTestNameGenerator(const ::testing::TestParamInfo MakeFmatForObjTest(std::string const& obj); +/** + * @brief Construct random label for testing. 
+ */ +void MakeLabelForObjTest(std::shared_ptr p_fmat, std::string const& obj); + +std::shared_ptr MakeFmatForObjTest(std::string const& obj, bst_idx_t n_samples, + bst_feature_t n_features); } // namespace xgboost diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu index 237bdeb9d..008952a4f 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cu +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -108,6 +108,32 @@ TEST_F(FederatedCollTestGPU, Allreduce) { }); } +TEST(FederatedCollGPUGlobal, Allreduce) { + std::int32_t n_workers = common::AllVisibleGPUs(); + TestFederatedGlobal(n_workers, [&] { + auto r = collective::GetRank(); + auto world = collective::GetWorldSize(); + CHECK_EQ(n_workers, world); + + dh::device_vector values(3, r); + auto ctx = MakeCUDACtx(r); + auto rc = collective::Allreduce( + &ctx, linalg::MakeVec(values.data().get(), values.size(), DeviceOrd::CUDA(r)), + Op::kBitwiseOR); + SafeColl(rc); + + std::vector expected(values.size(), 0); + for (std::int32_t rank = 0; rank < world; ++rank) { + for (std::size_t i = 0; i < expected.size(); ++i) { + expected[i] |= rank; + } + } + for (std::size_t i = 0; i < expected.size(); ++i) { + CHECK_EQ(expected[i], values[i]); + } + }); +} + TEST_F(FederatedCollTestGPU, Broadcast) { std::int32_t n_workers = common::AllVisibleGPUs(); TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t rank) { diff --git a/tests/cpp/plugin/federated/test_worker.h b/tests/cpp/plugin/federated/test_worker.h index d0edecc15..8ec76237d 100644 --- a/tests/cpp/plugin/federated/test_worker.h +++ b/tests/cpp/plugin/federated/test_worker.h @@ -11,12 +11,24 @@ #include "../../../../plugin/federated/federated_tracker.h" #include "../../../../src/collective/comm_group.h" +#include "../../../../src/collective/communicator-inl.h" #include "federated_comm.h" // for FederatedComm #include "xgboost/json.h" // for Json namespace xgboost::collective { +inline Json FederatedTestConfig(std::int32_t n_workers, std::int32_t port, std::int32_t i) { + Json config{Object{}}; + config["dmlc_communicator"] = std::string{"federated"}; + config["dmlc_task_id"] = std::to_string(i); + config["dmlc_retry"] = 2; + config["federated_world_size"] = n_workers; + config["federated_rank"] = i; + config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); + return config; +} + template -void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { +void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) { Json config{Object()}; config["federated_secure"] = Boolean{false}; config["n_workers"] = Integer{n_workers}; @@ -30,16 +42,7 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { std::int32_t port = tracker.Port(); for (std::int32_t i = 0; i < n_workers; ++i) { - workers.emplace_back([=] { - Json config{Object{}}; - config["federated_world_size"] = n_workers; - config["federated_rank"] = i; - config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); - auto comm = std::make_shared( - DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config); - - fn(comm, i); - }); + workers.emplace_back([=] { fn(port, i); }); } for (auto& t : workers) { @@ -51,39 +54,33 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { ASSERT_TRUE(fut.get().OK()); } +template +void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { + TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) { + auto config = FederatedTestConfig(n_workers, port, 
i); + auto comm = std::make_shared( + DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config); + + fn(comm, i); + }); +} + template void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) { - Json config{Object()}; - config["federated_secure"] = Boolean{false}; - config["n_workers"] = Integer{n_workers}; - FederatedTracker tracker{config}; - auto fut = tracker.Run(); + TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) { + auto config = FederatedTestConfig(n_workers, port, i); + std::shared_ptr comm_group{CommGroup::Create(config)}; + fn(comm_group, i); + }); +} - std::vector workers; - auto rc = tracker.WaitUntilReady(); - ASSERT_TRUE(rc.OK()) << rc.Report(); - std::int32_t port = tracker.Port(); - - for (std::int32_t i = 0; i < n_workers; ++i) { - workers.emplace_back([=] { - Json config{Object{}}; - config["dmlc_communicator"] = std::string{"federated"}; - config["dmlc_task_id"] = std::to_string(i); - config["dmlc_retry"] = 2; - config["federated_world_size"] = n_workers; - config["federated_rank"] = i; - config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); - std::shared_ptr comm_group{CommGroup::Create(config)}; - fn(comm_group, i); - }); - } - - for (auto& t : workers) { - t.join(); - } - - rc = tracker.Shutdown(); - ASSERT_TRUE(rc.OK()) << rc.Report(); - ASSERT_TRUE(fut.get().OK()); +template +void TestFederatedGlobal(std::int32_t n_workers, WorkerFn&& fn) { + TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) { + auto config = FederatedTestConfig(n_workers, port, i); + collective::Init(config); + fn(); + collective::Finalize(); + }); } } // namespace xgboost::collective diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h deleted file mode 100644 index 85f2e014b..000000000 --- a/tests/cpp/plugin/helpers.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2022-2023, XGBoost contributors - */ -#pragma once - -#include -#include -#include -#include - -#include -#include // for thread, sleep_for - -#include "../../../plugin/federated/federated_server.h" -#include "../../../src/collective/communicator-inl.h" -#include "../../../src/common/threading_utils.h" - -namespace xgboost { - -class ServerForTest { - std::string server_address_; - std::unique_ptr server_thread_; - std::unique_ptr server_; - - public: - explicit ServerForTest(std::size_t world_size) { - server_thread_.reset(new std::thread([this, world_size] { - grpc::ServerBuilder builder; - xgboost::federated::FederatedService service{static_cast(world_size)}; - int selected_port; - builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port); - builder.RegisterService(&service); - server_ = builder.BuildAndStart(); - server_address_ = std::string("localhost:") + std::to_string(selected_port); - server_->Wait(); - })); - } - - ~ServerForTest() { - using namespace std::chrono_literals; - while (!server_) { - std::this_thread::sleep_for(100ms); - } - server_->Shutdown(); - while (!server_thread_) { - std::this_thread::sleep_for(100ms); - } - server_thread_->join(); - } - - auto Address() const { - using namespace std::chrono_literals; - while (server_address_.empty()) { - std::this_thread::sleep_for(100ms); - } - return server_address_; - } -}; - -class BaseFederatedTest : public ::testing::Test { - protected: - void SetUp() override { server_ = std::make_unique(kWorldSize); } - - void TearDown() override { server_.reset(nullptr); } - - static int constexpr kWorldSize{2}; - std::unique_ptr server_; 
-}; - -template -void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address, - Function&& function, Args&&... args) { - auto run = [&](auto rank) { - Json config{JsonObject()}; - config["xgboost_communicator"] = String("federated"); - config["federated_secure"] = false; - config["federated_server_address"] = String(server_address); - config["federated_world_size"] = world_size; - config["federated_rank"] = rank; - xgboost::collective::Init(config); - - std::forward(function)(std::forward(args)...); - - xgboost::collective::Finalize(); - }; -#if defined(_OPENMP) - common::ParallelFor(world_size, world_size, run); -#else - std::vector threads; - for (auto rank = 0; rank < world_size; rank++) { - threads.emplace_back(run, rank); - } - for (auto& thread : threads) { - thread.join(); - } -#endif -} - -} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu deleted file mode 100644 index b96524878..000000000 --- a/tests/cpp/plugin/test_federated_adapter.cu +++ /dev/null @@ -1,97 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include -#include -#include - -#include "../../../plugin/federated/federated_communicator.h" -#include "../../../src/collective/communicator-inl.cuh" -#include "../../../src/collective/device_communicator_adapter.cuh" -#include "../helpers.h" -#include "./helpers.h" - -namespace xgboost::collective { - -class FederatedAdapterTest : public BaseFederatedTest {}; - -TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) { - auto construct = []() { DeviceCommunicatorAdapter adapter{-1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -namespace { -void VerifyAllReduceSum() { - auto const world_size = collective::GetWorldSize(); - auto const device = GPUIDX; - int count = 3; - common::SetDevice(device); - thrust::device_vector buffer(count, 0); - thrust::sequence(buffer.begin(), buffer.end()); - collective::AllReduce(device, buffer.data().get(), count); - thrust::host_vector host_buffer = buffer; - EXPECT_EQ(host_buffer.size(), count); - for (auto i = 0; i < count; i++) { - EXPECT_EQ(host_buffer[i], i * world_size); - } -} -} // anonymous namespace - -TEST_F(FederatedAdapterTest, MGPUAllReduceSum) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum); -} - -namespace { -void VerifyAllGather() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - auto const device = GPUIDX; - common::SetDevice(device); - thrust::device_vector send_buffer(1, rank); - thrust::device_vector receive_buffer(world_size, 0); - collective::AllGather(device, send_buffer.data().get(), receive_buffer.data().get(), - sizeof(double)); - thrust::host_vector host_buffer = receive_buffer; - EXPECT_EQ(host_buffer.size(), world_size); - for (auto i = 0; i < world_size; i++) { - EXPECT_EQ(host_buffer[i], i); - } -} -} // anonymous namespace - -TEST_F(FederatedAdapterTest, MGPUAllGather) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGather); -} - -namespace { -void VerifyAllGatherV() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - auto const device = GPUIDX; - int const count = rank + 2; - common::SetDevice(device); - thrust::device_vector buffer(count, 0); - thrust::sequence(buffer.begin(), buffer.end()); - std::vector segments(world_size); - dh::caching_device_vector receive_buffer{}; - - collective::AllGatherV(device, 
buffer.data().get(), count, &segments, &receive_buffer); - - EXPECT_EQ(segments[0], 2); - EXPECT_EQ(segments[1], 3); - thrust::host_vector host_buffer = receive_buffer; - EXPECT_EQ(host_buffer.size(), 5); - int expected[] = {0, 1, 0, 1, 2}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(host_buffer[i], expected[i]); - } -} -} // anonymous namespace - -TEST_F(FederatedAdapterTest, MGPUAllGatherV) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV); -} -} // namespace xgboost::collective diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc deleted file mode 100644 index 68b112f1c..000000000 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ /dev/null @@ -1,161 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include -#include - -#include "../../../plugin/federated/federated_communicator.h" -#include "helpers.h" - -namespace xgboost::collective { - -class FederatedCommunicatorTest : public BaseFederatedTest { - public: - static void VerifyAllgather(int rank, const std::string &server_address) { - FederatedCommunicator comm{kWorldSize, rank, server_address}; - CheckAllgather(comm, rank); - } - - static void VerifyAllgatherV(int rank, const std::string &server_address) { - FederatedCommunicator comm{kWorldSize, rank, server_address}; - CheckAllgatherV(comm, rank); - } - - static void VerifyAllreduce(int rank, const std::string &server_address) { - FederatedCommunicator comm{kWorldSize, rank, server_address}; - CheckAllreduce(comm); - } - - static void VerifyBroadcast(int rank, const std::string &server_address) { - FederatedCommunicator comm{kWorldSize, rank, server_address}; - CheckBroadcast(comm, rank); - } - - protected: - static void CheckAllgather(FederatedCommunicator &comm, int rank) { - std::string input{static_cast('0' + rank)}; - auto output = comm.AllGather(input); - for (auto i = 0; i < kWorldSize; i++) { - EXPECT_EQ(output[i], static_cast('0' + i)); - } - } - - static void CheckAllgatherV(FederatedCommunicator &comm, int rank) { - std::vector inputs{"Federated", " Learning!!!"}; - auto output = comm.AllGatherV(inputs[rank]); - EXPECT_EQ(output, "Federated Learning!!!"); - } - - static void CheckAllreduce(FederatedCommunicator &comm) { - int buffer[] = {1, 2, 3, 4, 5}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); - int expected[] = {2, 4, 6, 8, 10}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(buffer[i], expected[i]); - } - } - - static void CheckBroadcast(FederatedCommunicator &comm, int rank) { - if (rank == 0) { - std::string buffer{"hello"}; - comm.Broadcast(&buffer[0], buffer.size(), 0); - EXPECT_EQ(buffer, "hello"); - } else { - std::string buffer{" "}; - comm.Broadcast(&buffer[0], buffer.size(), 0); - EXPECT_EQ(buffer, "hello"); - } - } -}; - -TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { - auto construct = [] { FederatedCommunicator comm{0, 0, "localhost:0", "", "", ""}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) { - auto construct = [] { FederatedCommunicator comm{1, -1, "localhost:0", "", "", ""}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) { - auto construct = [] { FederatedCommunicator comm{1, 1, "localhost:0", "", "", ""}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { - 
auto construct = [] { - Json config{JsonObject()}; - config["federated_server_address"] = std::string("localhost:0"); - config["federated_world_size"] = std::string("1"); - config["federated_rank"] = Integer(0); - FederatedCommunicator::Create(config); - }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { - auto construct = [] { - Json config{JsonObject()}; - config["federated_server_address"] = std::string("localhost:0"); - config["federated_world_size"] = 1; - config["federated_rank"] = std::string("0"); - FederatedCommunicator::Create(config); - }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { - FederatedCommunicator comm{6, 3, "localhost:0"}; - EXPECT_EQ(comm.GetWorldSize(), 6); - EXPECT_EQ(comm.GetRank(), 3); -} - -TEST(FederatedCommunicatorSimpleTest, IsDistributed) { - FederatedCommunicator comm{2, 1, "localhost:0"}; - EXPECT_TRUE(comm.IsDistributed()); -} - -TEST_F(FederatedCommunicatorTest, Allgather) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address()); - } - for (auto &thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedCommunicatorTest, AllgatherV) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgatherV, rank, server_->Address()); - } - for (auto &thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedCommunicatorTest, Allreduce) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address()); - } - for (auto &thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedCommunicatorTest, Broadcast) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address()); - } - for (auto &thread : threads) { - thread.join(); - } -} -} // namespace xgboost::collective diff --git a/tests/cpp/plugin/test_federated_data.cc b/tests/cpp/plugin/test_federated_data.cc index 6a8233a0f..d0f649152 100644 --- a/tests/cpp/plugin/test_federated_data.cc +++ b/tests/cpp/plugin/test_federated_data.cc @@ -6,16 +6,13 @@ #include -#include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" #include "../filesystem.h" #include "../helpers.h" -#include "helpers.h" +#include "federated/test_worker.h" namespace xgboost { -class FederatedDataTest : public BaseFederatedTest {}; - void VerifyLoadUri() { auto const rank = collective::GetRank(); @@ -47,7 +44,8 @@ void VerifyLoadUri() { } } -TEST_F(FederatedDataTest, LoadUri) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLoadUri); +TEST(FederatedDataTest, LoadUri) { + static int constexpr kWorldSize{2}; + collective::TestFederatedGlobal(kWorldSize, [] { VerifyLoadUri(); }); } } // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_learner.cc b/tests/cpp/plugin/test_federated_learner.cc index a9adedc63..948914e0f 100644 --- a/tests/cpp/plugin/test_federated_learner.cc +++ b/tests/cpp/plugin/test_federated_learner.cc @@ -1,17 +1,19 @@ -/*! 
- * Copyright 2023 XGBoost contributors +/** + * Copyright 2023-2024, XGBoost contributors + * + * Some other tests for federated learning are in the main test suite (test_learner.cc), + * guarded by the `XGBOOST_USE_FEDERATED` macro. */ #include #include #include #include -#include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" -#include "../../../src/common/linalg_op.h" +#include "../../../src/common/linalg_op.h" // for begin, end #include "../helpers.h" #include "../objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator -#include "helpers.h" +#include "federated/test_worker.h" namespace xgboost { namespace { @@ -36,32 +38,16 @@ auto MakeModel(std::string tree_method, std::string device, std::string objectiv return model; } -void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model, - std::string tree_method, std::string device, std::string objective) { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); +void VerifyObjective(std::size_t rows, std::size_t cols, float expected_base_score, + Json expected_model, std::string const &tree_method, std::string device, + std::string const &objective) { + auto rank = collective::GetRank(); std::shared_ptr dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; if (rank == 0) { - auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); - auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); - h_lower.resize(rows); - h_upper.resize(rows); - for (size_t i = 0; i < rows; ++i) { - h_lower[i] = 1; - h_upper[i] = 10; - } - - if (objective.find("rank:") != std::string::npos) { - auto h_label = dmat->Info().labels.HostView(); - std::size_t k = 0; - for (auto &v : h_label) { - v = k % 2 == 0; - ++k; - } - } + MakeLabelForObjTest(dmat, objective); } - std::shared_ptr sliced{dmat->SliceCol(world_size, rank)}; + std::shared_ptr sliced{dmat->SliceCol(collective::GetWorldSize(), rank)}; auto model = MakeModel(tree_method, device, objective, sliced); auto base_score = GetBaseScore(model); @@ -71,18 +57,15 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e } // namespace class VerticalFederatedLearnerTest : public ::testing::TestWithParam { - std::unique_ptr server_; static int constexpr kWorldSize{3}; protected: - void SetUp() override { server_ = std::make_unique(kWorldSize); } - void TearDown() override { server_.reset(nullptr); } - void Run(std::string tree_method, std::string device, std::string objective) { static auto constexpr kRows{16}; static auto constexpr kCols{16}; std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; + MakeLabelForObjTest(dmat, objective); auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); @@ -103,9 +86,9 @@ class VerticalFederatedLearnerTest : public ::testing::TestWithParamAddress(), &VerifyObjective, kRows, kCols, - score, model, tree_method, device, objective); + collective::TestFederatedGlobal(kWorldSize, [&]() { + VerifyObjective(kRows, kCols, score, model, tree_method, device, objective); + }); } }; diff --git a/tests/cpp/plugin/test_federated_metrics.cc b/tests/cpp/plugin/test_federated_metrics.cc deleted file mode 100644 index 1bdda567f..000000000 --- a/tests/cpp/plugin/test_federated_metrics.cc +++ /dev/null @@ -1,243 +0,0 @@ -/*!
- * Copyright 2023 XGBoost contributors - */ -#include - -#include "../metric/test_auc.h" -#include "../metric/test_elementwise_metric.h" -#include "../metric/test_multiclass_metric.h" -#include "../metric/test_rank_metric.h" -#include "../metric/test_survival_metric.h" -#include "helpers.h" - -namespace { -class FederatedMetricTest : public xgboost::BaseFederatedTest {}; -} // anonymous namespace - -namespace xgboost { -namespace metric { -TEST_F(FederatedMetricTest, BinaryAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, BinaryAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiClassAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiClassAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, RankingAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, RankingAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, PRAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, PRAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiClassPRAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiClassPRAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, RankingPRAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, RankingPRAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, RMSERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, RMSEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, RMSLERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, RMSLEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MAERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MAEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MAPERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MAPEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, 
server_->Address(), &VerifyMAPE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MPHERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MPHEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, LogLossRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, LogLossColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, ErrorRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, ErrorColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, PoissonNegLogLikRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, PoissonNegLogLikColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiRMSERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiRMSEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, QuantileRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, QuantileColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiClassErrorRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiClassErrorColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiClassLogLossRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiClassLogLossColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, PrecisionRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, PrecisionColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, NDCGRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, NDCGColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MAPRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MAPColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), 
&VerifyMAP, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, NDCGExpGainRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, NDCGExpGainColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain, - DataSplitMode::kCol); -} -} // namespace metric -} // namespace xgboost - -namespace xgboost { -namespace common { -TEST_F(FederatedMetricTest, AFTNegLogLikRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, AFTNegLogLikColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, IntervalRegressionAccuracyRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, IntervalRegressionAccuracyColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy, - DataSplitMode::kCol); -} -} // namespace common -} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc deleted file mode 100644 index c40e58fa3..000000000 --- a/tests/cpp/plugin/test_federated_server.cc +++ /dev/null @@ -1,133 +0,0 @@ -/*! - * Copyright 2017-2020 XGBoost contributors - */ -#include - -#include -#include - -#include "federated_client.h" -#include "helpers.h" - -namespace xgboost { - -class FederatedServerTest : public BaseFederatedTest { - public: - static void VerifyAllgather(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckAllgather(client, rank); - } - - static void VerifyAllgatherV(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckAllgatherV(client, rank); - } - - static void VerifyAllreduce(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckAllreduce(client); - } - - static void VerifyBroadcast(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckBroadcast(client, rank); - } - - static void VerifyMixture(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - for (auto i = 0; i < 10; i++) { - CheckAllgather(client, rank); - CheckAllreduce(client); - CheckBroadcast(client, rank); - } - } - - protected: - static void CheckAllgather(federated::FederatedClient& client, int rank) { - int data[] = {rank}; - std::string send_buffer(reinterpret_cast(data), sizeof(data)); - auto reply = client.Allgather(send_buffer); - auto const* result = reinterpret_cast(reply.data()); - for (auto i = 0; i < kWorldSize; i++) { - EXPECT_EQ(result[i], i); - } - } - - static void CheckAllgatherV(federated::FederatedClient& client, int rank) { - std::vector inputs{"Hello,", " World!"}; - auto reply = client.AllgatherV(inputs[rank]); - EXPECT_EQ(reply, "Hello, World!"); - } - - static void CheckAllreduce(federated::FederatedClient& client) { - int data[] = {1, 2, 3, 4, 5}; - std::string send_buffer(reinterpret_cast(data), sizeof(data)); - auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM); - auto const* result = reinterpret_cast(reply.data()); - int expected[] = {2, 4, 6, 8, 
10}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(result[i], expected[i]); - } - } - - static void CheckBroadcast(federated::FederatedClient& client, int rank) { - std::string send_buffer{}; - if (rank == 0) { - send_buffer = "hello broadcast"; - } - auto reply = client.Broadcast(send_buffer, 0); - EXPECT_EQ(reply, "hello broadcast") << "rank " << rank; - } -}; - -TEST_F(FederatedServerTest, Allgather) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedServerTest, AllgatherV) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllgatherV, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedServerTest, Allreduce) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedServerTest, Broadcast) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedServerTest, Mixture) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -} // namespace xgboost diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 46b085916..637b77b25 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -12,6 +12,7 @@ #include "../../../src/data/proxy_dmatrix.h" #include "../../../src/gbm/gbtree.h" #include "../../../src/gbm/gbtree_model.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" #include "test_predictor.h" @@ -43,7 +44,7 @@ void TestColumnSplit() { TEST(CpuPredictor, BasicColumnSplit) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit); + collective::TestDistributedGlobal(kWorldSize, TestColumnSplit); } TEST(CpuPredictor, IterationRange) { @@ -157,7 +158,7 @@ TEST(CPUPredictor, CategoricalPrediction) { TEST(CPUPredictor, CategoricalPredictionColumnSplit) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, false, true); + collective::TestDistributedGlobal(kWorldSize, [] { TestCategoricalPrediction(false, true); }); } TEST(CPUPredictor, CategoricalPredictLeaf) { @@ -168,7 +169,7 @@ TEST(CPUPredictor, CategoricalPredictLeaf) { TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) { auto constexpr kWorldSize = 2; Context ctx; - RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, &ctx, true); + collective::TestDistributedGlobal(kWorldSize, [&] { TestCategoricalPredictLeaf(&ctx, true); }); } TEST(CpuPredictor, UpdatePredictionCache) { @@ -183,7 +184,8 @@ TEST(CpuPredictor, LesserFeatures) { TEST(CpuPredictor, LesserFeaturesColumnSplit) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestPredictionWithLesserFeaturesColumnSplit, false); 
+ collective::TestDistributedGlobal(kWorldSize, + [] { TestPredictionWithLesserFeaturesColumnSplit(false); }); } TEST(CpuPredictor, Sparse) { diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 50e036b90..4895fb63f 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -12,6 +12,7 @@ #include "../../../src/data/device_adapter.cuh" #include "../../../src/data/proxy_dmatrix.h" #include "../../../src/gbm/gbtree_model.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal, BaseMGPUTest #include "../helpers.h" #include "test_predictor.h" @@ -85,7 +86,7 @@ void VerifyBasicColumnSplit(std::array, 32> const& expected_r } } // anonymous namespace -class MGPUPredictorTest : public BaseMGPUTest {}; +class MGPUPredictorTest : public collective::BaseMGPUTest {}; TEST_F(MGPUPredictorTest, BasicColumnSplit) { auto ctx = MakeCUDACtx(0); @@ -111,7 +112,8 @@ TEST_F(MGPUPredictorTest, BasicColumnSplit) { result[i - 1] = out_predictions_h; } - DoTest(VerifyBasicColumnSplit, result); + this->DoTest([&] { VerifyBasicColumnSplit(result); }, true); + this->DoTest([&] { VerifyBasicColumnSplit(result); }, false); } TEST(GPUPredictor, EllpackBasic) { @@ -209,7 +211,8 @@ TEST(GpuPredictor, LesserFeatures) { } TEST_F(MGPUPredictorTest, LesserFeaturesColumnSplit) { - RunWithInMemoryCommunicator(world_size_, TestPredictionWithLesserFeaturesColumnSplit, true); + this->DoTest([] { TestPredictionWithLesserFeaturesColumnSplit(true); }, true); + this->DoTest([] { TestPredictionWithLesserFeaturesColumnSplit(true); }, false); } // Very basic test of empty model @@ -277,7 +280,7 @@ TEST(GPUPredictor, IterationRange) { } TEST_F(MGPUPredictorTest, IterationRangeColumnSplit) { - TestIterationRangeColumnSplit(world_size_, true); + TestIterationRangeColumnSplit(common::AllVisibleGPUs(), true); } TEST(GPUPredictor, CategoricalPrediction) { @@ -285,7 +288,8 @@ TEST(GPUPredictor, CategoricalPrediction) { } TEST_F(MGPUPredictorTest, CategoricalPredictionColumnSplit) { - RunWithInMemoryCommunicator(world_size_, TestCategoricalPrediction, true, true); + this->DoTest([] { TestCategoricalPrediction(true, true); }, true); + this->DoTest([] { TestCategoricalPrediction(true, true); }, false); } TEST(GPUPredictor, CategoricalPredictLeaf) { @@ -294,8 +298,18 @@ TEST(GPUPredictor, CategoricalPredictLeaf) { } TEST_F(MGPUPredictorTest, CategoricalPredictionLeafColumnSplit) { - auto ctx = MakeCUDACtx(common::AllVisibleGPUs() == 1 ? 
0 : collective::GetRank()); - RunWithInMemoryCommunicator(world_size_, TestCategoricalPredictLeaf, &ctx, true); + this->DoTest( + [&] { + auto ctx = MakeCUDACtx(collective::GetRank()); + TestCategoricalPredictLeaf(&ctx, true); + }, + true); + this->DoTest( + [&] { + auto ctx = MakeCUDACtx(collective::GetRank()); + TestCategoricalPredictLeaf(&ctx, true); + }, + false); } TEST(GPUPredictor, PredictLeafBasic) { @@ -325,7 +339,7 @@ TEST(GPUPredictor, Sparse) { } TEST_F(MGPUPredictorTest, SparseColumnSplit) { - TestSparsePredictionColumnSplit(world_size_, true, 0.2); - TestSparsePredictionColumnSplit(world_size_, true, 0.8); + TestSparsePredictionColumnSplit(common::AllVisibleGPUs(), true, 0.2); + TestSparsePredictionColumnSplit(common::AllVisibleGPUs(), true, 0.8); } } // namespace xgboost::predictor diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 4108d74b8..fde0e480b 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2023 by XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #include "test_predictor.h" @@ -10,7 +10,6 @@ #include <xgboost/predictor.h> // for PredictionCacheEntry, Predictor, Predic... #include <xgboost/string_view.h> // for StringView -#include <algorithm> // for max #include <limits> // for numeric_limits #include <memory> // for shared_ptr #include <unordered_map> // for unordered_map @@ -18,6 +17,7 @@ #include "../../../src/common/bitfield.h" // for LBitField32 #include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix #include "../../../src/data/proxy_dmatrix.h" // for DMatrixProxy +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" // for GetDMatrixFromData, RandomDataGenerator #include "xgboost/json.h" // for Json, Object, get, String #include "xgboost/linalg.h" // for MakeVec, Tensor, TensorView, Vector @@ -593,9 +593,23 @@ void TestIterationRangeColumnSplit(int world_size, bool use_gpu) { Json sliced_model{Object{}}; sliced->SaveModel(&sliced_model); - RunWithInMemoryCommunicator(world_size, VerifyIterationRangeColumnSplit, use_gpu, ranged_model, - sliced_model, kRows, kCols, kClasses, margin_ranged, margin_sliced, - leaf_ranged, leaf_sliced); +#if !defined(XGBOOST_USE_NCCL) + if (use_gpu) { + GTEST_SKIP_("Not compiled with NCCL"); + return; + } +#endif // defined(XGBOOST_USE_NCCL) + collective::TestDistributedGlobal(world_size, [&] { + VerifyIterationRangeColumnSplit(use_gpu, ranged_model, sliced_model, kRows, kCols, kClasses, + margin_ranged, margin_sliced, leaf_ranged, leaf_sliced); + }); + +#if defined(XGBOOST_USE_FEDERATED) + collective::TestFederatedGlobal(world_size, [&] { + VerifyIterationRangeColumnSplit(use_gpu, ranged_model, sliced_model, kRows, kCols, kClasses, + margin_ranged, margin_sliced, leaf_ranged, leaf_sliced); + }); +#endif // defined(XGBOOST_USE_FEDERATED) } void TestSparsePrediction(Context const *ctx, float sparsity) { @@ -701,8 +715,23 @@ void TestSparsePredictionColumnSplit(int world_size, bool use_gpu, float sparsit learner->SetParam("device", ctx.DeviceName()); learner->Predict(Xy, false, &sparse_predt, 0, 0); - RunWithInMemoryCommunicator(world_size, VerifySparsePredictionColumnSplit, use_gpu, model, - kRows, kCols, sparsity, sparse_predt.HostVector()); +#if !defined(XGBOOST_USE_NCCL) + if (use_gpu) { + GTEST_SKIP_("Not compiled with NCCL."); + return; + } +#endif // defined(XGBOOST_USE_NCCL) + collective::TestDistributedGlobal(world_size, [&] { + VerifySparsePredictionColumnSplit(use_gpu, model, kRows, kCols, sparsity, +
sparse_predt.HostVector()); + }); + +#if defined(XGBOOST_USE_FEDERATED) + collective::TestFederatedGlobal(world_size, [&] { + VerifySparsePredictionColumnSplit(use_gpu, model, kRows, kCols, sparsity, + sparse_predt.HostVector()); + }); +#endif // defined(XGBOOST_USE_FEDERATED) } void TestVectorLeafPrediction(Context const *ctx) { diff --git a/tests/cpp/rabit/allreduce_base_test.cc b/tests/cpp/rabit/allreduce_base_test.cc deleted file mode 100644 index 55cce5c7d..000000000 --- a/tests/cpp/rabit/allreduce_base_test.cc +++ /dev/null @@ -1,42 +0,0 @@ -#define RABIT_CXXTESTDEFS_H -#if !defined(_WIN32) -#include - -#include -#include -#include "../../../rabit/src/allreduce_base.h" - -TEST(AllreduceBase, InitTask) -{ - rabit::engine::AllreduceBase base; - - std::string rabit_task_id = "rabit_task_id=1"; - char cmd[rabit_task_id.size()+1]; - std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd); - cmd[rabit_task_id.size()] = '\0'; - - char* argv[] = {cmd}; - base.Init(1, argv); - EXPECT_EQ(base.task_id, "1"); -} - -TEST(AllreduceBase, InitWithRingReduce) -{ - rabit::engine::AllreduceBase base; - - std::string rabit_task_id = "rabit_task_id=1"; - char cmd[rabit_task_id.size()+1]; - std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd); - cmd[rabit_task_id.size()] = '\0'; - - std::string rabit_reduce_ring_mincount = "rabit_reduce_ring_mincount=1"; - char cmd2[rabit_reduce_ring_mincount.size()+1]; - std::copy(rabit_reduce_ring_mincount.begin(), rabit_reduce_ring_mincount.end(), cmd2); - cmd2[rabit_reduce_ring_mincount.size()] = '\0'; - - char* argv[] = {cmd, cmd2}; - base.Init(2, argv); - EXPECT_EQ(base.task_id, "1"); - EXPECT_EQ(base.reduce_ring_mincount, 1ul); -} -#endif // !defined(_WIN32) diff --git a/tests/cpp/rabit/test_utils.cc b/tests/cpp/rabit/test_utils.cc deleted file mode 100644 index 0b8787bdd..000000000 --- a/tests/cpp/rabit/test_utils.cc +++ /dev/null @@ -1,6 +0,0 @@ -#include -#include - -TEST(Utils, Assert) { - EXPECT_THROW({rabit::utils::Assert(false, "foo");}, dmlc::Error); -} diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 541f53008..976ae2147 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -1,15 +1,14 @@ /** * Copyright 2017-2024, XGBoost contributors */ -#include #include -#include // for Learner -#include // for LogCheck_NE, CHECK_NE, LogCheck_EQ -#include // for ObjFunction -#include // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR +#include +#include // for Learner +#include // for LogCheck_NE, CHECK_NE, LogCheck_EQ +#include // for ObjFunction +#include // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR #include // for equal, transform -#include // for int32_t, int64_t, uint32_t #include // for size_t #include // for ofstream #include // for numeric_limits @@ -27,6 +26,7 @@ #include "../../src/common/io.h" // for LoadSequentialFile #include "../../src/common/linalg_op.h" // for ElementWiseTransformHost, begin, end #include "../../src/common/random.h" // for GlobalRandom +#include "./collective/test_worker.h" // for TestDistributedGlobal #include "dmlc/io.h" // for Stream #include "dmlc/omp.h" // for omp_get_max_threads #include "filesystem.h" // for TemporaryDirectory @@ -658,7 +658,7 @@ class TestColumnSplit : public ::testing::TestWithParam { auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); - auto p_fmat = MakeFmatForObjTest(objective); + auto p_fmat = MakeFmatForObjTest(objective, 10, 10); std::shared_ptr sliced{p_fmat->SliceCol(world_size, rank)}; std::unique_ptr 
learner{Learner::Create({sliced})}; learner->SetParam("tree_method", "approx"); @@ -682,7 +682,7 @@ class TestColumnSplit : public ::testing::TestWithParam { public: void Run(std::string objective) { - auto p_fmat = MakeFmatForObjTest(objective); + auto p_fmat = MakeFmatForObjTest(objective, 10, 10); std::unique_ptr learner{Learner::Create({p_fmat})}; learner->SetParam("tree_method", "approx"); learner->SetParam("objective", objective); @@ -703,7 +703,9 @@ class TestColumnSplit : public ::testing::TestWithParam { auto constexpr kWorldSize{3}; auto call = [this, &objective](auto&... args) { TestBaseScore(objective, args...); }; auto score = GetBaseScore(config); - RunWithInMemoryCommunicator(kWorldSize, call, score, model); + collective::TestDistributedGlobal(kWorldSize, [&] { + call(score, model); + }); } }; @@ -736,7 +738,7 @@ void VerifyColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Arg Json const& expected_model) { auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); - auto p_fmat = MakeFmatForObjTest(""); + auto p_fmat = MakeFmatForObjTest("", 10, 10); std::shared_ptr sliced{p_fmat->SliceCol(world_size, rank)}; std::string device = "cpu"; if (use_gpu) { @@ -747,82 +749,99 @@ void VerifyColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Arg ASSERT_EQ(model, expected_model); } -void TestColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Args const& args) { - auto p_fmat = MakeFmatForObjTest(""); +void TestColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Args const& args, + bool federated) { + auto p_fmat = MakeFmatForObjTest("", 10, 10); std::string device = use_gpu ? "cuda:0" : "cpu"; auto model = GetModelWithArgs(p_fmat, tree_method, device, args); auto world_size{3}; if (use_gpu) { world_size = common::AllVisibleGPUs(); - // Simulate MPU on a single GPU. - if (world_size == 1) { + // Simulate MPU on a single GPU. Federated doesn't use nccl, can run multiple + // instances on the same GPU. 
+ if (world_size == 1 && federated) { world_size = 3; } } - RunWithInMemoryCommunicator(world_size, VerifyColumnSplitWithArgs, tree_method, use_gpu, args, - model); + if (federated) { +#if defined(XGBOOST_USE_FEDERATED) + collective::TestFederatedGlobal( + world_size, [&] { VerifyColumnSplitWithArgs(tree_method, use_gpu, args, model); }); +#else + GTEST_SKIP_("Not compiled with federated learning."); +#endif // defined(XGBOOST_USE_FEDERATED) + } else { +#if !defined(XGBOOST_USE_NCCL) + if (use_gpu) { + GTEST_SKIP_("Not compiled with NCCL."); + return; + } +#endif // defined(XGBOOST_USE_NCCL) + collective::TestDistributedGlobal( + world_size, [&] { VerifyColumnSplitWithArgs(tree_method, use_gpu, args, model); }); + } } -void TestColumnSplitColumnSampler(std::string const& tree_method, bool use_gpu) { - Args args{{"colsample_bytree", "0.5"}, {"colsample_bylevel", "0.6"}, {"colsample_bynode", "0.7"}}; - TestColumnSplitWithArgs(tree_method, use_gpu, args); -} +class ColumnSplitTrainingTest + : public ::testing::TestWithParam> { + public: + static void TestColumnSplitColumnSampler(std::string const& tree_method, bool use_gpu, + bool federated) { + Args args{ + {"colsample_bytree", "0.5"}, {"colsample_bylevel", "0.6"}, {"colsample_bynode", "0.7"}}; + TestColumnSplitWithArgs(tree_method, use_gpu, args, federated); + } + static void TestColumnSplitInteractionConstraints(std::string const& tree_method, bool use_gpu, + bool federated) { + Args args{{"interaction_constraints", "[[0, 5, 7], [2, 8, 9], [1, 3, 6]]"}}; + TestColumnSplitWithArgs(tree_method, use_gpu, args, federated); + } + static void TestColumnSplitMonotoneConstraints(std::string const& tree_method, bool use_gpu, + bool federated) { + Args args{{"monotone_constraints", "(1,-1,0,1,1,-1,-1,0,0,1)"}}; + TestColumnSplitWithArgs(tree_method, use_gpu, args, federated); + } +}; -void TestColumnSplitInteractionConstraints(std::string const& tree_method, bool use_gpu) { - Args args{{"interaction_constraints", "[[0, 5, 7], [2, 8, 9], [1, 3, 6]]"}}; - TestColumnSplitWithArgs(tree_method, use_gpu, args); -} - -void TestColumnSplitMonotoneConstraints(std::string const& tree_method, bool use_gpu) { - Args args{{"monotone_constraints", "(1,-1,0,1,1,-1,-1,0,0,1)"}}; - TestColumnSplitWithArgs(tree_method, use_gpu, args); +auto MakeParamsForTest() { + std::vector> configs; + for (auto tm : {"hist", "approx"}) { +#if defined(XGBOOST_USE_CUDA) + std::array use_gpu{true, false}; +#else + std::array use_gpu{false}; +#endif + for (auto i : use_gpu) { +#if defined(XGBOOST_USE_FEDERATED) + std::array fed{true, false}; +#else + std::array fed{false}; +#endif + for (auto j : fed) { + configs.emplace_back(tm, i, j); + } + } + } + return configs; } } // anonymous namespace -TEST(ColumnSplitColumnSampler, Approx) { TestColumnSplitColumnSampler("approx", false); } - -TEST(ColumnSplitColumnSampler, Hist) { TestColumnSplitColumnSampler("hist", false); } - -#if defined(XGBOOST_USE_CUDA) -TEST(MGPUColumnSplitColumnSampler, GPUApprox) { TestColumnSplitColumnSampler("approx", true); } - -TEST(MGPUColumnSplitColumnSampler, GPUHist) { TestColumnSplitColumnSampler("hist", true); } -#endif // defined(XGBOOST_USE_CUDA) - -TEST(ColumnSplitInteractionConstraints, Approx) { - TestColumnSplitInteractionConstraints("approx", false); +TEST_P(ColumnSplitTrainingTest, ColumnSampler) { + auto param = GetParam(); + std::apply(TestColumnSplitColumnSampler, param); } -TEST(ColumnSplitInteractionConstraints, Hist) { - TestColumnSplitInteractionConstraints("hist", false); 
+TEST_P(ColumnSplitTrainingTest, InteractionConstraints) { + auto param = GetParam(); + std::apply(TestColumnSplitInteractionConstraints, param); } -#if defined(XGBOOST_USE_CUDA) -TEST(MGPUColumnSplitInteractionConstraints, GPUApprox) { - TestColumnSplitInteractionConstraints("approx", true); +TEST_P(ColumnSplitTrainingTest, MonotoneConstraints) { + auto param = GetParam(); + std::apply(TestColumnSplitMonotoneConstraints, param); } -TEST(MGPUColumnSplitInteractionConstraints, GPUHist) { - TestColumnSplitInteractionConstraints("hist", true); -} -#endif // defined(XGBOOST_USE_CUDA) - -TEST(ColumnSplitMonotoneConstraints, Approx) { - TestColumnSplitMonotoneConstraints("approx", false); -} - -TEST(ColumnSplitMonotoneConstraints, Hist) { - TestColumnSplitMonotoneConstraints("hist", false); -} - -#if defined(XGBOOST_USE_CUDA) -TEST(MGPUColumnSplitMonotoneConstraints, GPUApprox) { - TestColumnSplitMonotoneConstraints("approx", true); -} - -TEST(MGPUColumnSplitMonotoneConstraints, GPUHist) { - TestColumnSplitMonotoneConstraints("hist", true); -} -#endif // defined(XGBOOST_USE_CUDA) +INSTANTIATE_TEST_SUITE_P(ColumnSplit, ColumnSplitTrainingTest, + ::testing::ValuesIn(MakeParamsForTest())); } // namespace xgboost diff --git a/tests/cpp/test_main.cc b/tests/cpp/test_main.cc index b93329c2e..37be97f08 100644 --- a/tests/cpp/test_main.cc +++ b/tests/cpp/test_main.cc @@ -1,15 +1,16 @@ -// Copyright by Contributors +/** + * Copyright 2016-2024, XGBoost Contributors + */ #include #include #include + #include -#include -#include #include "helpers.h" -int main(int argc, char ** argv) { - xgboost::Args args {{"verbosity", "2"}}; +int main(int argc, char** argv) { + xgboost::Args args{{"verbosity", "2"}}; xgboost::ConsoleLogger::Configure(args); testing::InitGoogleTest(&argc, argv); diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index f4accfc8a..72a8b5449 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -1,12 +1,12 @@ /** - * Copyright 2020-2023, XGBoost contributors + * Copyright 2020-2024, XGBoost contributors */ #include #include #include "../../../../src/tree/gpu_hist/evaluate_splits.cuh" +#include "../../collective/test_worker.h" // for BaseMGPUTest #include "../../helpers.h" -#include "../../histogram_helpers.h" #include "../test_evaluate_splits.h" // TestPartitionBasedSplit namespace xgboost::tree { @@ -17,13 +17,13 @@ auto ZeroParam() { tparam.UpdateAllowUnknown(args); return tparam; } -} // anonymous namespace -inline GradientQuantiser DummyRoundingFactor(Context const* ctx) { +GradientQuantiser DummyRoundingFactor(Context const* ctx) { thrust::device_vector gpair(1); gpair[0] = {1000.f, 1000.f}; // Tests should not exceed sum of 1000 return {ctx, dh::ToSpan(gpair), MetaInfo()}; } +} // anonymous namespace thrust::device_vector ConvertToInteger(Context const* ctx, std::vector x) { @@ -546,7 +546,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) { ASSERT_NEAR(split.loss_chg, best_score_, 1e-2); } -class MGPUHistTest : public BaseMGPUTest {}; +class MGPUHistTest : public collective::BaseMGPUTest {}; namespace { void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) { @@ -589,21 +589,29 @@ void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) { evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, true, ctx.Device()); DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split; - EXPECT_EQ(result.findex, 
1) << "rank: " << rank; + EXPECT_EQ(result.findex, 1); if (is_categorical) { ASSERT_TRUE(std::isnan(result.fvalue)); } else { - EXPECT_EQ(result.fvalue, 11.0) << "rank: " << rank; + EXPECT_EQ(result.fvalue, 11.0); } - EXPECT_EQ(result.left_sum + result.right_sum, parent_sum) << "rank: " << rank; + EXPECT_EQ(result.left_sum + result.right_sum, parent_sum); } } // anonymous namespace TEST_F(MGPUHistTest, ColumnSplitEvaluateSingleSplit) { - DoTest(VerifyColumnSplitEvaluateSingleSplit, false); + if (common::AllVisibleGPUs() > 1) { + // We can't emulate multiple GPUs with NCCL. + this->DoTest([] { VerifyColumnSplitEvaluateSingleSplit(false); }, false, true); + } + this->DoTest([] { VerifyColumnSplitEvaluateSingleSplit(false); }, true, true); } TEST_F(MGPUHistTest, ColumnSplitEvaluateSingleCategoricalSplit) { - DoTest(VerifyColumnSplitEvaluateSingleSplit, true); + if (common::AllVisibleGPUs() > 1) { + // We can't emulate multiple GPUs with NCCL. + this->DoTest([] { VerifyColumnSplitEvaluateSingleSplit(true); }, false, true); + } + this->DoTest([] { VerifyColumnSplitEvaluateSingleSplit(true); }, true, true); } } // namespace xgboost::tree diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index 5b48f2793..740175c57 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -33,6 +33,7 @@ #include "../../../../src/tree/hist/histogram.h" // for HistogramBuilder #include "../../../../src/tree/hist/param.h" // for HistMakerTrainParam #include "../../categorical_helpers.h" // for OneHotEncodeFeature +#include "../../collective/test_worker.h" // for TestDistributedGlobal #include "../../helpers.h" // for RandomDataGenerator, GenerateRa... namespace xgboost::tree { @@ -300,8 +301,8 @@ TEST(CPUHistogram, BuildHist) { TEST(CPUHistogram, BuildHistColSplit) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, true); - RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, true); + collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, true, true); }); + collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, false, true); }); } namespace { diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 38da629b1..b2949e595 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -1,15 +1,15 @@ /** - * Copyright 2021-2023 by XGBoost contributors. + * Copyright 2021-2024, XGBoost contributors. 
*/ #include #include "../../../src/common/numeric.h" #include "../../../src/tree/common_row_partitioner.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" #include "test_partitioner.h" -namespace xgboost { -namespace tree { +namespace xgboost::tree { namespace { std::vector GenerateHess(size_t n_samples) { auto grad = GenerateRandomGradients(n_samples); @@ -145,8 +145,9 @@ TEST(Approx, PartitionerColSplit) { } auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, TestColumnSplitPartitioner, n_samples, base_rowid, Xy, - &hess, min_value, mid_value, mid_partitioner); + collective::TestDistributedGlobal(kWorkers, [&] { + TestColumnSplitPartitioner(n_samples, base_rowid, Xy, &hess, min_value, mid_value, + mid_partitioner); + }); } -} // namespace tree -} // namespace xgboost +} // namespace xgboost::tree diff --git a/tests/cpp/tree/test_evaluate_splits.h b/tests/cpp/tree/test_evaluate_splits.h index 6506b54e8..a25e75aef 100644 --- a/tests/cpp/tree/test_evaluate_splits.h +++ b/tests/cpp/tree/test_evaluate_splits.h @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023 by XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include #include // for GradientPairInternal, GradientPairPrecise @@ -14,7 +14,6 @@ #include // for numeric_limits #include // for iota #include // for make_tuple, tie, tuple -#include // for pair #include // for vector #include "../../../src/common/hist_util.h" // for HistogramCuts, HistCollection, GHistRow @@ -23,7 +22,6 @@ #include "../../../src/tree/param.h" // for TrainParam, GradStats #include "../../../src/tree/split_evaluator.h" // for TreeEvaluator #include "../helpers.h" // for SimpleLCG, SimpleRealUniformDistribution -#include "gtest/gtest_pred_impl.h" // for AssertionResult, ASSERT_EQ, ASSERT_TRUE namespace xgboost::tree { /** @@ -96,13 +94,11 @@ class TestPartitionBasedSplit : public ::testing::Test { // enumerate all possible partitions to find the optimal split do { - int32_t thresh; - float score; std::vector sorted_hist(node_hist.size()); for (size_t i = 0; i < sorted_hist.size(); ++i) { sorted_hist[i] = node_hist[sorted_idx_[i]]; } - std::tie(thresh, score) = enumerate({sorted_hist}, total_gpair_); + auto [thresh, score] = enumerate({sorted_hist}, total_gpair_); if (score > best_score_) { best_score_ = score; } diff --git a/tests/cpp/tree/test_fit_stump.cc b/tests/cpp/tree/test_fit_stump.cc index d9441fd6f..720d87852 100644 --- a/tests/cpp/tree/test_fit_stump.cc +++ b/tests/cpp/tree/test_fit_stump.cc @@ -1,11 +1,12 @@ /** - * Copyright 2022-2023, XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include #include #include "../../src/common/linalg_op.h" #include "../../src/tree/fit_stump.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" namespace xgboost::tree { @@ -43,7 +44,7 @@ TEST(InitEstimation, FitStump) { #if defined(XGBOOST_USE_CUDA) TEST(InitEstimation, GPUFitStump) { Context ctx; - ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + ctx.UpdateAllowUnknown(Args{{"device", "cuda"}}); TestFitStump(&ctx); } #endif // defined(XGBOOST_USE_CUDA) @@ -51,6 +52,6 @@ TEST(InitEstimation, GPUFitStump) { TEST(InitEstimation, FitStumpColumnSplit) { Context ctx; auto constexpr kWorldSize{3}; - RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol); + collective::TestDistributedGlobal(kWorldSize, [&] { TestFitStump(&ctx, DataSplitMode::kCol); }); } } // namespace xgboost::tree diff --git 
a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index aaeba13f1..c3a949008 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -13,14 +13,19 @@ #include "../../../src/common/common.h" #include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl #include "../../../src/data/ellpack_page.h" // for EllpackPage -#include "../../../src/tree/param.h" // for TrainParam +#include "../../../src/tree/param.h" // for TrainParam #include "../../../src/tree/updater_gpu_hist.cu" -#include "../filesystem.h" // dmlc::TemporaryDirectory +#include "../collective/test_worker.h" // for BaseMGPUTest +#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" #include "../histogram_helpers.h" #include "xgboost/context.h" #include "xgboost/json.h" +#if defined(XGBOOST_USE_FEDERATED) +#include "../plugin/federated/test_worker.h" // for TestFederatedGlobal +#endif // defined(XGBOOST_USE_FEDERATED) + namespace xgboost::tree { TEST(GpuHist, DeviceHistogram) { // Ensures that node allocates correctly after reaching `kStopGrowingSize`. @@ -458,9 +463,9 @@ void VerifyHistColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& ex } } // anonymous namespace -class MGPUHistTest : public BaseMGPUTest {}; +class MGPUHistTest : public collective::BaseMGPUTest {}; -TEST_F(MGPUHistTest, GPUHistColumnSplit) { +TEST_F(MGPUHistTest, HistColumnSplit) { auto constexpr kRows = 32; auto constexpr kCols = 16; @@ -468,7 +473,8 @@ TEST_F(MGPUHistTest, GPUHistColumnSplit) { auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); RegTree expected_tree = GetHistTree(&ctx, dmat.get()); - DoTest(VerifyHistColumnSplit, kRows, kCols, expected_tree); + this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, true); + this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, false); } namespace { @@ -508,7 +514,7 @@ void VerifyApproxColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& } } // anonymous namespace -class MGPUApproxTest : public BaseMGPUTest {}; +class MGPUApproxTest : public collective::BaseMGPUTest {}; TEST_F(MGPUApproxTest, GPUApproxColumnSplit) { auto constexpr kRows = 32; @@ -518,6 +524,7 @@ TEST_F(MGPUApproxTest, GPUApproxColumnSplit) { auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); RegTree expected_tree = GetApproxTree(&ctx, dmat.get()); - DoTest(VerifyApproxColumnSplit, kRows, kCols, expected_tree); + this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, true); + this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, false); } } // namespace xgboost::tree diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index 963660f59..b8b9e46ca 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -5,7 +5,8 @@ #include #include -#include "../../../src/tree/param.h" // for TrainParam +#include "../../../src/tree/param.h" // for TrainParam +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" namespace xgboost::tree { @@ -118,8 +119,8 @@ void TestColumnSplit(bool categorical) { } auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, categorical, - std::cref(expected_tree)); + collective::TestDistributedGlobal( + kWorldSize, [&] { VerifyColumnSplit(kRows, kCols, categorical, expected_tree); }); } } // anonymous namespace diff --git 
a/tests/cpp/tree/test_multi_target_tree_model.cc b/tests/cpp/tree/test_multi_target_tree_model.cc index 0b5745a20..39e4cb4b5 100644 --- a/tests/cpp/tree/test_multi_target_tree_model.cc +++ b/tests/cpp/tree/test_multi_target_tree_model.cc @@ -11,26 +11,26 @@ namespace { auto MakeTreeForTest() { bst_target_t n_targets{3}; bst_feature_t n_features{4}; - RegTree tree{n_targets, n_features}; - CHECK(tree.IsMultiTarget()); + std::unique_ptr tree{std::make_unique(n_targets, n_features)}; + CHECK(tree->IsMultiTarget()); linalg::Vector base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()}; - tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), - left_weight.HostView(), right_weight.HostView()); + tree->ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), + left_weight.HostView(), right_weight.HostView()); return tree; } } // namespace TEST(MultiTargetTree, JsonIO) { auto tree = MakeTreeForTest(); - ASSERT_EQ(tree.NumNodes(), 3); - ASSERT_EQ(tree.NumTargets(), 3); - ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3); - ASSERT_EQ(tree.Size(), 3); + ASSERT_EQ(tree->NumNodes(), 3); + ASSERT_EQ(tree->NumTargets(), 3); + ASSERT_EQ(tree->GetMultiTargetTree()->Size(), 3); + ASSERT_EQ(tree->Size(), 3); Json jtree{Object{}}; - tree.SaveModel(&jtree); + tree->SaveModel(&jtree); auto check_jtree = [](Json jtree, RegTree const& tree) { ASSERT_EQ(get(jtree["tree_param"]["num_nodes"]), std::to_string(tree.NumNodes())); @@ -40,7 +40,7 @@ TEST(MultiTargetTree, JsonIO) { ASSERT_EQ(get(jtree["left_children"]).size(), tree.NumNodes()); ASSERT_EQ(get(jtree["right_children"]).size(), tree.NumNodes()); }; - check_jtree(jtree, tree); + check_jtree(jtree, *tree); RegTree loaded; loaded.LoadModel(jtree); @@ -49,18 +49,18 @@ TEST(MultiTargetTree, JsonIO) { Json jtree1{Object{}}; loaded.SaveModel(&jtree1); - check_jtree(jtree1, tree); + check_jtree(jtree1, *tree); } TEST(MultiTargetTree, DumpDot) { auto tree = MakeTreeForTest(); - auto n_features = tree.NumFeatures(); + auto n_features = tree->NumFeatures(); FeatureMap fmap; for (bst_feature_t f = 0; f < n_features; ++f) { auto name = "feat_" + std::to_string(f); fmap.PushBack(f, name.c_str(), "q"); } - auto str = tree.DumpModel(fmap, true, "dot"); + auto str = tree->DumpModel(fmap, true, "dot"); ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos); ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos); diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 1c3651005..ce637caa4 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -13,6 +13,7 @@ #include "../../../src/tree/common_row_partitioner.h" #include "../../../src/tree/hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry #include "../../../src/tree/param.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" #include "test_partitioner.h" #include "xgboost/data.h" @@ -190,9 +191,10 @@ void TestColumnSplitPartitioner(bst_target_t n_targets) { } auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, VerifyColumnSplitPartitioner, n_targets, - n_samples, n_features, base_rowid, Xy, min_value, mid_value, - mid_partitioner); + collective::TestDistributedGlobal(kWorkers, [&] { + VerifyColumnSplitPartitioner(n_targets, n_samples, n_features, base_rowid, Xy, + min_value, 
mid_value, mid_partitioner); + }); } } // anonymous namespace @@ -245,8 +247,9 @@ void TestColumnSplit(bst_target_t n_targets) { } auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, &ctx, kRows, kCols, n_targets, - std::cref(expected_tree)); + collective::TestDistributedGlobal(kWorldSize, [&] { + VerifyColumnSplit(&ctx, kRows, kCols, n_targets, std::cref(expected_tree)); + }); } } // anonymous namespace diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index f7de0400d..a3923e9df 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -1,16 +1,14 @@ import multiprocessing import socket import sys -import time +from threading import Thread import numpy as np import pytest import xgboost as xgb from xgboost import RabitTracker, build_info, federated - -if sys.platform.startswith("win"): - pytest.skip("Skipping collective tests on Windows", allow_module_level=True) +from xgboost import testing as tm def run_rabit_worker(rabit_env, world_size): @@ -18,20 +16,21 @@ def run_rabit_worker(rabit_env, world_size): assert xgb.collective.get_world_size() == world_size assert xgb.collective.is_distributed() assert xgb.collective.get_processor_name() == socket.gethostname() - ret = xgb.collective.broadcast('test1234', 0) - assert str(ret) == 'test1234' + ret = xgb.collective.broadcast("test1234", 0) + assert str(ret) == "test1234" ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) assert np.array_equal(ret, np.asarray([2, 4, 6])) -def test_rabit_communicator(): +def test_rabit_communicator() -> None: world_size = 2 - tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size) - tracker.start(world_size) + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size) + tracker.start() workers = [] for _ in range(world_size): - worker = multiprocessing.Process(target=run_rabit_worker, - args=(tracker.worker_envs(), world_size)) + worker = multiprocessing.Process( + target=run_rabit_worker, args=(tracker.worker_args(), world_size) + ) workers.append(worker) worker.start() for worker in workers: @@ -39,39 +38,44 @@ def test_rabit_communicator(): assert worker.exitcode == 0 -def run_federated_worker(port, world_size, rank): - with xgb.collective.CommunicatorContext(xgboost_communicator='federated', - federated_server_address=f'localhost:{port}', - federated_world_size=world_size, - federated_rank=rank): +def run_federated_worker(port: int, world_size: int, rank: int) -> None: + with xgb.collective.CommunicatorContext( + dmlc_communicator="federated", + federated_server_address=f"localhost:{port}", + federated_world_size=world_size, + federated_rank=rank, + ): assert xgb.collective.get_world_size() == world_size assert xgb.collective.is_distributed() - assert xgb.collective.get_processor_name() == f'rank{rank}' - ret = xgb.collective.broadcast('test1234', 0) - assert str(ret) == 'test1234' - ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) - assert np.array_equal(ret, np.asarray([2, 4, 6])) + assert xgb.collective.get_processor_name() == f"rank:{rank}" + bret = xgb.collective.broadcast("test1234", 0) + assert str(bret) == "test1234" + aret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) + assert np.array_equal(aret, np.asarray([2, 4, 6])) +@pytest.mark.skipif(**tm.skip_win()) def test_federated_communicator(): if not build_info()["USE_FEDERATED"]: pytest.skip("XGBoost not built with federated learning enabled") port = 9091 
world_size = 2 - server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size)) - server.start() - time.sleep(1) - if not server.is_alive(): + tracker = multiprocessing.Process( + target=federated.run_federated_server, + kwargs={"port": port, "n_workers": world_size}, + ) + tracker.start() + if not tracker.is_alive(): raise Exception("Error starting Federated Learning server") workers = [] for rank in range(world_size): - worker = multiprocessing.Process(target=run_federated_worker, - args=(port, world_size, rank)) + worker = multiprocessing.Process( + target=run_federated_worker, args=(port, world_size, rank) + ) workers.append(worker) worker.start() for worker in workers: worker.join() assert worker.exitcode == 0 - server.terminate() diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 1f42711a2..5d508f0d1 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -3,33 +3,33 @@ import sys import numpy as np import pytest +from hypothesis import HealthCheck, given, settings, strategies import xgboost as xgb from xgboost import RabitTracker, collective from xgboost import testing as tm -if sys.platform.startswith("win"): - pytest.skip("Skipping dask tests on Windows", allow_module_level=True) - def test_rabit_tracker(): tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1) - tracker.start(1) - with xgb.collective.CommunicatorContext(**tracker.worker_envs()): + tracker.start() + with xgb.collective.CommunicatorContext(**tracker.worker_args()): ret = xgb.collective.broadcast("test1234", 0) assert str(ret) == "test1234" @pytest.mark.skipif(**tm.not_linux()) def test_socket_error(): - tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1) - tracker.start(1) - env = tracker.worker_envs() - env["DMLC_TRACKER_PORT"] = 0 - env["DMLC_WORKER_CONNECT_RETRY"] = 1 - with pytest.raises(ValueError, match="127.0.0.1:0\n.*refused"): + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2) + tracker.start() + env = tracker.worker_args() + env["dmlc_tracker_port"] = 0 + env["dmlc_retry"] = 1 + with pytest.raises(ValueError, match="Failed to bootstrap the communication."): with xgb.collective.CommunicatorContext(**env): pass + with pytest.raises(ValueError): + tracker.free() def run_rabit_ops(client, n_workers): @@ -70,6 +70,40 @@ def test_rabit_ops(): run_rabit_ops(client, n_workers) +def run_allreduce(client) -> None: + from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args + + workers = tm.get_client_workers(client) + rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client) + n_workers = len(workers) + + def local_test(worker_id: int) -> None: + x = np.full(shape=(1024 * 1024 * 32), fill_value=1.0) + with CommunicatorContext(**rabit_args): + k = np.asarray([1.0]) + for i in range(128): + m = collective.allreduce(k, collective.Op.SUM) + assert m == n_workers + + y = collective.allreduce(x, collective.Op.SUM) + np.testing.assert_allclose(y, np.full_like(y, fill_value=float(n_workers))) + + futures = client.map(local_test, range(len(workers)), workers=workers) + results = client.gather(futures) + + +@pytest.mark.skipif(**tm.no_dask()) +def test_allreduce() -> None: + from distributed import Client, LocalCluster + + n_workers = 4 + for i in range(2): + with LocalCluster(n_workers=n_workers) as cluster: + with Client(cluster) as client: + for i in range(2): + run_allreduce(client) + + def run_broadcast(client): from xgboost.dask import _get_dask_config, _get_rabit_args 
@@ -109,6 +143,7 @@ def test_rabit_ops_ipv6(): run_rabit_ops(client, n_workers) +@pytest.mark.skipif(**tm.no_dask()) def test_rank_assignment() -> None: from distributed import Client, LocalCluster @@ -133,3 +168,107 @@ def test_rank_assignment() -> None: futures = client.map(local_test, range(len(workers)), workers=workers) client.gather(futures) + + +@pytest.fixture +def local_cluster(): + from distributed import LocalCluster + + n_workers = 8 + with LocalCluster(n_workers=n_workers, dashboard_address=":0") as cluster: + yield cluster + + +ops_strategy = strategies.lists( + strategies.sampled_from(["broadcast", "allreduce_max", "allreduce_sum"]) +) + + +@pytest.mark.skipif(**tm.no_dask()) +@given(ops=ops_strategy, size=strategies.integers(2**4, 2**16)) +@settings( + deadline=None, + print_blob=True, + max_examples=10, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_ops_restart_comm(local_cluster, ops, size) -> None: + from distributed import Client + + def local_test(w: int, n_workers: int) -> None: + a = np.arange(0, n_workers) + with xgb.dask.CommunicatorContext(**args): + for op in ops: + if op == "broadcast": + b = collective.broadcast(a, root=1) + np.testing.assert_allclose(b, a) + elif op == "allreduce_max": + b = collective.allreduce(a, collective.Op.MAX) + np.testing.assert_allclose(b, a) + elif op == "allreduce_sum": + b = collective.allreduce(a, collective.Op.SUM) + np.testing.assert_allclose(a * n_workers, b) + else: + raise ValueError() + + with Client(local_cluster) as client: + workers = tm.get_client_workers(client) + args = client.sync( + xgb.dask._get_rabit_args, + len(workers), + None, + client, + ) + + workers = tm.get_client_workers(client) + n_workers = len(workers) + + futures = client.map( + local_test, range(len(workers)), workers=workers, n_workers=n_workers + ) + client.gather(futures) + + +@pytest.mark.skipif(**tm.no_dask()) +def test_ops_reuse_comm(local_cluster) -> None: + from distributed import Client + + rng = np.random.default_rng(1994) + n_examples = 10 + ops = rng.choice( + ["broadcast", "allreduce_sum", "allreduce_max"], size=n_examples + ).tolist() + + def local_test(w: int, n_workers: int) -> None: + a = np.arange(0, n_workers) + + with xgb.dask.CommunicatorContext(**args): + for op in ops: + if op == "broadcast": + b = collective.broadcast(a, root=1) + assert np.allclose(b, a) + elif op == "allreduce_max": + c = np.full_like(a, collective.get_rank()) + b = collective.allreduce(c, collective.Op.MAX) + assert np.allclose(b, n_workers - 1), b + elif op == "allreduce_sum": + b = collective.allreduce(a, collective.Op.SUM) + assert np.allclose(a * 8, b) + else: + raise ValueError() + + with Client(local_cluster) as client: + workers = tm.get_client_workers(client) + args = client.sync( + xgb.dask._get_rabit_args, + len(workers), + None, + client, + ) + + n_workers = len(workers) + + futures = client.map( + local_test, range(len(workers)), workers=workers, n_workers=n_workers + ) + client.gather(futures) diff --git a/tests/python/test_with_arrow.py b/tests/python/test_with_arrow.py index 4d12f32df..145cc0f2b 100644 --- a/tests/python/test_with_arrow.py +++ b/tests/python/test_with_arrow.py @@ -8,19 +8,14 @@ import xgboost as xgb from xgboost import testing as tm from xgboost.core import DataSplitMode -try: - import pandas as pd - import pyarrow as pa - import pyarrow.csv as pc -except ImportError: - pass - pytestmark = pytest.mark.skipif( tm.no_arrow()["condition"] or tm.no_pandas()["condition"], reason=tm.no_arrow()["reason"] + 
" or " + tm.no_pandas()["reason"], ) -dpath = "demo/data/" +import pandas as pd +import pyarrow as pa +import pyarrow.csv as pc class TestArrowTable: diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index dc1b0b669..61f33832a 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1098,9 +1098,10 @@ def test_pandas_input(): np.testing.assert_equal(model.feature_names_in_, np.array(feature_names)) columns = list(train.columns) - random.shuffle(columns) + rng.shuffle(columns) df_incorrect = df[columns] - with pytest.raises(ValueError): + + with pytest.raises(ValueError, match="feature_names mismatch"): model.predict(df_incorrect) clf_isotonic = CalibratedClassifierCV(model, cv="prefit", method="isotonic")