Upgrade mypy. (#8302)

Some breaking changes were made in mypy.
This commit is contained in:
Jiaming Yuan 2022-10-05 14:31:59 +08:00 committed by GitHub
parent 97c3a80a34
commit e47b3a3da3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 26 deletions

View File

@ -1,7 +1,18 @@
# pylint: disable=protected-access
"""Shared typing definition.""" """Shared typing definition."""
import ctypes import ctypes
import os import os
from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Sequence,
Type,
TypeVar,
Union,
)
# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/ # os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
# cudf.DataFrame/cupy.array/dlpack # cudf.DataFrame/cupy.array/dlpack
@ -32,14 +43,15 @@ FPreProcCallable = Callable
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103 c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103
CTypeT = Union[ CTypeT = TypeVar(
"CTypeT",
ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_char_p, ctypes.c_char_p,
ctypes.c_int, ctypes.c_int,
ctypes.c_float, ctypes.c_float,
ctypes.c_uint, ctypes.c_uint,
ctypes.c_size_t, ctypes.c_size_t,
] )
# supported numeric types # supported numeric types
CNumeric = Union[ CNumeric = Union[
@ -52,21 +64,36 @@ CNumeric = Union[
] ]
# c pointer types # c pointer types
# real type should be, as defined in typeshed if TYPE_CHECKING:
# but this has to be put in a .pyi file CStrPtr = ctypes._Pointer[ctypes.c_char]
# c_str_ptr_t = ctypes.pointer[ctypes.c_char]
CStrPtr = ctypes.pointer
# c_str_pptr_t = ctypes.pointer[ctypes.c_char_p]
CStrPptr = ctypes.pointer
# c_float_ptr_t = ctypes.pointer[ctypes.c_float]
CFloatPtr = ctypes.pointer
# c_numeric_ptr_t = Union[ CStrPptr = ctypes._Pointer[ctypes.c_char_p]
# ctypes.pointer[ctypes.c_float], ctypes.pointer[ctypes.c_double],
# ctypes.pointer[ctypes.c_uint], ctypes.pointer[ctypes.c_uint64], CFloatPtr = ctypes._Pointer[ctypes.c_float]
# ctypes.pointer[ctypes.c_int32], ctypes.pointer[ctypes.c_int64]
# ] CNumericPtr = Union[
CNumericPtr = ctypes.pointer ctypes._Pointer[ctypes.c_float],
ctypes._Pointer[ctypes.c_double],
ctypes._Pointer[ctypes.c_uint],
ctypes._Pointer[ctypes.c_uint64],
ctypes._Pointer[ctypes.c_int32],
ctypes._Pointer[ctypes.c_int64],
]
else:
CStrPtr = ctypes._Pointer
CStrPptr = ctypes._Pointer
CFloatPtr = ctypes._Pointer
CNumericPtr = Union[
ctypes._Pointer,
ctypes._Pointer,
ctypes._Pointer,
ctypes._Pointer,
ctypes._Pointer,
ctypes._Pointer,
]
# template parameter # template parameter
_T = TypeVar("_T") _T = TypeVar("_T")

View File

@ -99,9 +99,9 @@ def from_cstr_to_pystr(data: CStrPptr, length: c_bst_ulong) -> List[str]:
res = [] res = []
for i in range(length.value): for i in range(length.value):
try: try:
res.append(str(data[i].decode('ascii'))) # type: ignore res.append(str(cast(bytes, data[i]).decode('ascii')))
except UnicodeDecodeError: except UnicodeDecodeError:
res.append(str(data[i].decode('utf-8'))) # type: ignore res.append(str(cast(bytes, data[i]).decode('utf-8')))
return res return res
@ -381,7 +381,7 @@ def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray:
raise RuntimeError('expected char pointer') raise RuntimeError('expected char pointer')
res = bytearray(length) res = bytearray(length)
rptr = (ctypes.c_char * length).from_buffer(res) rptr = (ctypes.c_char * length).from_buffer(res)
if not ctypes.memmove(rptr, cptr, length): # type: ignore if not ctypes.memmove(rptr, cptr, length):
raise RuntimeError('memmove failed') raise RuntimeError('memmove failed')
return res return res
@ -393,8 +393,8 @@ def c_str(string: str) -> ctypes.c_char_p:
def c_array( def c_array(
ctype: Type[CTypeT], values: ArrayLike ctype: Type[CTypeT], values: ArrayLike
) -> Union[ctypes.Array, ctypes.pointer]: ) -> Union[ctypes.Array, ctypes._Pointer]:
"""Convert a python string to c array.""" """Convert a python array to c array."""
if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype): if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
return values.ctypes.data_as(ctypes.POINTER(ctype)) return values.ctypes.data_as(ctypes.POINTER(ctype))
return (ctype * len(values))(*values) return (ctype * len(values))(*values)

View File

@ -13,7 +13,7 @@ from .callback import TrainingCallback, CallbackContainer, EvaluationMonitor, Ea
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
from .core import Metric, Objective from .core import Metric, Objective
from .compat import SKLEARN_INSTALLED, XGBStratifiedKFold, DataFrame from .compat import SKLEARN_INSTALLED, XGBStratifiedKFold, DataFrame
from ._typing import _F, FPreProcCallable, BoosterParam from ._typing import Callable, FPreProcCallable, BoosterParam
_CVFolds = Sequence["CVPack"] _CVFolds = Sequence["CVPack"]
@ -205,10 +205,10 @@ class CVPack:
self.watchlist = [(dtrain, 'train'), (dtest, 'test')] self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
self.bst = Booster(param, [dtrain, dtest]) self.bst = Booster(param, [dtrain, dtest])
def __getattr__(self, name: str) -> _F: def __getattr__(self, name: str) -> Callable:
def _inner(*args: Any, **kwargs: Any) -> Any: def _inner(*args: Any, **kwargs: Any) -> Any:
return getattr(self.bst, name)(*args, **kwargs) return getattr(self.bst, name)(*args, **kwargs)
return cast(_F, _inner) return _inner
def update(self, iteration: int, fobj: Optional[Objective]) -> None: def update(self, iteration: int, fobj: Optional[Objective]) -> None:
""""Update the boosters for one iteration""" """"Update the boosters for one iteration"""

View File

@ -6,7 +6,7 @@ dependencies:
- pylint - pylint
- wheel - wheel
- setuptools - setuptools
- mypy=0.961 - mypy>=0.981
- numpy - numpy
- scipy - scipy
- pandas - pandas