Typehint for Sklearn. (#6799)

This commit is contained in:
Jiaming Yuan
2021-04-14 06:55:21 +08:00
committed by GitHub
parent 3d919db0c0
commit dee5ef2dfd
11 changed files with 335 additions and 262 deletions

View File

@@ -3,4 +3,5 @@ description-file = README.rst
[mypy]
ignore_missing_imports = True
disallow_untyped_defs = True
disallow_untyped_defs = True
follow_imports = silent

View File

@@ -276,6 +276,9 @@ class TrainingCallback(ABC):
.. versionadded:: 1.3.0
'''
EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
def __init__(self):
pass
@@ -287,13 +290,11 @@ class TrainingCallback(ABC):
'''Run after training is finished.'''
return model
def before_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
def before_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
'''Run before each iteration. Return True when training should stop.'''
return False
def after_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
def after_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
'''Run after each iteration. Return True when training should stop.'''
return False
@@ -351,7 +352,7 @@ class CallbackContainer:
'''
EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
EvalsLog = TrainingCallback.EvalsLog
def __init__(self,
callbacks: List[TrainingCallback],

View File

@@ -1,6 +1,7 @@
# coding: utf-8
# pylint: disable= invalid-name, unused-import
"""For compatibility and optional dependencies."""
from typing import Any
import sys
import types
import importlib.util
@@ -36,7 +37,7 @@ except ImportError:
MultiIndex = object
Int64Index = object
DataFrame = object
DataFrame: Any = object
Series = object
pandas_concat = None
PANDAS_INSTALLED = False
@@ -109,10 +110,12 @@ except pkg_resources.DistributionNotFound:
try:
import sparse
import scipy.sparse as scipy_sparse
from scipy.sparse import csr_matrix as scipy_csr
SCIPY_INSTALLED = True
except ImportError:
sparse = False
scipy_sparse = False
scipy_csr: Any = object
SCIPY_INSTALLED = False

View File

@@ -96,7 +96,11 @@ def from_cstr_to_pystr(data, length) -> List[str]:
return res
def _convert_ntree_limit(booster, ntree_limit, iteration_range):
def _convert_ntree_limit(
booster: "Booster",
ntree_limit: Optional[int],
iteration_range: Optional[Tuple[int, int]]
) -> Optional[Tuple[int, int]]:
if ntree_limit is not None and ntree_limit != 0:
warnings.warn(
"ntree_limit is deprecated, use `iteration_range` or model "
@@ -1234,7 +1238,7 @@ class Booster(object):
params += [('eval_metric', eval_metric)]
return params
def _transform_monotone_constrains(self, value: Union[dict, str]) -> str:
def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str:
if isinstance(value, str):
return value
@@ -1246,7 +1250,9 @@ class Booster(object):
return '(' + ','.join([str(value.get(feature_name, 0))
for feature_name in self.feature_names]) + ')'
def _transform_interaction_constraints(self, value: Union[list, str]) -> str:
def _transform_interaction_constraints(
self, value: Union[List[Tuple[str]], str]
) -> str:
if isinstance(value, str):
return value
@@ -1447,7 +1453,7 @@ class Booster(object):
attr_names = from_cstr_to_pystr(sarr, length)
return {n: self.attr(n) for n in attr_names}
def set_attr(self, **kwargs):
def set_attr(self, **kwargs: Optional[str]) -> None:
"""Set the attribute of the Booster.
Parameters
@@ -1971,7 +1977,7 @@ class Booster(object):
"Data type:" + str(type(data)) + " not supported by inplace prediction."
)
def save_model(self, fname):
def save_model(self, fname: Union[str, os.PathLike]):
"""Save the model to a file.
The model is saved in an XGBoost internal format which is universal among the

View File

@@ -1028,7 +1028,8 @@ async def _direct_predict_impl( # pylint: disable=too-many-branches
# Somehow dask fails to infer the output shape change for 2-dim prediction,
# and `chunks = (None, output_shape[1])` doesn't work because None is not
# supported in map_blocks.
chunks = list(data.chunks)
chunks: Optional[List[Tuple]] = list(data.chunks)
assert isinstance(chunks, list)
chunks[1] = (output_shape[1], )
else:
chunks = None
@@ -1633,7 +1634,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
)
if callable(self.objective):
obj = _objective_decorator(self.objective)
obj: Optional[Callable] = _objective_decorator(self.objective)
else:
obj = None
model, metric, params = self._configure_fit(
@@ -1734,7 +1735,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
params["objective"] = "binary:logistic"
if callable(self.objective):
obj = _objective_decorator(self.objective)
obj: Optional[Callable] = _objective_decorator(self.objective)
else:
obj = None
model, metric, params = self._configure_fit(

File diff suppressed because it is too large Load Diff