diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst
index 1684317ac..7585dbf57 100644
--- a/doc/python/python_api.rst
+++ b/doc/python/python_api.rst
@@ -106,3 +106,9 @@ Dask API
 .. autofunction:: xgboost.dask.DaskXGBClassifier
 
 .. autofunction:: xgboost.dask.DaskXGBRegressor
+
+.. autofunction:: xgboost.dask.DaskXGBRanker
+
+.. autofunction:: xgboost.dask.DaskXGBRFRegressor
+
+.. autofunction:: xgboost.dask.DaskXGBRFClassifier
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index b1cc71447..6c40a8c97 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -38,8 +38,8 @@ from .core import Objective, Metric
 from .core import _deprecate_positional_args
 from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
-from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
-from .sklearn import xgboost_model_doc
+from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
+from .sklearn import xgboost_model_doc, _objective_decorator
 from .sklearn import _cls_predict_proba
 from .sklearn import XGBRanker
 
@@ -1262,7 +1262,6 @@ class DaskScikitLearnBase(XGBModel):
 
     _client = None
 
-    # pylint: disable=arguments-differ
     @_deprecate_positional_args
     async def _predict_async(
         self, data: _DaskCollection,
@@ -1282,7 +1281,7 @@
 
     def predict(
         self,
-        data: _DaskCollection,
+        X: _DaskCollection,
         output_margin: bool = False,
         ntree_limit: Optional[int] = None,
         validate_features: bool = True,
@@ -1291,10 +1290,13 @@
         _assert_dask_support()
         msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
         assert ntree_limit is None, msg
-        return self.client.sync(self._predict_async, data,
-                                output_margin=output_margin,
-                                validate_features=validate_features,
-                                base_margin=base_margin)
+        return self.client.sync(
+            self._predict_async,
+            X,
+            output_margin=output_margin,
+            validate_features=validate_features,
+            base_margin=base_margin
+        )
 
     def __await__(self) -> Awaitable[Any]:
         # Generate a coroutine wrapper to make this class awaitable.
@@ -1586,7 +1588,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
 """,
 )
 class DaskXGBRanker(DaskScikitLearnBase):
-    def __init__(self, objective: str = "rank:pairwise", **kwargs: Any):
+    @_deprecate_positional_args
+    def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
         if callable(objective):
             raise ValueError("Custom objective function not supported by XGBRanker.")
         super().__init__(objective=objective, kwargs=kwargs)
@@ -1698,3 +1701,75 @@ class DaskXGBRanker(DaskScikitLearnBase):
 
     # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
     fit.__doc__ = XGBRanker.fit.__doc__
+
+
+@xgboost_model_doc(
+    "Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
+    ["model", "objective"],
+    extra_parameters="""
+    n_estimators : int
+        Number of trees in random forest to fit.
+""", +) +class DaskXGBRFRegressor(DaskXGBRegressor): + @_deprecate_positional_args + def __init__( + self, + *, + learning_rate: Optional[float] = 1, + subsample: Optional[float] = 0.8, + colsample_bynode: Optional[float] = 0.8, + reg_lambda: Optional[float] = 1e-5, + **kwargs: Any + ) -> None: + super().__init__( + learning_rate=learning_rate, + subsample=subsample, + colsample_bynode=colsample_bynode, + reg_lambda=reg_lambda, + **kwargs + ) + + def get_xgb_params(self) -> Dict[str, Any]: + params = super().get_xgb_params() + params["num_parallel_tree"] = self.n_estimators + return params + + def get_num_boosting_rounds(self) -> int: + return 1 + + +@xgboost_model_doc( + "Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.", + ["model", "objective"], + extra_parameters=""" + n_estimators : int + Number of trees in random forest to fit. +""", +) +class DaskXGBRFClassifier(DaskXGBClassifier): + @_deprecate_positional_args + def __init__( + self, + *, + learning_rate: Optional[float] = 1, + subsample: Optional[float] = 0.8, + colsample_bynode: Optional[float] = 0.8, + reg_lambda: Optional[float] = 1e-5, + **kwargs: Any + ) -> None: + super().__init__( + learning_rate=learning_rate, + subsample=subsample, + colsample_bynode=colsample_bynode, + reg_lambda=reg_lambda, + **kwargs + ) + + def get_xgb_params(self) -> Dict[str, Any]: + params = super().get_xgb_params() + params["num_parallel_tree"] = self.n_estimators + return params + + def get_num_boosting_rounds(self) -> int: + return 1 diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 2ea726436..3fcbcc0ed 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -91,7 +91,7 @@ __model_doc = ''' node of the tree. min_child_weight : float Minimum sum of instance weight(hessian) needed in a child. - max_delta_step : int + max_delta_step : float Maximum delta step we allow each tree's weight estimation to be. subsample : float Subsample ratio of the training instance. 
@@ -1465,7 +1465,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
             xgb_model = xgb_model._Booster  # pylint: disable=protected-access
 
         self._Booster = train(params, train_dmatrix,
-                              self.n_estimators,
+                              self.get_num_boosting_rounds(),
                               early_stopping_rounds=early_stopping_rounds,
                               evals=evals,
                               evals_result=evals_result, feval=feval,
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 6f696b8d2..126f9ce98 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -34,7 +34,7 @@ from xgboost.dask import DaskDMatrix
 if hasattr(HealthCheck, 'function_scoped_fixture'):
     suppress = [HealthCheck.function_scoped_fixture]
 else:
-    suppress = hypothesis.utils.conventions.not_set
+    suppress = hypothesis.utils.conventions.not_set  # type:ignore
 
 kRows = 1000
 
@@ -264,100 +264,127 @@ def test_dask_missing_value_cls() -> None:
     assert hasattr(cls, 'missing')
 
 
-def test_dask_regressor() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X, y, w = generate_array(with_weights=True)
-            regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
-            assert regressor._estimator_type == "regressor"
-            assert sklearn.base.is_regressor(regressor)
+@pytest.mark.parametrize("model", ["boosting", "rf"])
+def test_dask_regressor(model: str, client: "Client") -> None:
+    X, y, w = generate_array(with_weights=True)
+    if model == "boosting":
+        regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+    else:
+        regressor = xgb.dask.DaskXGBRFRegressor(verbosity=1, n_estimators=2)
 
-            regressor.set_params(tree_method='hist')
-            regressor.client = client
-            regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
-            prediction = regressor.predict(X)
+    assert regressor._estimator_type == "regressor"
+    assert sklearn.base.is_regressor(regressor)
 
-            assert prediction.ndim == 1
-            assert prediction.shape[0] == kRows
+    regressor.set_params(tree_method='hist')
+    regressor.client = client
+    regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
+    prediction = regressor.predict(X)
 
-            history = regressor.evals_result()
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows
 
-            assert isinstance(prediction, da.Array)
-            assert isinstance(history, dict)
+    history = regressor.evals_result()
 
-            assert list(history['validation_0'].keys())[0] == 'rmse'
-            assert len(history['validation_0']['rmse']) == 2
+    assert isinstance(prediction, da.Array)
+    assert isinstance(history, dict)
+
+    assert list(history['validation_0'].keys())[0] == 'rmse'
+    forest = int(
+        json.loads(regressor.get_booster().save_config())["learner"][
+            "gradient_booster"
+        ]["gbtree_train_param"]["num_parallel_tree"]
+    )
+
+    if model == "boosting":
+        assert len(history['validation_0']['rmse']) == 2
+        assert forest == 1
+    else:
+        assert len(history['validation_0']['rmse']) == 1
+        assert forest == 2
 
 
-def test_dask_classifier() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X, y, w = generate_array(with_weights=True)
-            y = (y * 10).astype(np.int32)
-            classifier = xgb.dask.DaskXGBClassifier(
-                verbosity=1, n_estimators=2, eval_metric='merror')
-            assert classifier._estimator_type == "classifier"
-            assert sklearn.base.is_classifier(classifier)
+@pytest.mark.parametrize("model", ["boosting", "rf"])
+def test_dask_classifier(model: str, client: "Client") -> None:
+    X, y, w = generate_array(with_weights=True)
+    y = (y * 10).astype(np.int32)
+    if model == "boosting":
+        classifier = xgb.dask.DaskXGBClassifier(
+            verbosity=1, n_estimators=2, eval_metric="merror"
+        )
+    else:
+        classifier = xgb.dask.DaskXGBRFClassifier(
+            verbosity=1, n_estimators=2, eval_metric="merror"
+        )
 
-            classifier.client = client
-            classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
-            prediction = classifier.predict(X)
+    assert classifier._estimator_type == "classifier"
+    assert sklearn.base.is_classifier(classifier)
 
-            assert prediction.ndim == 1
-            assert prediction.shape[0] == kRows
+    classifier.client = client
+    classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
+    prediction = classifier.predict(X)
 
-            history = classifier.evals_result()
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows
 
-            assert isinstance(prediction, da.Array)
-            assert isinstance(history, dict)
+    history = classifier.evals_result()
 
-            assert list(history.keys())[0] == 'validation_0'
-            assert list(history['validation_0'].keys())[0] == 'merror'
-            assert len(list(history['validation_0'])) == 1
-            assert len(history['validation_0']['merror']) == 2
+    assert isinstance(prediction, da.Array)
+    assert isinstance(history, dict)
 
-            # Test .predict_proba()
-            probas = classifier.predict_proba(X)
-            assert classifier.n_classes_ == 10
-            assert probas.ndim == 2
-            assert probas.shape[0] == kRows
-            assert probas.shape[1] == 10
+    assert list(history.keys())[0] == "validation_0"
+    assert list(history["validation_0"].keys())[0] == "merror"
+    assert len(list(history["validation_0"])) == 1
+    forest = int(
+        json.loads(classifier.get_booster().save_config())["learner"][
+            "gradient_booster"
+        ]["gbtree_train_param"]["num_parallel_tree"]
+    )
+    if model == "boosting":
+        assert len(history["validation_0"]["merror"]) == 2
+        assert forest == 1
+    else:
+        assert len(history["validation_0"]["merror"]) == 1
+        assert forest == 2
 
-            cls_booster = classifier.get_booster()
-            single_node_proba = cls_booster.inplace_predict(X.compute())
+    # Test .predict_proba()
+    probas = classifier.predict_proba(X)
+    assert classifier.n_classes_ == 10
+    assert probas.ndim == 2
+    assert probas.shape[0] == kRows
+    assert probas.shape[1] == 10
 
-            np.testing.assert_allclose(single_node_proba,
-                                       probas.compute())
+    cls_booster = classifier.get_booster()
+    single_node_proba = cls_booster.inplace_predict(X.compute())
 
-            # Test with dataframe.
-            X_d = dd.from_dask_array(X)
-            y_d = dd.from_dask_array(y)
-            classifier.fit(X_d, y_d)
+    np.testing.assert_allclose(single_node_proba, probas.compute())
 
-            assert classifier.n_classes_ == 10
-            prediction = classifier.predict(X_d)
+    # Test with dataframe.
+    X_d = dd.from_dask_array(X)
+    y_d = dd.from_dask_array(y)
+    classifier.fit(X_d, y_d)
 
-            assert prediction.ndim == 1
-            assert prediction.shape[0] == kRows
+    assert classifier.n_classes_ == 10
+    prediction = classifier.predict(X_d)
+
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows
 
 
 @pytest.mark.skipif(**tm.no_sklearn())
-def test_sklearn_grid_search() -> None:
+def test_sklearn_grid_search(client: "Client") -> None:
     from sklearn.model_selection import GridSearchCV
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X, y, _ = generate_array()
-            reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
-                                            tree_method='hist')
-            reg.client = client
-            model = GridSearchCV(reg, {'max_depth': [2, 4],
-                                       'n_estimators': [5, 10]},
-                                 cv=2, verbose=1)
-            model.fit(X, y)
-            # Expect unique results for each parameter value This confirms
-            # sklearn is able to successfully update the parameter
-            means = model.cv_results_['mean_test_score']
-            assert len(means) == len(set(means))
+    X, y, _ = generate_array()
+    reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
+                                    tree_method='hist')
+    reg.client = client
+    model = GridSearchCV(reg, {'max_depth': [2, 4],
+                               'n_estimators': [5, 10]},
+                         cv=2, verbose=1)
+    model.fit(X, y)
+    # Expect unique results for each parameter value. This confirms
+    # sklearn is able to successfully update the parameter.
+    means = model.cv_results_['mean_test_score']
+    assert len(means) == len(set(means))
 
 
 def test_empty_dmatrix_training_continuation(client: "Client") -> None:
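
Usage note (an editor's sketch, not part of the patch): the two estimators added in dask.py reuse the existing Dask boosting wrappers, so they are driven the same way; the only behavioural difference is that get_num_boosting_rounds() pins training to a single round and get_xgb_params() maps n_estimators onto num_parallel_tree, which is what the parametrized tests check via save_config(). The cluster size and synthetic data below are illustrative, not taken from the patch.

import dask.array as da
import xgboost as xgb
from dask.distributed import Client, LocalCluster


def run(client: Client) -> None:
    # Synthetic regression data; any dask collection accepted by the wrappers works.
    X = da.random.random((1000, 10), chunks=(100, 10))
    y = da.random.random(1000, chunks=100)

    # n_estimators becomes num_parallel_tree and is grown in a single boosting round.
    rf = xgb.dask.DaskXGBRFRegressor(n_estimators=10, tree_method="hist")
    rf.client = client
    rf.fit(X, y, eval_set=[(X, y)])

    prediction = rf.predict(X)  # the first argument is now named `X`
    print(prediction.compute()[:5])


if __name__ == "__main__":
    with LocalCluster(n_workers=2) as cluster:
        with Client(cluster) as client:
            run(client)

DaskXGBRFClassifier mirrors this with integer class labels for y and an additional predict_proba, matching the parametrized tests above.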
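
A second sketch, also hypothetical, for the `ntree_limit` assertion added to DaskScikitLearnBase.predict: rather than limiting trees at prediction time, slice the booster by boosting rounds and predict with the slice. It assumes a fitted multi-round DaskXGBRegressor `reg`, a dask collection `X`, and a `client` as in the example above; slicing by rounds is not meaningful for the single-round random forest variants.

booster = reg.get_booster()
first_half = booster[:5]  # Booster slicing keeps boosting rounds 0-4
dtest = xgb.dask.DaskDMatrix(client, X)
sliced_prediction = xgb.dask.predict(client, first_half, dtest)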