Require keyword args for data iterator. (#8327)

This commit is contained in:
Jiaming Yuan 2022-10-10 17:47:13 +08:00 committed by GitHub
parent e1f9f80df2
commit 5545c49cfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 36 deletions

View File

@ -502,8 +502,8 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
pointer.
"""
@_deprecate_positional_args
def data_handle(
@require_pos_args(True)
def input_data(
data: Any,
*,
feature_names: Optional[FeatureNames] = None,
@ -528,7 +528,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
**kwargs,
)
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(data_handle), 0)
return self._handle_exception(lambda: self.next(input_data), 0)
@abstractmethod
def reset(self) -> None:
@ -554,7 +554,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
raise NotImplementedError()
# Notice for `_deprecate_positional_args`
# Notice for `require_pos_args`
# Authors: Olivier Grisel
# Gael Varoquaux
# Andreas Mueller
@ -563,20 +563,30 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
# Nicolas Tresegnie
# Sylvain Marie
# License: BSD 3 clause
def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
def require_pos_args(error: bool) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
"""Decorator for methods that issues warnings or errors for positional arguments
Using the keyword-only argument syntax in pep 3102, arguments after the
* will issue a warning when passed as a positional argument.
* will issue a warning or error when passed as a positional argument.
Modified from sklearn utils.validation.
Parameters
----------
f : function
function to check arguments on
error :
If True, raise a TypeError; otherwise issue a FutureWarning.
"""
sig = signature(f)
def throw_if(func: Callable[..., _T]) -> Callable[..., _T]:
"""Throw error/warning if there are positional arguments after the asterisk.
Parameters
----------
func :
function to check arguments on.
"""
sig = signature(func)
kwonly_args = []
all_args = []
@ -586,28 +596,31 @@ def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)
@wraps(f)
@wraps(func)
def inner_f(*args: Any, **kwargs: Any) -> _T:
extra_args = len(args) - len(all_args)
if extra_args > 0:
# ignore first 'self' argument for instance methods
args_msg = [
f"{name}" for name, _ in zip(
kwonly_args[:extra_args], args[-extra_args:]
)
f"{name}"
for name, _ in zip(kwonly_args[:extra_args], args[-extra_args:])
]
# pylint: disable=consider-using-f-string
warnings.warn(
"Pass `{}` as keyword args. Passing these as positional "
"arguments will be considered as error in future releases.".
format(", ".join(args_msg)), FutureWarning
)
msg = "Pass `{}` as keyword args.".format(", ".join(args_msg))
if error:
raise TypeError(msg)
warnings.warn(msg, FutureWarning)
for k, arg in zip(sig.parameters, args):
kwargs[k] = arg
return f(**kwargs)
return func(**kwargs)
return inner_f
return throw_if
_deprecate_positional_args = require_pos_args(False)
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""Data Matrix used in XGBoost.

View File

@ -198,6 +198,10 @@ class IteratorForTest(xgb.core.DataIter):
def next(self, input_data: Callable) -> int:
if self.it == len(self.X):
return 0
with pytest.raises(TypeError, match="keyword args"):
input_data(self.X[self.it], self.y[self.it], None)
# Use copy to make sure the iterator doesn't hold a reference to the data.
input_data(
data=self.X[self.it].copy(),