Disable callback and ES on random forest. (#7236)
This commit is contained in:
parent
c311a8c1d8
commit
c735c17f33
@ -42,7 +42,7 @@ from .core import _deprecate_positional_args
|
||||
from .training import train as worker_train
|
||||
from .tracker import RabitTracker, get_host_ip
|
||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
|
||||
from .sklearn import _wrap_evaluation_matrices, _objective_decorator
|
||||
from .sklearn import _wrap_evaluation_matrices, _objective_decorator, _check_rf_callback
|
||||
from .sklearn import XGBRankerMixIn
|
||||
from .sklearn import xgboost_model_doc
|
||||
from .sklearn import _cls_predict_proba
|
||||
@ -1710,7 +1710,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
callbacks: Optional[List[TrainingCallback]] = None,
|
||||
) -> "DaskXGBRegressor":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != "self"}
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
|
||||
@ -1813,7 +1813,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "DaskXGBClassifier":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != 'self'}
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
async def _predict_proba_async(
|
||||
@ -2001,7 +2001,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "DaskXGBRanker":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != "self"}
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
|
||||
@ -2047,6 +2047,30 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
|
||||
def get_num_boosting_rounds(self) -> int:
|
||||
return 1
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
y: _DaskCollection,
|
||||
*,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "DaskXGBRFRegressor":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
_check_rf_callback(early_stopping_rounds, callbacks)
|
||||
super().fit(**args)
|
||||
return self
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
|
||||
@ -2086,3 +2110,27 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
|
||||
|
||||
def get_num_boosting_rounds(self) -> int:
|
||||
return 1
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
y: _DaskCollection,
|
||||
*,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "DaskXGBRFClassifier":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
_check_rf_callback(early_stopping_rounds, callbacks)
|
||||
super().fit(**args)
|
||||
return self
|
||||
|
||||
@ -40,6 +40,17 @@ class XGBRankerMixIn: # pylint: disable=too-few-public-methods
|
||||
_estimator_type = "ranker"
|
||||
|
||||
|
||||
def _check_rf_callback(
|
||||
early_stopping_rounds: Optional[int],
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
) -> None:
|
||||
if early_stopping_rounds is not None or callbacks is not None:
|
||||
raise NotImplementedError(
|
||||
"`early_stopping_rounds` and `callbacks` are not implemented for"
|
||||
" random forest."
|
||||
)
|
||||
|
||||
|
||||
_SklObjective = Optional[
|
||||
Union[
|
||||
str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
||||
@ -1420,6 +1431,30 @@ class XGBRFClassifier(XGBClassifier):
|
||||
def get_num_boosting_rounds(self) -> int:
|
||||
return 1
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@_deprecate_positional_args
|
||||
def fit(
|
||||
self,
|
||||
X: array_like,
|
||||
y: array_like,
|
||||
*,
|
||||
sample_weight: Optional[array_like] = None,
|
||||
base_margin: Optional[array_like] = None,
|
||||
eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: Optional[bool] = True,
|
||||
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[array_like]] = None,
|
||||
base_margin_eval_set: Optional[List[array_like]] = None,
|
||||
feature_weights: Optional[array_like] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "XGBRFClassifier":
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
_check_rf_callback(early_stopping_rounds, callbacks)
|
||||
super().fit(**args)
|
||||
return self
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"Implementation of the scikit-learn API for XGBoost regression.",
|
||||
@ -1451,18 +1486,46 @@ class XGBRFRegressor(XGBRegressor):
|
||||
reg_lambda: float = 1e-5,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(learning_rate=learning_rate, subsample=subsample,
|
||||
colsample_bynode=colsample_bynode,
|
||||
reg_lambda=reg_lambda, **kwargs)
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
subsample=subsample,
|
||||
colsample_bynode=colsample_bynode,
|
||||
reg_lambda=reg_lambda,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def get_xgb_params(self) -> Dict[str, Any]:
|
||||
params = super().get_xgb_params()
|
||||
params['num_parallel_tree'] = self.n_estimators
|
||||
params["num_parallel_tree"] = self.n_estimators
|
||||
return params
|
||||
|
||||
def get_num_boosting_rounds(self) -> int:
|
||||
return 1
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@_deprecate_positional_args
|
||||
def fit(
|
||||
self,
|
||||
X: array_like,
|
||||
y: array_like,
|
||||
*,
|
||||
sample_weight: Optional[array_like] = None,
|
||||
base_margin: Optional[array_like] = None,
|
||||
eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: Optional[bool] = True,
|
||||
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[array_like]] = None,
|
||||
base_margin_eval_set: Optional[List[array_like]] = None,
|
||||
feature_weights: Optional[array_like] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "XGBRFRegressor":
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
_check_rf_callback(early_stopping_rounds, callbacks)
|
||||
super().fit(**args)
|
||||
return self
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
'Implementation of the Scikit-Learn API for XGBoost Ranking.',
|
||||
|
||||
@ -402,6 +402,10 @@ def run_boston_housing_rf_regression(tree_method):
|
||||
labels = y[test_index]
|
||||
assert mean_squared_error(preds, labels) < 35
|
||||
|
||||
rfreg = xgb.XGBRFRegressor()
|
||||
with pytest.raises(NotImplementedError):
|
||||
rfreg.fit(X, y, early_stopping_rounds=10)
|
||||
|
||||
|
||||
def test_boston_housing_rf_regression():
|
||||
run_boston_housing_rf_regression("hist")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user