Typehint for Sklearn. (#6799)
This commit is contained in:
@@ -3,4 +3,5 @@ description-file = README.rst
|
||||
|
||||
[mypy]
|
||||
ignore_missing_imports = True
|
||||
disallow_untyped_defs = True
|
||||
disallow_untyped_defs = True
|
||||
follow_imports = silent
|
||||
@@ -276,6 +276,9 @@ class TrainingCallback(ABC):
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
'''
|
||||
|
||||
EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -287,13 +290,11 @@ class TrainingCallback(ABC):
|
||||
'''Run after training is finished.'''
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch: int,
|
||||
evals_log: 'CallbackContainer.EvalsLog') -> bool:
|
||||
def before_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
|
||||
'''Run before each iteration. Return True when training should stop.'''
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch: int,
|
||||
evals_log: 'CallbackContainer.EvalsLog') -> bool:
|
||||
def after_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
|
||||
'''Run after each iteration. Return True when training should stop.'''
|
||||
return False
|
||||
|
||||
@@ -351,7 +352,7 @@ class CallbackContainer:
|
||||
|
||||
'''
|
||||
|
||||
EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
|
||||
EvalsLog = TrainingCallback.EvalsLog
|
||||
|
||||
def __init__(self,
|
||||
callbacks: List[TrainingCallback],
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable= invalid-name, unused-import
|
||||
"""For compatibility and optional dependencies."""
|
||||
from typing import Any
|
||||
import sys
|
||||
import types
|
||||
import importlib.util
|
||||
@@ -36,7 +37,7 @@ except ImportError:
|
||||
|
||||
MultiIndex = object
|
||||
Int64Index = object
|
||||
DataFrame = object
|
||||
DataFrame: Any = object
|
||||
Series = object
|
||||
pandas_concat = None
|
||||
PANDAS_INSTALLED = False
|
||||
@@ -109,10 +110,12 @@ except pkg_resources.DistributionNotFound:
|
||||
try:
|
||||
import sparse
|
||||
import scipy.sparse as scipy_sparse
|
||||
from scipy.sparse import csr_matrix as scipy_csr
|
||||
SCIPY_INSTALLED = True
|
||||
except ImportError:
|
||||
sparse = False
|
||||
scipy_sparse = False
|
||||
scipy_csr: Any = object
|
||||
SCIPY_INSTALLED = False
|
||||
|
||||
|
||||
|
||||
@@ -96,7 +96,11 @@ def from_cstr_to_pystr(data, length) -> List[str]:
|
||||
return res
|
||||
|
||||
|
||||
def _convert_ntree_limit(booster, ntree_limit, iteration_range):
|
||||
def _convert_ntree_limit(
|
||||
booster: "Booster",
|
||||
ntree_limit: Optional[int],
|
||||
iteration_range: Optional[Tuple[int, int]]
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
if ntree_limit is not None and ntree_limit != 0:
|
||||
warnings.warn(
|
||||
"ntree_limit is deprecated, use `iteration_range` or model "
|
||||
@@ -1234,7 +1238,7 @@ class Booster(object):
|
||||
params += [('eval_metric', eval_metric)]
|
||||
return params
|
||||
|
||||
def _transform_monotone_constrains(self, value: Union[dict, str]) -> str:
|
||||
def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
@@ -1246,7 +1250,9 @@ class Booster(object):
|
||||
return '(' + ','.join([str(value.get(feature_name, 0))
|
||||
for feature_name in self.feature_names]) + ')'
|
||||
|
||||
def _transform_interaction_constraints(self, value: Union[list, str]) -> str:
|
||||
def _transform_interaction_constraints(
|
||||
self, value: Union[List[Tuple[str]], str]
|
||||
) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
@@ -1447,7 +1453,7 @@ class Booster(object):
|
||||
attr_names = from_cstr_to_pystr(sarr, length)
|
||||
return {n: self.attr(n) for n in attr_names}
|
||||
|
||||
def set_attr(self, **kwargs):
|
||||
def set_attr(self, **kwargs: Optional[str]) -> None:
|
||||
"""Set the attribute of the Booster.
|
||||
|
||||
Parameters
|
||||
@@ -1971,7 +1977,7 @@ class Booster(object):
|
||||
"Data type:" + str(type(data)) + " not supported by inplace prediction."
|
||||
)
|
||||
|
||||
def save_model(self, fname):
|
||||
def save_model(self, fname: Union[str, os.PathLike]):
|
||||
"""Save the model to a file.
|
||||
|
||||
The model is saved in an XGBoost internal format which is universal among the
|
||||
|
||||
@@ -1028,7 +1028,8 @@ async def _direct_predict_impl( # pylint: disable=too-many-branches
|
||||
# Somehow dask fail to infer output shape change for 2-dim prediction, and
|
||||
# `chunks = (None, output_shape[1])` doesn't work due to None is not
|
||||
# supported in map_blocks.
|
||||
chunks = list(data.chunks)
|
||||
chunks: Optional[List[Tuple]] = list(data.chunks)
|
||||
assert isinstance(chunks, list)
|
||||
chunks[1] = (output_shape[1], )
|
||||
else:
|
||||
chunks = None
|
||||
@@ -1633,7 +1634,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
)
|
||||
|
||||
if callable(self.objective):
|
||||
obj = _objective_decorator(self.objective)
|
||||
obj: Optional[Callable] = _objective_decorator(self.objective)
|
||||
else:
|
||||
obj = None
|
||||
model, metric, params = self._configure_fit(
|
||||
@@ -1734,7 +1735,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
params["objective"] = "binary:logistic"
|
||||
|
||||
if callable(self.objective):
|
||||
obj = _objective_decorator(self.objective)
|
||||
obj: Optional[Callable] = _objective_decorator(self.objective)
|
||||
else:
|
||||
obj = None
|
||||
model, metric, params = self._configure_fit(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user