Remove the deprecated Python rabit module. (#9523)

This commit is contained in:
Jiaming Yuan 2023-08-27 03:37:05 +08:00 committed by GitHub
parent aa86bd5207
commit 209335b18c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 1 additions and 201 deletions

View File

@ -4,7 +4,7 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
""" """
from . import tracker # noqa from . import tracker # noqa
from . import collective, dask, rabit from . import collective, dask
from .core import ( from .core import (
Booster, Booster,
DataIter, DataIter,

View File

@ -1,169 +0,0 @@
"""Compatibility shim for xgboost.rabit; to be removed in 2.0"""
import logging
import warnings
from enum import IntEnum, unique
from typing import Any, Callable, List, Optional, TypeVar
import numpy as np
from . import collective
LOGGER = logging.getLogger("[xgboost.rabit]")
def _deprecation_warning() -> str:
return (
"The xgboost.rabit submodule is marked as deprecated in 1.7 and will be removed "
"in 2.0. Please use xgboost.collective instead."
)
def init(args: Optional[List[bytes]] = None) -> None:
"""Initialize the rabit library with arguments"""
warnings.warn(_deprecation_warning(), FutureWarning)
parsed = {}
if args:
for arg in args:
kv = arg.decode().split("=")
if len(kv) == 2:
parsed[kv[0]] = kv[1]
collective.init(**parsed)
def finalize() -> None:
"""Finalize the process, notify tracker everything is done."""
collective.finalize()
def get_rank() -> int:
"""Get rank of current process.
Returns
-------
rank : int
Rank of current process.
"""
return collective.get_rank()
def get_world_size() -> int:
"""Get total number workers.
Returns
-------
n : int
Total number of process.
"""
return collective.get_world_size()
def is_distributed() -> int:
"""If rabit is distributed."""
return collective.is_distributed()
def tracker_print(msg: Any) -> None:
"""Print message to the tracker.
This function can be used to communicate the information of
the progress to the tracker
Parameters
----------
msg : str
The message to be printed to tracker.
"""
collective.communicator_print(msg)
def get_processor_name() -> bytes:
"""Get the processor name.
Returns
-------
name : str
the name of processor(host)
"""
return collective.get_processor_name().encode()
T = TypeVar("T") # pylint:disable=invalid-name
def broadcast(data: T, root: int) -> T:
"""Broadcast object from one node to all other nodes.
Parameters
----------
data : any type that can be pickled
Input data, if current rank does not equal root, this can be None
root : int
Rank of the node to broadcast data from.
Returns
-------
object : int
the result of broadcast.
"""
return collective.broadcast(data, root)
@unique
class Op(IntEnum):
"""Supported operations for rabit."""
MAX = 0
MIN = 1
SUM = 2
OR = 3
def allreduce( # pylint:disable=invalid-name
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
"""Perform allreduce, return the result.
Parameters
----------
data :
Input data.
op :
Reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun :
Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to initialize the data
If the result of Allreduce can be recovered directly,
then prepare_fun will NOT be called
Returns
-------
result :
The result of allreduce, have same shape as data
Notes
-----
This function is not thread-safe.
"""
if prepare_fun is None:
return collective.allreduce(data, collective.Op(op))
raise ValueError("preprocessing function is no longer supported")
def version_number() -> int:
"""Returns version number of current stored model.
This means how many calls to CheckPoint we made so far.
Returns
-------
version : int
Version number of currently stored model
"""
return 0
class RabitContext:
"""A context controlling rabit initialization and finalization."""
def __init__(self, args: Optional[List[bytes]] = None) -> None:
if args is None:
args = []
self.args = args
def __enter__(self) -> None:
init(self.args)
assert is_distributed()
LOGGER.warning(_deprecation_warning())
LOGGER.debug("-------------- rabit say hello ------------------")
def __exit__(self, *args: List) -> None:
finalize()
LOGGER.debug("--------------- rabit say bye ------------------")

View File

@ -39,37 +39,6 @@ def test_rabit_communicator():
assert worker.exitcode == 0 assert worker.exitcode == 0
# TODO(rongou): remove this once we remove the rabit api.
def run_rabit_api_worker(rabit_env, world_size):
with xgb.rabit.RabitContext(rabit_env):
assert xgb.rabit.get_world_size() == world_size
assert xgb.rabit.is_distributed()
assert xgb.rabit.get_processor_name().decode() == socket.gethostname()
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
ret = xgb.rabit.allreduce(np.asarray([1, 2, 3]), xgb.rabit.Op.SUM)
assert np.array_equal(ret, np.asarray([2, 4, 6]))
# TODO(rongou): remove this once we remove the rabit api.
def test_rabit_api():
world_size = 2
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
tracker.start(world_size)
rabit_env = []
for k, v in tracker.worker_envs().items():
rabit_env.append(f"{k}={v}".encode())
workers = []
for _ in range(world_size):
worker = multiprocessing.Process(target=run_rabit_api_worker,
args=(rabit_env, world_size))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
assert worker.exitcode == 0
def run_federated_worker(port, world_size, rank): def run_federated_worker(port, world_size, rank):
with xgb.collective.CommunicatorContext(xgboost_communicator='federated', with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
federated_server_address=f'localhost:{port}', federated_server_address=f'localhost:{port}',