Support more sklearn tags for testing. (#10230)
This commit is contained in:
@@ -782,7 +782,10 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
def _more_tags(self) -> Dict[str, bool]:
|
||||
"""Tags used for scikit-learn data validation."""
|
||||
return {"allow_nan": True, "no_validation": True}
|
||||
tags = {"allow_nan": True, "no_validation": True}
|
||||
if hasattr(self, "kwargs") and self.kwargs.get("updater") == "shotgun":
|
||||
tags["non_deterministic"] = True
|
||||
return tags
|
||||
|
||||
def __sklearn_is_fitted__(self) -> bool:
|
||||
return hasattr(self, "_Booster")
|
||||
@@ -1439,6 +1442,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
) -> None:
|
||||
super().__init__(objective=objective, **kwargs)
|
||||
|
||||
def _more_tags(self) -> Dict[str, bool]:
|
||||
tags = super()._more_tags()
|
||||
tags["multilabel"] = True
|
||||
return tags
|
||||
|
||||
@_deprecate_positional_args
|
||||
def fit(
|
||||
self,
|
||||
@@ -1717,6 +1725,12 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
|
||||
) -> None:
|
||||
super().__init__(objective=objective, **kwargs)
|
||||
|
||||
def _more_tags(self) -> Dict[str, bool]:
|
||||
tags = super()._more_tags()
|
||||
tags["multioutput"] = True
|
||||
tags["multioutput_only"] = False
|
||||
return tags
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"scikit-learn API for XGBoost random forest regression.",
|
||||
|
||||
Reference in New Issue
Block a user