diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index f17ac23ba..220093b47 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -4,7 +4,7 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md """ from . import tracker # noqa -from . import collective, dask, rabit +from . import collective, dask from .core import ( Booster, DataIter, diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py deleted file mode 100644 index 132d72178..000000000 --- a/python-package/xgboost/rabit.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Compatibility shim for xgboost.rabit; to be removed in 2.0""" -import logging -import warnings -from enum import IntEnum, unique -from typing import Any, Callable, List, Optional, TypeVar - -import numpy as np - -from . import collective - -LOGGER = logging.getLogger("[xgboost.rabit]") - - -def _deprecation_warning() -> str: - return ( - "The xgboost.rabit submodule is marked as deprecated in 1.7 and will be removed " - "in 2.0. Please use xgboost.collective instead." - ) - - -def init(args: Optional[List[bytes]] = None) -> None: - """Initialize the rabit library with arguments""" - warnings.warn(_deprecation_warning(), FutureWarning) - parsed = {} - if args: - for arg in args: - kv = arg.decode().split("=") - if len(kv) == 2: - parsed[kv[0]] = kv[1] - collective.init(**parsed) - - -def finalize() -> None: - """Finalize the process, notify tracker everything is done.""" - collective.finalize() - - -def get_rank() -> int: - """Get rank of current process. - Returns - ------- - rank : int - Rank of current process. - """ - return collective.get_rank() - - -def get_world_size() -> int: - """Get total number workers. - Returns - ------- - n : int - Total number of process. - """ - return collective.get_world_size() - - -def is_distributed() -> int: - """If rabit is distributed.""" - return collective.is_distributed() - - -def tracker_print(msg: Any) -> None: - """Print message to the tracker. - This function can be used to communicate the information of - the progress to the tracker - Parameters - ---------- - msg : str - The message to be printed to tracker. - """ - collective.communicator_print(msg) - - -def get_processor_name() -> bytes: - """Get the processor name. - Returns - ------- - name : str - the name of processor(host) - """ - return collective.get_processor_name().encode() - - -T = TypeVar("T") # pylint:disable=invalid-name - - -def broadcast(data: T, root: int) -> T: - """Broadcast object from one node to all other nodes. - Parameters - ---------- - data : any type that can be pickled - Input data, if current rank does not equal root, this can be None - root : int - Rank of the node to broadcast data from. - Returns - ------- - object : int - the result of broadcast. - """ - return collective.broadcast(data, root) - - -@unique -class Op(IntEnum): - """Supported operations for rabit.""" - - MAX = 0 - MIN = 1 - SUM = 2 - OR = 3 - - -def allreduce( # pylint:disable=invalid-name - data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None -) -> np.ndarray: - """Perform allreduce, return the result. - Parameters - ---------- - data : - Input data. - op : - Reduction operators, can be MIN, MAX, SUM, BITOR - prepare_fun : - Lazy preprocessing function, if it is not None, prepare_fun(data) - will be called by the function before performing allreduce, to initialize the data - If the result of Allreduce can be recovered directly, - then prepare_fun will NOT be called - Returns - ------- - result : - The result of allreduce, have same shape as data - Notes - ----- - This function is not thread-safe. - """ - if prepare_fun is None: - return collective.allreduce(data, collective.Op(op)) - raise ValueError("preprocessing function is no longer supported") - - -def version_number() -> int: - """Returns version number of current stored model. - This means how many calls to CheckPoint we made so far. - Returns - ------- - version : int - Version number of currently stored model - """ - return 0 - - -class RabitContext: - """A context controlling rabit initialization and finalization.""" - - def __init__(self, args: Optional[List[bytes]] = None) -> None: - if args is None: - args = [] - self.args = args - - def __enter__(self) -> None: - init(self.args) - assert is_distributed() - LOGGER.warning(_deprecation_warning()) - LOGGER.debug("-------------- rabit say hello ------------------") - - def __exit__(self, *args: List) -> None: - finalize() - LOGGER.debug("--------------- rabit say bye ------------------") diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 32b0a67a7..f7de0400d 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -39,37 +39,6 @@ def test_rabit_communicator(): assert worker.exitcode == 0 -# TODO(rongou): remove this once we remove the rabit api. -def run_rabit_api_worker(rabit_env, world_size): - with xgb.rabit.RabitContext(rabit_env): - assert xgb.rabit.get_world_size() == world_size - assert xgb.rabit.is_distributed() - assert xgb.rabit.get_processor_name().decode() == socket.gethostname() - ret = xgb.rabit.broadcast('test1234', 0) - assert str(ret) == 'test1234' - ret = xgb.rabit.allreduce(np.asarray([1, 2, 3]), xgb.rabit.Op.SUM) - assert np.array_equal(ret, np.asarray([2, 4, 6])) - - -# TODO(rongou): remove this once we remove the rabit api. -def test_rabit_api(): - world_size = 2 - tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size) - tracker.start(world_size) - rabit_env = [] - for k, v in tracker.worker_envs().items(): - rabit_env.append(f"{k}={v}".encode()) - workers = [] - for _ in range(world_size): - worker = multiprocessing.Process(target=run_rabit_api_worker, - args=(rabit_env, world_size)) - workers.append(worker) - worker.start() - for worker in workers: - worker.join() - assert worker.exitcode == 0 - - def run_federated_worker(port, world_size, rank): with xgb.collective.CommunicatorContext(xgboost_communicator='federated', federated_server_address=f'localhost:{port}',