Support more sklearn tags for testing. (#10230)

This commit is contained in:
Jiaming Yuan
2024-04-29 06:33:23 +08:00
committed by GitHub
parent f8c3d22587
commit 837d44a345
2 changed files with 37 additions and 15 deletions

View File

@@ -782,7 +782,10 @@ class XGBModel(XGBModelBase):
def _more_tags(self) -> Dict[str, bool]:
"""Tags used for scikit-learn data validation."""
return {"allow_nan": True, "no_validation": True}
tags = {"allow_nan": True, "no_validation": True}
if hasattr(self, "kwargs") and self.kwargs.get("updater") == "shotgun":
tags["non_deterministic"] = True
return tags
def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")
@@ -1439,6 +1442,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
) -> None:
super().__init__(objective=objective, **kwargs)
def _more_tags(self) -> Dict[str, bool]:
tags = super()._more_tags()
tags["multilabel"] = True
return tags
@_deprecate_positional_args
def fit(
self,
@@ -1717,6 +1725,12 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
) -> None:
super().__init__(objective=objective, **kwargs)
def _more_tags(self) -> Dict[str, bool]:
tags = super()._more_tags()
tags["multioutput"] = True
tags["multioutput_only"] = False
return tags
@xgboost_model_doc(
"scikit-learn API for XGBoost random forest regression.",