parent
97c3a80a34
commit
e47b3a3da3
@ -1,7 +1,18 @@
|
||||
# pylint: disable=protected-access
|
||||
"""Shared typing definition."""
|
||||
import ctypes
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
|
||||
# cudf.DataFrame/cupy.array/dlpack
|
||||
@ -32,14 +43,15 @@ FPreProcCallable = Callable
|
||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||
c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103
|
||||
|
||||
CTypeT = Union[
|
||||
CTypeT = TypeVar(
|
||||
"CTypeT",
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int,
|
||||
ctypes.c_float,
|
||||
ctypes.c_uint,
|
||||
ctypes.c_size_t,
|
||||
]
|
||||
)
|
||||
|
||||
# supported numeric types
|
||||
CNumeric = Union[
|
||||
@ -52,21 +64,36 @@ CNumeric = Union[
|
||||
]
|
||||
|
||||
# c pointer types
|
||||
# real type should be, as defined in typeshed
|
||||
# but this has to be put in a .pyi file
|
||||
# c_str_ptr_t = ctypes.pointer[ctypes.c_char]
|
||||
CStrPtr = ctypes.pointer
|
||||
# c_str_pptr_t = ctypes.pointer[ctypes.c_char_p]
|
||||
CStrPptr = ctypes.pointer
|
||||
# c_float_ptr_t = ctypes.pointer[ctypes.c_float]
|
||||
CFloatPtr = ctypes.pointer
|
||||
if TYPE_CHECKING:
|
||||
CStrPtr = ctypes._Pointer[ctypes.c_char]
|
||||
|
||||
# c_numeric_ptr_t = Union[
|
||||
# ctypes.pointer[ctypes.c_float], ctypes.pointer[ctypes.c_double],
|
||||
# ctypes.pointer[ctypes.c_uint], ctypes.pointer[ctypes.c_uint64],
|
||||
# ctypes.pointer[ctypes.c_int32], ctypes.pointer[ctypes.c_int64]
|
||||
# ]
|
||||
CNumericPtr = ctypes.pointer
|
||||
CStrPptr = ctypes._Pointer[ctypes.c_char_p]
|
||||
|
||||
CFloatPtr = ctypes._Pointer[ctypes.c_float]
|
||||
|
||||
CNumericPtr = Union[
|
||||
ctypes._Pointer[ctypes.c_float],
|
||||
ctypes._Pointer[ctypes.c_double],
|
||||
ctypes._Pointer[ctypes.c_uint],
|
||||
ctypes._Pointer[ctypes.c_uint64],
|
||||
ctypes._Pointer[ctypes.c_int32],
|
||||
ctypes._Pointer[ctypes.c_int64],
|
||||
]
|
||||
else:
|
||||
CStrPtr = ctypes._Pointer
|
||||
|
||||
CStrPptr = ctypes._Pointer
|
||||
|
||||
CFloatPtr = ctypes._Pointer
|
||||
|
||||
CNumericPtr = Union[
|
||||
ctypes._Pointer,
|
||||
ctypes._Pointer,
|
||||
ctypes._Pointer,
|
||||
ctypes._Pointer,
|
||||
ctypes._Pointer,
|
||||
ctypes._Pointer,
|
||||
]
|
||||
|
||||
# template parameter
|
||||
_T = TypeVar("_T")
|
||||
|
||||
@ -99,9 +99,9 @@ def from_cstr_to_pystr(data: CStrPptr, length: c_bst_ulong) -> List[str]:
|
||||
res = []
|
||||
for i in range(length.value):
|
||||
try:
|
||||
res.append(str(data[i].decode('ascii'))) # type: ignore
|
||||
res.append(str(cast(bytes, data[i]).decode('ascii')))
|
||||
except UnicodeDecodeError:
|
||||
res.append(str(data[i].decode('utf-8'))) # type: ignore
|
||||
res.append(str(cast(bytes, data[i]).decode('utf-8')))
|
||||
return res
|
||||
|
||||
|
||||
@ -381,7 +381,7 @@ def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray:
|
||||
raise RuntimeError('expected char pointer')
|
||||
res = bytearray(length)
|
||||
rptr = (ctypes.c_char * length).from_buffer(res)
|
||||
if not ctypes.memmove(rptr, cptr, length): # type: ignore
|
||||
if not ctypes.memmove(rptr, cptr, length):
|
||||
raise RuntimeError('memmove failed')
|
||||
return res
|
||||
|
||||
@ -393,8 +393,8 @@ def c_str(string: str) -> ctypes.c_char_p:
|
||||
|
||||
def c_array(
|
||||
ctype: Type[CTypeT], values: ArrayLike
|
||||
) -> Union[ctypes.Array, ctypes.pointer]:
|
||||
"""Convert a python string to c array."""
|
||||
) -> Union[ctypes.Array, ctypes._Pointer]:
|
||||
"""Convert a python array to c array."""
|
||||
if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
|
||||
return values.ctypes.data_as(ctypes.POINTER(ctype))
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
@ -13,7 +13,7 @@ from .callback import TrainingCallback, CallbackContainer, EvaluationMonitor, Ea
|
||||
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
||||
from .core import Metric, Objective
|
||||
from .compat import SKLEARN_INSTALLED, XGBStratifiedKFold, DataFrame
|
||||
from ._typing import _F, FPreProcCallable, BoosterParam
|
||||
from ._typing import Callable, FPreProcCallable, BoosterParam
|
||||
|
||||
_CVFolds = Sequence["CVPack"]
|
||||
|
||||
@ -205,10 +205,10 @@ class CVPack:
|
||||
self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
|
||||
self.bst = Booster(param, [dtrain, dtest])
|
||||
|
||||
def __getattr__(self, name: str) -> _F:
|
||||
def __getattr__(self, name: str) -> Callable:
|
||||
def _inner(*args: Any, **kwargs: Any) -> Any:
|
||||
return getattr(self.bst, name)(*args, **kwargs)
|
||||
return cast(_F, _inner)
|
||||
return _inner
|
||||
|
||||
def update(self, iteration: int, fobj: Optional[Objective]) -> None:
|
||||
""""Update the boosters for one iteration"""
|
||||
|
||||
@ -6,7 +6,7 @@ dependencies:
|
||||
- pylint
|
||||
- wheel
|
||||
- setuptools
|
||||
- mypy=0.961
|
||||
- mypy>=0.981
|
||||
- numpy
|
||||
- scipy
|
||||
- pandas
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user