Require keyword args for data iterator. (#8327)

This commit is contained in:
Jiaming Yuan 2022-10-10 17:47:13 +08:00 committed by GitHub
parent e1f9f80df2
commit 5545c49cfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 36 deletions

View File

@ -502,8 +502,8 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
pointer. pointer.
""" """
@_deprecate_positional_args @require_pos_args(True)
def data_handle( def input_data(
data: Any, data: Any,
*, *,
feature_names: Optional[FeatureNames] = None, feature_names: Optional[FeatureNames] = None,
@ -528,7 +528,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
**kwargs, **kwargs,
) )
# pylint: disable=not-callable # pylint: disable=not-callable
return self._handle_exception(lambda: self.next(data_handle), 0) return self._handle_exception(lambda: self.next(input_data), 0)
@abstractmethod @abstractmethod
def reset(self) -> None: def reset(self) -> None:
@ -554,7 +554,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
raise NotImplementedError() raise NotImplementedError()
# Notice for `_deprecate_positional_args` # Notice for `require_pos_args`
# Authors: Olivier Grisel # Authors: Olivier Grisel
# Gael Varoquaux # Gael Varoquaux
# Andreas Mueller # Andreas Mueller
@ -563,50 +563,63 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
# Nicolas Tresegnie # Nicolas Tresegnie
# Sylvain Marie # Sylvain Marie
# License: BSD 3 clause # License: BSD 3 clause
def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]: def require_pos_args(error: bool) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
"""Decorator for methods that issues warnings for positional arguments """Decorator for methods that issues warnings for positional arguments
Using the keyword-only argument syntax in pep 3102, arguments after the Using the keyword-only argument syntax in pep 3102, arguments after the
* will issue a warning when passed as a positional argument. * will issue a warning or error when passed as a positional argument.
Modified from sklearn utils.validation. Modified from sklearn utils.validation.
Parameters Parameters
---------- ----------
f : function error :
function to check arguments on Whether to throw an error or raise a warning.
""" """
sig = signature(f)
kwonly_args = []
all_args = []
for name, param in sig.parameters.items(): def throw_if(func: Callable[..., _T]) -> Callable[..., _T]:
if param.kind == Parameter.POSITIONAL_OR_KEYWORD: """Throw error/warning if there are positional arguments after the asterisk.
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)
@wraps(f) Parameters
def inner_f(*args: Any, **kwargs: Any) -> _T: ----------
extra_args = len(args) - len(all_args) f :
if extra_args > 0: function to check arguments on.
# ignore first 'self' argument for instance methods
args_msg = [
f"{name}" for name, _ in zip(
kwonly_args[:extra_args], args[-extra_args:]
)
]
# pylint: disable=consider-using-f-string
warnings.warn(
"Pass `{}` as keyword args. Passing these as positional "
"arguments will be considered as error in future releases.".
format(", ".join(args_msg)), FutureWarning
)
for k, arg in zip(sig.parameters, args):
kwargs[k] = arg
return f(**kwargs)
return inner_f """
sig = signature(func)
kwonly_args = []
all_args = []
for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)
@wraps(func)
def inner_f(*args: Any, **kwargs: Any) -> _T:
extra_args = len(args) - len(all_args)
if extra_args > 0:
# ignore first 'self' argument for instance methods
args_msg = [
f"{name}"
for name, _ in zip(kwonly_args[:extra_args], args[-extra_args:])
]
# pylint: disable=consider-using-f-string
msg = "Pass `{}` as keyword args.".format(", ".join(args_msg))
if error:
raise TypeError(msg)
warnings.warn(msg, FutureWarning)
for k, arg in zip(sig.parameters, args):
kwargs[k] = arg
return func(**kwargs)
return inner_f
return throw_if
_deprecate_positional_args = require_pos_args(False)
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods

View File

@ -198,6 +198,10 @@ class IteratorForTest(xgb.core.DataIter):
def next(self, input_data: Callable) -> int: def next(self, input_data: Callable) -> int:
if self.it == len(self.X): if self.it == len(self.X):
return 0 return 0
with pytest.raises(TypeError, match="keyword args"):
input_data(self.X[self.it], self.y[self.it], None)
# Use copy to make sure the iterator doesn't hold a reference to the data. # Use copy to make sure the iterator doesn't hold a reference to the data.
input_data( input_data(
data=self.X[self.it].copy(), data=self.X[self.it].copy(),