Add typehint to rabit module. (#7240)
This commit is contained in:
@@ -3,12 +3,14 @@
|
||||
"""Distributed XGBoost Rabit related API."""
|
||||
import ctypes
|
||||
import pickle
|
||||
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .core import _LIB, c_str, STRING_TYPES, _check_call
|
||||
|
||||
|
||||
def _init_rabit():
|
||||
def _init_rabit() -> None:
|
||||
"""internal library initializer."""
|
||||
if _LIB is not None:
|
||||
_LIB.RabitGetRank.restype = ctypes.c_int
|
||||
@@ -17,21 +19,21 @@ def _init_rabit():
|
||||
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
||||
|
||||
|
||||
def init(args=None):
|
||||
def init(args: Optional[List[bytes]] = None) -> None:
|
||||
"""Initialize the rabit library with arguments"""
|
||||
if args is None:
|
||||
args = []
|
||||
arr = (ctypes.c_char_p * len(args))()
|
||||
arr[:] = args
|
||||
arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args)
|
||||
_LIB.RabitInit(len(arr), arr)
|
||||
|
||||
|
||||
def finalize():
|
||||
def finalize() -> None:
|
||||
"""Finalize the process, notify tracker everything is done."""
|
||||
_LIB.RabitFinalize()
|
||||
|
||||
|
||||
def get_rank():
|
||||
def get_rank() -> int:
|
||||
"""Get rank of current process.
|
||||
|
||||
Returns
|
||||
@@ -43,7 +45,7 @@ def get_rank():
|
||||
return ret
|
||||
|
||||
|
||||
def get_world_size():
|
||||
def get_world_size() -> int:
|
||||
"""Get total number workers.
|
||||
|
||||
Returns
|
||||
@@ -55,13 +57,13 @@ def get_world_size():
|
||||
return ret
|
||||
|
||||
|
||||
def is_distributed():
|
||||
def is_distributed() -> int:
|
||||
'''If rabit is distributed.'''
|
||||
is_dist = _LIB.RabitIsDistributed()
|
||||
return is_dist
|
||||
|
||||
|
||||
def tracker_print(msg):
|
||||
def tracker_print(msg: Any) -> None:
|
||||
"""Print message to the tracker.
|
||||
|
||||
This function can be used to communicate the information of
|
||||
@@ -81,7 +83,7 @@ def tracker_print(msg):
|
||||
print(msg.strip(), flush=True)
|
||||
|
||||
|
||||
def get_processor_name():
|
||||
def get_processor_name() -> bytes:
|
||||
"""Get the processor name.
|
||||
|
||||
Returns
|
||||
@@ -96,7 +98,10 @@ def get_processor_name():
|
||||
return buf.value
|
||||
|
||||
|
||||
def broadcast(data, root):
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def broadcast(data: T, root: int) -> T:
|
||||
"""Broadcast object from one node to all other nodes.
|
||||
|
||||
Parameters
|
||||
@@ -155,7 +160,9 @@ class Op: # pylint: disable=too-few-public-methods
|
||||
OR = 3
|
||||
|
||||
|
||||
def allreduce(data, op, prepare_fun=None):
|
||||
def allreduce(
|
||||
data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
|
||||
) -> np.ndarray:
|
||||
"""Perform allreduce, return the result.
|
||||
|
||||
Parameters
|
||||
@@ -193,16 +200,17 @@ def allreduce(data, op, prepare_fun=None):
|
||||
else:
|
||||
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||
|
||||
def pfunc(_):
|
||||
def pfunc(_: Any) -> None:
|
||||
"""prepare function."""
|
||||
prepare_fun(data)
|
||||
fn = cast(Callable[[np.ndarray], None], prepare_fun)
|
||||
fn(data)
|
||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, func_ptr(pfunc), None))
|
||||
return buf
|
||||
|
||||
|
||||
def version_number():
|
||||
def version_number() -> int:
|
||||
"""Returns version number of current stored model.
|
||||
|
||||
This means how many calls to CheckPoint we made so far.
|
||||
|
||||
Reference in New Issue
Block a user