[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator
* fix size_t specialization
* really fix size_t
* try again
* add include
* more include
* fix lint errors
* remove rabit includes
* fix pylint error
* return dict from communicator context
* fix communicator shutdown
* fix dask test
* reset communicator mocklist
* fix distributed tests
* do not save device communicator
* fix jvm gpu tests
* add python test for federated communicator
* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -3,9 +3,8 @@
|
||||
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
"""
|
||||
|
||||
from . import rabit # noqa
|
||||
from . import tracker # noqa
|
||||
from . import dask
|
||||
from . import collective, dask
|
||||
from .core import (
|
||||
Booster,
|
||||
DataIter,
|
||||
@@ -63,4 +62,6 @@ __all__ = [
|
||||
"XGBRFRegressor",
|
||||
# dask
|
||||
"dask",
|
||||
# collective
|
||||
"collective",
|
||||
]
|
||||
|
||||
@@ -13,7 +13,7 @@ import pickle
|
||||
from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast, Sequence, Any
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
from . import collective
|
||||
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ def _allreduce_metric(score: _ART) -> _ART:
|
||||
as final result.
|
||||
|
||||
'''
|
||||
world = rabit.get_world_size()
|
||||
world = collective.get_world_size()
|
||||
assert world != 0
|
||||
if world == 1:
|
||||
return score
|
||||
@@ -108,7 +108,7 @@ def _allreduce_metric(score: _ART) -> _ART:
|
||||
raise ValueError(
|
||||
'xgboost.cv function should not be used in distributed environment.')
|
||||
arr = numpy.array([score])
|
||||
arr = rabit.allreduce(arr, rabit.Op.SUM) / world
|
||||
arr = collective.allreduce(arr, collective.Op.SUM) / world
|
||||
return arr[0]
|
||||
|
||||
|
||||
@@ -485,7 +485,7 @@ class EvaluationMonitor(TrainingCallback):
|
||||
return False
|
||||
|
||||
msg: str = f'[{epoch}]'
|
||||
if rabit.get_rank() == self.printer_rank:
|
||||
if collective.get_rank() == self.printer_rank:
|
||||
for data, metric in evals_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
stdv: Optional[float] = None
|
||||
@@ -498,7 +498,7 @@ class EvaluationMonitor(TrainingCallback):
|
||||
msg += '\n'
|
||||
|
||||
if (epoch % self.period) == 0 or self.period == 1:
|
||||
rabit.tracker_print(msg)
|
||||
collective.communicator_print(msg)
|
||||
self._latest = None
|
||||
else:
|
||||
# There is skipped message
|
||||
@@ -506,8 +506,8 @@ class EvaluationMonitor(TrainingCallback):
|
||||
return False
|
||||
|
||||
def after_training(self, model: _Model) -> _Model:
|
||||
if rabit.get_rank() == self.printer_rank and self._latest is not None:
|
||||
rabit.tracker_print(self._latest)
|
||||
if collective.get_rank() == self.printer_rank and self._latest is not None:
|
||||
collective.communicator_print(self._latest)
|
||||
return model
|
||||
|
||||
|
||||
@@ -552,7 +552,7 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
path = os.path.join(self._path, self._name + '_' + str(epoch) +
|
||||
('.pkl' if self._as_pickle else '.json'))
|
||||
self._epoch = 0
|
||||
if rabit.get_rank() == 0:
|
||||
if collective.get_rank() == 0:
|
||||
if self._as_pickle:
|
||||
with open(path, 'wb') as fd:
|
||||
pickle.dump(model, fd)
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import logging
|
||||
import pickle
|
||||
from enum import IntEnum, unique
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -233,10 +233,11 @@ class CommunicatorContext:
|
||||
def __init__(self, **args: Any) -> None:
|
||||
self.args = args
|
||||
|
||||
def __enter__(self) -> None:
|
||||
def __enter__(self) -> Dict[str, Any]:
|
||||
init(**self.args)
|
||||
assert is_distributed()
|
||||
LOGGER.debug("-------------- communicator say hello ------------------")
|
||||
return self.args
|
||||
|
||||
def __exit__(self, *args: List) -> None:
|
||||
finalize()
|
||||
|
||||
@@ -59,7 +59,7 @@ from typing import (
|
||||
|
||||
import numpy
|
||||
|
||||
from . import config, rabit
|
||||
from . import collective, config
|
||||
from ._typing import _T, FeatureNames, FeatureTypes
|
||||
from .callback import TrainingCallback
|
||||
from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
|
||||
@@ -112,7 +112,7 @@ TrainReturnT = TypedDict(
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RabitContext",
|
||||
"CommunicatorContext",
|
||||
"DaskDMatrix",
|
||||
"DaskDeviceQuantileDMatrix",
|
||||
"DaskXGBRegressor",
|
||||
@@ -158,7 +158,7 @@ def _try_start_tracker(
|
||||
if isinstance(addrs[0], tuple):
|
||||
host_ip = addrs[0][0]
|
||||
port = addrs[0][1]
|
||||
rabit_context = RabitTracker(
|
||||
rabit_tracker = RabitTracker(
|
||||
host_ip=get_host_ip(host_ip),
|
||||
n_workers=n_workers,
|
||||
port=port,
|
||||
@@ -168,12 +168,12 @@ def _try_start_tracker(
|
||||
addr = addrs[0]
|
||||
assert isinstance(addr, str) or addr is None
|
||||
host_ip = get_host_ip(addr)
|
||||
rabit_context = RabitTracker(
|
||||
rabit_tracker = RabitTracker(
|
||||
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
|
||||
)
|
||||
env.update(rabit_context.worker_envs())
|
||||
rabit_context.start(n_workers)
|
||||
thread = Thread(target=rabit_context.join)
|
||||
env.update(rabit_tracker.worker_envs())
|
||||
rabit_tracker.start(n_workers)
|
||||
thread = Thread(target=rabit_tracker.join)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
except socket.error as e:
|
||||
@@ -213,11 +213,11 @@ def _assert_dask_support() -> None:
|
||||
LOGGER.warning(msg)
|
||||
|
||||
|
||||
class RabitContext(rabit.RabitContext):
|
||||
"""A context controlling rabit initialization and finalization."""
|
||||
class CommunicatorContext(collective.CommunicatorContext):
|
||||
"""A context controlling collective communicator initialization and finalization."""
|
||||
|
||||
def __init__(self, args: List[bytes]) -> None:
|
||||
super().__init__(args)
|
||||
def __init__(self, **args: Any) -> None:
|
||||
super().__init__(**args)
|
||||
worker = distributed.get_worker()
|
||||
with distributed.worker_client() as client:
|
||||
info = client.scheduler_info()
|
||||
@@ -227,9 +227,7 @@ class RabitContext(rabit.RabitContext):
|
||||
# not the same as task ID is string and "10" is sorted before "2") with dask
|
||||
# worker ID. This outsources the rank assignment to dask and prevents
|
||||
# non-deterministic issue.
|
||||
self.args.append(
|
||||
(f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
|
||||
)
|
||||
self.args["DMLC_TASK_ID"] = f"[xgboost.dask-{wid}]:" + str(worker.address)
|
||||
|
||||
|
||||
def dconcat(value: Sequence[_T]) -> _T:
|
||||
@@ -811,7 +809,7 @@ def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix:
|
||||
|
||||
async def _get_rabit_args(
|
||||
n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
|
||||
) -> List[bytes]:
|
||||
) -> Dict[str, Union[str, int]]:
|
||||
"""Get rabit context arguments from data distribution in DaskDMatrix."""
|
||||
# There are 3 possible different addresses:
|
||||
# 1. Provided by user via dask.config
|
||||
@@ -854,9 +852,7 @@ async def _get_rabit_args(
|
||||
env = await client.run_on_scheduler(
|
||||
_start_tracker, n_workers, sched_addr, user_addr
|
||||
)
|
||||
|
||||
rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
|
||||
return rabit_args
|
||||
return env
|
||||
|
||||
|
||||
def _get_dask_config() -> Optional[Dict[str, Any]]:
|
||||
@@ -911,7 +907,7 @@ async def _train_async(
|
||||
|
||||
def dispatched_train(
|
||||
parameters: Dict,
|
||||
rabit_args: List[bytes],
|
||||
rabit_args: Dict[str, Union[str, int]],
|
||||
train_id: int,
|
||||
evals_name: List[str],
|
||||
evals_id: List[int],
|
||||
@@ -935,7 +931,7 @@ async def _train_async(
|
||||
n_threads = dwnt
|
||||
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
|
||||
local_history: TrainingCallback.EvalsLog = {}
|
||||
with RabitContext(rabit_args), config.config_context(**global_config):
|
||||
with CommunicatorContext(**rabit_args), config.config_context(**global_config):
|
||||
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
|
||||
evals: List[Tuple[DMatrix, str]] = []
|
||||
for i, ref in enumerate(refs):
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
"""Distributed XGBoost Rabit related API."""
|
||||
import ctypes
|
||||
from enum import IntEnum, unique
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .core import _LIB, c_str, _check_call
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.rabit]")
|
||||
|
||||
|
||||
def _init_rabit() -> None:
|
||||
"""internal library initializer."""
|
||||
if _LIB is not None:
|
||||
_LIB.RabitGetRank.restype = ctypes.c_int
|
||||
_LIB.RabitGetWorldSize.restype = ctypes.c_int
|
||||
_LIB.RabitIsDistributed.restype = ctypes.c_int
|
||||
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
||||
|
||||
|
||||
def init(args: Optional[List[bytes]] = None) -> None:
|
||||
"""Initialize the rabit library with arguments"""
|
||||
if args is None:
|
||||
args = []
|
||||
arr = (ctypes.c_char_p * len(args))()
|
||||
arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args)
|
||||
_LIB.RabitInit(len(arr), arr)
|
||||
|
||||
|
||||
def finalize() -> None:
|
||||
"""Finalize the process, notify tracker everything is done."""
|
||||
_LIB.RabitFinalize()
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
"""Get rank of current process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : int
|
||||
Rank of current process.
|
||||
"""
|
||||
ret = _LIB.RabitGetRank()
|
||||
return ret
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Get total number workers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
n : int
|
||||
Total number of process.
|
||||
"""
|
||||
ret = _LIB.RabitGetWorldSize()
|
||||
return ret
|
||||
|
||||
|
||||
def is_distributed() -> int:
|
||||
'''If rabit is distributed.'''
|
||||
is_dist = _LIB.RabitIsDistributed()
|
||||
return is_dist
|
||||
|
||||
|
||||
def tracker_print(msg: Any) -> None:
|
||||
"""Print message to the tracker.
|
||||
|
||||
This function can be used to communicate the information of
|
||||
the progress to the tracker
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : str
|
||||
The message to be printed to tracker.
|
||||
"""
|
||||
if not isinstance(msg, str):
|
||||
msg = str(msg)
|
||||
is_dist = _LIB.RabitIsDistributed()
|
||||
if is_dist != 0:
|
||||
_check_call(_LIB.RabitTrackerPrint(c_str(msg)))
|
||||
else:
|
||||
print(msg.strip(), flush=True)
|
||||
|
||||
|
||||
def get_processor_name() -> bytes:
|
||||
"""Get the processor name.
|
||||
|
||||
Returns
|
||||
-------
|
||||
name : str
|
||||
the name of processor(host)
|
||||
"""
|
||||
mxlen = 256
|
||||
length = ctypes.c_ulong()
|
||||
buf = ctypes.create_string_buffer(mxlen)
|
||||
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
|
||||
return buf.value
|
||||
|
||||
|
||||
T = TypeVar("T") # pylint:disable=invalid-name
|
||||
|
||||
|
||||
def broadcast(data: T, root: int) -> T:
|
||||
"""Broadcast object from one node to all other nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : any type that can be pickled
|
||||
Input data, if current rank does not equal root, this can be None
|
||||
root : int
|
||||
Rank of the node to broadcast data from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object : int
|
||||
the result of broadcast.
|
||||
"""
|
||||
rank = get_rank()
|
||||
length = ctypes.c_ulong()
|
||||
if root == rank:
|
||||
assert data is not None, 'need to pass in data when broadcasting'
|
||||
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
length.value = len(s)
|
||||
# run first broadcast
|
||||
_check_call(_LIB.RabitBroadcast(ctypes.byref(length),
|
||||
ctypes.sizeof(ctypes.c_ulong), root))
|
||||
if root != rank:
|
||||
dptr = (ctypes.c_char * length.value)()
|
||||
# run second
|
||||
_check_call(_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
|
||||
length.value, root))
|
||||
data = pickle.loads(dptr.raw)
|
||||
del dptr
|
||||
else:
|
||||
_check_call(_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
|
||||
length.value, root))
|
||||
del s
|
||||
return data
|
||||
|
||||
|
||||
# enumeration of dtypes
|
||||
DTYPE_ENUM__ = {
|
||||
np.dtype('int8'): 0,
|
||||
np.dtype('uint8'): 1,
|
||||
np.dtype('int32'): 2,
|
||||
np.dtype('uint32'): 3,
|
||||
np.dtype('int64'): 4,
|
||||
np.dtype('uint64'): 5,
|
||||
np.dtype('float32'): 6,
|
||||
np.dtype('float64'): 7
|
||||
}
|
||||
|
||||
|
||||
@unique
|
||||
class Op(IntEnum):
|
||||
'''Supported operations for rabit.'''
|
||||
MAX = 0
|
||||
MIN = 1
|
||||
SUM = 2
|
||||
OR = 3
|
||||
|
||||
|
||||
def allreduce( # pylint:disable=invalid-name
|
||||
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
|
||||
) -> np.ndarray:
|
||||
"""Perform allreduce, return the result.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data :
|
||||
Input data.
|
||||
op :
|
||||
Reduction operators, can be MIN, MAX, SUM, BITOR
|
||||
prepare_fun :
|
||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||
will be called by the function before performing allreduce, to initialize the data
|
||||
If the result of Allreduce can be recovered directly,
|
||||
then prepare_fun will NOT be called
|
||||
|
||||
Returns
|
||||
-------
|
||||
result :
|
||||
The result of allreduce, have same shape as data
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is not thread-safe.
|
||||
"""
|
||||
if not isinstance(data, np.ndarray):
|
||||
raise Exception('allreduce only takes in numpy.ndarray')
|
||||
buf = data.ravel()
|
||||
if buf.base is data.base:
|
||||
buf = buf.copy()
|
||||
if buf.dtype not in DTYPE_ENUM__:
|
||||
raise Exception(f"data type {buf.dtype} not supported")
|
||||
if prepare_fun is None:
|
||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
int(op), None, None))
|
||||
else:
|
||||
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||
|
||||
def pfunc(_: Any) -> None:
|
||||
"""prepare function."""
|
||||
fn = cast(Callable[[np.ndarray], None], prepare_fun)
|
||||
fn(data)
|
||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, func_ptr(pfunc), None))
|
||||
return buf
|
||||
|
||||
|
||||
def version_number() -> int:
|
||||
"""Returns version number of current stored model.
|
||||
|
||||
This means how many calls to CheckPoint we made so far.
|
||||
|
||||
Returns
|
||||
-------
|
||||
version : int
|
||||
Version number of currently stored model
|
||||
"""
|
||||
ret = _LIB.RabitVersionNumber()
|
||||
return ret
|
||||
|
||||
|
||||
class RabitContext:
|
||||
"""A context controlling rabit initialization and finalization."""
|
||||
|
||||
def __init__(self, args: List[bytes] = None) -> None:
|
||||
if args is None:
|
||||
args = []
|
||||
self.args = args
|
||||
|
||||
def __enter__(self) -> None:
|
||||
init(self.args)
|
||||
assert is_distributed()
|
||||
LOGGER.debug("-------------- rabit say hello ------------------")
|
||||
|
||||
def __exit__(self, *args: List) -> None:
|
||||
finalize()
|
||||
LOGGER.debug("--------------- rabit say bye ------------------")
|
||||
|
||||
|
||||
# initialization script
|
||||
_init_rabit()
|
||||
@@ -2,6 +2,7 @@
|
||||
"""Xgboost pyspark integration submodule for core code."""
|
||||
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
|
||||
# pylint: disable=too-few-public-methods, too-many-lines
|
||||
import json
|
||||
from typing import Iterator, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -57,7 +58,7 @@ from .params import (
|
||||
HasQueryIdCol,
|
||||
)
|
||||
from .utils import (
|
||||
RabitContext,
|
||||
CommunicatorContext,
|
||||
_get_args_from_message_list,
|
||||
_get_default_params_from_func,
|
||||
_get_gpu_id,
|
||||
@@ -769,7 +770,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
):
|
||||
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
|
||||
|
||||
_rabit_args = ""
|
||||
_rabit_args = {}
|
||||
if context.partitionId() == 0:
|
||||
get_logger("XGBoostPySpark").info(
|
||||
"booster params: %s\n"
|
||||
@@ -780,12 +781,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
dmatrix_kwargs,
|
||||
)
|
||||
|
||||
_rabit_args = str(_get_rabit_args(context, num_workers))
|
||||
_rabit_args = _get_rabit_args(context, num_workers)
|
||||
|
||||
messages = context.allGather(message=str(_rabit_args))
|
||||
messages = context.allGather(message=json.dumps(_rabit_args))
|
||||
_rabit_args = _get_args_from_message_list(messages)
|
||||
evals_result = {}
|
||||
with RabitContext(_rabit_args, context):
|
||||
with CommunicatorContext(context, **_rabit_args):
|
||||
dtrain, dvalid = create_dmatrix_from_partitions(
|
||||
pandas_df_iter,
|
||||
features_cols_names,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# type: ignore
|
||||
"""Xgboost pyspark integration submodule for helper functions."""
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from threading import Thread
|
||||
@@ -9,7 +10,7 @@ import pyspark
|
||||
from pyspark.sql.session import SparkSession
|
||||
from xgboost.tracker import RabitTracker
|
||||
|
||||
from xgboost import rabit
|
||||
from xgboost import collective
|
||||
|
||||
|
||||
def get_class_name(cls):
|
||||
@@ -36,21 +37,21 @@ def _get_default_params_from_func(func, unsupported_set):
|
||||
return filtered_params_dict
|
||||
|
||||
|
||||
class RabitContext:
|
||||
class CommunicatorContext:
|
||||
"""
|
||||
A context controlling rabit initialization and finalization.
|
||||
A context controlling collective communicator initialization and finalization.
|
||||
This isn't specifically necessary (note Part 3), but it is more understandable coding-wise.
|
||||
"""
|
||||
|
||||
def __init__(self, args, context):
|
||||
def __init__(self, context, **args):
|
||||
self.args = args
|
||||
self.args.append(("DMLC_TASK_ID=" + str(context.partitionId())).encode())
|
||||
self.args["DMLC_TASK_ID"] = str(context.partitionId())
|
||||
|
||||
def __enter__(self):
|
||||
rabit.init(self.args)
|
||||
collective.init(**self.args)
|
||||
|
||||
def __exit__(self, *args):
|
||||
rabit.finalize()
|
||||
collective.finalize()
|
||||
|
||||
|
||||
def _start_tracker(context, n_workers):
|
||||
@@ -74,8 +75,7 @@ def _get_rabit_args(context, n_workers):
|
||||
"""
|
||||
# pylint: disable=consider-using-f-string
|
||||
env = _start_tracker(context, n_workers)
|
||||
rabit_args = [("%s=%s" % item).encode() for item in env.items()]
|
||||
return rabit_args
|
||||
return env
|
||||
|
||||
|
||||
def _get_host_ip(context):
|
||||
@@ -95,7 +95,7 @@ def _get_args_from_message_list(messages):
|
||||
if message != "":
|
||||
output = message
|
||||
break
|
||||
return [elem.split("'")[1].encode() for elem in output.strip("][").split(", ")]
|
||||
return json.loads(output)
|
||||
|
||||
|
||||
def _get_spark_session():
|
||||
|
||||
Reference in New Issue
Block a user