Add typehint to rabit module. (#7240)

This commit is contained in:
Jiaming Yuan 2021-09-17 18:31:02 +08:00 committed by GitHub
parent c735c17f33
commit e48e05e6e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 27 deletions

View File

@ -92,6 +92,7 @@ endif
mypy: mypy:
cd python-package; \ cd python-package; \
mypy ./xgboost/dask.py && \ mypy ./xgboost/dask.py && \
mypy ./xgboost/rabit.py && \
mypy ../demo/guide-python/external_memory.py && \ mypy ../demo/guide-python/external_memory.py && \
mypy ../tests/python-gpu/test_gpu_with_dask.py && \ mypy ../tests/python-gpu/test_gpu_with_dask.py && \
mypy ../tests/python/test_data_iterator.py && \ mypy ../tests/python/test_data_iterator.py && \

View File

@ -7,7 +7,7 @@ import collections
from collections.abc import Mapping from collections.abc import Mapping
from typing import List, Optional, Any, Union, Dict, TypeVar from typing import List, Optional, Any, Union, Dict, TypeVar
# pylint: enable=no-name-in-module,import-error # pylint: enable=no-name-in-module,import-error
from typing import Callable, Tuple from typing import Callable, Tuple, cast
import ctypes import ctypes
import os import os
import re import re
@ -2008,7 +2008,8 @@ class Booster(object):
dims = c_bst_ulong() dims = c_bst_ulong()
if base_margin is not None: if base_margin is not None:
proxy = _ProxyDMatrix() proxy: Optional[_ProxyDMatrix] = _ProxyDMatrix()
assert proxy is not None
proxy.set_info(base_margin=base_margin) proxy.set_info(base_margin=base_margin)
p_handle = proxy.handle p_handle = proxy.handle
else: else:
@ -2274,8 +2275,8 @@ class Booster(object):
return self.get_score(fmap, importance_type='weight') return self.get_score(fmap, importance_type='weight')
def get_score( def get_score(
self, fmap: os.PathLike = '', importance_type: str = 'weight' self, fmap: Union[str, os.PathLike] = '', importance_type: str = 'weight'
) -> Dict[str, float]: ) -> Dict[str, Union[float, List[float]]]:
"""Get feature importance of each feature. """Get feature importance of each feature.
For tree model Importance type can be defined as: For tree model Importance type can be defined as:
@ -2332,7 +2333,7 @@ class Booster(object):
features_arr = from_cstr_to_pystr(features, n_out_features) features_arr = from_cstr_to_pystr(features, n_out_features)
scores_arr = _prediction_output(shape, out_dim, scores, False) scores_arr = _prediction_output(shape, out_dim, scores, False)
results = {} results: Dict[str, Union[float, List[float]]] = {}
if len(scores_arr.shape) > 1 and scores_arr.shape[1] > 1: if len(scores_arr.shape) > 1 and scores_arr.shape[1] > 1:
for feat, score in zip(features_arr, scores_arr): for feat, score in zip(features_arr, scores_arr):
results[feat] = [float(s) for s in score] results[feat] = [float(s) for s in score]
@ -2481,7 +2482,7 @@ class Booster(object):
fmap: Union[os.PathLike, str] = '', fmap: Union[os.PathLike, str] = '',
bins: Optional[int] = None, bins: Optional[int] = None,
as_pandas: bool = True as_pandas: bool = True
): ) -> Union[np.ndarray, DataFrame]:
"""Get split value histogram of a feature """Get split value histogram of a feature
Parameters Parameters
@ -2526,7 +2527,7 @@ class Booster(object):
fn = [f"f{i}" for i in range(self.num_features())] fn = [f"f{i}" for i in range(self.num_features())]
try: try:
index = fn.index(feature) index = fn.index(feature)
feature_t = ft[index] feature_t: Optional[str] = cast(List[str], ft)[index]
except (ValueError, AttributeError, TypeError): except (ValueError, AttributeError, TypeError):
# None.index: attr err, None[0]: type err, fn.index(-1): value err # None.index: attr err, None[0]: type err, fn.index(-1): value err
feature_t = None feature_t = None

View File

@ -3,12 +3,14 @@
"""Distributed XGBoost Rabit related API.""" """Distributed XGBoost Rabit related API."""
import ctypes import ctypes
import pickle import pickle
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
import numpy as np import numpy as np
from .core import _LIB, c_str, STRING_TYPES, _check_call from .core import _LIB, c_str, STRING_TYPES, _check_call
def _init_rabit(): def _init_rabit() -> None:
"""internal library initializer.""" """internal library initializer."""
if _LIB is not None: if _LIB is not None:
_LIB.RabitGetRank.restype = ctypes.c_int _LIB.RabitGetRank.restype = ctypes.c_int
@ -17,21 +19,21 @@ def _init_rabit():
_LIB.RabitVersionNumber.restype = ctypes.c_int _LIB.RabitVersionNumber.restype = ctypes.c_int
def init(args=None): def init(args: Optional[List[bytes]] = None) -> None:
"""Initialize the rabit library with arguments""" """Initialize the rabit library with arguments"""
if args is None: if args is None:
args = [] args = []
arr = (ctypes.c_char_p * len(args))() arr = (ctypes.c_char_p * len(args))()
arr[:] = args arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args)
_LIB.RabitInit(len(arr), arr) _LIB.RabitInit(len(arr), arr)
def finalize(): def finalize() -> None:
"""Finalize the process, notify tracker everything is done.""" """Finalize the process, notify tracker everything is done."""
_LIB.RabitFinalize() _LIB.RabitFinalize()
def get_rank(): def get_rank() -> int:
"""Get rank of current process. """Get rank of current process.
Returns Returns
@ -43,7 +45,7 @@ def get_rank():
return ret return ret
def get_world_size(): def get_world_size() -> int:
"""Get total number workers. """Get total number workers.
Returns Returns
@ -55,13 +57,13 @@ def get_world_size():
return ret return ret
def is_distributed(): def is_distributed() -> int:
'''If rabit is distributed.''' '''If rabit is distributed.'''
is_dist = _LIB.RabitIsDistributed() is_dist = _LIB.RabitIsDistributed()
return is_dist return is_dist
def tracker_print(msg): def tracker_print(msg: Any) -> None:
"""Print message to the tracker. """Print message to the tracker.
This function can be used to communicate the information of This function can be used to communicate the information of
@ -81,7 +83,7 @@ def tracker_print(msg):
print(msg.strip(), flush=True) print(msg.strip(), flush=True)
def get_processor_name(): def get_processor_name() -> bytes:
"""Get the processor name. """Get the processor name.
Returns Returns
@ -96,7 +98,10 @@ def get_processor_name():
return buf.value return buf.value
def broadcast(data, root): T = TypeVar("T")
def broadcast(data: T, root: int) -> T:
"""Broadcast object from one node to all other nodes. """Broadcast object from one node to all other nodes.
Parameters Parameters
@ -155,7 +160,9 @@ class Op: # pylint: disable=too-few-public-methods
OR = 3 OR = 3
def allreduce(data, op, prepare_fun=None): def allreduce(
data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
"""Perform allreduce, return the result. """Perform allreduce, return the result.
Parameters Parameters
@ -193,16 +200,17 @@ def allreduce(data, op, prepare_fun=None):
else: else:
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p) func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def pfunc(_): def pfunc(_: Any) -> None:
"""prepare function.""" """prepare function."""
prepare_fun(data) fn = cast(Callable[[np.ndarray], None], prepare_fun)
fn(data)
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype], buf.size, DTYPE_ENUM__[buf.dtype],
op, func_ptr(pfunc), None)) op, func_ptr(pfunc), None))
return buf return buf
def version_number(): def version_number() -> int:
"""Returns version number of current stored model. """Returns version number of current stored model.
This is the number of calls to CheckPoint made so far. This is the number of calls to CheckPoint made so far.

View File

@ -23,14 +23,9 @@ from .compat import (
XGBClassifierBase, XGBClassifierBase,
XGBRegressorBase, XGBRegressorBase,
XGBoostLabelEncoder, XGBoostLabelEncoder,
DataFrame,
scipy_csr,
) )
# Actually XGBoost supports a lot more data types including `scipy.sparse.csr_matrix` and array_like = Any
# many others. See `data.py` for a complete list. The `array_like` here is just for
# easier type checks.
array_like = TypeVar("array_like", bound=Union[np.ndarray, DataFrame, scipy_csr])
class XGBRankerMixIn: # pylint: disable=too-few-public-methods class XGBRankerMixIn: # pylint: disable=too-few-public-methods