Disable callback and ES on random forest. (#7236)

Jiaming Yuan 2021-09-17 18:21:17 +08:00 committed by GitHub
parent c311a8c1d8
commit c735c17f33
3 changed files with 123 additions and 8 deletions
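
The user-visible effect of this change: the single-node and Dask random-forest wrappers now reject `early_stopping_rounds` and `callbacks` outright, since a forest is grown in a single boosting round and those arguments cannot take effect. A minimal sketch of the new behavior (the synthetic data and variable names are illustrative, not part of the patch):

import numpy as np
import xgboost as xgb

# Any regression data works; the guard fires before training starts.
X = np.random.randn(128, 4)
y = np.random.randn(128)

rfreg = xgb.XGBRFRegressor(n_estimators=10)
try:
    rfreg.fit(X, y, early_stopping_rounds=10)
except NotImplementedError as err:
    print(err)  # early stopping is not implemented for random forest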

python-package/xgboost/dask.py

@@ -42,7 +42,7 @@ from .core import _deprecate_positional_args
 from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
 from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
-from .sklearn import _wrap_evaluation_matrices, _objective_decorator
+from .sklearn import _wrap_evaluation_matrices, _objective_decorator, _check_rf_callback
 from .sklearn import XGBRankerMixIn
 from .sklearn import xgboost_model_doc
 from .sklearn import _cls_predict_proba
@@ -1710,7 +1710,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         callbacks: Optional[List[TrainingCallback]] = None,
     ) -> "DaskXGBRegressor":
         _assert_dask_support()
-        args = {k: v for k, v in locals().items() if k != "self"}
+        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
         return self._client_sync(self._fit_async, **args)
@@ -1813,7 +1813,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         callbacks: Optional[List[TrainingCallback]] = None
     ) -> "DaskXGBClassifier":
         _assert_dask_support()
-        args = {k: v for k, v in locals().items() if k != 'self'}
+        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
         return self._client_sync(self._fit_async, **args)

     async def _predict_proba_async(
@@ -2001,7 +2001,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         callbacks: Optional[List[TrainingCallback]] = None
     ) -> "DaskXGBRanker":
         _assert_dask_support()
-        args = {k: v for k, v in locals().items() if k != "self"}
+        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
         return self._client_sync(self._fit_async, **args)

     # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
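
The three `locals()` filters above gain `"__class__"` because the new RF `fit` overrides added below call zero-argument `super()`, which makes the compiler inject an implicit `__class__` closure cell that then shows up in `locals()`; the existing filters are updated to the same defensive form. A small self-contained demonstration with hypothetical classes:

class Base:
    def fit(self, **kwargs):
        print(sorted(kwargs))  # the forwarded keyword names

class Derived(Base):
    def fit(self, x=1, y=2):
        # Zero-argument super() forces an implicit `__class__` cell,
        # which appears in locals() alongside the real parameters.
        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
        super().fit(**args)

Derived().fit()  # prints ['x', 'y'] -- no stray __class__ key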
@@ -2047,6 +2047,30 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
     def get_num_boosting_rounds(self) -> int:
         return 1

+    # pylint: disable=unused-argument
+    def fit(
+        self,
+        X: _DaskCollection,
+        y: _DaskCollection,
+        *,
+        sample_weight: Optional[_DaskCollection] = None,
+        base_margin: Optional[_DaskCollection] = None,
+        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
+        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        early_stopping_rounds: Optional[int] = None,
+        verbose: bool = True,
+        xgb_model: Optional[Union[Booster, XGBModel]] = None,
+        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
+        feature_weights: Optional[_DaskCollection] = None,
+        callbacks: Optional[List[TrainingCallback]] = None
+    ) -> "DaskXGBRFRegressor":
+        _assert_dask_support()
+        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
+        _check_rf_callback(early_stopping_rounds, callbacks)
+        super().fit(**args)
+        return self

 @xgboost_model_doc(
     """Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
@@ -2086,3 +2110,27 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
     def get_num_boosting_rounds(self) -> int:
         return 1

+    # pylint: disable=unused-argument
+    def fit(
+        self,
+        X: _DaskCollection,
+        y: _DaskCollection,
+        *,
+        sample_weight: Optional[_DaskCollection] = None,
+        base_margin: Optional[_DaskCollection] = None,
+        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
+        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        early_stopping_rounds: Optional[int] = None,
+        verbose: bool = True,
+        xgb_model: Optional[Union[Booster, XGBModel]] = None,
+        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
+        feature_weights: Optional[_DaskCollection] = None,
+        callbacks: Optional[List[TrainingCallback]] = None
+    ) -> "DaskXGBRFClassifier":
+        _assert_dask_support()
+        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
+        _check_rf_callback(early_stopping_rounds, callbacks)
+        super().fit(**args)
+        return self
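
End to end, the Dask estimators fail the same way; a hedged sketch using a local cluster (the cluster setup and data are illustrative, not from the patch):

import dask.array as da
import xgboost as xgb
from dask.distributed import Client, LocalCluster

if __name__ == "__main__":
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 4), chunks=(250, 4))
        y = da.random.random(1000, chunks=250)

        rf = xgb.dask.DaskXGBRFRegressor(n_estimators=10)
        rf.client = client
        try:
            # _check_rf_callback fires before any work is dispatched to workers.
            rf.fit(X, y, early_stopping_rounds=10)
        except NotImplementedError as err:
            print(err)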

python-package/xgboost/sklearn.py

@@ -40,6 +40,17 @@ class XGBRankerMixIn:  # pylint: disable=too-few-public-methods
     _estimator_type = "ranker"

+def _check_rf_callback(
+    early_stopping_rounds: Optional[int],
+    callbacks: Optional[List[TrainingCallback]],
+) -> None:
+    if early_stopping_rounds is not None or callbacks is not None:
+        raise NotImplementedError(
+            "`early_stopping_rounds` and `callbacks` are not implemented for"
+            " random forest."
+        )

 _SklObjective = Optional[
     Union[
         str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
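
The helper itself is a plain precondition check: `None` (the defaults) passes through silently, and anything else raises before training begins. A standalone restatement for illustration (the real helper takes `TrainingCallback` instances and is private to xgboost/sklearn.py):

from typing import List, Optional

def check_rf_callback(
    early_stopping_rounds: Optional[int],
    callbacks: Optional[List[object]],
) -> None:
    # Mirrors the _check_rf_callback guard added in this commit.
    if early_stopping_rounds is not None or callbacks is not None:
        raise NotImplementedError(
            "`early_stopping_rounds` and `callbacks` are not implemented for"
            " random forest."
        )

check_rf_callback(None, None)  # defaults: no-op
try:
    check_rf_callback(10, None)
except NotImplementedError:
    print("rejected as expected")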
@@ -1420,6 +1431,30 @@ class XGBRFClassifier(XGBClassifier):
     def get_num_boosting_rounds(self) -> int:
         return 1

+    # pylint: disable=unused-argument
+    @_deprecate_positional_args
+    def fit(
+        self,
+        X: array_like,
+        y: array_like,
+        *,
+        sample_weight: Optional[array_like] = None,
+        base_margin: Optional[array_like] = None,
+        eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
+        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        early_stopping_rounds: Optional[int] = None,
+        verbose: Optional[bool] = True,
+        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
+        sample_weight_eval_set: Optional[List[array_like]] = None,
+        base_margin_eval_set: Optional[List[array_like]] = None,
+        feature_weights: Optional[array_like] = None,
+        callbacks: Optional[List[TrainingCallback]] = None
+    ) -> "XGBRFClassifier":
+        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
+        _check_rf_callback(early_stopping_rounds, callbacks)
+        super().fit(**args)
+        return self

 @xgboost_model_doc(
     "Implementation of the scikit-learn API for XGBoost regression.",
@@ -1451,18 +1486,46 @@ class XGBRFRegressor(XGBRegressor):
         reg_lambda: float = 1e-5,
         **kwargs: Any
     ) -> None:
-        super().__init__(learning_rate=learning_rate, subsample=subsample,
-                         colsample_bynode=colsample_bynode,
-                         reg_lambda=reg_lambda, **kwargs)
+        super().__init__(
+            learning_rate=learning_rate,
+            subsample=subsample,
+            colsample_bynode=colsample_bynode,
+            reg_lambda=reg_lambda,
+            **kwargs
+        )

     def get_xgb_params(self) -> Dict[str, Any]:
         params = super().get_xgb_params()
-        params['num_parallel_tree'] = self.n_estimators
+        params["num_parallel_tree"] = self.n_estimators
         return params

     def get_num_boosting_rounds(self) -> int:
         return 1

+    # pylint: disable=unused-argument
+    @_deprecate_positional_args
+    def fit(
+        self,
+        X: array_like,
+        y: array_like,
+        *,
+        sample_weight: Optional[array_like] = None,
+        base_margin: Optional[array_like] = None,
+        eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
+        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        early_stopping_rounds: Optional[int] = None,
+        verbose: Optional[bool] = True,
+        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
+        sample_weight_eval_set: Optional[List[array_like]] = None,
+        base_margin_eval_set: Optional[List[array_like]] = None,
+        feature_weights: Optional[array_like] = None,
+        callbacks: Optional[List[TrainingCallback]] = None
+    ) -> "XGBRFRegressor":
+        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
+        _check_rf_callback(early_stopping_rounds, callbacks)
+        super().fit(**args)
+        return self

 @xgboost_model_doc(
     'Implementation of the Scikit-Learn API for XGBoost Ranking.',
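
For context on why these options are meaningless here: the RF wrappers grow every tree within a single boosting round (`get_num_boosting_rounds` returns 1 and `num_parallel_tree` is set to `n_estimators`), so there are no intermediate rounds for early stopping or per-round callbacks to act on. A quick inspection sketch:

import xgboost as xgb

rf = xgb.XGBRFRegressor(n_estimators=100)
print(rf.get_num_boosting_rounds())              # 1: a forest is one boosting round
print(rf.get_xgb_params()["num_parallel_tree"])  # 100: all trees grown in that round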

tests/python/test_with_sklearn.py

@@ -402,6 +402,10 @@ def run_boston_housing_rf_regression(tree_method):
         labels = y[test_index]
         assert mean_squared_error(preds, labels) < 35

+    rfreg = xgb.XGBRFRegressor()
+    with pytest.raises(NotImplementedError):
+        rfreg.fit(X, y, early_stopping_rounds=10)

 def test_boston_housing_rf_regression():
     run_boston_housing_rf_regression("hist")
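
The added assertion can be reproduced outside the test suite; a minimal standalone check mirroring the new test lines (the data here is illustrative), including the `callbacks` path the test does not exercise directly:

import numpy as np
import pytest
import xgboost as xgb

X = np.random.randn(64, 3)
y = np.random.randn(64)

rfreg = xgb.XGBRFRegressor()
with pytest.raises(NotImplementedError):
    rfreg.fit(X, y, early_stopping_rounds=10)
with pytest.raises(NotImplementedError):
    # Callbacks are rejected by the same guard.
    rfreg.fit(X, y, callbacks=[xgb.callback.EarlyStopping(rounds=2)])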