Use dlopen to load NCCL. (#9796)
This PR adds optional support for loading nccl with `dlopen` as an alternative of compile time linking. This is to address the size bloat issue with the PyPI binary release. - Add CMake option to load `nccl` at runtime. - Add an NCCL stub. After this, `nccl` will be fetched from PyPI when using pip to install XGBoost, either by a user or by `pyproject.toml`. Others who want to link the nccl at compile time can continue to do so without any change. At the moment, this is Linux only since we only support MNMG on Linux.
This commit is contained in:
parent
fedd9674c8
commit
0715ab3c10
@ -69,7 +69,10 @@ option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binar
|
||||
option(USE_CUDA "Build with GPU acceleration" OFF)
|
||||
option(USE_PER_THREAD_DEFAULT_STREAM "Build with per-thread default stream" ON)
|
||||
option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
|
||||
# This is specifically designed for PyPI binary release and should be disabled for most of the cases.
|
||||
option(USE_DLOPEN_NCCL "Whether to load nccl dynamically." OFF)
|
||||
option(BUILD_WITH_SHARED_NCCL "Build with shared NCCL library." OFF)
|
||||
|
||||
if(USE_CUDA)
|
||||
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND NOT DEFINED ENV{CUDAARCHS})
|
||||
set(GPU_COMPUTE_VER "" CACHE STRING
|
||||
@ -80,6 +83,7 @@ if(USE_CUDA)
|
||||
unset(GPU_COMPUTE_VER CACHE)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# CUDA device LTO was introduced in CMake v3.25 and requires host LTO to also be enabled but can still
|
||||
# be explicitly disabled allowing for LTO on host only, host and device, or neither, but device-only LTO
|
||||
# is not a supproted configuration
|
||||
@ -115,6 +119,12 @@ endif()
|
||||
if(BUILD_WITH_SHARED_NCCL AND (NOT USE_NCCL))
|
||||
message(SEND_ERROR "Build XGBoost with -DUSE_NCCL=ON to enable BUILD_WITH_SHARED_NCCL.")
|
||||
endif()
|
||||
if(USE_DLOPEN_NCCL AND (NOT USE_NCCL))
|
||||
message(SEND_ERROR "Build XGBoost with -DUSE_NCCL=ON to enable USE_DLOPEN_NCCL.")
|
||||
endif()
|
||||
if(USE_DLOPEN_NCCL AND (NOT (CMAKE_SYSTEM_NAME STREQUAL "Linux")))
|
||||
message(SEND_ERROR "`USE_DLOPEN_NCCL` supports only Linux at the moment.")
|
||||
endif()
|
||||
if(JVM_BINDINGS AND R_LIB)
|
||||
message(SEND_ERROR "`R_LIB' is not compatible with `JVM_BINDINGS' as they both have customized configurations.")
|
||||
endif()
|
||||
|
||||
@ -171,17 +171,24 @@ function(xgboost_set_cuda_flags target)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
macro(xgboost_link_nccl target)
|
||||
function(xgboost_link_nccl target)
|
||||
set(xgboost_nccl_flags -DXGBOOST_USE_NCCL=1)
|
||||
if(USE_DLOPEN_NCCL)
|
||||
list(APPEND xgboost_nccl_flags -DXGBOOST_USE_DLOPEN_NCCL=1)
|
||||
endif()
|
||||
|
||||
if(BUILD_STATIC_LIB)
|
||||
target_include_directories(${target} PUBLIC ${NCCL_INCLUDE_DIR})
|
||||
target_compile_definitions(${target} PUBLIC -DXGBOOST_USE_NCCL=1)
|
||||
target_compile_definitions(${target} PUBLIC ${xgboost_nccl_flags})
|
||||
target_link_libraries(${target} PUBLIC ${NCCL_LIBRARY})
|
||||
else()
|
||||
target_include_directories(${target} PRIVATE ${NCCL_INCLUDE_DIR})
|
||||
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_NCCL=1)
|
||||
target_compile_definitions(${target} PRIVATE ${xgboost_nccl_flags})
|
||||
if(NOT USE_DLOPEN_NCCL)
|
||||
target_link_libraries(${target} PRIVATE ${NCCL_LIBRARY})
|
||||
endif()
|
||||
endmacro()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# compile options
|
||||
macro(xgboost_target_properties target)
|
||||
|
||||
@ -54,17 +54,24 @@ find_path(NCCL_INCLUDE_DIR
|
||||
NAMES nccl.h
|
||||
HINTS ${NCCL_ROOT}/include $ENV{NCCL_ROOT}/include)
|
||||
|
||||
find_library(NCCL_LIBRARY
|
||||
if(USE_DLOPEN_NCCL)
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(Nccl DEFAULT_MSG NCCL_INCLUDE_DIR)
|
||||
|
||||
mark_as_advanced(NCCL_INCLUDE_DIR)
|
||||
else()
|
||||
find_library(NCCL_LIBRARY
|
||||
NAMES ${NCCL_LIB_NAME}
|
||||
HINTS ${NCCL_ROOT}/lib $ENV{NCCL_ROOT}/lib/)
|
||||
|
||||
message(STATUS "Using nccl library: ${NCCL_LIBRARY}")
|
||||
message(STATUS "Using nccl library: ${NCCL_LIBRARY}")
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(Nccl DEFAULT_MSG
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(Nccl DEFAULT_MSG
|
||||
NCCL_INCLUDE_DIR NCCL_LIBRARY)
|
||||
|
||||
mark_as_advanced(
|
||||
mark_as_advanced(
|
||||
NCCL_INCLUDE_DIR
|
||||
NCCL_LIBRARY
|
||||
)
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -536,6 +536,37 @@ Troubleshooting
|
||||
- MIG (Multi-Instance GPU) is not yet supported by NCCL. You will receive an error message
|
||||
that includes `Multiple processes within a communication group ...` upon initialization.
|
||||
|
||||
.. _nccl-load:
|
||||
|
||||
- Starting from version 2.1.0, to reduce the size of the binary wheel, the XGBoost package
|
||||
(installed using pip) loads NCCL from the environment instead of bundling it
|
||||
directly. This means that if you encounter an error message like
|
||||
"Failed to load nccl ...", it indicates that NCCL is not installed or properly
|
||||
configured in your environment.
|
||||
|
||||
To resolve this issue, you can install NCCL using pip:
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
pip install nvidia-nccl-cu12 # (or with any compatible CUDA version)
|
||||
|
||||
The default conda installation of XGBoost should not encounter this error. If you are
|
||||
using a customized XGBoost, please make sure one of the followings is true:
|
||||
|
||||
+ XGBoost is NOT compiled with the `USE_DLOPEN_NCCL` flag.
|
||||
+ The `dmlc_nccl_path` parameter is set to full NCCL path when initializing the collective.
|
||||
|
||||
Here are some additional tips for troubleshooting NCCL dependency issues:
|
||||
|
||||
+ Check the NCCL installation path and verify that it's installed correctly. We try to
|
||||
find NCCL by using ``from nvidia.nccl import lib`` in Python when XGBoost is installed
|
||||
using pip.
|
||||
+ Ensure that you have the correct CUDA version installed. NCCL requires a compatible
|
||||
CUDA version to function properly.
|
||||
+ If you are not using distributed training with XGBoost and yet see this error, please
|
||||
open an issue on GitHub.
|
||||
+ If you continue to encounter NCCL dependency issues, please open an issue on GitHub.
|
||||
|
||||
************
|
||||
IPv6 Support
|
||||
************
|
||||
|
||||
@ -1613,6 +1613,8 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle);
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
|
||||
* `USE_DLOPEN_NCCL`.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
|
||||
@ -1,23 +1,24 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_STRING_VIEW_H_
|
||||
#define XGBOOST_STRING_VIEW_H_
|
||||
#include <xgboost/logging.h> // CHECK_LT
|
||||
#include <xgboost/span.h> // Span
|
||||
|
||||
#include <algorithm> // std::equal,std::min
|
||||
#include <iterator> // std::reverse_iterator
|
||||
#include <ostream> // std::ostream
|
||||
#include <string> // std::char_traits,std::string
|
||||
#include <algorithm> // for equal, min
|
||||
#include <cstddef> // for size_t
|
||||
#include <iterator> // for reverse_iterator
|
||||
#include <ostream> // for ostream
|
||||
#include <string> // for char_traits, string
|
||||
|
||||
namespace xgboost {
|
||||
struct StringView {
|
||||
private:
|
||||
using CharT = char; // unsigned char
|
||||
using CharT = char;
|
||||
using Traits = std::char_traits<CharT>;
|
||||
CharT const* str_{nullptr};
|
||||
size_t size_{0};
|
||||
std::size_t size_{0};
|
||||
|
||||
public:
|
||||
using value_type = CharT; // NOLINT
|
||||
@ -28,40 +29,41 @@ struct StringView {
|
||||
|
||||
public:
|
||||
constexpr StringView() = default;
|
||||
constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {}
|
||||
constexpr StringView(value_type const* str, std::size_t size) : str_{str}, size_{size} {}
|
||||
StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {} // NOLINT
|
||||
constexpr StringView(CharT const* str) // NOLINT
|
||||
constexpr StringView(value_type const* str) // NOLINT
|
||||
: str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {}
|
||||
|
||||
CharT const& operator[](size_t p) const { return str_[p]; }
|
||||
CharT const& at(size_t p) const { // NOLINT
|
||||
[[nodiscard]] value_type const& operator[](std::size_t p) const { return str_[p]; }
|
||||
[[nodiscard]] explicit operator std::string() const { return {this->c_str(), this->size()}; }
|
||||
[[nodiscard]] value_type const& at(std::size_t p) const { // NOLINT
|
||||
CHECK_LT(p, size_);
|
||||
return str_[p];
|
||||
}
|
||||
constexpr std::size_t size() const { return size_; } // NOLINT
|
||||
constexpr bool empty() const { return size() == 0; } // NOLINT
|
||||
StringView substr(size_t beg, size_t n) const { // NOLINT
|
||||
[[nodiscard]] constexpr std::size_t size() const { return size_; } // NOLINT
|
||||
[[nodiscard]] constexpr bool empty() const { return size() == 0; } // NOLINT
|
||||
[[nodiscard]] StringView substr(std::size_t beg, std::size_t n) const { // NOLINT
|
||||
CHECK_LE(beg, size_);
|
||||
size_t len = std::min(n, size_ - beg);
|
||||
std::size_t len = std::min(n, size_ - beg);
|
||||
return {str_ + beg, len};
|
||||
}
|
||||
CharT const* c_str() const { return str_; } // NOLINT
|
||||
[[nodiscard]] value_type const* c_str() const { return str_; } // NOLINT
|
||||
|
||||
constexpr CharT const* cbegin() const { return str_; } // NOLINT
|
||||
constexpr CharT const* cend() const { return str_ + size(); } // NOLINT
|
||||
constexpr CharT const* begin() const { return str_; } // NOLINT
|
||||
constexpr CharT const* end() const { return str_ + size(); } // NOLINT
|
||||
[[nodiscard]] constexpr const_iterator cbegin() const { return str_; } // NOLINT
|
||||
[[nodiscard]] constexpr const_iterator cend() const { return str_ + size(); } // NOLINT
|
||||
[[nodiscard]] constexpr iterator begin() const { return str_; } // NOLINT
|
||||
[[nodiscard]] constexpr iterator end() const { return str_ + size(); } // NOLINT
|
||||
|
||||
const_reverse_iterator rbegin() const noexcept { // NOLINT
|
||||
[[nodiscard]] const_reverse_iterator rbegin() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->end());
|
||||
}
|
||||
const_reverse_iterator crbegin() const noexcept { // NOLINT
|
||||
[[nodiscard]] const_reverse_iterator crbegin() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->end());
|
||||
}
|
||||
const_reverse_iterator rend() const noexcept { // NOLINT
|
||||
[[nodiscard]] const_reverse_iterator rend() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->begin());
|
||||
}
|
||||
const_reverse_iterator crend() const noexcept { // NOLINT
|
||||
[[nodiscard]] const_reverse_iterator crend() const noexcept { // NOLINT
|
||||
return const_reverse_iterator(this->begin());
|
||||
}
|
||||
};
|
||||
|
||||
@ -103,6 +103,7 @@ if __name__ == "__main__":
|
||||
if cli_args.use_cuda == 'ON':
|
||||
CONFIG['USE_CUDA'] = 'ON'
|
||||
CONFIG['USE_NCCL'] = 'ON'
|
||||
CONFIG["USE_DLOPEN_NCCL"] = "OFF"
|
||||
|
||||
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
|
||||
|
||||
|
||||
@ -5,9 +5,11 @@
|
||||
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "../../src/collective/coll.h" // for Coll
|
||||
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class CUDAFederatedComm : public FederatedComm {
|
||||
@ -16,5 +18,9 @@ class CUDAFederatedComm : public FederatedComm {
|
||||
public:
|
||||
explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);
|
||||
[[nodiscard]] auto Stream() const { return stream_; }
|
||||
Comm* MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const override {
|
||||
LOG(FATAL) << "[Internal Error]: Invalid request for CUDA variant.";
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -10,12 +10,12 @@
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/comm.h" // for Comm
|
||||
#include "../../src/collective/comm.h" // for HostComm
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedComm : public Comm {
|
||||
class FederatedComm : public HostComm {
|
||||
std::shared_ptr<federated::Federated::Stub> stub_;
|
||||
|
||||
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
|
||||
@ -64,6 +64,6 @@ class FederatedComm : public Comm {
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
|
||||
|
||||
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -15,6 +15,8 @@ class BuildConfiguration: # pylint: disable=R0902
|
||||
use_cuda: bool = False
|
||||
# Whether to enable NCCL
|
||||
use_nccl: bool = False
|
||||
# Whether to load nccl dynamically
|
||||
use_dlopen_nccl: bool = False
|
||||
# Whether to enable HDFS
|
||||
use_hdfs: bool = False
|
||||
# Whether to enable Azure Storage
|
||||
|
||||
@ -29,7 +29,8 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"numpy",
|
||||
"scipy"
|
||||
"scipy",
|
||||
"nvidia-nccl-cu12 ; platform_system == 'Linux' and platform_machine != 'aarch64'"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@ -2,14 +2,15 @@
|
||||
import ctypes
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from enum import IntEnum, unique
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._typing import _T
|
||||
from .core import _LIB, _check_call, c_str, from_pystr_to_cstr, py_str
|
||||
from .core import _LIB, _check_call, build_info, c_str, from_pystr_to_cstr, py_str
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.collective]")
|
||||
|
||||
@ -250,6 +251,31 @@ class CommunicatorContext:
|
||||
|
||||
def __init__(self, **args: Any) -> None:
|
||||
self.args = args
|
||||
key = "dmlc_nccl_path"
|
||||
if args.get(key, None) is not None:
|
||||
return
|
||||
|
||||
binfo = build_info()
|
||||
if not binfo["USE_DLOPEN_NCCL"]:
|
||||
return
|
||||
|
||||
try:
|
||||
# PyPI package of NCCL.
|
||||
from nvidia.nccl import lib
|
||||
|
||||
# There are two versions of nvidia-nccl, one is from PyPI, another one from
|
||||
# nvidia-pyindex. We support only the first one as the second one is too old
|
||||
# (2.9.8 as of writing).
|
||||
if lib.__file__ is not None:
|
||||
dirname: Optional[str] = os.path.dirname(lib.__file__)
|
||||
else:
|
||||
dirname = None
|
||||
|
||||
if dirname:
|
||||
path = os.path.join(dirname, "libnccl.so.2")
|
||||
self.args[key] = path
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> Dict[str, Any]:
|
||||
init(**self.args)
|
||||
|
||||
@ -184,6 +184,13 @@ def _py_version() -> str:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
def _register_log_callback(lib: ctypes.CDLL) -> None:
|
||||
lib.XGBGetLastError.restype = ctypes.c_char_p
|
||||
lib.callback = _get_log_callback_func() # type: ignore
|
||||
if lib.XGBRegisterLogCallback(lib.callback) != 0:
|
||||
raise XGBoostError(lib.XGBGetLastError())
|
||||
|
||||
|
||||
def _load_lib() -> ctypes.CDLL:
|
||||
"""Load xgboost Library."""
|
||||
lib_paths = find_lib_path()
|
||||
@ -228,10 +235,7 @@ Likely causes:
|
||||
Error message(s): {os_error_list}
|
||||
"""
|
||||
)
|
||||
lib.XGBGetLastError.restype = ctypes.c_char_p
|
||||
lib.callback = _get_log_callback_func() # type: ignore
|
||||
if lib.XGBRegisterLogCallback(lib.callback) != 0:
|
||||
raise XGBoostError(lib.XGBGetLastError())
|
||||
_register_log_callback(lib)
|
||||
|
||||
def parse(ver: str) -> Tuple[int, int, int]:
|
||||
"""Avoid dependency on packaging (PEP 440)."""
|
||||
|
||||
@ -7,8 +7,6 @@
|
||||
#include <cinttypes> // for strtoimax
|
||||
#include <cmath> // for nan
|
||||
#include <cstring> // for strcmp
|
||||
#include <fstream> // for operator<<, basic_ostream, ios, stringstream
|
||||
#include <functional> // for less
|
||||
#include <limits> // for numeric_limits
|
||||
#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre...
|
||||
#include <memory> // for shared_ptr, allocator, __shared_ptr_access
|
||||
@ -22,7 +20,6 @@
|
||||
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
|
||||
#include "../common/linalg_op.h" // for ElementWiseTransformHost
|
||||
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
|
||||
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
|
||||
#include "../data/ellpack_page.h" // for EllpackPage
|
||||
@ -35,14 +32,12 @@
|
||||
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
|
||||
#include "dmlc/thread_local.h" // for ThreadLocalStore
|
||||
#include "rabit/c_api.h" // for RabitLinkTag
|
||||
#include "rabit/rabit.h" // for CheckPoint, LoadCheckPoint
|
||||
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
|
||||
#include "xgboost/feature_map.h" // for FeatureMap
|
||||
#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal...
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/intrusive_ptr.h" // for xgboost
|
||||
#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String
|
||||
#include "xgboost/learner.h" // for Learner, PredictionType
|
||||
#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ
|
||||
@ -79,6 +74,7 @@ void XGBBuildInfoDevice(Json *p_info) {
|
||||
info["USE_CUDA"] = Boolean{false};
|
||||
info["USE_NCCL"] = Boolean{false};
|
||||
info["USE_RMM"] = Boolean{false};
|
||||
info["USE_DLOPEN_NCCL"] = Boolean{false};
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif
|
||||
|
||||
@ -33,8 +33,16 @@ void XGBBuildInfoDevice(Json *p_info) {
|
||||
info["USE_NCCL"] = Boolean{true};
|
||||
v = {Json{Integer{NCCL_MAJOR}}, Json{Integer{NCCL_MINOR}}, Json{Integer{NCCL_PATCH}}};
|
||||
info["NCCL_VERSION"] = v;
|
||||
|
||||
#if defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
info["USE_DLOPEN_NCCL"] = Boolean{true};
|
||||
#else
|
||||
info["USE_DLOPEN_NCCL"] = Boolean{false};
|
||||
#endif // defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
|
||||
#else
|
||||
info["USE_NCCL"] = Boolean{false};
|
||||
info["USE_DLOPEN_NCCL"] = Boolean{false};
|
||||
#endif
|
||||
|
||||
#if defined(XGBOOST_USE_RMM)
|
||||
|
||||
@ -19,25 +19,6 @@ Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; }
|
||||
|
||||
NCCLColl::~NCCLColl() = default;
|
||||
namespace {
|
||||
Result GetNCCLResult(ncclResult_t code) {
|
||||
if (code == ncclSuccess) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "NCCL failure: " << ncclGetErrorString(code) << ".";
|
||||
if (code == ncclUnhandledCudaError) {
|
||||
// nccl usually preserves the last error so we can get more details.
|
||||
auto err = cudaPeekAtLastError();
|
||||
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
|
||||
} else if (code == ncclSystemError) {
|
||||
ss << " This might be caused by a network configuration issue. Please consider specifying "
|
||||
"the network interface for NCCL via environment variables listed in its reference: "
|
||||
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
|
||||
}
|
||||
return Fail(ss.str());
|
||||
}
|
||||
|
||||
auto GetNCCLType(ArrayInterfaceHandler::Type type) {
|
||||
auto fatal = [] {
|
||||
LOG(FATAL) << "Invalid type for NCCL operation.";
|
||||
@ -94,11 +75,12 @@ void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> ou
|
||||
common::Span<std::int8_t> data, Op op) {
|
||||
dh::device_vector<std::int8_t> buffer(data.size() * pcomm->World());
|
||||
auto* device_buffer = buffer.data().get();
|
||||
auto stub = pcomm->Stub();
|
||||
|
||||
// First gather data from all the workers.
|
||||
CHECK(handle);
|
||||
auto rc = GetNCCLResult(
|
||||
ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream()));
|
||||
auto rc = GetNCCLResult(stub, stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8,
|
||||
handle, pcomm->Stream()));
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@ -149,6 +131,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
return Success() << [&] {
|
||||
if (IsBitwiseOp(op)) {
|
||||
return BitwiseAllReduce(nccl, nccl->Handle(), data, op);
|
||||
@ -156,9 +140,9 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
return DispatchDType(type, [=](auto t) {
|
||||
using T = decltype(t);
|
||||
auto rdata = common::RestoreType<T>(data);
|
||||
auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
|
||||
auto rc = stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
|
||||
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
|
||||
return GetNCCLResult(rc);
|
||||
return GetNCCLResult(stub, rc);
|
||||
});
|
||||
}
|
||||
} << [&] { return nccl->Block(); };
|
||||
@ -171,9 +155,11 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
|
||||
nccl->Handle(), nccl->Stream()));
|
||||
return GetNCCLResult(stub, stub->Broadcast(data.data(), data.data(), data.size_bytes(),
|
||||
ncclInt8, root, nccl->Handle(), nccl->Stream()));
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
@ -184,10 +170,12 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
auto send = data.subspan(comm.Rank() * size, size);
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(
|
||||
ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream()));
|
||||
return GetNCCLResult(stub, stub->Allgather(send.data(), data.data(), size, ncclInt8,
|
||||
nccl->Handle(), nccl->Stream()));
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
@ -199,19 +187,20 @@ namespace cuda_impl {
|
||||
*/
|
||||
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
|
||||
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||
auto stub = comm->Stub();
|
||||
return Success() << [&stub] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
|
||||
std::size_t offset = 0;
|
||||
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||
ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
||||
if (rc != ncclSuccess) {
|
||||
return GetNCCLResult(rc);
|
||||
return GetNCCLResult(stub, rc);
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [] { return GetNCCLResult(ncclGroupEnd()); };
|
||||
} << [&] { return GetNCCLResult(stub, stub->GroupEnd()); };
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
|
||||
@ -224,10 +213,11 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
switch (algo) {
|
||||
case AllgatherVAlgo::kRing: {
|
||||
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||
return Success() << [&] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, recv_segments);
|
||||
// copy data
|
||||
@ -237,8 +227,8 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
|
||||
cudaMemcpyDeviceToDevice, nccl->Stream()));
|
||||
}
|
||||
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
|
||||
} << [] {
|
||||
return GetNCCLResult(ncclGroupEnd());
|
||||
} << [&] {
|
||||
return GetNCCLResult(stub, stub->GroupEnd());
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
case AllgatherVAlgo::kBcast: {
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "nccl_stub.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
@ -7,15 +7,12 @@
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdlib> // for exit
|
||||
#include <memory> // for shared_ptr
|
||||
#include <mutex> // for unique_lock
|
||||
#include <string> // for string
|
||||
#include <utility> // for move, forward
|
||||
|
||||
#include "../common/common.h" // for AssertGPUSupport
|
||||
#include "../common/json_utils.h" // for OptionalArg
|
||||
#include "allgather.h" // for RingAllgather
|
||||
#include "protocol.h" // for kMagic
|
||||
#include "tracker.h" // for GetHostAddress
|
||||
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/json.h" // for Json, Object
|
||||
@ -62,14 +59,6 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
this->Rank(), this->World());
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_NCCL)
|
||||
Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
common::AssertGPUSupport();
|
||||
common::AssertNCCLSupport();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_NCCL)
|
||||
|
||||
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
|
||||
proto::PeerInfo ninfo, std::chrono::seconds timeout,
|
||||
std::int32_t retry,
|
||||
@ -194,12 +183,21 @@ Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
}
|
||||
|
||||
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id)
|
||||
: Comm{std::move(host), port, timeout, retry, std::move(task_id)} {
|
||||
std::int32_t retry, std::string task_id, StringView nccl_path)
|
||||
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
|
||||
nccl_path_{std::move(nccl_path)} {
|
||||
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_NCCL)
|
||||
Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
common::AssertGPUSupport();
|
||||
common::AssertNCCLSupport();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_NCCL)
|
||||
|
||||
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id) {
|
||||
TCPSocket tracker;
|
||||
|
||||
@ -13,19 +13,21 @@
|
||||
#include "../common/cuda_context.cuh" // for CUDAContext
|
||||
#include "../common/device_helpers.cuh" // for DefaultStream
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "broadcast.h" // for Broadcast
|
||||
#include "comm.cuh" // for NCCLComm
|
||||
#include "comm.h" // for Comm
|
||||
#include "nccl_stub.h" // for NcclStub
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
Result GetUniqueId(Comm const& comm, std::shared_ptr<Coll> coll, ncclUniqueId* pid) {
|
||||
Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared_ptr<Coll> coll,
|
||||
ncclUniqueId* pid) {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (comm.Rank() == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
auto rc = GetNCCLResult(stub, stub->GetUniqueId(&id));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
auto rc = coll->Broadcast(
|
||||
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
|
||||
@ -54,11 +56,12 @@ static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> c
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
|
||||
return new NCCLComm{ctx, *this, pimpl};
|
||||
Comm* RabitComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
|
||||
return new NCCLComm{ctx, *this, pimpl, StringView{this->nccl_path_}};
|
||||
}
|
||||
|
||||
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
|
||||
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
|
||||
StringView nccl_path)
|
||||
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
|
||||
root.TaskID()},
|
||||
stream_{ctx->CUDACtx()->Stream()} {
|
||||
@ -70,6 +73,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
stub_ = std::make_shared<NcclStub>(nccl_path);
|
||||
|
||||
std::vector<std::uint64_t> uuids(root.World() * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<std::uint64_t>{uuids.data(), uuids.size()};
|
||||
@ -95,19 +99,24 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
rc = GetUniqueId(root, pimpl, &nccl_unique_id_);
|
||||
rc = std::move(rc) << [&] {
|
||||
return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_);
|
||||
} << [&] {
|
||||
return GetNCCLResult(this->stub_, this->stub_->CommInitRank(&nccl_comm_, root.World(),
|
||||
nccl_unique_id_, root.Rank()));
|
||||
};
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
|
||||
|
||||
for (std::int32_t r = 0; r < root.World(); ++r) {
|
||||
this->channels_.emplace_back(
|
||||
std::make_shared<NCCLChannel>(root, r, nccl_comm_, dh::DefaultStream()));
|
||||
std::make_shared<NCCLChannel>(root, r, nccl_comm_, stub_, dh::DefaultStream()));
|
||||
}
|
||||
}
|
||||
|
||||
NCCLComm::~NCCLComm() {
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -6,9 +6,13 @@
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
#endif // XGBOOST_USE_NCCL
|
||||
|
||||
#include <utility> // for move
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "coll.h"
|
||||
#include "comm.h"
|
||||
#include "nccl_stub.h" // for NcclStub
|
||||
#include "xgboost/context.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
@ -21,15 +25,20 @@ inline Result GetCUDAResult(cudaError rc) {
|
||||
return Fail(msg);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
class NCCLComm : public Comm {
|
||||
ncclComm_t nccl_comm_{nullptr};
|
||||
std::shared_ptr<NcclStub> stub_;
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
dh::CUDAStreamView stream_;
|
||||
std::string nccl_path_;
|
||||
|
||||
public:
|
||||
[[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; }
|
||||
auto Stub() const { return stub_; }
|
||||
|
||||
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl);
|
||||
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
|
||||
StringView nccl_path);
|
||||
[[nodiscard]] Result LogTracker(std::string) const override {
|
||||
LOG(FATAL) << "Device comm is used for logging.";
|
||||
return Fail("Undefined.");
|
||||
@ -43,25 +52,53 @@ class NCCLComm : public Comm {
|
||||
}
|
||||
};
|
||||
|
||||
inline Result GetNCCLResult(std::shared_ptr<NcclStub> stub, ncclResult_t code) {
|
||||
if (code == ncclSuccess) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "NCCL failure: " << stub->GetErrorString(code) << ".";
|
||||
if (code == ncclUnhandledCudaError) {
|
||||
// nccl usually preserves the last error so we can get more details.
|
||||
auto err = cudaPeekAtLastError();
|
||||
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
|
||||
} else if (code == ncclSystemError) {
|
||||
ss << " This might be caused by a network configuration issue. Please consider specifying "
|
||||
"the network interface for NCCL via environment variables listed in its reference: "
|
||||
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
|
||||
}
|
||||
return Fail(ss.str());
|
||||
}
|
||||
|
||||
class NCCLChannel : public Channel {
|
||||
std::int32_t rank_{-1};
|
||||
ncclComm_t nccl_comm_{};
|
||||
std::shared_ptr<NcclStub> stub_;
|
||||
dh::CUDAStreamView stream_;
|
||||
|
||||
public:
|
||||
explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
|
||||
dh::CUDAStreamView stream)
|
||||
: rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {}
|
||||
std::shared_ptr<NcclStub> stub, dh::CUDAStreamView stream)
|
||||
: rank_{rank},
|
||||
nccl_comm_{nccl_comm},
|
||||
stub_{std::move(stub)},
|
||||
Channel{comm, nullptr},
|
||||
stream_{stream} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) override {
|
||||
dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
auto rc = GetNCCLResult(stub_, stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) override {
|
||||
dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
auto rc = GetNCCLResult(stub_, stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
[[nodiscard]] Result Block() override {
|
||||
auto rc = stream_.Sync(false);
|
||||
return GetCUDAResult(rc);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -34,6 +34,8 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
|
||||
return nrank;
|
||||
}
|
||||
|
||||
inline StringView DefaultNcclName() { return "libnccl.so.2"; }
|
||||
|
||||
class Channel;
|
||||
class Coll;
|
||||
|
||||
@ -86,11 +88,21 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
|
||||
virtual Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const;
|
||||
};
|
||||
|
||||
class RabitComm : public Comm {
|
||||
/**
|
||||
* @brief Base class for CPU-based communicator.
|
||||
*/
|
||||
class HostComm : public Comm {
|
||||
public:
|
||||
using Comm::Comm;
|
||||
[[nodiscard]] virtual Comm* MakeCUDAVar(Context const* ctx,
|
||||
std::shared_ptr<Coll> pimpl) const = 0;
|
||||
};
|
||||
|
||||
class RabitComm : public HostComm {
|
||||
std::string nccl_path_ = std::string{DefaultNcclName()};
|
||||
|
||||
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id);
|
||||
[[nodiscard]] Result Shutdown();
|
||||
@ -100,13 +112,15 @@ class RabitComm : public Comm {
|
||||
RabitComm() = default;
|
||||
// ctor for testing where environment is known.
|
||||
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id);
|
||||
std::int32_t retry, std::string task_id, StringView nccl_path);
|
||||
~RabitComm() noexcept(false) override;
|
||||
|
||||
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||
[[nodiscard]] Result LogTracker(std::string msg) const override;
|
||||
|
||||
[[nodiscard]] Result SignalError(Result const&) override;
|
||||
|
||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@ -37,7 +37,7 @@ namespace xgboost::collective {
|
||||
[[nodiscard]] Comm const& CommGroup::Ctx(Context const* ctx, DeviceOrd device) const {
|
||||
if (device.IsCUDA()) {
|
||||
CHECK(ctx->IsCUDA());
|
||||
if (!gpu_comm_) {
|
||||
if (!gpu_comm_ || gpu_comm_->World() != comm_->World()) {
|
||||
gpu_comm_.reset(comm_->MakeCUDAVar(ctx, backend_));
|
||||
}
|
||||
return *gpu_comm_;
|
||||
@ -55,7 +55,6 @@ CommGroup::CommGroup()
|
||||
}
|
||||
|
||||
std::string type = OptionalArg<String>(config, "dmlc_communicator", std::string{"rabit"});
|
||||
std::vector<std::string> keys;
|
||||
// Try both lower and upper case for compatibility
|
||||
auto get_param = [&](std::string name, auto dft, auto t) {
|
||||
std::string upper;
|
||||
@ -63,8 +62,6 @@ CommGroup::CommGroup()
|
||||
[](char c) { return std::toupper(c); });
|
||||
std::transform(name.cbegin(), name.cend(), name.begin(),
|
||||
[](char c) { return std::tolower(c); });
|
||||
keys.push_back(upper);
|
||||
keys.push_back(name);
|
||||
|
||||
auto const& obj = get<Object const>(config);
|
||||
auto it = obj.find(upper);
|
||||
@ -75,19 +72,19 @@ CommGroup::CommGroup()
|
||||
}
|
||||
};
|
||||
// Common args
|
||||
auto retry =
|
||||
OptionalArg<Integer>(config, "dmlc_retry", static_cast<Integer::Int>(DefaultRetry()));
|
||||
auto timeout = OptionalArg<Integer>(config, "dmlc_timeout_sec",
|
||||
static_cast<Integer::Int>(DefaultTimeoutSec()));
|
||||
auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
|
||||
auto timeout =
|
||||
get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
|
||||
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
|
||||
|
||||
if (type == "rabit") {
|
||||
auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
|
||||
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
|
||||
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
|
||||
auto ptr =
|
||||
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
|
||||
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
|
||||
static_cast<std::int32_t>(retry), task_id}},
|
||||
static_cast<std::int32_t>(retry), task_id, nccl}},
|
||||
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
|
||||
return ptr;
|
||||
} else if (type == "federated") {
|
||||
|
||||
@ -17,14 +17,16 @@ namespace xgboost::collective {
|
||||
* collective implementations.
|
||||
*/
|
||||
class CommGroup {
|
||||
std::shared_ptr<Comm> comm_;
|
||||
std::shared_ptr<HostComm> comm_;
|
||||
mutable std::shared_ptr<Comm> gpu_comm_;
|
||||
|
||||
std::shared_ptr<Coll> backend_;
|
||||
mutable std::shared_ptr<Coll> gpu_coll_; // lazy initialization
|
||||
|
||||
CommGroup(std::shared_ptr<Comm> comm, std::shared_ptr<Coll> coll)
|
||||
: comm_{std::move(comm)}, backend_{std::move(coll)} {}
|
||||
: comm_{std::dynamic_pointer_cast<HostComm>(comm)}, backend_{std::move(coll)} {
|
||||
CHECK(comm_);
|
||||
}
|
||||
|
||||
public:
|
||||
CommGroup();
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "comm.h"
|
||||
#include "in_memory_communicator.h"
|
||||
#include "noop_communicator.h"
|
||||
#include "rabit_communicator.h"
|
||||
@ -14,8 +15,12 @@
|
||||
namespace xgboost::collective {
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
thread_local std::string Communicator::nccl_path_{};
|
||||
|
||||
void Communicator::Init(Json const& config) {
|
||||
auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
|
||||
nccl_path_ = nccl;
|
||||
|
||||
auto type = GetTypeFromEnv();
|
||||
auto const arg = GetTypeFromConfig(config);
|
||||
if (arg != CommunicatorType::kUnknown) {
|
||||
|
||||
@ -31,17 +31,17 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
switch (type_) {
|
||||
case CommunicatorType::kRabit:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
|
||||
break;
|
||||
case CommunicatorType::kFederated:
|
||||
case CommunicatorType::kInMemory:
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
break;
|
||||
case CommunicatorType::kInMemoryNccl:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_));
|
||||
break;
|
||||
default:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
|
||||
}
|
||||
#else
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
|
||||
@ -234,6 +234,7 @@ class Communicator {
|
||||
|
||||
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||
static thread_local CommunicatorType type_;
|
||||
static thread_local std::string nccl_path_;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||
#endif
|
||||
|
||||
@ -2,12 +2,14 @@
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include "comm.cuh"
|
||||
#include "nccl_device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
|
||||
StringView nccl_path)
|
||||
: device_ordinal_{device_ordinal},
|
||||
needs_sync_{needs_sync},
|
||||
world_size_{GetWorldSize()},
|
||||
@ -18,6 +20,7 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
stub_ = std::make_shared<NcclStub>(std::move(nccl_path));
|
||||
|
||||
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
@ -43,7 +46,9 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
|
||||
auto rc =
|
||||
GetNCCLResult(stub_, stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
|
||||
@ -51,7 +56,8 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() {
|
||||
return;
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
@ -137,8 +143,10 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
|
||||
auto *device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
auto rc = GetNCCLResult(
|
||||
stub_, stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
if (needs_sync_) {
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
@ -170,9 +178,10 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
|
||||
if (IsBitwiseOp(op)) {
|
||||
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
|
||||
} else {
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
dh::DefaultStream()));
|
||||
auto rc = GetNCCLResult(stub_, stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op),
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
@ -185,8 +194,9 @@ void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_bu
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllGather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
|
||||
dh::DefaultStream()));
|
||||
auto rc = GetNCCLResult(stub_, stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8,
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
@ -206,14 +216,19 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
dh::safe_nccl(ncclGroupStart());
|
||||
auto rc = Success() << [&] { return GetNCCLResult(stub_, stub_->GroupStart()); } << [&] {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
auto rc = GetNCCLResult(
|
||||
stub_, stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream()));
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_nccl(ncclGroupEnd());
|
||||
return Success();
|
||||
} << [&] { return GetNCCLResult(stub_, stub_->GroupEnd()); };
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::Synchronize() {
|
||||
|
||||
@ -4,8 +4,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "comm.cuh"
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "nccl_stub.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
@ -25,7 +27,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
* needed. The in-memory communicator is used in tests with multiple threads, each thread
|
||||
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
|
||||
*/
|
||||
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
|
||||
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override;
|
||||
@ -64,7 +66,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (rank_ == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
auto rc = GetNCCLResult(stub_, stub_->GetUniqueId(&id));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
return id;
|
||||
@ -78,6 +81,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
std::shared_ptr<NcclStub> stub_;
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
|
||||
109
src/collective/nccl_stub.cc
Normal file
109
src/collective/nccl_stub.cc
Normal file
@ -0,0 +1,109 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include "nccl_stub.h"
|
||||
|
||||
#include <cuda.h> // for CUDA_VERSION
|
||||
#include <dlfcn.h> // for dlclose, dlsym, dlopen
|
||||
#include <nccl.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <sstream> // for stringstream
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
NcclStub::NcclStub(StringView path) : path_{std::move(path)} {
|
||||
#if defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
CHECK(!path_.empty()) << "Empty path for NCCL.";
|
||||
|
||||
auto cu_major = (CUDA_VERSION) / 1000;
|
||||
std::stringstream ss;
|
||||
ss << R"m(
|
||||
|
||||
If XGBoost is installed from PyPI with pip, the error can fixed by:
|
||||
|
||||
- Run `pip install nvidia-nccl-cu)m"
|
||||
<< cu_major << "` (Or with any CUDA version that's compatible with " << cu_major << ").";
|
||||
ss << R"m(
|
||||
|
||||
Otherwise, please refer to:
|
||||
|
||||
https://xgboost.readthedocs.io/en/stable/tutorials/dask.html#troubleshooting
|
||||
|
||||
for more info, or open an issue on GitHub. Starting from XGBoost 2.1.0, the PyPI package
|
||||
no long bundles NCCL in the binary wheel.
|
||||
|
||||
)m";
|
||||
auto help = ss.str();
|
||||
std::string msg{"Failed to load NCCL from path: `" + path_ + "`. Error:\n "};
|
||||
|
||||
auto safe_load = [&](auto t, StringView name) {
|
||||
std::stringstream errs;
|
||||
auto ptr = reinterpret_cast<decltype(t)>(dlsym(handle_, name.c_str()));
|
||||
if (!ptr) {
|
||||
errs << "Failed to load NCCL symbol `" << name << "` from " << path_ << ". Error:\n "
|
||||
<< dlerror() << help;
|
||||
LOG(FATAL) << errs.str();
|
||||
}
|
||||
return ptr;
|
||||
};
|
||||
|
||||
handle_ = dlopen(path_.c_str(), RTLD_LAZY);
|
||||
if (!handle_) {
|
||||
LOG(FATAL) << msg << dlerror() << help;
|
||||
}
|
||||
|
||||
allreduce_ = safe_load(allreduce_, "ncclAllReduce");
|
||||
broadcast_ = safe_load(broadcast_, "ncclBroadcast");
|
||||
allgather_ = safe_load(allgather_, "ncclAllGather");
|
||||
comm_init_rank_ = safe_load(comm_init_rank_, "ncclCommInitRank");
|
||||
comm_destroy_ = safe_load(comm_destroy_, "ncclCommDestroy");
|
||||
get_uniqueid_ = safe_load(get_uniqueid_, "ncclGetUniqueId");
|
||||
send_ = safe_load(send_, "ncclSend");
|
||||
recv_ = safe_load(recv_, "ncclRecv");
|
||||
group_start_ = safe_load(group_start_, "ncclGroupStart");
|
||||
group_end_ = safe_load(group_end_, "ncclGroupEnd");
|
||||
get_error_string_ = safe_load(get_error_string_, "ncclGetErrorString");
|
||||
get_version_ = safe_load(get_version_, "ncclGetVersion");
|
||||
|
||||
std::int32_t v;
|
||||
CHECK_EQ(get_version_(&v), ncclSuccess);
|
||||
auto patch = v % 100;
|
||||
auto minor = (v / 100) % 100;
|
||||
auto major = v / 10000;
|
||||
|
||||
LOG(INFO) << "Loaded shared NCCL " << major << "." << minor << "." << patch << ":`" << path_
|
||||
<< "`" << std::endl;
|
||||
#else
|
||||
allreduce_ = ncclAllReduce;
|
||||
broadcast_ = ncclBroadcast;
|
||||
allgather_ = ncclAllGather;
|
||||
comm_init_rank_ = ncclCommInitRank;
|
||||
comm_destroy_ = ncclCommDestroy;
|
||||
get_uniqueid_ = ncclGetUniqueId;
|
||||
send_ = ncclSend;
|
||||
recv_ = ncclRecv;
|
||||
group_start_ = ncclGroupStart;
|
||||
group_end_ = ncclGroupEnd;
|
||||
get_error_string_ = ncclGetErrorString;
|
||||
get_version_ = ncclGetVersion;
|
||||
#endif
|
||||
};
|
||||
|
||||
NcclStub::~NcclStub() { // NOLINT
|
||||
#if defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
if (handle_) {
|
||||
auto rc = dlclose(handle_);
|
||||
if (rc != 0) {
|
||||
LOG(WARNING) << "Failed to close NCCL handle:" << dlerror();
|
||||
}
|
||||
}
|
||||
handle_ = nullptr;
|
||||
#endif // defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
94
src/collective/nccl_stub.h
Normal file
94
src/collective/nccl_stub.h
Normal file
@ -0,0 +1,94 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <nccl.h>
|
||||
|
||||
#include <string> // for string
|
||||
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost::collective {
|
||||
class NcclStub {
|
||||
#if defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
void* handle_{nullptr};
|
||||
#endif // defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
std::string path_;
|
||||
|
||||
decltype(ncclAllReduce)* allreduce_{nullptr};
|
||||
decltype(ncclBroadcast)* broadcast_{nullptr};
|
||||
decltype(ncclAllGather)* allgather_{nullptr};
|
||||
decltype(ncclCommInitRank)* comm_init_rank_{nullptr};
|
||||
decltype(ncclCommDestroy)* comm_destroy_{nullptr};
|
||||
decltype(ncclGetUniqueId)* get_uniqueid_{nullptr};
|
||||
decltype(ncclSend)* send_{nullptr};
|
||||
decltype(ncclRecv)* recv_{nullptr};
|
||||
decltype(ncclGroupStart)* group_start_{nullptr};
|
||||
decltype(ncclGroupEnd)* group_end_{nullptr};
|
||||
decltype(ncclGetErrorString)* get_error_string_{nullptr};
|
||||
decltype(ncclGetVersion)* get_version_{nullptr};
|
||||
|
||||
public:
|
||||
explicit NcclStub(StringView path);
|
||||
~NcclStub();
|
||||
|
||||
[[nodiscard]] ncclResult_t Allreduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
CHECK(allreduce_);
|
||||
return this->allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream);
|
||||
}
|
||||
[[nodiscard]] ncclResult_t Broadcast(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, int root, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
CHECK(broadcast_);
|
||||
return this->broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream);
|
||||
}
|
||||
[[nodiscard]] ncclResult_t Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
CHECK(allgather_);
|
||||
return this->allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream);
|
||||
}
|
||||
[[nodiscard]] ncclResult_t CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
|
||||
int rank) const {
|
||||
CHECK(comm_init_rank_);
|
||||
return this->comm_init_rank_(comm, nranks, commId, rank);
|
||||
}
|
||||
[[nodiscard]] ncclResult_t CommDestroy(ncclComm_t comm) const {
|
||||
CHECK(comm_destroy_);
|
||||
return this->comm_destroy_(comm);
|
||||
}
|
||||
|
||||
[[nodiscard]] ncclResult_t GetUniqueId(ncclUniqueId* uniqueId) const {
|
||||
CHECK(get_uniqueid_);
|
||||
return this->get_uniqueid_(uniqueId);
|
||||
}
|
||||
[[nodiscard]] ncclResult_t Send(const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
int peer, ncclComm_t comm, cudaStream_t stream) {
|
||||
CHECK(send_);
|
||||
return send_(sendbuff, count, datatype, peer, comm, stream);
|
||||
}
|
||||
[[nodiscard]] ncclResult_t Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
ncclComm_t comm, cudaStream_t stream) const {
|
||||
CHECK(recv_);
|
||||
return recv_(recvbuff, count, datatype, peer, comm, stream);
|
||||
}
|
||||
[[nodiscard]] ncclResult_t GroupStart() const {
|
||||
CHECK(group_start_);
|
||||
return group_start_();
|
||||
}
|
||||
[[nodiscard]] ncclResult_t GroupEnd() const {
|
||||
CHECK(group_end_);
|
||||
return group_end_();
|
||||
}
|
||||
|
||||
[[nodiscard]] const char* GetErrorString(ncclResult_t result) const {
|
||||
return get_error_string_(result);
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
@ -115,30 +115,6 @@ XGBOOST_DEV_INLINE T atomicAdd(T *addr, T v) { // NOLINT
|
||||
}
|
||||
namespace dh {
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#define safe_nccl(ans) ThrowOnNcclError((ans), __FILE__, __LINE__)
|
||||
|
||||
inline ncclResult_t ThrowOnNcclError(ncclResult_t code, const char *file, int line) {
|
||||
if (code != ncclSuccess) {
|
||||
std::stringstream ss;
|
||||
ss << "NCCL failure: " << ncclGetErrorString(code) << ".";
|
||||
ss << " " << file << "(" << line << ")\n";
|
||||
if (code == ncclUnhandledCudaError) {
|
||||
// nccl usually preserves the last error so we can get more details.
|
||||
auto err = cudaPeekAtLastError();
|
||||
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
|
||||
} else if (code == ncclSystemError) {
|
||||
ss << " This might be caused by a network configuration issue. Please consider specifying "
|
||||
"the network interface for NCCL via environment variables listed in its reference: "
|
||||
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
|
||||
}
|
||||
LOG(FATAL) << ss.str();
|
||||
}
|
||||
|
||||
return code;
|
||||
}
|
||||
#endif
|
||||
|
||||
inline int32_t CudaGetPointerDevice(void const *ptr) {
|
||||
int32_t device = -1;
|
||||
cudaPointerAttributes attr;
|
||||
|
||||
@ -21,11 +21,18 @@ command_wrapper="tests/ci_build/ci_build.sh gpu_build_centos7 docker --build-arg
|
||||
`"RAPIDS_VERSION_ARG=$RAPIDS_VERSION"
|
||||
|
||||
echo "--- Build libxgboost from the source"
|
||||
$command_wrapper tests/ci_build/prune_libnccl.sh
|
||||
$command_wrapper tests/ci_build/build_via_cmake.sh -DCMAKE_PREFIX_PATH="/opt/grpc;/opt/rmm" \
|
||||
-DUSE_CUDA=ON -DUSE_NCCL=ON -DUSE_OPENMP=ON -DHIDE_CXX_SYMBOLS=ON -DPLUGIN_FEDERATED=ON \
|
||||
-DPLUGIN_RMM=ON -DUSE_NCCL_LIB_PATH=ON -DNCCL_INCLUDE_DIR=/usr/include \
|
||||
-DNCCL_LIBRARY=/workspace/libnccl_static.a ${arch_flag}
|
||||
$command_wrapper tests/ci_build/build_via_cmake.sh \
|
||||
-DCMAKE_PREFIX_PATH="/opt/grpc;/opt/rmm" \
|
||||
-DUSE_CUDA=ON \
|
||||
-DUSE_OPENMP=ON \
|
||||
-DHIDE_CXX_SYMBOLS=ON \
|
||||
-DPLUGIN_FEDERATED=ON \
|
||||
-DPLUGIN_RMM=ON \
|
||||
-DUSE_NCCL=ON \
|
||||
-DUSE_NCCL_LIB_PATH=ON \
|
||||
-DNCCL_INCLUDE_DIR=/usr/include \
|
||||
-DUSE_DLOPEN_NCCL=ON \
|
||||
${arch_flag}
|
||||
echo "--- Build binary wheel"
|
||||
$command_wrapper bash -c \
|
||||
"cd python-package && rm -rf dist/* && pip wheel --no-deps -v . --wheel-dir dist/"
|
||||
|
||||
@ -21,11 +21,17 @@ command_wrapper="tests/ci_build/ci_build.sh gpu_build_centos7 docker --build-arg
|
||||
`"RAPIDS_VERSION_ARG=$RAPIDS_VERSION"
|
||||
|
||||
echo "--- Build libxgboost from the source"
|
||||
$command_wrapper tests/ci_build/prune_libnccl.sh
|
||||
$command_wrapper tests/ci_build/build_via_cmake.sh -DCMAKE_PREFIX_PATH="/opt/grpc" \
|
||||
-DUSE_CUDA=ON -DUSE_NCCL=ON -DUSE_OPENMP=ON -DHIDE_CXX_SYMBOLS=ON -DPLUGIN_FEDERATED=ON \
|
||||
-DUSE_NCCL_LIB_PATH=ON -DNCCL_INCLUDE_DIR=/usr/include \
|
||||
-DNCCL_LIBRARY=/workspace/libnccl_static.a ${arch_flag}
|
||||
$command_wrapper tests/ci_build/build_via_cmake.sh \
|
||||
-DCMAKE_PREFIX_PATH="/opt/grpc" \
|
||||
-DUSE_CUDA=ON \
|
||||
-DUSE_OPENMP=ON \
|
||||
-DHIDE_CXX_SYMBOLS=ON \
|
||||
-DPLUGIN_FEDERATED=ON \
|
||||
-DUSE_NCCL=ON \
|
||||
-DUSE_NCCL_LIB_PATH=ON \
|
||||
-DNCCL_INCLUDE_DIR=/usr/include \
|
||||
-DUSE_DLOPEN_NCCL=ON \
|
||||
${arch_flag}
|
||||
echo "--- Build binary wheel"
|
||||
$command_wrapper bash -c \
|
||||
"cd python-package && rm -rf dist/* && pip wheel --no-deps -v . --wheel-dir dist/"
|
||||
|
||||
@ -10,6 +10,7 @@ chmod +x build/testxgboost
|
||||
tests/ci_build/ci_build.sh gpu nvidia-docker \
|
||||
--build-arg CUDA_VERSION_ARG=$CUDA_VERSION \
|
||||
--build-arg RAPIDS_VERSION_ARG=$RAPIDS_VERSION \
|
||||
--build-arg NCCL_VERSION_ARG=$NCCL_VERSION \
|
||||
build/testxgboost
|
||||
|
||||
echo "--- Run Google Tests with CUDA, using a GPU, RMM enabled"
|
||||
|
||||
@ -13,4 +13,5 @@ chmod +x build/testxgboost
|
||||
tests/ci_build/ci_build.sh gpu nvidia-docker \
|
||||
--build-arg CUDA_VERSION_ARG=$CUDA_VERSION \
|
||||
--build-arg RAPIDS_VERSION_ARG=$RAPIDS_VERSION \
|
||||
--build-arg NCCL_VERSION_ARG=$NCCL_VERSION \
|
||||
build/testxgboost --gtest_filter=*MGPU*
|
||||
|
||||
@ -24,7 +24,8 @@ export CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'
|
||||
|
||||
command_wrapper="tests/ci_build/ci_build.sh gpu nvidia-docker --build-arg "`
|
||||
`"CUDA_VERSION_ARG=$CUDA_VERSION --build-arg "`
|
||||
`"RAPIDS_VERSION_ARG=$RAPIDS_VERSION"
|
||||
`"RAPIDS_VERSION_ARG=$RAPIDS_VERSION --build-arg "`
|
||||
`"NCCL_VERSION_ARG=$NCCL_VERSION"
|
||||
|
||||
# Run specified test suite
|
||||
case "$suite" in
|
||||
|
||||
@ -2,6 +2,7 @@ ARG CUDA_VERSION_ARG
|
||||
FROM nvidia/cuda:$CUDA_VERSION_ARG-runtime-ubuntu22.04
|
||||
ARG CUDA_VERSION_ARG
|
||||
ARG RAPIDS_VERSION_ARG
|
||||
ARG NCCL_VERSION_ARG
|
||||
|
||||
# Environment
|
||||
ENV DEBIAN_FRONTEND noninteractive
|
||||
@ -23,7 +24,9 @@ RUN \
|
||||
conda install -c conda-forge mamba && \
|
||||
mamba create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
|
||||
python=3.10 cudf=$RAPIDS_VERSION_ARG* rmm=$RAPIDS_VERSION_ARG* cudatoolkit=$CUDA_VERSION_ARG \
|
||||
dask dask-cuda=$RAPIDS_VERSION_ARG* dask-cudf=$RAPIDS_VERSION_ARG* cupy \
|
||||
nccl>=$(cut -d "-" -f 1 << $NCCL_VERSION_ARG) \
|
||||
dask \
|
||||
dask-cuda=$RAPIDS_VERSION_ARG* dask-cudf=$RAPIDS_VERSION_ARG* cupy \
|
||||
numpy pytest pytest-timeout scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \
|
||||
pyspark>=3.4.0 cloudpickle cuda-python && \
|
||||
mamba clean --all && \
|
||||
|
||||
@ -27,7 +27,7 @@ RUN \
|
||||
wget -nv -nc https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
|
||||
rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
|
||||
yum -y update && \
|
||||
yum install -y libnccl-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-devel-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-static-${NCCL_VERSION}+cuda${CUDA_SHORT} && \
|
||||
yum install -y libnccl-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-devel-${NCCL_VERSION}+cuda${CUDA_SHORT} && \
|
||||
rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm;
|
||||
|
||||
ENV PATH=/opt/mambaforge/bin:/usr/local/ninja:$PATH
|
||||
|
||||
@ -1,35 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
rm -rf tmp_nccl
|
||||
|
||||
mkdir tmp_nccl
|
||||
pushd tmp_nccl
|
||||
|
||||
set -x
|
||||
|
||||
cat << EOF > test.cu
|
||||
int main(void) { return 0; }
|
||||
EOF
|
||||
|
||||
cat << EOF > CMakeLists.txt
|
||||
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
|
||||
project(gencode_extractor CXX C)
|
||||
cmake_policy(SET CMP0104 NEW)
|
||||
set(CMAKE_CUDA_HOST_COMPILER \${CMAKE_CXX_COMPILER})
|
||||
enable_language(CUDA)
|
||||
include(../cmake/Utils.cmake)
|
||||
compute_cmake_cuda_archs("")
|
||||
add_library(test OBJECT test.cu)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
EOF
|
||||
|
||||
cmake . -GNinja -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
|
||||
gen_code=$(grep -o -- '--generate-code=\S*' compile_commands.json | paste -sd ' ')
|
||||
|
||||
nvprune ${gen_code} /usr/lib64/libnccl_static.a -o ../libnccl_static.a
|
||||
|
||||
popd
|
||||
rm -rf tmp_nccl
|
||||
|
||||
set +x
|
||||
@ -1,22 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cd(path):
|
||||
path = os.path.normpath(path)
|
||||
cwd = os.getcwd()
|
||||
os.chdir(path)
|
||||
print("cd " + path)
|
||||
try:
|
||||
yield path
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
from test_utils import DirectoryExcursion
|
||||
|
||||
if len(sys.argv) != 4:
|
||||
print('Usage: {} [wheel to rename] [commit id] [platform tag]'.format(sys.argv[0]))
|
||||
print("Usage: {} [wheel to rename] [commit id] [platform tag]".format(sys.argv[0]))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@ -26,20 +14,26 @@ platform_tag = sys.argv[3]
|
||||
|
||||
dirname, basename = os.path.dirname(whl_path), os.path.basename(whl_path)
|
||||
|
||||
with cd(dirname):
|
||||
tokens = basename.split('-')
|
||||
with DirectoryExcursion(dirname):
|
||||
tokens = basename.split("-")
|
||||
assert len(tokens) == 5
|
||||
version = tokens[1].split('+')[0]
|
||||
keywords = {'pkg_name': tokens[0],
|
||||
'version': version,
|
||||
'commit_id': commit_id,
|
||||
'platform_tag': platform_tag}
|
||||
new_name = '{pkg_name}-{version}+{commit_id}-py3-none-{platform_tag}.whl'.format(**keywords)
|
||||
print('Renaming {} to {}...'.format(basename, new_name))
|
||||
version = tokens[1].split("+")[0]
|
||||
keywords = {
|
||||
"pkg_name": tokens[0],
|
||||
"version": version,
|
||||
"commit_id": commit_id,
|
||||
"platform_tag": platform_tag,
|
||||
}
|
||||
new_name = "{pkg_name}-{version}+{commit_id}-py3-none-{platform_tag}.whl".format(
|
||||
**keywords
|
||||
)
|
||||
print("Renaming {} to {}...".format(basename, new_name))
|
||||
if os.path.isfile(new_name):
|
||||
os.remove(new_name)
|
||||
os.rename(basename, new_name)
|
||||
|
||||
filesize = os.path.getsize(new_name) / 1024 / 1024 # MB
|
||||
print(f"Wheel size: {filesize}")
|
||||
|
||||
msg = f"Limit of wheel size set by PyPI is exceeded. {new_name}: {filesize}"
|
||||
assert filesize <= 300, msg
|
||||
|
||||
@ -90,10 +90,10 @@ class Worker : public NCCLWorkerForTest {
|
||||
}
|
||||
};
|
||||
|
||||
class AllgatherTestGPU : public SocketTest {};
|
||||
class MGPUAllgatherTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllgatherTestGPU, MGPUTestVRing) {
|
||||
TEST_F(MGPUAllgatherTest, MGPUTestVRing) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
@ -104,7 +104,7 @@ TEST_F(AllgatherTestGPU, MGPUTestVRing) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTestGPU, MGPUTestVBcast) {
|
||||
TEST_F(MGPUAllgatherTest, MGPUTestVBcast) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
|
||||
@ -5,17 +5,15 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h> // for host_vector
|
||||
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/common/common.h"
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "../helpers.h" // for MakeCUDACtx
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class AllreduceTestGPU : public SocketTest {};
|
||||
class MGPUAllreduceTest : public SocketTest {};
|
||||
|
||||
class Worker : public NCCLWorkerForTest {
|
||||
public:
|
||||
@ -47,7 +45,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllreduceTestGPU, BitOr) {
|
||||
TEST_F(MGPUAllreduceTest, BitOr) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
@ -57,7 +55,7 @@ TEST_F(AllreduceTestGPU, BitOr) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllreduceTestGPU, Sum) {
|
||||
TEST_F(MGPUAllreduceTest, Sum) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <bitset>
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../../src/collective/comm.cuh"
|
||||
#include "../../../src/collective/communicator-inl.cuh"
|
||||
#include "../../../src/collective/nccl_device_communicator.cuh"
|
||||
#include "../helpers.h"
|
||||
@ -16,17 +17,15 @@ namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
|
||||
auto construct = []() { NcclDeviceCommunicator comm{-1, false}; };
|
||||
auto construct = []() { NcclDeviceCommunicator comm{-1, false, DefaultNcclName()}; };
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
|
||||
try {
|
||||
dh::safe_nccl(ncclSystemError);
|
||||
} catch (dmlc::Error const& e) {
|
||||
auto str = std::string{e.what()};
|
||||
ASSERT_TRUE(str.find("environment variables") != std::string::npos);
|
||||
}
|
||||
auto stub = std::make_shared<NcclStub>(DefaultNcclName());
|
||||
auto rc = GetNCCLResult(stub, ncclSystemError);
|
||||
auto msg = rc.Report();
|
||||
ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@ -33,7 +33,7 @@ class WorkerForTest {
|
||||
tracker_port_{port},
|
||||
world_size_{world},
|
||||
task_id_{"t:" + std::to_string(rank)},
|
||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} {
|
||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} {
|
||||
CHECK_EQ(world_size_, comm_.World());
|
||||
}
|
||||
virtual ~WorkerForTest() = default;
|
||||
|
||||
@ -12,6 +12,7 @@ from hypothesis._settings import duration
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.collective import CommunicatorContext
|
||||
from xgboost.testing.params import hist_parameter_strategy
|
||||
|
||||
pytestmark = [
|
||||
@ -572,6 +573,65 @@ def test_with_asyncio(local_cuda_client: Client) -> None:
|
||||
assert isinstance(output["history"], dict)
|
||||
|
||||
|
||||
def test_invalid_nccl(local_cuda_client: Client) -> None:
|
||||
client = local_cuda_client
|
||||
workers = tm.get_client_workers(client)
|
||||
args = client.sync(
|
||||
dxgb._get_rabit_args, len(workers), dxgb._get_dask_config(), client
|
||||
)
|
||||
|
||||
def run(wid: int) -> None:
|
||||
ctx = CommunicatorContext(dmlc_nccl_path="foo", **args)
|
||||
X, y, w = tm.make_regression(n_samples=10, n_features=10, use_cupy=True)
|
||||
|
||||
with ctx:
|
||||
with pytest.raises(ValueError, match=r"pip install"):
|
||||
xgb.QuantileDMatrix(X, y, weight=w)
|
||||
|
||||
futures = client.map(run, range(len(workers)), workers=workers)
|
||||
client.gather(futures)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
|
||||
def test_nccl_load(local_cuda_client: Client, tree_method: str) -> None:
|
||||
X, y, w = tm.make_regression(128, 16, use_cupy=True)
|
||||
|
||||
def make_model() -> None:
|
||||
xgb.XGBRegressor(
|
||||
device="cuda",
|
||||
tree_method=tree_method,
|
||||
objective="reg:quantileerror",
|
||||
verbosity=2,
|
||||
quantile_alpha=[0.2, 0.8],
|
||||
).fit(X, y, sample_weight=w)
|
||||
|
||||
# no nccl load when using single-node.
|
||||
with tm.captured_output() as (out, err):
|
||||
make_model()
|
||||
assert out.getvalue().find("NCCL") == -1
|
||||
assert err.getvalue().find("NCCL") == -1
|
||||
|
||||
client = local_cuda_client
|
||||
workers = tm.get_client_workers(client)
|
||||
args = client.sync(
|
||||
dxgb._get_rabit_args, len(workers), dxgb._get_dask_config(), client
|
||||
)
|
||||
|
||||
# nccl is loaded
|
||||
def run(wid: int) -> None:
|
||||
# FIXME(jiamingy): https://github.com/dmlc/xgboost/issues/9147
|
||||
from xgboost.core import _LIB, _register_log_callback
|
||||
_register_log_callback(_LIB)
|
||||
|
||||
with CommunicatorContext(**args):
|
||||
with tm.captured_output() as (out, err):
|
||||
make_model()
|
||||
assert out.getvalue().find("Loaded shared NCCL") != -1, out.getvalue()
|
||||
|
||||
futures = client.map(run, range(len(workers)), workers=workers)
|
||||
client.gather(futures)
|
||||
|
||||
|
||||
async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainReturnT:
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
import cupy as cp
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user