Address remaining mypy errors in python package. (#7914)
This commit is contained in:
parent
edf9a9608e
commit
f93a727869
5
Makefile
5
Makefile
@ -91,10 +91,7 @@ endif
|
||||
# If any of the dask tests failed, contributor won't see the other error.
|
||||
mypy:
|
||||
cd python-package; \
|
||||
mypy ./xgboost/dask.py && \
|
||||
mypy ./xgboost/rabit.py && \
|
||||
mypy ./xgboost/tracker.py && \
|
||||
mypy ./xgboost/sklearn.py && \
|
||||
mypy . && \
|
||||
mypy ../demo/guide-python/external_memory.py && \
|
||||
mypy ../demo/guide-python/categorical.py && \
|
||||
mypy ../demo/guide-python/cat_in_the_dat.py && \
|
||||
|
||||
@ -335,11 +335,12 @@ if __name__ == '__main__':
|
||||
'scipy',
|
||||
],
|
||||
ext_modules=[CMakeExtension('libxgboost')],
|
||||
# error: expected "str": "Type[Command]"
|
||||
cmdclass={
|
||||
'build_ext': BuildExt,
|
||||
'sdist': Sdist,
|
||||
'install_lib': InstallLib,
|
||||
'install': Install
|
||||
'build_ext': BuildExt, # type: ignore
|
||||
'sdist': Sdist, # type: ignore
|
||||
'install_lib': InstallLib, # type: ignore
|
||||
'install': Install # type: ignore
|
||||
},
|
||||
extras_require={
|
||||
'pandas': ['pandas'],
|
||||
|
||||
@ -8,12 +8,10 @@ import importlib.util
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from xgboost._typing import CStrPtr
|
||||
|
||||
assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
|
||||
|
||||
|
||||
def py_str(x: CStrPtr) -> str:
|
||||
def py_str(x: bytes) -> str:
|
||||
"""convert c string back to python string"""
|
||||
return x.decode('utf-8') # type: ignore
|
||||
|
||||
@ -103,16 +101,6 @@ class XGBoostLabelEncoder(LabelEncoder):
|
||||
self.__dict__.update(meta)
|
||||
|
||||
|
||||
# dask
|
||||
try:
|
||||
import pkg_resources
|
||||
pkg_resources.get_distribution('dask')
|
||||
DASK_INSTALLED = True
|
||||
except pkg_resources.DistributionNotFound:
|
||||
dask = None
|
||||
DASK_INSTALLED = False
|
||||
|
||||
|
||||
try:
|
||||
import scipy.sparse as scipy_sparse
|
||||
from scipy.sparse import csr_matrix as scipy_csr
|
||||
|
||||
@ -114,7 +114,9 @@ def set_config(**new_config: Any) -> None:
|
||||
def get_config() -> Dict[str, Any]:
|
||||
config_str = ctypes.c_char_p()
|
||||
_check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config_str)))
|
||||
config = json.loads(py_str(config_str.value))
|
||||
value = config_str.value
|
||||
assert value
|
||||
config = json.loads(py_str(value))
|
||||
return config
|
||||
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ class XGBoostError(ValueError):
|
||||
"""Error thrown by xgboost trainer."""
|
||||
|
||||
|
||||
def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, CStrPptr]:
|
||||
def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, ctypes.Array]:
|
||||
"""Convert a Python str or list of Python str to C pointer
|
||||
|
||||
Parameters
|
||||
@ -55,9 +55,9 @@ def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, CStrPptr]:
|
||||
if isinstance(data, str):
|
||||
return bytes(data, "utf-8")
|
||||
if isinstance(data, list):
|
||||
pointers: ctypes.pointer = (ctypes.c_char_p * len(data))()
|
||||
pointers: ctypes.Array[ctypes.c_char_p] = (ctypes.c_char_p * len(data))()
|
||||
data_as_bytes = [bytes(d, 'utf-8') for d in data]
|
||||
pointers[:] = data_as_bytes
|
||||
pointers[:] = data_as_bytes # type: ignore
|
||||
return pointers
|
||||
raise TypeError()
|
||||
|
||||
@ -272,7 +272,7 @@ def _cuda_array_interface(data: DataType) -> bytes:
|
||||
def ctypes2numpy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> np.ndarray:
|
||||
"""Convert a ctypes pointer array to a numpy array."""
|
||||
ctype: Type[CNumeric] = _numpy2ctypes_type(dtype)
|
||||
if not isinstance(cptr, ctypes.POINTER(ctype)):
|
||||
if not isinstance(cptr, ctypes.POINTER(ctype)): # type: ignore
|
||||
raise RuntimeError(f"expected {ctype} pointer")
|
||||
res = np.zeros(length, dtype=dtype)
|
||||
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): # type: ignore
|
||||
@ -287,7 +287,10 @@ def ctypes2cupy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> CupyT
|
||||
from cupy.cuda.memory import MemoryPointer
|
||||
from cupy.cuda.memory import UnownedMemory
|
||||
|
||||
CUPY_TO_CTYPES_MAPPING = {cupy.float32: ctypes.c_float, cupy.uint32: ctypes.c_uint}
|
||||
CUPY_TO_CTYPES_MAPPING: Dict[Type[np.number], Type[CNumeric]] = {
|
||||
cupy.float32: ctypes.c_float,
|
||||
cupy.uint32: ctypes.c_uint,
|
||||
}
|
||||
if dtype not in CUPY_TO_CTYPES_MAPPING:
|
||||
raise RuntimeError(f"Supported types: {CUPY_TO_CTYPES_MAPPING.keys()}")
|
||||
addr = ctypes.cast(cptr, ctypes.c_void_p).value
|
||||
@ -322,10 +325,12 @@ def c_str(string: str) -> ctypes.c_char_p:
|
||||
return ctypes.c_char_p(string.encode('utf-8'))
|
||||
|
||||
|
||||
def c_array(ctype: Type[CTypeT], values: ArrayLike) -> ctypes.Array:
|
||||
def c_array(
|
||||
ctype: Type[CTypeT], values: ArrayLike
|
||||
) -> Union[ctypes.Array, ctypes.pointer]:
|
||||
"""Convert a python string to c array."""
|
||||
if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
|
||||
return (ctype * len(values)).from_buffer_copy(values)
|
||||
return values.ctypes.data_as(ctypes.POINTER(ctype))
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
@ -1611,7 +1616,9 @@ class Booster:
|
||||
_check_call(_LIB.XGBoosterGetAttr(
|
||||
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
|
||||
if success.value != 0:
|
||||
return py_str(ret.value)
|
||||
value = ret.value
|
||||
assert value
|
||||
return py_str(value)
|
||||
return None
|
||||
|
||||
def attributes(self) -> Dict[str, Optional[str]]:
|
||||
@ -1639,12 +1646,10 @@ class Booster:
|
||||
The attributes to set. Setting a value to None deletes an attribute.
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
c_value = None
|
||||
if value is not None:
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Set Attr only accepts string values")
|
||||
value = c_str(str(value))
|
||||
_check_call(_LIB.XGBoosterSetAttr(
|
||||
self.handle, c_str(key), value))
|
||||
c_value = c_str(str(value))
|
||||
_check_call(_LIB.XGBoosterSetAttr(self.handle, c_str(key), c_value))
|
||||
|
||||
def _get_feature_info(self, field: str) -> Optional[FeatureInfo]:
|
||||
length = c_bst_ulong()
|
||||
@ -2628,10 +2633,10 @@ class Booster:
|
||||
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
|
||||
|
||||
nph = np.histogram(values, bins=bins)
|
||||
nph = np.column_stack((nph[1][1:], nph[0]))
|
||||
nph = nph[nph[:, 1] > 0]
|
||||
nph_stacked = np.column_stack((nph[1][1:], nph[0]))
|
||||
nph_stacked = nph_stacked[nph_stacked[:, 1] > 0]
|
||||
|
||||
if nph.size == 0:
|
||||
if nph_stacked.size == 0:
|
||||
ft = self.feature_types
|
||||
fn = self.feature_names
|
||||
if fn is None:
|
||||
@ -2649,11 +2654,11 @@ class Booster:
|
||||
)
|
||||
|
||||
if as_pandas and PANDAS_INSTALLED:
|
||||
return DataFrame(nph, columns=['SplitValue', 'Count'])
|
||||
return DataFrame(nph_stacked, columns=['SplitValue', 'Count'])
|
||||
if as_pandas and not PANDAS_INSTALLED:
|
||||
warnings.warn(
|
||||
"Returning histogram as ndarray"
|
||||
" (as_pandas == True, but pandas is not installed).",
|
||||
UserWarning
|
||||
)
|
||||
return nph
|
||||
return nph_stacked
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import Optional
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
from xgboost.compat import DASK_INSTALLED
|
||||
import pytest
|
||||
import gc
|
||||
import xgboost as xgb
|
||||
@ -44,8 +43,14 @@ def no_sklearn():
|
||||
|
||||
|
||||
def no_dask():
|
||||
return {'condition': not DASK_INSTALLED,
|
||||
'reason': 'Dask is not installed'}
|
||||
try:
|
||||
import pkg_resources
|
||||
|
||||
pkg_resources.get_distribution("dask")
|
||||
DASK_INSTALLED = True
|
||||
except pkg_resources.DistributionNotFound:
|
||||
DASK_INSTALLED = False
|
||||
return {"condition": not DASK_INSTALLED, "reason": "Dask is not installed"}
|
||||
|
||||
|
||||
def no_pandas():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user