parent
97c3a80a34
commit
e47b3a3da3
@ -1,7 +1,18 @@
|
|||||||
|
# pylint: disable=protected-access
|
||||||
"""Shared typing definition."""
|
"""Shared typing definition."""
|
||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
|
# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
|
||||||
# cudf.DataFrame/cupy.array/dlpack
|
# cudf.DataFrame/cupy.array/dlpack
|
||||||
@ -32,14 +43,15 @@ FPreProcCallable = Callable
|
|||||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||||
c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103
|
c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103
|
||||||
|
|
||||||
CTypeT = Union[
|
CTypeT = TypeVar(
|
||||||
|
"CTypeT",
|
||||||
ctypes.c_void_p,
|
ctypes.c_void_p,
|
||||||
ctypes.c_char_p,
|
ctypes.c_char_p,
|
||||||
ctypes.c_int,
|
ctypes.c_int,
|
||||||
ctypes.c_float,
|
ctypes.c_float,
|
||||||
ctypes.c_uint,
|
ctypes.c_uint,
|
||||||
ctypes.c_size_t,
|
ctypes.c_size_t,
|
||||||
]
|
)
|
||||||
|
|
||||||
# supported numeric types
|
# supported numeric types
|
||||||
CNumeric = Union[
|
CNumeric = Union[
|
||||||
@ -52,21 +64,36 @@ CNumeric = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
# c pointer types
|
# c pointer types
|
||||||
# real type should be, as defined in typeshed
|
if TYPE_CHECKING:
|
||||||
# but this has to be put in a .pyi file
|
CStrPtr = ctypes._Pointer[ctypes.c_char]
|
||||||
# c_str_ptr_t = ctypes.pointer[ctypes.c_char]
|
|
||||||
CStrPtr = ctypes.pointer
|
|
||||||
# c_str_pptr_t = ctypes.pointer[ctypes.c_char_p]
|
|
||||||
CStrPptr = ctypes.pointer
|
|
||||||
# c_float_ptr_t = ctypes.pointer[ctypes.c_float]
|
|
||||||
CFloatPtr = ctypes.pointer
|
|
||||||
|
|
||||||
# c_numeric_ptr_t = Union[
|
CStrPptr = ctypes._Pointer[ctypes.c_char_p]
|
||||||
# ctypes.pointer[ctypes.c_float], ctypes.pointer[ctypes.c_double],
|
|
||||||
# ctypes.pointer[ctypes.c_uint], ctypes.pointer[ctypes.c_uint64],
|
CFloatPtr = ctypes._Pointer[ctypes.c_float]
|
||||||
# ctypes.pointer[ctypes.c_int32], ctypes.pointer[ctypes.c_int64]
|
|
||||||
# ]
|
CNumericPtr = Union[
|
||||||
CNumericPtr = ctypes.pointer
|
ctypes._Pointer[ctypes.c_float],
|
||||||
|
ctypes._Pointer[ctypes.c_double],
|
||||||
|
ctypes._Pointer[ctypes.c_uint],
|
||||||
|
ctypes._Pointer[ctypes.c_uint64],
|
||||||
|
ctypes._Pointer[ctypes.c_int32],
|
||||||
|
ctypes._Pointer[ctypes.c_int64],
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
CStrPtr = ctypes._Pointer
|
||||||
|
|
||||||
|
CStrPptr = ctypes._Pointer
|
||||||
|
|
||||||
|
CFloatPtr = ctypes._Pointer
|
||||||
|
|
||||||
|
CNumericPtr = Union[
|
||||||
|
ctypes._Pointer,
|
||||||
|
ctypes._Pointer,
|
||||||
|
ctypes._Pointer,
|
||||||
|
ctypes._Pointer,
|
||||||
|
ctypes._Pointer,
|
||||||
|
ctypes._Pointer,
|
||||||
|
]
|
||||||
|
|
||||||
# template parameter
|
# template parameter
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|||||||
@ -99,9 +99,9 @@ def from_cstr_to_pystr(data: CStrPptr, length: c_bst_ulong) -> List[str]:
|
|||||||
res = []
|
res = []
|
||||||
for i in range(length.value):
|
for i in range(length.value):
|
||||||
try:
|
try:
|
||||||
res.append(str(data[i].decode('ascii'))) # type: ignore
|
res.append(str(cast(bytes, data[i]).decode('ascii')))
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
res.append(str(data[i].decode('utf-8'))) # type: ignore
|
res.append(str(cast(bytes, data[i]).decode('utf-8')))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@ -381,7 +381,7 @@ def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray:
|
|||||||
raise RuntimeError('expected char pointer')
|
raise RuntimeError('expected char pointer')
|
||||||
res = bytearray(length)
|
res = bytearray(length)
|
||||||
rptr = (ctypes.c_char * length).from_buffer(res)
|
rptr = (ctypes.c_char * length).from_buffer(res)
|
||||||
if not ctypes.memmove(rptr, cptr, length): # type: ignore
|
if not ctypes.memmove(rptr, cptr, length):
|
||||||
raise RuntimeError('memmove failed')
|
raise RuntimeError('memmove failed')
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -393,8 +393,8 @@ def c_str(string: str) -> ctypes.c_char_p:
|
|||||||
|
|
||||||
def c_array(
|
def c_array(
|
||||||
ctype: Type[CTypeT], values: ArrayLike
|
ctype: Type[CTypeT], values: ArrayLike
|
||||||
) -> Union[ctypes.Array, ctypes.pointer]:
|
) -> Union[ctypes.Array, ctypes._Pointer]:
|
||||||
"""Convert a python string to c array."""
|
"""Convert a python array to c array."""
|
||||||
if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
|
if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
|
||||||
return values.ctypes.data_as(ctypes.POINTER(ctype))
|
return values.ctypes.data_as(ctypes.POINTER(ctype))
|
||||||
return (ctype * len(values))(*values)
|
return (ctype * len(values))(*values)
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from .callback import TrainingCallback, CallbackContainer, EvaluationMonitor, Ea
|
|||||||
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
||||||
from .core import Metric, Objective
|
from .core import Metric, Objective
|
||||||
from .compat import SKLEARN_INSTALLED, XGBStratifiedKFold, DataFrame
|
from .compat import SKLEARN_INSTALLED, XGBStratifiedKFold, DataFrame
|
||||||
from ._typing import _F, FPreProcCallable, BoosterParam
|
from ._typing import Callable, FPreProcCallable, BoosterParam
|
||||||
|
|
||||||
_CVFolds = Sequence["CVPack"]
|
_CVFolds = Sequence["CVPack"]
|
||||||
|
|
||||||
@ -205,10 +205,10 @@ class CVPack:
|
|||||||
self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
|
self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
|
||||||
self.bst = Booster(param, [dtrain, dtest])
|
self.bst = Booster(param, [dtrain, dtest])
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> _F:
|
def __getattr__(self, name: str) -> Callable:
|
||||||
def _inner(*args: Any, **kwargs: Any) -> Any:
|
def _inner(*args: Any, **kwargs: Any) -> Any:
|
||||||
return getattr(self.bst, name)(*args, **kwargs)
|
return getattr(self.bst, name)(*args, **kwargs)
|
||||||
return cast(_F, _inner)
|
return _inner
|
||||||
|
|
||||||
def update(self, iteration: int, fobj: Optional[Objective]) -> None:
|
def update(self, iteration: int, fobj: Optional[Objective]) -> None:
|
||||||
""""Update the boosters for one iteration"""
|
""""Update the boosters for one iteration"""
|
||||||
|
|||||||
@ -6,7 +6,7 @@ dependencies:
|
|||||||
- pylint
|
- pylint
|
||||||
- wheel
|
- wheel
|
||||||
- setuptools
|
- setuptools
|
||||||
- mypy=0.961
|
- mypy>=0.981
|
||||||
- numpy
|
- numpy
|
||||||
- scipy
|
- scipy
|
||||||
- pandas
|
- pandas
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user