Remove the deprecated Python rabit module. (#9523)
This commit is contained in:
parent
aa86bd5207
commit
209335b18c
@ -4,7 +4,7 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
"""
|
||||
|
||||
from . import tracker # noqa
|
||||
from . import collective, dask, rabit
|
||||
from . import collective, dask
|
||||
from .core import (
|
||||
Booster,
|
||||
DataIter,
|
||||
|
||||
@ -1,169 +0,0 @@
|
||||
"""Compatibility shim for xgboost.rabit; to be removed in 2.0"""
|
||||
import logging
|
||||
import warnings
|
||||
from enum import IntEnum, unique
|
||||
from typing import Any, Callable, List, Optional, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import collective
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.rabit]")
|
||||
|
||||
|
||||
def _deprecation_warning() -> str:
|
||||
return (
|
||||
"The xgboost.rabit submodule is marked as deprecated in 1.7 and will be removed "
|
||||
"in 2.0. Please use xgboost.collective instead."
|
||||
)
|
||||
|
||||
|
||||
def init(args: Optional[List[bytes]] = None) -> None:
|
||||
"""Initialize the rabit library with arguments"""
|
||||
warnings.warn(_deprecation_warning(), FutureWarning)
|
||||
parsed = {}
|
||||
if args:
|
||||
for arg in args:
|
||||
kv = arg.decode().split("=")
|
||||
if len(kv) == 2:
|
||||
parsed[kv[0]] = kv[1]
|
||||
collective.init(**parsed)
|
||||
|
||||
|
||||
def finalize() -> None:
|
||||
"""Finalize the process, notify tracker everything is done."""
|
||||
collective.finalize()
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
"""Get rank of current process.
|
||||
Returns
|
||||
-------
|
||||
rank : int
|
||||
Rank of current process.
|
||||
"""
|
||||
return collective.get_rank()
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Get total number workers.
|
||||
Returns
|
||||
-------
|
||||
n : int
|
||||
Total number of process.
|
||||
"""
|
||||
return collective.get_world_size()
|
||||
|
||||
|
||||
def is_distributed() -> int:
|
||||
"""If rabit is distributed."""
|
||||
return collective.is_distributed()
|
||||
|
||||
|
||||
def tracker_print(msg: Any) -> None:
|
||||
"""Print message to the tracker.
|
||||
This function can be used to communicate the information of
|
||||
the progress to the tracker
|
||||
Parameters
|
||||
----------
|
||||
msg : str
|
||||
The message to be printed to tracker.
|
||||
"""
|
||||
collective.communicator_print(msg)
|
||||
|
||||
|
||||
def get_processor_name() -> bytes:
|
||||
"""Get the processor name.
|
||||
Returns
|
||||
-------
|
||||
name : str
|
||||
the name of processor(host)
|
||||
"""
|
||||
return collective.get_processor_name().encode()
|
||||
|
||||
|
||||
T = TypeVar("T") # pylint:disable=invalid-name
|
||||
|
||||
|
||||
def broadcast(data: T, root: int) -> T:
|
||||
"""Broadcast object from one node to all other nodes.
|
||||
Parameters
|
||||
----------
|
||||
data : any type that can be pickled
|
||||
Input data, if current rank does not equal root, this can be None
|
||||
root : int
|
||||
Rank of the node to broadcast data from.
|
||||
Returns
|
||||
-------
|
||||
object : int
|
||||
the result of broadcast.
|
||||
"""
|
||||
return collective.broadcast(data, root)
|
||||
|
||||
|
||||
@unique
|
||||
class Op(IntEnum):
|
||||
"""Supported operations for rabit."""
|
||||
|
||||
MAX = 0
|
||||
MIN = 1
|
||||
SUM = 2
|
||||
OR = 3
|
||||
|
||||
|
||||
def allreduce( # pylint:disable=invalid-name
|
||||
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
|
||||
) -> np.ndarray:
|
||||
"""Perform allreduce, return the result.
|
||||
Parameters
|
||||
----------
|
||||
data :
|
||||
Input data.
|
||||
op :
|
||||
Reduction operators, can be MIN, MAX, SUM, BITOR
|
||||
prepare_fun :
|
||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||
will be called by the function before performing allreduce, to initialize the data
|
||||
If the result of Allreduce can be recovered directly,
|
||||
then prepare_fun will NOT be called
|
||||
Returns
|
||||
-------
|
||||
result :
|
||||
The result of allreduce, have same shape as data
|
||||
Notes
|
||||
-----
|
||||
This function is not thread-safe.
|
||||
"""
|
||||
if prepare_fun is None:
|
||||
return collective.allreduce(data, collective.Op(op))
|
||||
raise ValueError("preprocessing function is no longer supported")
|
||||
|
||||
|
||||
def version_number() -> int:
|
||||
"""Returns version number of current stored model.
|
||||
This means how many calls to CheckPoint we made so far.
|
||||
Returns
|
||||
-------
|
||||
version : int
|
||||
Version number of currently stored model
|
||||
"""
|
||||
return 0
|
||||
|
||||
|
||||
class RabitContext:
|
||||
"""A context controlling rabit initialization and finalization."""
|
||||
|
||||
def __init__(self, args: Optional[List[bytes]] = None) -> None:
|
||||
if args is None:
|
||||
args = []
|
||||
self.args = args
|
||||
|
||||
def __enter__(self) -> None:
|
||||
init(self.args)
|
||||
assert is_distributed()
|
||||
LOGGER.warning(_deprecation_warning())
|
||||
LOGGER.debug("-------------- rabit say hello ------------------")
|
||||
|
||||
def __exit__(self, *args: List) -> None:
|
||||
finalize()
|
||||
LOGGER.debug("--------------- rabit say bye ------------------")
|
||||
@ -39,37 +39,6 @@ def test_rabit_communicator():
|
||||
assert worker.exitcode == 0
|
||||
|
||||
|
||||
# TODO(rongou): remove this once we remove the rabit api.
|
||||
def run_rabit_api_worker(rabit_env, world_size):
|
||||
with xgb.rabit.RabitContext(rabit_env):
|
||||
assert xgb.rabit.get_world_size() == world_size
|
||||
assert xgb.rabit.is_distributed()
|
||||
assert xgb.rabit.get_processor_name().decode() == socket.gethostname()
|
||||
ret = xgb.rabit.broadcast('test1234', 0)
|
||||
assert str(ret) == 'test1234'
|
||||
ret = xgb.rabit.allreduce(np.asarray([1, 2, 3]), xgb.rabit.Op.SUM)
|
||||
assert np.array_equal(ret, np.asarray([2, 4, 6]))
|
||||
|
||||
|
||||
# TODO(rongou): remove this once we remove the rabit api.
|
||||
def test_rabit_api():
|
||||
world_size = 2
|
||||
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
|
||||
tracker.start(world_size)
|
||||
rabit_env = []
|
||||
for k, v in tracker.worker_envs().items():
|
||||
rabit_env.append(f"{k}={v}".encode())
|
||||
workers = []
|
||||
for _ in range(world_size):
|
||||
worker = multiprocessing.Process(target=run_rabit_api_worker,
|
||||
args=(rabit_env, world_size))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
|
||||
|
||||
def run_federated_worker(port, world_size, rank):
|
||||
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
|
||||
federated_server_address=f'localhost:{port}',
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user