Address remaining mypy errors in python package. (#7914)

This commit is contained in:
Jiaming Yuan 2022-05-18 22:46:15 +08:00 committed by GitHub
parent edf9a9608e
commit f93a727869
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 41 additions and 43 deletions

View File

@ -91,10 +91,7 @@ endif
# If any of the dask tests failed, contributor won't see the other error.
mypy:
cd python-package; \
mypy ./xgboost/dask.py && \
mypy ./xgboost/rabit.py && \
mypy ./xgboost/tracker.py && \
mypy ./xgboost/sklearn.py && \
mypy . && \
mypy ../demo/guide-python/external_memory.py && \
mypy ../demo/guide-python/categorical.py && \
mypy ../demo/guide-python/cat_in_the_dat.py && \

View File

@ -335,11 +335,12 @@ if __name__ == '__main__':
'scipy',
],
ext_modules=[CMakeExtension('libxgboost')],
# error: expected "str": "Type[Command]"
cmdclass={
'build_ext': BuildExt,
'sdist': Sdist,
'install_lib': InstallLib,
'install': Install
'build_ext': BuildExt, # type: ignore
'sdist': Sdist, # type: ignore
'install_lib': InstallLib, # type: ignore
'install': Install # type: ignore
},
extras_require={
'pandas': ['pandas'],

View File

@ -8,12 +8,10 @@ import importlib.util
import logging
import numpy as np
from xgboost._typing import CStrPtr
assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
def py_str(x: CStrPtr) -> str:
def py_str(x: bytes) -> str:
"""convert c string back to python string"""
return x.decode('utf-8') # type: ignore
@ -103,16 +101,6 @@ class XGBoostLabelEncoder(LabelEncoder):
self.__dict__.update(meta)
# dask
try:
import pkg_resources
pkg_resources.get_distribution('dask')
DASK_INSTALLED = True
except pkg_resources.DistributionNotFound:
dask = None
DASK_INSTALLED = False
try:
import scipy.sparse as scipy_sparse
from scipy.sparse import csr_matrix as scipy_csr

View File

@ -114,7 +114,9 @@ def set_config(**new_config: Any) -> None:
def get_config() -> Dict[str, Any]:
config_str = ctypes.c_char_p()
_check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str)))
config = json.loads(py_str(config_str.value))
value = config_str.value
assert value
config = json.loads(py_str(value))
return config

View File

@ -43,7 +43,7 @@ class XGBoostError(ValueError):
"""Error thrown by xgboost trainer."""
def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, CStrPptr]:
def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, ctypes.Array]:
"""Convert a Python str or list of Python str to C pointer
Parameters
@ -55,9 +55,9 @@ def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, CStrPptr]:
if isinstance(data, str):
return bytes(data, "utf-8")
if isinstance(data, list):
pointers: ctypes.pointer = (ctypes.c_char_p * len(data))()
pointers: ctypes.Array[ctypes.c_char_p] = (ctypes.c_char_p * len(data))()
data_as_bytes = [bytes(d, 'utf-8') for d in data]
pointers[:] = data_as_bytes
pointers[:] = data_as_bytes # type: ignore
return pointers
raise TypeError()
@ -272,7 +272,7 @@ def _cuda_array_interface(data: DataType) -> bytes:
def ctypes2numpy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> np.ndarray:
"""Convert a ctypes pointer array to a numpy array."""
ctype: Type[CNumeric] = _numpy2ctypes_type(dtype)
if not isinstance(cptr, ctypes.POINTER(ctype)):
if not isinstance(cptr, ctypes.POINTER(ctype)): # type: ignore
raise RuntimeError(f"expected {ctype} pointer")
res = np.zeros(length, dtype=dtype)
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): # type: ignore
@ -287,7 +287,10 @@ def ctypes2cupy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> CupyT
from cupy.cuda.memory import MemoryPointer
from cupy.cuda.memory import UnownedMemory
CUPY_TO_CTYPES_MAPPING = {cupy.float32: ctypes.c_float, cupy.uint32: ctypes.c_uint}
CUPY_TO_CTYPES_MAPPING: Dict[Type[np.number], Type[CNumeric]] = {
cupy.float32: ctypes.c_float,
cupy.uint32: ctypes.c_uint,
}
if dtype not in CUPY_TO_CTYPES_MAPPING:
raise RuntimeError(f"Supported types: {CUPY_TO_CTYPES_MAPPING.keys()}")
addr = ctypes.cast(cptr, ctypes.c_void_p).value
@ -322,10 +325,12 @@ def c_str(string: str) -> ctypes.c_char_p:
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype: Type[CTypeT], values: ArrayLike) -> ctypes.Array:
def c_array(
ctype: Type[CTypeT], values: ArrayLike
) -> Union[ctypes.Array, ctypes.pointer]:
"""Convert a python string to c array."""
if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
return (ctype * len(values)).from_buffer_copy(values)
return values.ctypes.data_as(ctypes.POINTER(ctype))
return (ctype * len(values))(*values)
@ -1611,7 +1616,9 @@ class Booster:
_check_call(_LIB.XGBoosterGetAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
value = ret.value
assert value
return py_str(value)
return None
def attributes(self) -> Dict[str, Optional[str]]:
@ -1639,12 +1646,10 @@ class Booster:
The attributes to set. Setting a value to None deletes an attribute.
"""
for key, value in kwargs.items():
c_value = None
if value is not None:
if not isinstance(value, str):
raise ValueError("Set Attr only accepts string values")
value = c_str(str(value))
_check_call(_LIB.XGBoosterSetAttr(
self.handle, c_str(key), value))
c_value = c_str(str(value))
_check_call(_LIB.XGBoosterSetAttr(self.handle, c_str(key), c_value))
def _get_feature_info(self, field: str) -> Optional[FeatureInfo]:
length = c_bst_ulong()
@ -2628,10 +2633,10 @@ class Booster:
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
nph = np.histogram(values, bins=bins)
nph = np.column_stack((nph[1][1:], nph[0]))
nph = nph[nph[:, 1] > 0]
nph_stacked = np.column_stack((nph[1][1:], nph[0]))
nph_stacked = nph_stacked[nph_stacked[:, 1] > 0]
if nph.size == 0:
if nph_stacked.size == 0:
ft = self.feature_types
fn = self.feature_names
if fn is None:
@ -2649,11 +2654,11 @@ class Booster:
)
if as_pandas and PANDAS_INSTALLED:
return DataFrame(nph, columns=['SplitValue', 'Count'])
return DataFrame(nph_stacked, columns=['SplitValue', 'Count'])
if as_pandas and not PANDAS_INSTALLED:
warnings.warn(
"Returning histogram as ndarray"
" (as_pandas == True, but pandas is not installed).",
UserWarning
)
return nph
return nph_stacked

View File

@ -7,7 +7,6 @@ from typing import Optional
from contextlib import contextmanager
from io import StringIO
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
from xgboost.compat import DASK_INSTALLED
import pytest
import gc
import xgboost as xgb
@ -44,8 +43,14 @@ def no_sklearn():
def no_dask():
return {'condition': not DASK_INSTALLED,
'reason': 'Dask is not installed'}
try:
import pkg_resources
pkg_resources.get_distribution("dask")
DASK_INSTALLED = True
except pkg_resources.DistributionNotFound:
DASK_INSTALLED = False
return {"condition": not DASK_INSTALLED, "reason": "Dask is not installed"}
def no_pandas():