Require keyword args for data iterator. (#8327)
This commit is contained in:
parent
e1f9f80df2
commit
5545c49cfc
@ -502,8 +502,8 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
pointer.
|
||||
|
||||
"""
|
||||
@_deprecate_positional_args
|
||||
def data_handle(
|
||||
@require_pos_args(True)
|
||||
def input_data(
|
||||
data: Any,
|
||||
*,
|
||||
feature_names: Optional[FeatureNames] = None,
|
||||
@ -528,7 +528,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
**kwargs,
|
||||
)
|
||||
# pylint: disable=not-callable
|
||||
return self._handle_exception(lambda: self.next(data_handle), 0)
|
||||
return self._handle_exception(lambda: self.next(input_data), 0)
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
@ -554,7 +554,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# Notice for `_deprecate_positional_args`
|
||||
# Notice for `require_pos_args`
|
||||
# Authors: Olivier Grisel
|
||||
# Gael Varoquaux
|
||||
# Andreas Mueller
|
||||
@ -563,50 +563,63 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
# Nicolas Tresegnie
|
||||
# Sylvain Marie
|
||||
# License: BSD 3 clause
|
||||
def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
|
||||
def require_pos_args(error: bool) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
|
||||
"""Decorator for methods that issues warnings for positional arguments
|
||||
|
||||
Using the keyword-only argument syntax in pep 3102, arguments after the
|
||||
* will issue a warning when passed as a positional argument.
|
||||
* will issue a warning or error when passed as a positional argument.
|
||||
|
||||
Modified from sklearn utils.validation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
f : function
|
||||
function to check arguments on
|
||||
error :
|
||||
Whether to throw an error or raise a warning.
|
||||
"""
|
||||
sig = signature(f)
|
||||
kwonly_args = []
|
||||
all_args = []
|
||||
|
||||
for name, param in sig.parameters.items():
|
||||
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
all_args.append(name)
|
||||
elif param.kind == Parameter.KEYWORD_ONLY:
|
||||
kwonly_args.append(name)
|
||||
def throw_if(func: Callable[..., _T]) -> Callable[..., _T]:
|
||||
"""Throw error/warning if there are positional arguments after the asterisk.
|
||||
|
||||
@wraps(f)
|
||||
def inner_f(*args: Any, **kwargs: Any) -> _T:
|
||||
extra_args = len(args) - len(all_args)
|
||||
if extra_args > 0:
|
||||
# ignore first 'self' argument for instance methods
|
||||
args_msg = [
|
||||
f"{name}" for name, _ in zip(
|
||||
kwonly_args[:extra_args], args[-extra_args:]
|
||||
)
|
||||
]
|
||||
# pylint: disable=consider-using-f-string
|
||||
warnings.warn(
|
||||
"Pass `{}` as keyword args. Passing these as positional "
|
||||
"arguments will be considered as error in future releases.".
|
||||
format(", ".join(args_msg)), FutureWarning
|
||||
)
|
||||
for k, arg in zip(sig.parameters, args):
|
||||
kwargs[k] = arg
|
||||
return f(**kwargs)
|
||||
Parameters
|
||||
----------
|
||||
f :
|
||||
function to check arguments on.
|
||||
|
||||
return inner_f
|
||||
"""
|
||||
sig = signature(func)
|
||||
kwonly_args = []
|
||||
all_args = []
|
||||
|
||||
for name, param in sig.parameters.items():
|
||||
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
all_args.append(name)
|
||||
elif param.kind == Parameter.KEYWORD_ONLY:
|
||||
kwonly_args.append(name)
|
||||
|
||||
@wraps(func)
|
||||
def inner_f(*args: Any, **kwargs: Any) -> _T:
|
||||
extra_args = len(args) - len(all_args)
|
||||
if extra_args > 0:
|
||||
# ignore first 'self' argument for instance methods
|
||||
args_msg = [
|
||||
f"{name}"
|
||||
for name, _ in zip(kwonly_args[:extra_args], args[-extra_args:])
|
||||
]
|
||||
# pylint: disable=consider-using-f-string
|
||||
msg = "Pass `{}` as keyword args.".format(", ".join(args_msg))
|
||||
if error:
|
||||
raise TypeError(msg)
|
||||
warnings.warn(msg, FutureWarning)
|
||||
for k, arg in zip(sig.parameters, args):
|
||||
kwargs[k] = arg
|
||||
return func(**kwargs)
|
||||
|
||||
return inner_f
|
||||
|
||||
return throw_if
|
||||
|
||||
|
||||
_deprecate_positional_args = require_pos_args(False)
|
||||
|
||||
|
||||
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||
|
||||
@ -198,6 +198,10 @@ class IteratorForTest(xgb.core.DataIter):
|
||||
def next(self, input_data: Callable) -> int:
|
||||
if self.it == len(self.X):
|
||||
return 0
|
||||
|
||||
with pytest.raises(TypeError, match="keyword args"):
|
||||
input_data(self.X[self.it], self.y[self.it], None)
|
||||
|
||||
# Use copy to make sure the iterator doesn't hold a reference to the data.
|
||||
input_data(
|
||||
data=self.X[self.it].copy(),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user