Common interface for collective communication (#8057)

* implement broadcast for federated communicator

* implement allreduce

* add communicator factory

* add device adapter

* add device communicator to factory

* add rabit communicator

* add rabit communicator to the factory

* add nccl device communicator

* add synchronize to device communicator

* add back print and getprocessorname

* add python wrapper and c api

* clean up types

* fix non-gpu build

* try to fix ci

* fix std::size_t

* portable string compare ignore case

* c style size_t

* fix lint errors

* cross platform setenv

* fix memory leak

* fix lint errors

* address review feedback

* add python test for rabit communicator

* fix failing gtest

* use json to configure communicators

* fix lint error

* get rid of factories

* fix cpu build

* fix include

* fix python import

* don't export collective.py yet

* skip collective communicator pytest on windows

* add review feedback

* update documentation

* remove mpi communicator type

* fix tests

* shutdown the communicator separately

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-09-12 15:21:12 -07:00
committed by GitHub
parent bc818316f2
commit a2686543a9
25 changed files with 1771 additions and 95 deletions

View File

@@ -0,0 +1,243 @@
"""XGBoost collective communication related API."""
import ctypes
import json
import logging
import pickle
from enum import IntEnum, unique
from typing import Any, List
import numpy as np
from ._typing import _T
from .core import _LIB, _check_call, c_str, py_str, from_pystr_to_cstr
# Module-level logger; CommunicatorContext uses it to emit debug messages on entry and exit.
LOGGER = logging.getLogger("[xgboost.collective]")
def init(**args: Any) -> None:
    """Initialize the collective library with arguments.

    The keyword arguments are serialized to a JSON document and handed to the native
    ``XGCommunicatorInit`` entry point.

    Parameters
    ----------
    args : Dict[str, Any]
        Keyword arguments representing the parameters and their values.

        Accepted parameters:

        - xgboost_communicator: The type of the communicator. Can be set as an
          environment variable.

          * rabit: Use Rabit. This is the default if the type is unspecified.
          * federated: Use the gRPC interface for Federated Learning.

        Only applicable to the Rabit communicator (these are case-sensitive):

        - rabit_tracker_uri: Hostname of the tracker.
        - rabit_tracker_port: Port number of the tracker.
        - rabit_task_id: ID of the current task, can be used to obtain deterministic
          rank assignment.
        - rabit_world_size: Total number of workers.
        - rabit_hadoop_mode: Enable Hadoop support.
        - rabit_tree_reduce_minsize: Minimal size for tree reduce.
        - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
        - rabit_reduce_buffer: Size of the reduce buffer.
        - rabit_bootstrap_cache: Size of the bootstrap cache.
        - rabit_debug: Enable debugging.
        - rabit_timeout: Enable timeout.
        - rabit_timeout_sec: Timeout in seconds.
        - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.

        Only applicable to the Rabit communicator (these are case-sensitive, and can
        be set as environment variables):

        - DMLC_TRACKER_URI: Hostname of the tracker.
        - DMLC_TRACKER_PORT: Port number of the tracker.
        - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic
          rank assignment.
        - DMLC_ROLE: Role of the current task, "worker" or "server".
        - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
        - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.

        Only applicable to the Federated communicator (use upper case for environment
        variables, use lower case for runtime configuration):

        - federated_server_address: Address of the federated server.
        - federated_world_size: Number of federated workers.
        - federated_rank: Rank of the current worker.
        - federated_server_cert: Server certificate file path. Only needed for the
          SSL mode.
        - federated_client_key: Client key file path. Only needed for the SSL mode.
        - federated_client_cert: Client certificate file path. Only needed for the
          SSL mode.
    """
    # The native side receives the whole configuration as a single JSON string.
    serialized = json.dumps(args)
    _check_call(_LIB.XGCommunicatorInit(from_pystr_to_cstr(serialized)))
def finalize() -> None:
    """Finalize the communicator by calling the native ``XGCommunicatorFinalize``."""
    status = _LIB.XGCommunicatorFinalize()
    _check_call(status)
def get_rank() -> int:
    """Return the rank of the current process.

    Returns
    -------
    rank : int
        Rank of the current process, as reported by ``XGCommunicatorGetRank``.
    """
    # The native call returns the rank itself, so it is not routed through _check_call.
    return _LIB.XGCommunicatorGetRank()
def get_world_size() -> int:
    """Return the total number of workers.

    Returns
    -------
    n : int
        Total number of processes, as reported by ``XGCommunicatorGetWorldSize``.
    """
    # The native call returns the size itself, so it is not routed through _check_call.
    return _LIB.XGCommunicatorGetWorldSize()
def is_distributed() -> int:
    """Report whether the collective communicator is distributed.

    Returns
    -------
    is_dist : int
        The integer returned by ``XGCommunicatorIsDistributed``; this module treats a
        non-zero value as "distributed".
    """
    return _LIB.XGCommunicatorIsDistributed()
def communicator_print(msg: Any) -> None:
    """Print a message through the communicator.

    This function can be used to communicate progress information to the
    communicator.

    Parameters
    ----------
    msg : str
        The message to be printed to the communicator. Non-string values are
        converted with ``str()`` first.
    """
    text = msg if isinstance(msg, str) else str(msg)

    # Outside distributed mode there is nothing to forward to: print locally.
    if _LIB.XGCommunicatorIsDistributed() == 0:
        print(text.strip(), flush=True)
        return

    _check_call(_LIB.XGCommunicatorPrint(c_str(text)))
def get_processor_name() -> str:
    """Return the processor name.

    Returns
    -------
    name : str
        The name of the processor (host).
    """
    # Out-parameter that the native call fills with a pointer to the name.
    out_name = ctypes.c_char_p()
    status = _LIB.XGCommunicatorGetProcessorName(ctypes.byref(out_name))
    _check_call(status)

    raw_name = out_name.value
    assert raw_name
    return py_str(raw_name)
def broadcast(data: _T, root: int) -> _T:
    """Broadcast an object from one node to all other nodes.

    The object is pickled on the root and sent in two rounds: first the payload
    length (as a C ``unsigned long``), then the pickled bytes themselves.

    Parameters
    ----------
    data : any type that can be pickled
        Input data. If the current rank does not equal ``root``, this can be None.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object : same type as the broadcast data
        The result of the broadcast: ``data`` itself on the root, the unpickled copy
        on every other rank.

    Notes
    -----
    Non-root ranks call ``pickle.loads`` on the bytes they receive, so this should
    only be used between trusted peers.
    """
    is_root = root == get_rank()
    length = ctypes.c_ulong()

    if is_root:
        assert data is not None, 'need to pass in data when broadcasting'
        payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        length.value = len(payload)

    # Round 1: every rank learns the payload size from the root.
    _check_call(
        _LIB.XGCommunicatorBroadcast(
            ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root
        )
    )

    if is_root:
        # Round 2 (sender): ship the pickled bytes; the root keeps its own object.
        send_ptr = ctypes.cast(ctypes.c_char_p(payload), ctypes.c_void_p)
        _check_call(_LIB.XGCommunicatorBroadcast(send_ptr, length.value, root))
        return data

    # Round 2 (receiver): allocate a buffer of the announced size and fill it.
    recv_buf = (ctypes.c_char * length.value)()
    recv_ptr = ctypes.cast(recv_buf, ctypes.c_void_p)
    _check_call(_LIB.XGCommunicatorBroadcast(recv_ptr, length.value, root))
    return pickle.loads(recv_buf.raw)
# Integer code for each supported NumPy dtype; allreduce passes the code of the
# buffer's dtype to the native library. Codes follow the order of the names below.
DTYPE_ENUM__ = {
    np.dtype(type_name): type_code
    for type_code, type_name in enumerate(
        (
            'int8',
            'uint8',
            'int32',
            'uint32',
            'int64',
            'uint64',
            'float32',
            'float64',
        )
    )
}
@unique
class Op(IntEnum):
    """Supported operations for allreduce.

    The integer value of the selected member is what allreduce forwards to the
    native library.
    """

    MAX = 0  # element-wise maximum
    MIN = 1  # element-wise minimum
    SUM = 2  # element-wise sum
def allreduce(  # pylint:disable=invalid-name
    data: np.ndarray, op: Op
) -> np.ndarray:
    """Perform allreduce, return the result.

    Parameters
    ----------
    data :
        Input data. It is never modified: the reduction runs on a private copy.
    op :
        Reduction operator.

    Returns
    -------
    result :
        The result of allreduce as a new flattened (1-D, C-order) array with the same
        dtype and number of elements as ``data``.

    Raises
    ------
    TypeError
        If ``data`` is not a ``numpy.ndarray`` or its dtype is not supported.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError('allreduce only takes in numpy.ndarray')
    # Reject unsupported dtypes before paying for the copy.
    if data.dtype not in DTYPE_ENUM__:
        raise TypeError(f"data type {data.dtype} not supported")

    # flatten() always returns a fresh C-contiguous copy. The native call reduces in
    # place, so working on a copy guarantees the caller's array is left untouched
    # (ravel() may return a view that aliases the input's memory).
    buf = data.flatten()

    _check_call(_LIB.XGCommunicatorAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                                             buf.size, DTYPE_ENUM__[buf.dtype],
                                             int(op), None, None))
    return buf
class CommunicatorContext:
    """A context controlling collective communicator initialization and finalization.

    Entering the context initializes the communicator with the stored keyword
    arguments; leaving it finalizes the communicator.
    """

    def __init__(self, **args: Any) -> None:
        # Kept as-is and forwarded to init() when the context is entered.
        self.args = args

    def __enter__(self) -> None:
        """Initialize the communicator and check that it is distributed."""
        init(**self.args)
        assert is_distributed()
        LOGGER.debug("-------------- communicator say hello ------------------")

    def __exit__(self, *args: List) -> None:
        """Finalize the communicator; the exception details are ignored."""
        finalize()
        LOGGER.debug("--------------- communicator say bye ------------------")