Add typehint to rabit module. (#7240)
This commit is contained in:
parent
c735c17f33
commit
e48e05e6e2
1
Makefile
1
Makefile
@ -92,6 +92,7 @@ endif
|
|||||||
mypy:
|
mypy:
|
||||||
cd python-package; \
|
cd python-package; \
|
||||||
mypy ./xgboost/dask.py && \
|
mypy ./xgboost/dask.py && \
|
||||||
|
mypy ./xgboost/rabit.py && \
|
||||||
mypy ../demo/guide-python/external_memory.py && \
|
mypy ../demo/guide-python/external_memory.py && \
|
||||||
mypy ../tests/python-gpu/test_gpu_with_dask.py && \
|
mypy ../tests/python-gpu/test_gpu_with_dask.py && \
|
||||||
mypy ../tests/python/test_data_iterator.py && \
|
mypy ../tests/python/test_data_iterator.py && \
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import collections
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import List, Optional, Any, Union, Dict, TypeVar
|
from typing import List, Optional, Any, Union, Dict, TypeVar
|
||||||
# pylint: enable=no-name-in-module,import-error
|
# pylint: enable=no-name-in-module,import-error
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple, cast
|
||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@ -2008,7 +2008,8 @@ class Booster(object):
|
|||||||
dims = c_bst_ulong()
|
dims = c_bst_ulong()
|
||||||
|
|
||||||
if base_margin is not None:
|
if base_margin is not None:
|
||||||
proxy = _ProxyDMatrix()
|
proxy: Optional[_ProxyDMatrix] = _ProxyDMatrix()
|
||||||
|
assert proxy is not None
|
||||||
proxy.set_info(base_margin=base_margin)
|
proxy.set_info(base_margin=base_margin)
|
||||||
p_handle = proxy.handle
|
p_handle = proxy.handle
|
||||||
else:
|
else:
|
||||||
@ -2274,8 +2275,8 @@ class Booster(object):
|
|||||||
return self.get_score(fmap, importance_type='weight')
|
return self.get_score(fmap, importance_type='weight')
|
||||||
|
|
||||||
def get_score(
|
def get_score(
|
||||||
self, fmap: os.PathLike = '', importance_type: str = 'weight'
|
self, fmap: Union[str, os.PathLike] = '', importance_type: str = 'weight'
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, Union[float, List[float]]]:
|
||||||
"""Get feature importance of each feature.
|
"""Get feature importance of each feature.
|
||||||
For tree model Importance type can be defined as:
|
For tree model Importance type can be defined as:
|
||||||
|
|
||||||
@ -2332,7 +2333,7 @@ class Booster(object):
|
|||||||
features_arr = from_cstr_to_pystr(features, n_out_features)
|
features_arr = from_cstr_to_pystr(features, n_out_features)
|
||||||
scores_arr = _prediction_output(shape, out_dim, scores, False)
|
scores_arr = _prediction_output(shape, out_dim, scores, False)
|
||||||
|
|
||||||
results = {}
|
results: Dict[str, Union[float, List[float]]] = {}
|
||||||
if len(scores_arr.shape) > 1 and scores_arr.shape[1] > 1:
|
if len(scores_arr.shape) > 1 and scores_arr.shape[1] > 1:
|
||||||
for feat, score in zip(features_arr, scores_arr):
|
for feat, score in zip(features_arr, scores_arr):
|
||||||
results[feat] = [float(s) for s in score]
|
results[feat] = [float(s) for s in score]
|
||||||
@ -2481,7 +2482,7 @@ class Booster(object):
|
|||||||
fmap: Union[os.PathLike, str] = '',
|
fmap: Union[os.PathLike, str] = '',
|
||||||
bins: Optional[int] = None,
|
bins: Optional[int] = None,
|
||||||
as_pandas: bool = True
|
as_pandas: bool = True
|
||||||
):
|
) -> Union[np.ndarray, DataFrame]:
|
||||||
"""Get split value histogram of a feature
|
"""Get split value histogram of a feature
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -2526,7 +2527,7 @@ class Booster(object):
|
|||||||
fn = [f"f{i}" for i in range(self.num_features())]
|
fn = [f"f{i}" for i in range(self.num_features())]
|
||||||
try:
|
try:
|
||||||
index = fn.index(feature)
|
index = fn.index(feature)
|
||||||
feature_t = ft[index]
|
feature_t: Optional[str] = cast(List[str], ft)[index]
|
||||||
except (ValueError, AttributeError, TypeError):
|
except (ValueError, AttributeError, TypeError):
|
||||||
# None.index: attr err, None[0]: type err, fn.index(-1): value err
|
# None.index: attr err, None[0]: type err, fn.index(-1): value err
|
||||||
feature_t = None
|
feature_t = None
|
||||||
|
|||||||
@ -3,12 +3,14 @@
|
|||||||
"""Distributed XGBoost Rabit related API."""
|
"""Distributed XGBoost Rabit related API."""
|
||||||
import ctypes
|
import ctypes
|
||||||
import pickle
|
import pickle
|
||||||
|
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .core import _LIB, c_str, STRING_TYPES, _check_call
|
from .core import _LIB, c_str, STRING_TYPES, _check_call
|
||||||
|
|
||||||
|
|
||||||
def _init_rabit():
|
def _init_rabit() -> None:
|
||||||
"""internal library initializer."""
|
"""internal library initializer."""
|
||||||
if _LIB is not None:
|
if _LIB is not None:
|
||||||
_LIB.RabitGetRank.restype = ctypes.c_int
|
_LIB.RabitGetRank.restype = ctypes.c_int
|
||||||
@ -17,21 +19,21 @@ def _init_rabit():
|
|||||||
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
||||||
|
|
||||||
|
|
||||||
def init(args=None):
|
def init(args: Optional[List[bytes]] = None) -> None:
|
||||||
"""Initialize the rabit library with arguments"""
|
"""Initialize the rabit library with arguments"""
|
||||||
if args is None:
|
if args is None:
|
||||||
args = []
|
args = []
|
||||||
arr = (ctypes.c_char_p * len(args))()
|
arr = (ctypes.c_char_p * len(args))()
|
||||||
arr[:] = args
|
arr[:] = cast(List[Union[ctypes.c_char_p, bytes, None, int]], args)
|
||||||
_LIB.RabitInit(len(arr), arr)
|
_LIB.RabitInit(len(arr), arr)
|
||||||
|
|
||||||
|
|
||||||
def finalize():
|
def finalize() -> None:
|
||||||
"""Finalize the process, notify tracker everything is done."""
|
"""Finalize the process, notify tracker everything is done."""
|
||||||
_LIB.RabitFinalize()
|
_LIB.RabitFinalize()
|
||||||
|
|
||||||
|
|
||||||
def get_rank():
|
def get_rank() -> int:
|
||||||
"""Get rank of current process.
|
"""Get rank of current process.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@ -43,7 +45,7 @@ def get_rank():
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size() -> int:
|
||||||
"""Get total number workers.
|
"""Get total number workers.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@ -55,13 +57,13 @@ def get_world_size():
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def is_distributed():
|
def is_distributed() -> int:
|
||||||
'''If rabit is distributed.'''
|
'''If rabit is distributed.'''
|
||||||
is_dist = _LIB.RabitIsDistributed()
|
is_dist = _LIB.RabitIsDistributed()
|
||||||
return is_dist
|
return is_dist
|
||||||
|
|
||||||
|
|
||||||
def tracker_print(msg):
|
def tracker_print(msg: Any) -> None:
|
||||||
"""Print message to the tracker.
|
"""Print message to the tracker.
|
||||||
|
|
||||||
This function can be used to communicate the information of
|
This function can be used to communicate the information of
|
||||||
@ -81,7 +83,7 @@ def tracker_print(msg):
|
|||||||
print(msg.strip(), flush=True)
|
print(msg.strip(), flush=True)
|
||||||
|
|
||||||
|
|
||||||
def get_processor_name():
|
def get_processor_name() -> bytes:
|
||||||
"""Get the processor name.
|
"""Get the processor name.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@ -96,7 +98,10 @@ def get_processor_name():
|
|||||||
return buf.value
|
return buf.value
|
||||||
|
|
||||||
|
|
||||||
def broadcast(data, root):
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast(data: T, root: int) -> T:
|
||||||
"""Broadcast object from one node to all other nodes.
|
"""Broadcast object from one node to all other nodes.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -155,7 +160,9 @@ class Op: # pylint: disable=too-few-public-methods
|
|||||||
OR = 3
|
OR = 3
|
||||||
|
|
||||||
|
|
||||||
def allreduce(data, op, prepare_fun=None):
|
def allreduce(
|
||||||
|
data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
|
||||||
|
) -> np.ndarray:
|
||||||
"""Perform allreduce, return the result.
|
"""Perform allreduce, return the result.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -193,16 +200,17 @@ def allreduce(data, op, prepare_fun=None):
|
|||||||
else:
|
else:
|
||||||
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||||
|
|
||||||
def pfunc(_):
|
def pfunc(_: Any) -> None:
|
||||||
"""prepare function."""
|
"""prepare function."""
|
||||||
prepare_fun(data)
|
fn = cast(Callable[[np.ndarray], None], prepare_fun)
|
||||||
|
fn(data)
|
||||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||||
op, func_ptr(pfunc), None))
|
op, func_ptr(pfunc), None))
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
|
|
||||||
def version_number():
|
def version_number() -> int:
|
||||||
"""Returns version number of current stored model.
|
"""Returns version number of current stored model.
|
||||||
|
|
||||||
This means how many calls to CheckPoint we made so far.
|
This means how many calls to CheckPoint we made so far.
|
||||||
|
|||||||
@ -23,14 +23,9 @@ from .compat import (
|
|||||||
XGBClassifierBase,
|
XGBClassifierBase,
|
||||||
XGBRegressorBase,
|
XGBRegressorBase,
|
||||||
XGBoostLabelEncoder,
|
XGBoostLabelEncoder,
|
||||||
DataFrame,
|
|
||||||
scipy_csr,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Actually XGBoost supports a lot more data types including `scipy.sparse.csr_matrix` and
|
array_like = Any
|
||||||
# many others. See `data.py` for a complete list. The `array_like` here is just for
|
|
||||||
# easier type checks.
|
|
||||||
array_like = TypeVar("array_like", bound=Union[np.ndarray, DataFrame, scipy_csr])
|
|
||||||
|
|
||||||
|
|
||||||
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
|
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user