Add type hints to core.py (#7707)
Co-authored-by: Chengyang Gu <bridgream@gmail.com>
Co-authored-by: jiamingy <jm.yuan@outlook.com>

parent 66cb4afc6c
commit c92ab2ce49

python-package/xgboost/_typing.py | 60 ++ (new file)
python-package/xgboost/_typing.py (new file)
@@ -0,0 +1,60 @@
+"""Shared typing definition."""
+import ctypes
+import os
+from typing import Optional, List, Any, TypeVar, Union
+
+# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
+# cudf.DataFrame/cupy.array/dlpack
+DataType = Any
+
+# xgboost accepts some other possible types in practice due to historical reason, which is
+# lesser tested. For now we encourage users to pass a simple list of string.
+FeatureNames = Optional[List[str]]
+
+ArrayLike = Any
+PathLike = Union[str, os.PathLike]
+CupyT = ArrayLike  # maybe need a stub for cupy arrays
+NumpyOrCupy = Any
+
+# ctypes
+# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
+c_bst_ulong = ctypes.c_uint64  # pylint: disable=C0103
+
+CTypeT = Union[
+    ctypes.c_void_p,
+    ctypes.c_char_p,
+    ctypes.c_int,
+    ctypes.c_float,
+    ctypes.c_uint,
+    ctypes.c_size_t,
+]
+
+# supported numeric types
+CNumeric = Union[
+    ctypes.c_float,
+    ctypes.c_double,
+    ctypes.c_uint,
+    ctypes.c_uint64,
+    ctypes.c_int32,
+    ctypes.c_int64,
+]
+
+# c pointer types
+# real type should be, as defined in typeshed
+# but this has to be put in a .pyi file
+# c_str_ptr_t = ctypes.pointer[ctypes.c_char]
+CStrPtr = ctypes.pointer
+# c_str_pptr_t = ctypes.pointer[ctypes.c_char_p]
+CStrPptr = ctypes.pointer
+# c_float_ptr_t = ctypes.pointer[ctypes.c_float]
+CFloatPtr = ctypes.pointer
+
+# c_numeric_ptr_t = Union[
+#     ctypes.pointer[ctypes.c_float], ctypes.pointer[ctypes.c_double],
+#     ctypes.pointer[ctypes.c_uint], ctypes.pointer[ctypes.c_uint64],
+#     ctypes.pointer[ctypes.c_int32], ctypes.pointer[ctypes.c_int64]
+# ]
+CNumericPtr = ctypes.pointer
+
+# template parameter
+_T = TypeVar("_T")
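Note: a minimal, self-contained sketch of how the new aliases read at a call site. The aliases are re-declared locally so the snippet runs standalone; describe_features is a hypothetical function, not part of this commit.

import ctypes
from typing import Optional, List

# Local copies of the aliases defined in _typing.py above.
FeatureNames = Optional[List[str]]
c_bst_ulong = ctypes.c_uint64

def describe_features(names: FeatureNames, n: c_bst_ulong) -> str:
    # names may be None; n is a ctypes integer carrying .value
    return f"{n.value} features: {names}"

print(describe_features(["f0", "f1"], c_bst_ulong(2)))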
python-package/xgboost/core.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from typing import List, Optional, Any, Union, Dict, TypeVar
-from typing import Callable, Tuple, cast, Sequence
+from typing import Callable, Tuple, cast, Sequence, Type, Iterable
 import ctypes
 import os
 import re
@@ -17,22 +17,30 @@ from inspect import signature, Parameter
 import numpy as np
 import scipy.sparse
 
-from .compat import (STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED,
-                     lazy_isinstance)
+from .compat import STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED, lazy_isinstance
 from .libpath import find_lib_path
-
-# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
-c_bst_ulong = ctypes.c_uint64
-
-# xgboost accepts some other possible types in practice due to historical reason, which is
-# lesser tested. For now we encourage users to pass a simple list of string.
-FeatNamesT = Optional[List[str]]
+from ._typing import (
+    CStrPptr,
+    c_bst_ulong,
+    CNumeric,
+    DataType,
+    CNumericPtr,
+    CStrPtr,
+    CTypeT,
+    ArrayLike,
+    CFloatPtr,
+    NumpyOrCupy,
+    FeatureNames,
+    _T,
+    CupyT,
+)
 
 
 class XGBoostError(ValueError):
     """Error thrown by xgboost trainer."""
 
 
-def from_pystr_to_cstr(data: Union[str, List[str]]):
+def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, CStrPptr]:
     """Convert a Python str or list of Python str to C pointer
 
     Parameters
@@ -44,14 +52,14 @@ def from_pystr_to_cstr(data: Union[str, List[str]]):
     if isinstance(data, str):
         return bytes(data, "utf-8")
     if isinstance(data, list):
-        pointers = (ctypes.c_char_p * len(data))()
-        data = [bytes(d, 'utf-8') for d in data]
-        pointers[:] = data
+        pointers: ctypes.pointer = (ctypes.c_char_p * len(data))()
+        data_as_bytes = [bytes(d, 'utf-8') for d in data]
+        pointers[:] = data_as_bytes
        return pointers
    raise TypeError()


-def from_cstr_to_pystr(data, length) -> List[str]:
+def from_cstr_to_pystr(data: CStrPptr, length: c_bst_ulong) -> List[str]:
     """Revert C pointer to Python str
 
     Parameters
@@ -64,9 +72,9 @@ def from_cstr_to_pystr(data, length) -> List[str]:
     res = []
     for i in range(length.value):
         try:
-            res.append(str(data[i].decode('ascii')))
+            res.append(str(data[i].decode('ascii')))  # type: ignore
         except UnicodeDecodeError:
-            res.append(str(data[i].decode('utf-8')))
+            res.append(str(data[i].decode('utf-8')))  # type: ignore
     return res
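Note: a quick round-trip sketch for the pair above, using only behavior visible in this diff; import paths assume the package layout introduced by this commit.

import xgboost  # requires libxgboost to be installed
from xgboost.core import from_pystr_to_cstr, from_cstr_to_pystr
from xgboost._typing import c_bst_ulong

# Python strings -> C char** -> Python strings.
names = ["tree_method", "hist"]
cptr = from_pystr_to_cstr(names)
assert from_cstr_to_pystr(cptr, c_bst_ulong(len(names))) == names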
@@ -91,7 +99,7 @@ def _convert_ntree_limit(
     return iteration_range
 
 
-def _expect(expectations, got):
+def _expect(expectations: Sequence[Type], got: Type) -> str:
     """Translate input error into string.
 
     Parameters
@@ -167,7 +175,7 @@ Likely causes:
 Error message(s): {os_error_list}
 """)
     lib.XGBGetLastError.restype = ctypes.c_char_p
-    lib.callback = _get_log_callback_func()
+    lib.callback = _get_log_callback_func()  # type: ignore
     if lib.XGBRegisterLogCallback(lib.callback) != 0:
         raise XGBoostError(lib.XGBGetLastError())
     return lib
@@ -192,7 +200,7 @@ def _check_call(ret: int) -> None:
         raise XGBoostError(py_str(_LIB.XGBGetLastError()))
 
 
-def _has_categorical(booster: "Booster", data: Any) -> bool:
+def _has_categorical(booster: "Booster", data: DataType) -> bool:
     """Check whether the booster and input data for prediction contain categorical data.
 
     """
@@ -224,8 +232,8 @@ def build_info() -> dict:
     return res
 
 
-def _numpy2ctypes_type(dtype):
-    _NUMPY_TO_CTYPES_MAPPING = {
+def _numpy2ctypes_type(dtype: Type[np.number]) -> Type[CNumeric]:
+    _NUMPY_TO_CTYPES_MAPPING: Dict[Type[np.number], Type[CNumeric]] = {
         np.float32: ctypes.c_float,
         np.float64: ctypes.c_double,
         np.uint32: ctypes.c_uint,
@@ -242,7 +250,7 @@ def _numpy2ctypes_type(dtype):
     return _NUMPY_TO_CTYPES_MAPPING[dtype]
 
 
-def _cuda_array_interface(data) -> bytes:
+def _cuda_array_interface(data: DataType) -> bytes:
     assert (
         data.dtype.hasobject is False
     ), "Input data contains `object` dtype. Expecting numeric data."
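Note: the dtype-to-ctypes lookup that `_numpy2ctypes_type` wraps can be reproduced standalone; NUMPY_TO_CTYPES below is a local sketch of the same idea, not the private mapping itself.

import ctypes
import numpy as np

# Same shape of lookup as _NUMPY_TO_CTYPES_MAPPING above.
NUMPY_TO_CTYPES = {
    np.float32: ctypes.c_float,
    np.float64: ctypes.c_double,
    np.uint32: ctypes.c_uint,
}

# The returned class can allocate typed buffers directly.
buf = (NUMPY_TO_CTYPES[np.float32] * 4)()  # ctypes array of 4 floats
assert len(buf) == 4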
@@ -253,9 +261,9 @@ def _cuda_array_interface(data) -> bytes:
     return interface_str
 
 
-def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
+def ctypes2numpy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> np.ndarray:
     """Convert a ctypes pointer array to a numpy array."""
-    ctype = _numpy2ctypes_type(dtype)
+    ctype: Type[CNumeric] = _numpy2ctypes_type(dtype)
     if not isinstance(cptr, ctypes.POINTER(ctype)):
         raise RuntimeError(f"expected {ctype} pointer")
     res = np.zeros(length, dtype=dtype)
@@ -264,7 +272,7 @@ def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
     return res
 
 
-def ctypes2cupy(cptr, length, dtype):
+def ctypes2cupy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> CupyT:
     """Convert a ctypes pointer array to a cupy array."""
     # pylint: disable=import-error
     import cupy
@@ -290,7 +298,7 @@ def ctypes2cupy(cptr, length, dtype):
     return arr
 
 
-def ctypes2buffer(cptr, length) -> bytearray:
+def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray:
     """Convert ctypes pointer to buffer type."""
     if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
         raise RuntimeError('expected char pointer')
@@ -301,25 +309,30 @@ def ctypes2buffer(cptr, length) -> bytearray:
     return res
 
 
-def c_str(string):
+def c_str(string: str) -> ctypes.c_char_p:
     """Convert a python string to cstring."""
     return ctypes.c_char_p(string.encode('utf-8'))
 
 
-def c_array(ctype, values):
+def c_array(ctype: Type[CTypeT], values: ArrayLike) -> ctypes.Array:
     """Convert a python string to c array."""
     if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
         return (ctype * len(values)).from_buffer_copy(values)
     return (ctype * len(values))(*values)
 
 
-def _prediction_output(shape, dims, predts, is_cuda):
-    arr_shape: np.ndarray = ctypes2numpy(shape, dims.value, np.uint64)
+def _prediction_output(
+    shape: CNumericPtr,
+    dims: c_bst_ulong,
+    predts: CFloatPtr,
+    is_cuda: bool
+) -> NumpyOrCupy:
+    arr_shape = ctypes2numpy(shape, dims.value, np.uint64)
     length = int(np.prod(arr_shape))
     if is_cuda:
         arr_predict = ctypes2cupy(predts, length, np.float32)
     else:
-        arr_predict: np.ndarray = ctypes2numpy(predts, length, np.float32)
+        arr_predict = ctypes2numpy(predts, length, np.float32)
     arr_predict = arr_predict.reshape(arr_shape)
     return arr_predict
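Note: a sketch of how `c_array` and `ctypes2numpy` compose, assuming ctypes2numpy copies from the pointer as its docstring says; both functions are module-level in xgboost.core.

import ctypes
import numpy as np
from xgboost.core import c_array, ctypes2numpy

values = np.arange(4, dtype=np.float32)
arr = c_array(ctypes.c_float, values)  # matching itemsize -> from_buffer_copy path
ptr = ctypes.cast(arr, ctypes.POINTER(ctypes.c_float))
back = ctypes2numpy(ptr, len(values), np.float32)
assert np.array_equal(back, values)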
@@ -415,7 +428,7 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
         def data_handle(
             data: Any,
             *,
-            feature_names: FeatNamesT = None,
+            feature_names: FeatureNames = None,
             feature_types: Optional[List[str]] = None,
             **kwargs: Any,
         ) -> None:
@@ -472,7 +485,7 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
 # Nicolas Tresegnie
 # Sylvain Marie
 # License: BSD 3 clause
-def _deprecate_positional_args(f):
+def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
     """Decorator for methods that issues warnings for positional arguments
 
     Using the keyword-only argument syntax in pep 3102, arguments after the
@@ -496,7 +509,7 @@ def _deprecate_positional_args(f):
             kwonly_args.append(name)
 
     @wraps(f)
-    def inner_f(*args, **kwargs):
+    def inner_f(*args: Any, **kwargs: Any) -> _T:
         extra_args = len(args) - len(all_args)
         if extra_args > 0:
             # ignore first 'self' argument for instance methods
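Note: `Callable[..., _T]` keeps the wrapped function's return type visible to type checkers. A minimal standalone sketch of the same pattern (`passthrough` and `add` are hypothetical, not the xgboost implementation):

from functools import wraps
from typing import Any, Callable, TypeVar

_T = TypeVar("_T")

def passthrough(f: Callable[..., _T]) -> Callable[..., _T]:
    @wraps(f)
    def inner(*args: Any, **kwargs: Any) -> _T:
        return f(*args, **kwargs)
    return inner

@passthrough
def add(a: int, b: int) -> int:
    return a + b

assert add(1, 2) == 3  # checkers still infer int here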
@@ -529,21 +542,21 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
     @_deprecate_positional_args
     def __init__(
         self,
-        data,
-        label=None,
+        data: DataType,
+        label: Optional[ArrayLike] = None,
         *,
-        weight=None,
-        base_margin=None,
+        weight: Optional[ArrayLike] = None,
+        base_margin: Optional[ArrayLike] = None,
         missing: Optional[float] = None,
-        silent=False,
-        feature_names: FeatNamesT = None,
+        silent: bool = False,
+        feature_names: FeatureNames = None,
         feature_types: Optional[List[str]] = None,
         nthread: Optional[int] = None,
-        group=None,
-        qid=None,
-        label_lower_bound=None,
-        label_upper_bound=None,
-        feature_weights=None,
+        group: Optional[ArrayLike] = None,
+        qid: Optional[ArrayLike] = None,
+        label_lower_bound: Optional[ArrayLike] = None,
+        label_upper_bound: Optional[ArrayLike] = None,
+        feature_weights: Optional[ArrayLike] = None,
         enable_categorical: bool = False,
     ) -> None:
         """Parameters
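Note: from the caller's side nothing changes; the annotations only document what was already accepted. A small usage sketch:

import numpy as np
import xgboost as xgb

X = np.random.rand(8, 2).astype(np.float32)
y = np.array([0, 1] * 4, dtype=np.float32)
# data: DataType, label: Optional[ArrayLike], feature_names: FeatureNames
dtrain = xgb.DMatrix(X, label=y, feature_names=["f0", "f1"])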
@@ -658,7 +671,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                 "nthread": self.nthread,
                 "cache_prefix": it.cache_prefix if it.cache_prefix else "",
             }
-            args = from_pystr_to_cstr(json.dumps(args))
+            args_cstr = from_pystr_to_cstr(json.dumps(args))
             handle = ctypes.c_void_p()
             reset_callback, next_callback = it.get_callbacks(
                 True, enable_categorical
@@ -668,7 +681,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                 it.proxy.handle,
                 reset_callback,
                 next_callback,
-                args,
+                args_cstr,
                 ctypes.byref(handle),
             )
             it.reraise()
@@ -685,16 +698,16 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
     def set_info(
         self,
         *,
-        label=None,
-        weight=None,
-        base_margin=None,
-        group=None,
-        qid=None,
-        label_lower_bound=None,
-        label_upper_bound=None,
-        feature_names: FeatNamesT = None,
+        label: Optional[ArrayLike] = None,
+        weight: Optional[ArrayLike] = None,
+        base_margin: Optional[ArrayLike] = None,
+        group: Optional[ArrayLike] = None,
+        qid: Optional[ArrayLike] = None,
+        label_lower_bound: Optional[ArrayLike] = None,
+        label_upper_bound: Optional[ArrayLike] = None,
+        feature_names: FeatureNames = None,
         feature_types: Optional[List[str]] = None,
-        feature_weights=None
+        feature_weights: Optional[ArrayLike] = None
     ) -> None:
         """Set meta info for DMatrix. See doc string for :py:obj:`xgboost.DMatrix`."""
         from .data import dispatch_meta_backend
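Note: `set_info` is keyword-only, so the new annotations map one-to-one onto calls like this sketch:

import numpy as np
import xgboost as xgb

dmat = xgb.DMatrix(np.random.rand(6, 3))
dmat.set_info(label=np.zeros(6), weight=np.ones(6))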
@@ -763,7 +776,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
             ctypes.byref(ret)))
         return ctypes2numpy(ret, length.value, np.uint32)
 
-    def set_float_info(self, field: str, data) -> None:
+    def set_float_info(self, field: str, data: ArrayLike) -> None:
         """Set float type property into the DMatrix.
 
         Parameters
@@ -777,7 +790,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'float')
 
-    def set_float_info_npy2d(self, field: str, data) -> None:
+    def set_float_info_npy2d(self, field: str, data: ArrayLike) -> None:
         """Set float type property into the DMatrix
         for numpy 2d array input
 
@@ -792,7 +805,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'float')
 
-    def set_uint_info(self, field: str, data) -> None:
+    def set_uint_info(self, field: str, data: ArrayLike) -> None:
         """Set uint type property into the DMatrix.
 
         Parameters
@@ -806,7 +819,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'uint32')
 
-    def save_binary(self, fname, silent=True) -> None:
+    def save_binary(self, fname: Union[str, os.PathLike], silent: bool = True) -> None:
         """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
         by providing the path to :py:func:`xgboost.DMatrix` as input.
 
@@ -822,7 +835,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                              c_str(fname),
                                              ctypes.c_int(silent)))
 
-    def set_label(self, label) -> None:
+    def set_label(self, label: ArrayLike) -> None:
         """Set label of dmatrix
 
         Parameters
@@ -833,7 +846,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, label, 'label', 'float')
 
-    def set_weight(self, weight) -> None:
+    def set_weight(self, weight: ArrayLike) -> None:
         """Set weight of each instance.
 
         Parameters
@@ -852,7 +865,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, weight, 'weight', 'float')
 
-    def set_base_margin(self, margin) -> None:
+    def set_base_margin(self, margin: ArrayLike) -> None:
         """Set base margin of booster to start from.
 
         This can be used to specify a prediction value of existing model to be
@@ -869,7 +882,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, margin, 'base_margin', 'float')
 
-    def set_group(self, group) -> None:
+    def set_group(self, group: ArrayLike) -> None:
         """Set group size of DMatrix (used for ranking).
 
         Parameters
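Note: `ArrayLike` is deliberately loose here; `set_group`, for example, accepts a plain list of query-group sizes. A sketch:

import numpy as np
import xgboost as xgb

# Two query groups of sizes 3 and 4 for learning-to-rank.
X = np.random.rand(7, 2)
dmat = xgb.DMatrix(X, label=np.random.randint(0, 2, 7))
dmat.set_group([3, 4])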
@@ -997,7 +1010,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         return feature_names
 
     @feature_names.setter
-    def feature_names(self, feature_names: FeatNamesT) -> None:
+    def feature_names(self, feature_names: FeatureNames) -> None:
         """Set feature names (column labels).
 
         Parameters
@@ -1026,9 +1039,9 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                        not any(x in f for x in set(('[', ']', '<')))
                        for f in feature_names):
                 raise ValueError('feature_names must be string, and may not contain [, ] or <')
-            c_feature_names = [bytes(f, encoding='utf-8') for f in feature_names]
+            feature_names_bytes = [bytes(f, encoding='utf-8') for f in feature_names]
             c_feature_names = (ctypes.c_char_p *
-                               len(c_feature_names))(*c_feature_names)
+                               len(feature_names_bytes))(*feature_names_bytes)
             _check_call(_LIB.XGDMatrixSetStrFeatureInfo(
                 self.handle, c_str('feature_name'),
                 c_feature_names,
@@ -1091,10 +1104,10 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                 feature_types = [feature_types]
             except TypeError:
                 feature_types = [feature_types]
-            c_feature_types = [bytes(f, encoding='utf-8')
+            feature_types_bytes = [bytes(f, encoding='utf-8')
                                for f in feature_types]
             c_feature_types = (ctypes.c_char_p *
-                               len(c_feature_types))(*c_feature_types)
+                               len(feature_types_bytes))(*feature_types_bytes)
             _check_call(_LIB.XGDMatrixSetStrFeatureInfo(
                 self.handle, c_str('feature_type'),
                 c_feature_types,
@@ -1118,11 +1131,11 @@ class _ProxyDMatrix(DMatrix):
 
     """
 
-    def __init__(self):  # pylint: disable=super-init-not-called
+    def __init__(self) -> None:  # pylint: disable=super-init-not-called
         self.handle = ctypes.c_void_p()
         _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle)))
 
-    def _set_data_from_cuda_interface(self, data) -> None:
+    def _set_data_from_cuda_interface(self, data: DataType) -> None:
         """Set data from CUDA array interface."""
         interface = data.__cuda_array_interface__
         interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
@@ -1130,14 +1143,14 @@ class _ProxyDMatrix(DMatrix):
             _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str)
         )
 
-    def _set_data_from_cuda_columnar(self, data, cat_codes: list) -> None:
+    def _set_data_from_cuda_columnar(self, data: DataType, cat_codes: list) -> None:
         """Set data from CUDA columnar format."""
         from .data import _cudf_array_interfaces
 
         interfaces_str = _cudf_array_interfaces(data, cat_codes)
         _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str))
 
-    def _set_data_from_array(self, data: np.ndarray):
+    def _set_data_from_array(self, data: np.ndarray) -> None:
         """Set data from numpy array."""
         from .data import _array_interface
 
@@ -1145,7 +1158,7 @@ class _ProxyDMatrix(DMatrix):
             _LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data))
         )
 
-    def _set_data_from_csr(self, csr):
+    def _set_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
         """Set data from scipy csr"""
         from .data import _array_interface
@@ -1175,24 +1188,24 @@ class DeviceQuantileDMatrix(DMatrix):
     @_deprecate_positional_args
     def __init__(  # pylint: disable=super-init-not-called
         self,
-        data,
-        label=None,
+        data: DataType,
+        label: Optional[ArrayLike] = None,
         *,
-        weight=None,
-        base_margin=None,
-        missing=None,
-        silent=False,
-        feature_names: FeatNamesT = None,
-        feature_types=None,
+        weight: Optional[ArrayLike] = None,
+        base_margin: Optional[ArrayLike] = None,
+        missing: Optional[float] = None,
+        silent: bool = False,
+        feature_names: FeatureNames = None,
+        feature_types: Optional[List[str]] = None,
         nthread: Optional[int] = None,
         max_bin: int = 256,
-        group=None,
-        qid=None,
-        label_lower_bound=None,
-        label_upper_bound=None,
-        feature_weights=None,
+        group: Optional[ArrayLike] = None,
+        qid: Optional[ArrayLike] = None,
+        label_lower_bound: Optional[ArrayLike] = None,
+        label_upper_bound: Optional[ArrayLike] = None,
+        feature_weights: Optional[ArrayLike] = None,
         enable_categorical: bool = False,
-    ):
+    ) -> None:
         self.max_bin = max_bin
         self.missing = missing if missing is not None else np.nan
         self.nthread = nthread if nthread is not None else 1
@@ -1223,7 +1236,7 @@ class DeviceQuantileDMatrix(DMatrix):
             enable_categorical=enable_categorical,
         )
 
-    def _init(self, data, enable_categorical: bool, **meta) -> None:
+    def _init(self, data: DataType, enable_categorical: bool, **meta: Any) -> None:
         from .data import (
             _is_dlpack,
             _transform_dlpack,
@@ -1304,9 +1317,10 @@ def _configure_metrics(params: Union[Dict, List]) -> Union[Dict, List]:
         params = dict((k, v) for k, v in params.items())
         eval_metrics = params["eval_metric"]
         params.pop("eval_metric", None)
-        params = list(params.items())
+        params_list = list(params.items())
         for eval_metric in eval_metrics:
-            params += [("eval_metric", eval_metric)]
+            params_list += [("eval_metric", eval_metric)]
+        return params_list
     return params
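Note: the rename avoids re-binding `params` from dict to list mid-function, which confuses type checkers. The transformation itself looks like this standalone sketch:

params = {"max_depth": 3, "eval_metric": ["auc", "logloss"]}
metrics = params.pop("eval_metric")
params_list = list(params.items())
for m in metrics:
    params_list += [("eval_metric", m)]
# params_list == [("max_depth", 3), ("eval_metric", "auc"), ("eval_metric", "logloss")]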
@@ -1417,7 +1431,7 @@ class Booster:
                     "Constrained features are not a subset of training data feature names"
                 ) from e
 
-    def _configure_constraints(self, params: Union[Dict, List]) -> Union[Dict, List]:
+    def _configure_constraints(self, params: Union[List, Dict]) -> Union[List, Dict]:
         if isinstance(params, dict):
             value = params.get("monotone_constraints")
             if value:
@@ -1546,7 +1560,7 @@ class Booster:
     def __copy__(self) -> "Booster":
         return self.__deepcopy__(None)
 
-    def __deepcopy__(self, _) -> "Booster":
+    def __deepcopy__(self, _: Any) -> "Booster":
         '''Return a copy of booster.'''
         return Booster(model_file=self)
 
@@ -1629,8 +1643,8 @@ class Booster:
     def _set_feature_info(self, features: Optional[List[str]], field: str) -> None:
         if features is not None:
             assert isinstance(features, list)
-            c_feature_info = [bytes(f, encoding="utf-8") for f in features]
-            c_feature_info = (ctypes.c_char_p * len(c_feature_info))(*c_feature_info)
+            feature_info_bytes = [bytes(f, encoding="utf-8") for f in features]
+            c_feature_info = (ctypes.c_char_p * len(feature_info_bytes))(*feature_info_bytes)
             _check_call(
                 _LIB.XGBoosterSetStrFeatureInfo(
                     self.handle, c_str(field), c_feature_info, c_bst_ulong(len(features))
@@ -1664,10 +1678,14 @@ class Booster:
         return self._get_feature_info("feature_name")
 
     @feature_names.setter
-    def feature_names(self, features: FeatNamesT) -> None:
+    def feature_names(self, features: FeatureNames) -> None:
         self._set_feature_info(features, "feature_name")
 
-    def set_param(self, params, value=None):
+    def set_param(
+        self,
+        params: Union[Dict, Iterable[Tuple[str, Any]], str],
+        value: Optional[str] = None
+    ) -> None:
         """Set parameters into the Booster.
 
         Parameters
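Note: the widened `params` annotation documents the three calling conventions `set_param` already supports. A sketch:

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.rand(8, 2), label=np.zeros(8))
booster = xgb.train({}, dtrain, num_boost_round=1)
booster.set_param({"max_depth": "4"})  # dict
booster.set_param([("eta", "0.1")])    # iterable of (name, value) pairs
booster.set_param("verbosity", "1")    # single name plus value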
@@ -1966,14 +1984,14 @@ class Booster:
 
     def inplace_predict(
         self,
-        data: Any,
+        data: DataType,
         iteration_range: Tuple[int, int] = (0, 0),
         predict_type: str = "value",
         missing: float = np.nan,
         validate_features: bool = True,
         base_margin: Any = None,
         strict_shape: bool = False
-    ):
+    ) -> NumpyOrCupy:
         """Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction does not
         cache the prediction result.
 
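Note: `NumpyOrCupy` captures that the output type follows the input's device (numpy in, numpy out; cupy in, cupy out). A CPU-only sketch:

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.rand(16, 4), label=np.random.rand(16))
booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=2)
preds = booster.inplace_predict(np.random.rand(4, 4))  # numpy in -> numpy out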
@@ -2232,11 +2250,11 @@ class Booster:
             raise TypeError('Unknown file type: ', fname)
 
         if self.attr("best_iteration") is not None:
-            self.best_iteration = int(self.attr("best_iteration"))
+            self.best_iteration = int(self.attr("best_iteration"))  # type: ignore
         if self.attr("best_score") is not None:
-            self.best_score = float(self.attr("best_score"))
+            self.best_score = float(self.attr("best_score"))  # type: ignore
         if self.attr("best_ntree_limit") is not None:
-            self.best_ntree_limit = int(self.attr("best_ntree_limit"))
+            self.best_ntree_limit = int(self.attr("best_ntree_limit"))  # type: ignore
 
     def num_boosted_rounds(self) -> int:
         '''Get number of boosted rounds. For gblinear this is reset to 0 after
@@ -2255,7 +2273,8 @@ class Booster:
         _check_call(_LIB.XGBoosterGetNumFeature(self.handle, ctypes.byref(features)))
         return features.value
 
-    def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
+    def dump_model(self, fout: Union[str, os.PathLike], fmap: Union[str, os.PathLike] = '',
+                   with_stats: bool = False, dump_format: str = "text") -> None:
         """Dump model into a text or JSON file. Unlike :py:meth:`save_model`, the
         output format is primarily used for visualization or interpretation,
         hence it's more human readable but cannot be loaded back to XGBoost.
@@ -2274,24 +2293,25 @@ class Booster:
         if isinstance(fout, (STRING_TYPES, os.PathLike)):
             fout = os.fspath(os.path.expanduser(fout))
             # pylint: disable=consider-using-with
-            fout = open(fout, 'w', encoding="utf-8")
+            fout_obj = open(fout, 'w', encoding="utf-8")
             need_close = True
         else:
+            fout_obj = fout
             need_close = False
         ret = self.get_dump(fmap, with_stats, dump_format)
         if dump_format == 'json':
-            fout.write('[\n')
+            fout_obj.write('[\n')
             for i, _ in enumerate(ret):
-                fout.write(ret[i])
+                fout_obj.write(ret[i])
                 if i < len(ret) - 1:
-                    fout.write(",\n")
-            fout.write('\n]')
+                    fout_obj.write(",\n")
+            fout_obj.write('\n]')
         else:
             for i, _ in enumerate(ret):
-                fout.write(f"booster[{i}]:\n")
-                fout.write(ret[i])
+                fout_obj.write(f"booster[{i}]:\n")
+                fout_obj.write(ret[i])
         if need_close:
-            fout.close()
+            fout_obj.close()
 
     def get_dump(
         self,
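Note: the `fout_obj` split lets `fout` keep its path type while the opened handle gets its own name. From the outside the call is unchanged; a sketch:

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix(np.random.rand(8, 2), label=np.zeros(8))
booster = xgb.train({}, dtrain, num_boost_round=1)
booster.dump_model("dump.json", with_stats=True, dump_format="json")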
@@ -2438,11 +2458,11 @@ class Booster:
         tree_ids = []
         node_ids = []
         fids = []
-        splits = []
-        categories: List[Optional[float]] = []
-        y_directs = []
-        n_directs = []
-        missings = []
+        splits: List[Union[float, str]] = []
+        categories: List[Union[Optional[float], List[str]]] = []
+        y_directs: List[Union[float, str]] = []
+        n_directs: List[Union[float, str]] = []
+        missings: List[Union[float, str]] = []
         gains = []
         covers = []
 
@@ -2483,9 +2503,9 @@ class Booster:
                     # categorical
                     parse = fid[0].split(":")
                     cats = parse[1][1:-1]  # strip the {}
-                    cats = cats.split(",")
+                    cats_split = cats.split(",")
                     splits.append(float("NAN"))
-                    categories.append(cats if cats else None)
+                    categories.append(cats_split if cats_split else None)
                 else:
                     raise ValueError("Failed to parse model text dump.")
                 stats = re.split('=|,', fid[1])
python-package/xgboost/dask.py
@@ -57,7 +57,7 @@ from .compat import lazy_isinstance
 from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
 from .core import Objective, Metric
 from .core import _deprecate_positional_args, _has_categorical
-from .data import FeatNamesT
+from .data import FeatureNames
 from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
 from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
@@ -326,7 +326,7 @@ class DaskDMatrix:
         base_margin: Optional[_DaskCollection] = None,
         missing: float = None,
         silent: bool = False,  # pylint: disable=unused-argument
-        feature_names: FeatNamesT = None,
+        feature_names: FeatureNames = None,
         feature_types: Optional[List[str]] = None,
         group: Optional[_DaskCollection] = None,
         qid: Optional[_DaskCollection] = None,
@@ -602,7 +602,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
         qid: Optional[List[Any]] = None,
         label_lower_bound: Optional[List[Any]] = None,
         label_upper_bound: Optional[List[Any]] = None,
-        feature_names: FeatNamesT = None,
+        feature_names: FeatureNames = None,
         feature_types: Optional[Union[Any, List[Any]]] = None,
     ) -> None:
         self._data = data
@@ -645,7 +645,7 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
         if self._iter == len(self._data):
             # Return 0 when there's no more batch.
             return 0
-        feature_names: FeatNamesT = None
+        feature_names: FeatureNames = None
         if self._feature_names:
             feature_names = self._feature_names
         else:
@@ -696,7 +696,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
         base_margin: Optional[_DaskCollection] = None,
         missing: float = None,
         silent: bool = False,  # disable=unused-argument
-        feature_names: FeatNamesT = None,
+        feature_names: FeatureNames = None,
         feature_types: Optional[Union[Any, List[Any]]] = None,
         max_bin: int = 256,
         group: Optional[_DaskCollection] = None,
@@ -733,7 +733,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
 
 
 def _create_device_quantile_dmatrix(
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[Union[Any, List[Any]]],
     feature_weights: Optional[Any],
     missing: float,
@@ -774,7 +774,7 @@ def _create_device_quantile_dmatrix(
 
 
 def _create_dmatrix(
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[Union[Any, List[Any]]],
    feature_weights: Optional[Any],
     missing: float,
python-package/xgboost/data.py
@@ -12,7 +12,7 @@ import numpy as np
 
 from .core import c_array, _LIB, _check_call, c_str
 from .core import _cuda_array_interface
-from .core import DataIter, _ProxyDMatrix, DMatrix, FeatNamesT
+from .core import DataIter, _ProxyDMatrix, DMatrix, FeatureNames
 from .compat import lazy_isinstance, DataFrame
 
 c_bst_ulong = ctypes.c_uint64  # pylint: disable=invalid-name
@@ -69,7 +69,7 @@ def _from_scipy_csr(
     data,
     missing,
     nthread,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     """Initialize data from a CSR matrix."""
@@ -108,7 +108,7 @@ def _is_scipy_csc(data):
 def _from_scipy_csc(
     data,
     missing,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     if len(data.indices) != len(data.data):
@@ -164,7 +164,7 @@ def _from_numpy_array(
     data,
     missing,
     nthread,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     """Initialize data from a 2-D numpy matrix.
@@ -245,11 +245,11 @@ be set to `True`.""" + err
 def _transform_pandas_df(
     data: DataFrame,
     enable_categorical: bool,
-    feature_names: FeatNamesT = None,
+    feature_names: FeatureNames = None,
     feature_types: Optional[List[str]] = None,
     meta: Optional[str] = None,
     meta_type: Optional[str] = None,
-) -> Tuple[np.ndarray, FeatNamesT, Optional[List[str]]]:
+) -> Tuple[np.ndarray, FeatureNames, Optional[List[str]]]:
     import pandas as pd
     from pandas.api.types import is_sparse, is_categorical_dtype
 
@@ -313,9 +313,9 @@ def _from_pandas_df(
     enable_categorical: bool,
     missing: float,
     nthread: int,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
-) -> Tuple[ctypes.c_void_p, FeatNamesT, Optional[List[str]]]:
+) -> Tuple[ctypes.c_void_p, FeatureNames, Optional[List[str]]]:
     data, feature_names, feature_types = _transform_pandas_df(
         data, enable_categorical, feature_names, feature_types
     )
@@ -355,7 +355,7 @@ def _from_pandas_series(
     missing: float,
     nthread: int,
     enable_categorical: bool,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     from pandas.api.types import is_categorical_dtype
@@ -386,7 +386,7 @@ _dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
 
 def _transform_dt_df(
     data,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
     meta=None,
     meta_type=None,
@@ -427,10 +427,10 @@ def _from_dt_df(
     data,
     missing,
     nthread,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
     enable_categorical: bool,
-) -> Tuple[ctypes.c_void_p, FeatNamesT, Optional[List[str]]]:
+) -> Tuple[ctypes.c_void_p, FeatureNames, Optional[List[str]]]:
     if enable_categorical:
         raise ValueError("categorical data in datatable is not supported yet.")
     data, feature_names, feature_types = _transform_dt_df(
@@ -594,7 +594,7 @@ def _cudf_array_interfaces(data, cat_codes: list) -> bytes:
 
 def _transform_cudf_df(
     data,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
     enable_categorical: bool,
 ):
@@ -660,7 +660,7 @@ def _from_cudf_df(
     data,
     missing,
     nthread,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
     enable_categorical: bool,
 ) -> Tuple[ctypes.c_void_p, Any, Any]:
@@ -710,7 +710,7 @@ def _from_cupy_array(
     data,
     missing,
     nthread,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     """Initialize DMatrix from cupy ndarray."""
@@ -757,7 +757,7 @@ def _from_dlpack(
     data,
     missing,
     nthread,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     data = _transform_dlpack(data)
@@ -772,7 +772,7 @@ def _is_uri(data):
 def _from_uri(
     data,
     missing,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     _warn_unused_missing(data, missing)
@@ -792,7 +792,7 @@ def _from_list(
     data,
     missing,
     n_threads,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     array = np.array(data)
@@ -808,7 +808,7 @@ def _from_tuple(
     data,
     missing,
     n_threads,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
 ):
     return _from_list(data, missing, n_threads, feature_names, feature_types)
@@ -844,7 +844,7 @@ def dispatch_data_backend(
     data,
     missing,
     threads,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
     enable_categorical: bool = False,
 ):
@@ -1076,7 +1076,7 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
 
 def _proxy_transform(
     data,
-    feature_names: FeatNamesT,
+    feature_names: FeatureNames,
     feature_types: Optional[List[str]],
     enable_categorical: bool,
 ):
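Note: every dispatcher above threads `feature_names: FeatureNames` through to the C API; user-visibly, names are either passed explicitly or inferred from the container. A pandas sketch:

import pandas as pd
import xgboost as xgb

df = pd.DataFrame({"age": [1.0, 2.0], "income": [3.0, 4.0]})
dmat = xgb.DMatrix(df)  # names inferred from the columns
assert dmat.feature_names == ["age", "income"]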
python-package/xgboost/sklearn.py
@@ -14,6 +14,7 @@ from .core import Metric
 from .training import train
 from .callback import TrainingCallback
 from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
+from ._typing import ArrayLike
 
 # Do not use class names on scikit-learn directly. Re-define the classes on
 # .compat to guarantee the behavior without scikit-learn
@@ -25,8 +26,6 @@ from .compat import (
     XGBoostLabelEncoder,
 )
 
-array_like = Any
-
 
 class XGBRankerMixIn:  # pylint: disable=too-few-public-methods
     """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base
@@ -862,19 +861,19 @@ class XGBModel(XGBModelBase):
     @_deprecate_positional_args
     def fit(
         self,
-        X: array_like,
-        y: array_like,
+        X: ArrayLike,
+        y: ArrayLike,
         *,
-        sample_weight: Optional[array_like] = None,
-        base_margin: Optional[array_like] = None,
-        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
+        sample_weight: Optional[ArrayLike] = None,
+        base_margin: Optional[ArrayLike] = None,
+        eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
         eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = True,
         xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
-        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
-        base_margin_eval_set: Optional[Sequence[array_like]] = None,
-        feature_weights: Optional[array_like] = None,
+        sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
+        base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
+        feature_weights: Optional[ArrayLike] = None,
         callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBModel":
         # pylint: disable=invalid-name,attribute-defined-outside-init
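Note: the sklearn wrapper now shares the `ArrayLike` alias with core instead of its own `array_like = Any`; calls are unaffected. A sketch:

import numpy as np
from xgboost import XGBRegressor

X = np.random.rand(20, 3)  # X: ArrayLike
y = np.random.rand(20)     # y: ArrayLike
model = XGBRegressor(n_estimators=2).fit(X, y)
preds = model.predict(X)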
@ -1001,11 +1000,11 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
X: array_like,
|
X: ArrayLike,
|
||||||
output_margin: bool = False,
|
output_margin: bool = False,
|
||||||
ntree_limit: Optional[int] = None,
|
ntree_limit: Optional[int] = None,
|
||||||
validate_features: bool = True,
|
validate_features: bool = True,
|
||||||
base_margin: Optional[array_like] = None,
|
base_margin: Optional[ArrayLike] = None,
|
||||||
iteration_range: Optional[Tuple[int, int]] = None,
|
iteration_range: Optional[Tuple[int, int]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Predict with `X`. If the model is trained with early stopping, then `best_iteration`
|
"""Predict with `X`. If the model is trained with early stopping, then `best_iteration`
|
||||||
@ -1077,7 +1076,7 @@ class XGBModel(XGBModelBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self, X: array_like,
|
self, X: ArrayLike,
|
||||||
ntree_limit: int = 0,
|
ntree_limit: int = 0,
|
||||||
iteration_range: Optional[Tuple[int, int]] = None
|
iteration_range: Optional[Tuple[int, int]] = None
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
@@ -1317,19 +1316,19 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     @_deprecate_positional_args
     def fit(
         self,
-        X: array_like,
+        X: ArrayLike,
-        y: array_like,
+        y: ArrayLike,
         *,
-        sample_weight: Optional[array_like] = None,
+        sample_weight: Optional[ArrayLike] = None,
-        base_margin: Optional[array_like] = None,
+        base_margin: Optional[ArrayLike] = None,
-        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
         eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = True,
         xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
-        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
-        base_margin_eval_set: Optional[Sequence[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
-        feature_weights: Optional[array_like] = None,
+        feature_weights: Optional[ArrayLike] = None,
         callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBClassifier":
         # pylint: disable = attribute-defined-outside-init,too-many-statements
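(Aside: a sketch of `XGBClassifier.fit` exercising the `eval_set` and `early_stopping_rounds` parameters typed above; the data split and hyperparameter values are illustrative only:)

    from xgboost import XGBClassifier

    X_tr, y_tr = np.random.rand(80, 4), np.random.randint(0, 2, 80)
    X_va, y_va = np.random.rand(20, 4), np.random.randint(0, 2, 20)

    clf = XGBClassifier(n_estimators=50).fit(
        X_tr, y_tr,
        eval_set=[(X_va, y_va)],   # Sequence[Tuple[ArrayLike, ArrayLike]]
        early_stopping_rounds=5,
        verbose=False,
    )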
@@ -1425,11 +1424,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
     def predict(
         self,
-        X: array_like,
+        X: ArrayLike,
         output_margin: bool = False,
         ntree_limit: Optional[int] = None,
         validate_features: bool = True,
-        base_margin: Optional[array_like] = None,
+        base_margin: Optional[ArrayLike] = None,
         iteration_range: Optional[Tuple[int, int]] = None,
     ) -> np.ndarray:
         class_probs = super().predict(
@@ -1464,10 +1463,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
     def predict_proba(
         self,
-        X: array_like,
+        X: ArrayLike,
         ntree_limit: Optional[int] = None,
         validate_features: bool = True,
-        base_margin: Optional[array_like] = None,
+        base_margin: Optional[ArrayLike] = None,
         iteration_range: Optional[Tuple[int, int]] = None,
     ) -> np.ndarray:
         """ Predict the probability of each `X` example being of a given class.
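(Aside: both prediction paths of the classifier now declare `np.ndarray` return types; continuing the `clf` sketch above:)

    proba = clf.predict_proba(X_va)   # shape (n_samples, n_classes)
    labels = clf.predict(X_va)        # hard class labels, also np.ndarray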
@@ -1558,19 +1557,19 @@ class XGBRFClassifier(XGBClassifier):
     @_deprecate_positional_args
     def fit(
         self,
-        X: array_like,
+        X: ArrayLike,
-        y: array_like,
+        y: ArrayLike,
         *,
-        sample_weight: Optional[array_like] = None,
+        sample_weight: Optional[ArrayLike] = None,
-        base_margin: Optional[array_like] = None,
+        base_margin: Optional[ArrayLike] = None,
-        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
         eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = True,
         xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
-        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
-        base_margin_eval_set: Optional[Sequence[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
-        feature_weights: Optional[array_like] = None,
+        feature_weights: Optional[ArrayLike] = None,
         callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBRFClassifier":
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@@ -1630,19 +1629,19 @@ class XGBRFRegressor(XGBRegressor):
     @_deprecate_positional_args
     def fit(
         self,
-        X: array_like,
+        X: ArrayLike,
-        y: array_like,
+        y: ArrayLike,
         *,
-        sample_weight: Optional[array_like] = None,
+        sample_weight: Optional[ArrayLike] = None,
-        base_margin: Optional[array_like] = None,
+        base_margin: Optional[ArrayLike] = None,
-        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
         eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = True,
         xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
-        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
-        base_margin_eval_set: Optional[Sequence[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
-        feature_weights: Optional[array_like] = None,
+        feature_weights: Optional[ArrayLike] = None,
         callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBRFRegressor":
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@@ -1705,23 +1704,23 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
     @_deprecate_positional_args
     def fit(
         self,
-        X: array_like,
+        X: ArrayLike,
-        y: array_like,
+        y: ArrayLike,
         *,
-        group: Optional[array_like] = None,
+        group: Optional[ArrayLike] = None,
-        qid: Optional[array_like] = None,
+        qid: Optional[ArrayLike] = None,
-        sample_weight: Optional[array_like] = None,
+        sample_weight: Optional[ArrayLike] = None,
-        base_margin: Optional[array_like] = None,
+        base_margin: Optional[ArrayLike] = None,
-        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
-        eval_group: Optional[Sequence[array_like]] = None,
+        eval_group: Optional[Sequence[ArrayLike]] = None,
-        eval_qid: Optional[Sequence[array_like]] = None,
+        eval_qid: Optional[Sequence[ArrayLike]] = None,
         eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = False,
         xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
-        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
-        base_margin_eval_set: Optional[Sequence[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
-        feature_weights: Optional[array_like] = None,
+        feature_weights: Optional[ArrayLike] = None,
         callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBRanker":
         # pylint: disable = attribute-defined-outside-init,arguments-differ
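(Aside: a sketch of the annotated `XGBRanker.fit`. The `qid` values (one query id per row) are illustrative, and the choice to keep rows grouped by query is an assumption of this sketch rather than something the diff states:)

    from xgboost import XGBRanker

    X = np.random.rand(12, 4)
    y = np.random.randint(0, 3, 12)   # graded relevance labels
    qid = np.repeat([0, 1, 2], 4)     # query id per row (Optional[ArrayLike])

    ranker = XGBRanker(n_estimators=10).fit(X, y, qid=qid)
    scores = ranker.predict(X)        # one ranking score per row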