Require keyword args for data iterator. (#8327)
This commit is contained in:
parent
e1f9f80df2
commit
5545c49cfc
@ -502,8 +502,8 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
|||||||
pointer.
|
pointer.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@_deprecate_positional_args
|
@require_pos_args(True)
|
||||||
def data_handle(
|
def input_data(
|
||||||
data: Any,
|
data: Any,
|
||||||
*,
|
*,
|
||||||
feature_names: Optional[FeatureNames] = None,
|
feature_names: Optional[FeatureNames] = None,
|
||||||
@ -528,7 +528,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# pylint: disable=not-callable
|
# pylint: disable=not-callable
|
||||||
return self._handle_exception(lambda: self.next(data_handle), 0)
|
return self._handle_exception(lambda: self.next(input_data), 0)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
@ -554,7 +554,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
# Notice for `_deprecate_positional_args`
|
# Notice for `require_pos_args`
|
||||||
# Authors: Olivier Grisel
|
# Authors: Olivier Grisel
|
||||||
# Gael Varoquaux
|
# Gael Varoquaux
|
||||||
# Andreas Mueller
|
# Andreas Mueller
|
||||||
@ -563,50 +563,63 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
|||||||
# Nicolas Tresegnie
|
# Nicolas Tresegnie
|
||||||
# Sylvain Marie
|
# Sylvain Marie
|
||||||
# License: BSD 3 clause
|
# License: BSD 3 clause
|
||||||
def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
|
def require_pos_args(error: bool) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
|
||||||
"""Decorator for methods that issues warnings for positional arguments
|
"""Decorator for methods that issues warnings for positional arguments
|
||||||
|
|
||||||
Using the keyword-only argument syntax in pep 3102, arguments after the
|
Using the keyword-only argument syntax in pep 3102, arguments after the
|
||||||
* will issue a warning when passed as a positional argument.
|
* will issue a warning or error when passed as a positional argument.
|
||||||
|
|
||||||
Modified from sklearn utils.validation.
|
Modified from sklearn utils.validation.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
f : function
|
error :
|
||||||
function to check arguments on
|
Whether to throw an error or raise a warning.
|
||||||
"""
|
"""
|
||||||
sig = signature(f)
|
|
||||||
kwonly_args = []
|
|
||||||
all_args = []
|
|
||||||
|
|
||||||
for name, param in sig.parameters.items():
|
def throw_if(func: Callable[..., _T]) -> Callable[..., _T]:
|
||||||
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
"""Throw error/warning if there are positional arguments after the asterisk.
|
||||||
all_args.append(name)
|
|
||||||
elif param.kind == Parameter.KEYWORD_ONLY:
|
|
||||||
kwonly_args.append(name)
|
|
||||||
|
|
||||||
@wraps(f)
|
Parameters
|
||||||
def inner_f(*args: Any, **kwargs: Any) -> _T:
|
----------
|
||||||
extra_args = len(args) - len(all_args)
|
f :
|
||||||
if extra_args > 0:
|
function to check arguments on.
|
||||||
# ignore first 'self' argument for instance methods
|
|
||||||
args_msg = [
|
|
||||||
f"{name}" for name, _ in zip(
|
|
||||||
kwonly_args[:extra_args], args[-extra_args:]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
# pylint: disable=consider-using-f-string
|
|
||||||
warnings.warn(
|
|
||||||
"Pass `{}` as keyword args. Passing these as positional "
|
|
||||||
"arguments will be considered as error in future releases.".
|
|
||||||
format(", ".join(args_msg)), FutureWarning
|
|
||||||
)
|
|
||||||
for k, arg in zip(sig.parameters, args):
|
|
||||||
kwargs[k] = arg
|
|
||||||
return f(**kwargs)
|
|
||||||
|
|
||||||
return inner_f
|
"""
|
||||||
|
sig = signature(func)
|
||||||
|
kwonly_args = []
|
||||||
|
all_args = []
|
||||||
|
|
||||||
|
for name, param in sig.parameters.items():
|
||||||
|
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||||
|
all_args.append(name)
|
||||||
|
elif param.kind == Parameter.KEYWORD_ONLY:
|
||||||
|
kwonly_args.append(name)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def inner_f(*args: Any, **kwargs: Any) -> _T:
|
||||||
|
extra_args = len(args) - len(all_args)
|
||||||
|
if extra_args > 0:
|
||||||
|
# ignore first 'self' argument for instance methods
|
||||||
|
args_msg = [
|
||||||
|
f"{name}"
|
||||||
|
for name, _ in zip(kwonly_args[:extra_args], args[-extra_args:])
|
||||||
|
]
|
||||||
|
# pylint: disable=consider-using-f-string
|
||||||
|
msg = "Pass `{}` as keyword args.".format(", ".join(args_msg))
|
||||||
|
if error:
|
||||||
|
raise TypeError(msg)
|
||||||
|
warnings.warn(msg, FutureWarning)
|
||||||
|
for k, arg in zip(sig.parameters, args):
|
||||||
|
kwargs[k] = arg
|
||||||
|
return func(**kwargs)
|
||||||
|
|
||||||
|
return inner_f
|
||||||
|
|
||||||
|
return throw_if
|
||||||
|
|
||||||
|
|
||||||
|
_deprecate_positional_args = require_pos_args(False)
|
||||||
|
|
||||||
|
|
||||||
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||||
|
|||||||
@ -198,6 +198,10 @@ class IteratorForTest(xgb.core.DataIter):
|
|||||||
def next(self, input_data: Callable) -> int:
|
def next(self, input_data: Callable) -> int:
|
||||||
if self.it == len(self.X):
|
if self.it == len(self.X):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="keyword args"):
|
||||||
|
input_data(self.X[self.it], self.y[self.it], None)
|
||||||
|
|
||||||
# Use copy to make sure the iterator doesn't hold a reference to the data.
|
# Use copy to make sure the iterator doesn't hold a reference to the data.
|
||||||
input_data(
|
input_data(
|
||||||
data=self.X[self.it].copy(),
|
data=self.X[self.it].copy(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user