Support more sklearn tags for testing. (#10230)
This commit is contained in:
parent
f8c3d22587
commit
837d44a345
@ -782,7 +782,10 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
def _more_tags(self) -> Dict[str, bool]:
|
def _more_tags(self) -> Dict[str, bool]:
|
||||||
"""Tags used for scikit-learn data validation."""
|
"""Tags used for scikit-learn data validation."""
|
||||||
return {"allow_nan": True, "no_validation": True}
|
tags = {"allow_nan": True, "no_validation": True}
|
||||||
|
if hasattr(self, "kwargs") and self.kwargs.get("updater") == "shotgun":
|
||||||
|
tags["non_deterministic"] = True
|
||||||
|
return tags
|
||||||
|
|
||||||
def __sklearn_is_fitted__(self) -> bool:
|
def __sklearn_is_fitted__(self) -> bool:
|
||||||
return hasattr(self, "_Booster")
|
return hasattr(self, "_Booster")
|
||||||
@ -1439,6 +1442,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(objective=objective, **kwargs)
|
super().__init__(objective=objective, **kwargs)
|
||||||
|
|
||||||
|
def _more_tags(self) -> Dict[str, bool]:
|
||||||
|
tags = super()._more_tags()
|
||||||
|
tags["multilabel"] = True
|
||||||
|
return tags
|
||||||
|
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
@ -1717,6 +1725,12 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(objective=objective, **kwargs)
|
super().__init__(objective=objective, **kwargs)
|
||||||
|
|
||||||
|
def _more_tags(self) -> Dict[str, bool]:
|
||||||
|
tags = super()._more_tags()
|
||||||
|
tags["multioutput"] = True
|
||||||
|
tags["multioutput_only"] = False
|
||||||
|
return tags
|
||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc(
|
@xgboost_model_doc(
|
||||||
"scikit-learn API for XGBoost random forest regression.",
|
"scikit-learn API for XGBoost random forest regression.",
|
||||||
|
|||||||
@ -1300,19 +1300,11 @@ def test_estimator_reg(estimator, check):
|
|||||||
):
|
):
|
||||||
estimator.fit(X, y)
|
estimator.fit(X, y)
|
||||||
return
|
return
|
||||||
if (
|
elif os.environ["PYTEST_CURRENT_TEST"].find("check_regressor_multioutput") != -1:
|
||||||
os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params")
|
# sklearn requires float64
|
||||||
!= -1
|
with pytest.raises(AssertionError, match="Got float32"):
|
||||||
):
|
check(estimator)
|
||||||
# A hack to pass the scikit-learn parameter mutation tests. XGBoost regressor
|
else:
|
||||||
# returns actual internal default values for parameters in `get_params`, but
|
|
||||||
# those are set as `None` in sklearn interface to avoid duplication. So we fit
|
|
||||||
# a dummy model and obtain the default parameters here for the mutation tests.
|
|
||||||
from sklearn.datasets import make_regression
|
|
||||||
|
|
||||||
X, y = make_regression(n_samples=2, n_features=1)
|
|
||||||
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
|
|
||||||
|
|
||||||
check(estimator)
|
check(estimator)
|
||||||
|
|
||||||
|
|
||||||
@ -1475,3 +1467,19 @@ def test_fit_none() -> None:
|
|||||||
|
|
||||||
with pytest.raises(ValueError, match="labels"):
|
with pytest.raises(ValueError, match="labels"):
|
||||||
xgb.XGBRegressor().fit(X, None)
|
xgb.XGBRegressor().fit(X, None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tags() -> None:
|
||||||
|
for reg in [xgb.XGBRegressor(), xgb.XGBRFRegressor()]:
|
||||||
|
tags = reg._more_tags()
|
||||||
|
assert "non_deterministic" not in tags
|
||||||
|
assert tags["multioutput"] is True
|
||||||
|
assert tags["multioutput_only"] is False
|
||||||
|
|
||||||
|
for clf in [xgb.XGBClassifier()]:
|
||||||
|
tags = clf._more_tags()
|
||||||
|
assert "multioutput" not in tags
|
||||||
|
assert tags["multilabel"] is True
|
||||||
|
|
||||||
|
tags = xgb.XGBRanker()._more_tags()
|
||||||
|
assert "multioutput" not in tags
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user