2021-09-17 18:31:02 +08:00

229 lines
6.0 KiB
Python

# coding: utf-8
# pylint: disable= invalid-name
"""Distributed XGBoost Rabit related API."""
import ctypes
import pickle
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
import numpy as np
from .core import _LIB, c_str, STRING_TYPES, _check_call
def _init_rabit() -> None:
    """Declare ``ctypes.c_int`` as the return type of the rabit query functions.

    Does nothing when the native library handle ``_LIB`` is None.
    """
    if _LIB is None:
        return
    int_returning = (
        "RabitGetRank",
        "RabitGetWorldSize",
        "RabitIsDistributed",
        "RabitVersionNumber",
    )
    for symbol in int_returning:
        getattr(_LIB, symbol).restype = ctypes.c_int
def init(args: Optional[List[bytes]] = None) -> None:
    """Initialize the rabit library.

    Parameters
    ----------
    args : list of bytes, optional
        Arguments forwarded to ``RabitInit`` as a C ``char*`` array;
        ``None`` is treated as an empty list.
    """
    argv = [] if args is None else args
    c_argv = (ctypes.c_char_p * len(argv))()
    c_argv[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], argv)
    _LIB.RabitInit(len(c_argv), c_argv)
def finalize() -> None:
    """Shut down rabit for this process and tell the tracker it has finished.

    Thin wrapper over the native ``RabitFinalize``; its return value is ignored.
    """
    _LIB.RabitFinalize()
def get_rank() -> int:
    """Return the rank of the current process.

    Returns
    -------
    rank : int
        Rank of this process, as reported by the native ``RabitGetRank``.
    """
    return _LIB.RabitGetRank()
def get_world_size() -> int:
    """Return the total number of workers.

    Returns
    -------
    n : int
        Number of processes taking part, as reported by the native
        ``RabitGetWorldSize``.
    """
    return _LIB.RabitGetWorldSize()
def is_distributed() -> int:
    """Report whether rabit is running in distributed mode.

    Returns
    -------
    int
        The raw value of the native ``RabitIsDistributed``; elsewhere in this
        module a non-zero value is treated as "distributed".
    """
    return _LIB.RabitIsDistributed()
def tracker_print(msg: Any) -> None:
    """Send a progress message to the tracker, or print it locally.

    In distributed mode the message is forwarded to the tracker through the
    native library. Otherwise it is printed to stdout with surrounding
    whitespace stripped and flushed immediately.

    Parameters
    ----------
    msg : Any
        The message. Values that are not already one of ``STRING_TYPES`` are
        converted with ``str()`` first.
    """
    text = msg if isinstance(msg, STRING_TYPES) else str(msg)
    if _LIB.RabitIsDistributed() == 0:
        # Not distributed: there is no tracker to talk to, print locally.
        print(text.strip(), flush=True)
        return
    _check_call(_LIB.RabitTrackerPrint(c_str(text)))
def get_processor_name() -> bytes:
    """Get the name of the processor (host) this worker runs on.

    Returns
    -------
    name : bytes
        The processor (host) name as raw bytes, read from a 256-byte buffer
        filled by the native ``RabitGetProcessorName`` and cut at the first NUL
        byte. Decode it if a ``str`` is needed.
    """
    mxlen = 256
    length = ctypes.c_ulong()  # out-parameter written by the native call; unused here
    buf = ctypes.create_string_buffer(mxlen)
    _LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
    return buf.value
T = TypeVar("T")


def broadcast(data: T, root: int) -> T:
    """Broadcast a picklable object from one worker to all other workers.

    The object is pickled on ``root``. Its byte length is broadcast first, so
    the other ranks can allocate a receive buffer; then the bytes themselves
    are broadcast and unpickled on every non-root rank.

    Parameters
    ----------
    data : any type that can be pickled
        Input data. Must not be None on ``root``. On every other rank the
        value is ignored, so it may be None.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object
        On ``root``, ``data`` itself. On every other rank, the object
        unpickled from the bytes sent by ``root``.
    """
    rank = get_rank()
    length = ctypes.c_ulong()
    if root == rank:
        assert data is not None, 'need to pass in data when broadcasting'
        s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        length.value = len(s)
    # First round: every rank learns the size of the pickled payload.
    _check_call(_LIB.RabitBroadcast(ctypes.byref(length),
                                    ctypes.sizeof(ctypes.c_ulong), root))
    if root != rank:
        dptr = (ctypes.c_char * length.value)()
        # Second round: receive the payload itself.
        _check_call(_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
                                        length.value, root))
        # SECURITY NOTE: unpickling bytes received from a peer can execute
        # arbitrary code if that peer is malicious; only use between trusted
        # workers.
        data = pickle.loads(dptr.raw)
    else:
        # Second round on root: send the pickled payload.
        _check_call(_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
                                        length.value, root))
    return data
# Integer codes for the numpy dtypes allreduce accepts; allreduce passes the
# code for the buffer's dtype straight to the native RabitAllreduce, so the
# position of each name below (its code) must not change.
DTYPE_ENUM__ = {
    np.dtype(dtype_name): code
    for code, dtype_name in enumerate((
        'int8', 'uint8',
        'int32', 'uint32',
        'int64', 'uint64',
        'float32', 'float64',
    ))
}
class Op:  # pylint: disable=too-few-public-methods
    """Reduction operators accepted by ``allreduce``.

    The integer value is handed unchanged to the native reduction call.
    """

    MAX = 0  # maximum
    MIN = 1  # minimum
    SUM = 2  # sum
    OR = 3   # bitwise OR
def allreduce(
    data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
    """Perform an allreduce across all workers and return the result.

    Parameters
    ----------
    data : numpy.ndarray
        Input data. Its dtype must be one of the keys of ``DTYPE_ENUM__``.
        The reduction never writes into ``data``; the result is returned in a
        separate buffer.
    op : int
        Reduction operator: one of ``Op.MAX``, ``Op.MIN``, ``Op.SUM`` or
        ``Op.OR`` (bitwise OR).
    prepare_fun : callable, optional
        Lazy preprocessing function. If it is not None, ``prepare_fun(data)``
        will be called before performing the allreduce, to initialize the
        data; whatever it writes into ``data`` is what gets reduced. If the
        result of the allreduce can be recovered directly, ``prepare_fun``
        will NOT be called.

    Returns
    -------
    result : numpy.ndarray
        The reduced values as a flattened (1-D) array with the same size and
        dtype as ``data``.

    Raises
    ------
    TypeError
        If ``data`` is not a numpy array or its dtype is not supported.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError('allreduce only takes in numpy.ndarray')
    # Validate before copying anything; ravel() preserves the dtype.
    if data.dtype not in DTYPE_ENUM__:
        raise TypeError(f"data type {data.dtype} not supported")
    # The native call writes the reduced values into the buffer it is given,
    # so that buffer must be contiguous and must not alias the caller's array.
    # ravel() returns a contiguous 1-D result and only copies when it has to,
    # so copy again whenever it handed back a view of ``data``.
    buf = data.ravel()
    if np.may_share_memory(buf, data):
        buf = buf.copy()

    c_prepare = None
    if prepare_fun is not None:
        prepare = prepare_fun  # non-Optional binding for the closure

        def pfunc(_: Any) -> None:
            """Run ``prepare_fun`` on ``data``, then stage its output into ``buf``."""
            prepare(data)
            # buf is a private copy, so the freshly prepared values must be
            # copied into it or the reduction would use stale contents.
            buf[:] = data.ravel()

        # Keep the ctypes callback referenced for the whole native call.
        c_prepare = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(pfunc)

    _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                                    buf.size, DTYPE_ENUM__[buf.dtype],
                                    op, c_prepare, None))
    return buf
def version_number() -> int:
    """Return the version number of the currently stored model.

    The version counts how many calls to CheckPoint have been made so far.

    Returns
    -------
    version : int
        Value reported by the native ``RabitVersionNumber``.
    """
    return _LIB.RabitVersionNumber()
# initialization script: declare the native return types as soon as the module is imported
_init_rabit()