From e48e05e6e29477634c7d78bd9e0d4acfeac56b17 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 17 Sep 2021 18:31:02 +0800 Subject: [PATCH] Add typehint to rabit module. (#7240) --- Makefile | 1 + python-package/xgboost/core.py | 15 +++++++------ python-package/xgboost/rabit.py | 36 +++++++++++++++++++------------ python-package/xgboost/sklearn.py | 7 +----- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index 09b137d92..d65cb56cc 100644 --- a/Makefile +++ b/Makefile @@ -92,6 +92,7 @@ endif mypy: cd python-package; \ mypy ./xgboost/dask.py && \ + mypy ./xgboost/rabit.py && \ mypy ../demo/guide-python/external_memory.py && \ mypy ../tests/python-gpu/test_gpu_with_dask.py && \ mypy ../tests/python/test_data_iterator.py && \ diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 7bfd867cb..45a3d68d6 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -7,7 +7,7 @@ import collections from collections.abc import Mapping from typing import List, Optional, Any, Union, Dict, TypeVar # pylint: enable=no-name-in-module,import-error -from typing import Callable, Tuple +from typing import Callable, Tuple, cast import ctypes import os import re @@ -2008,7 +2008,8 @@ class Booster(object): dims = c_bst_ulong() if base_margin is not None: - proxy = _ProxyDMatrix() + proxy: Optional[_ProxyDMatrix] = _ProxyDMatrix() + assert proxy is not None proxy.set_info(base_margin=base_margin) p_handle = proxy.handle else: @@ -2274,8 +2275,8 @@ class Booster(object): return self.get_score(fmap, importance_type='weight') def get_score( - self, fmap: os.PathLike = '', importance_type: str = 'weight' - ) -> Dict[str, float]: + self, fmap: Union[str, os.PathLike] = '', importance_type: str = 'weight' + ) -> Dict[str, Union[float, List[float]]]: """Get feature importance of each feature. For tree model Importance type can be defined as: @@ -2332,7 +2333,7 @@ class Booster(object): features_arr = from_cstr_to_pystr(features, n_out_features) scores_arr = _prediction_output(shape, out_dim, scores, False) - results = {} + results: Dict[str, Union[float, List[float]]] = {} if len(scores_arr.shape) > 1 and scores_arr.shape[1] > 1: for feat, score in zip(features_arr, scores_arr): results[feat] = [float(s) for s in score] @@ -2481,7 +2482,7 @@ class Booster(object): fmap: Union[os.PathLike, str] = '', bins: Optional[int] = None, as_pandas: bool = True - ): + ) -> Union[np.ndarray, DataFrame]: """Get split value histogram of a feature Parameters @@ -2526,7 +2527,7 @@ class Booster(object): fn = [f"f{i}" for i in range(self.num_features())] try: index = fn.index(feature) - feature_t = ft[index] + feature_t: Optional[str] = cast(List[str], ft)[index] except (ValueError, AttributeError, TypeError): # None.index: attr err, None[0]: type err, fn.index(-1): value err feature_t = None diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py index b09d17b12..0980dec36 100644 --- a/python-package/xgboost/rabit.py +++ b/python-package/xgboost/rabit.py @@ -3,12 +3,14 @@ """Distributed XGBoost Rabit related API.""" import ctypes import pickle +from typing import Any, TypeVar, Callable, Optional, cast, List, Union + import numpy as np from .core import _LIB, c_str, STRING_TYPES, _check_call -def _init_rabit(): +def _init_rabit() -> None: """internal library initializer.""" if _LIB is not None: _LIB.RabitGetRank.restype = ctypes.c_int @@ -17,21 +19,21 @@ def _init_rabit(): _LIB.RabitVersionNumber.restype = ctypes.c_int -def init(args=None): +def init(args: Optional[List[bytes]] = None) -> None: """Initialize the rabit library with arguments""" if args is None: args = [] arr = (ctypes.c_char_p * len(args))() - arr[:] = args + arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args) _LIB.RabitInit(len(arr), arr) -def finalize(): +def finalize() -> None: """Finalize the process, notify tracker everything is done.""" _LIB.RabitFinalize() -def get_rank(): +def get_rank() -> int: """Get rank of current process. Returns @@ -43,7 +45,7 @@ def get_rank(): return ret -def get_world_size(): +def get_world_size() -> int: """Get total number workers. Returns @@ -55,13 +57,13 @@ def get_world_size(): return ret -def is_distributed(): +def is_distributed() -> int: '''If rabit is distributed.''' is_dist = _LIB.RabitIsDistributed() return is_dist -def tracker_print(msg): +def tracker_print(msg: Any) -> None: """Print message to the tracker. This function can be used to communicate the information of @@ -81,7 +83,7 @@ def tracker_print(msg): print(msg.strip(), flush=True) -def get_processor_name(): +def get_processor_name() -> bytes: """Get the processor name. Returns @@ -96,7 +98,10 @@ def get_processor_name(): return buf.value -def broadcast(data, root): +T = TypeVar("T") + + +def broadcast(data: T, root: int) -> T: """Broadcast object from one node to all other nodes. Parameters @@ -155,7 +160,9 @@ class Op: # pylint: disable=too-few-public-methods OR = 3 -def allreduce(data, op, prepare_fun=None): +def allreduce( + data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None +) -> np.ndarray: """Perform allreduce, return the result. Parameters @@ -193,16 +200,17 @@ def allreduce(data, op, prepare_fun=None): else: func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p) - def pfunc(_): + def pfunc(_: Any) -> None: """prepare function.""" - prepare_fun(data) + fn = cast(Callable[[np.ndarray], None], prepare_fun) + fn(data) _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), buf.size, DTYPE_ENUM__[buf.dtype], op, func_ptr(pfunc), None)) return buf -def version_number(): +def version_number() -> int: """Returns version number of current stored model. This means how many calls to CheckPoint we made so far. diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 573da3ae8..e4b6f2f8f 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -23,14 +23,9 @@ from .compat import ( XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder, - DataFrame, - scipy_csr, ) -# Actually XGBoost supports a lot more data types including `scipy.sparse.csr_matrix` and -# many others. See `data.py` for a complete list. The `array_like` here is just for -# easier type checks. -array_like = TypeVar("array_like", bound=Union[np.ndarray, DataFrame, scipy_csr]) +array_like = Any class XGBRankerMixIn: # pylint: disable=too-few-public-methods