[dask] Random forest estimators (#6602)

parent 0027220aa0
commit 89a00a5866
doc/python/python_api.rst
@@ -106,3 +106,9 @@ Dask API
 .. autofunction:: xgboost.dask.DaskXGBClassifier
 
 .. autofunction:: xgboost.dask.DaskXGBRegressor
+
+.. autofunction:: xgboost.dask.DaskXGBRanker
+
+.. autofunction:: xgboost.dask.DaskXGBRFRegressor
+
+.. autofunction:: xgboost.dask.DaskXGBRFClassifier
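The new estimators follow the existing Dask scikit-learn interface; a minimal usage sketch (the cluster setup and random data here are illustrative, not part of this commit):

    from dask import array as da
    from dask.distributed import Client, LocalCluster
    import xgboost as xgb

    with LocalCluster(n_workers=2) as cluster:
        with Client(cluster) as client:
            # Illustrative random data; any dask collection works.
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random(1000, chunks=100)

            # Grows a 10-tree forest in a single boosting round
            # (see the dask.py changes below).
            rf = xgb.dask.DaskXGBRFRegressor(n_estimators=10)
            rf.fit(X, y)
            predictions = rf.predict(X)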
python-package/xgboost/dask.py
@@ -38,8 +38,8 @@ from .core import Objective, Metric
 from .core import _deprecate_positional_args
 from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
-from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
-from .sklearn import xgboost_model_doc
+from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
+from .sklearn import xgboost_model_doc, _objective_decorator
+from .sklearn import _cls_predict_proba
 from .sklearn import XGBRanker
@@ -1262,7 +1262,6 @@ class DaskScikitLearnBase(XGBModel):
     _client = None
 
     # pylint: disable=arguments-differ
-    @_deprecate_positional_args
     async def _predict_async(
         self, data: _DaskCollection,
@@ -1282,7 +1281,7 @@ class DaskScikitLearnBase(XGBModel):
 
     def predict(
         self,
-        data: _DaskCollection,
+        X: _DaskCollection,
         output_margin: bool = False,
         ntree_limit: Optional[int] = None,
         validate_features: bool = True,
@@ -1291,10 +1290,13 @@ class DaskScikitLearnBase(XGBModel):
         _assert_dask_support()
         msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
         assert ntree_limit is None, msg
-        return self.client.sync(self._predict_async, data,
-                                output_margin=output_margin,
-                                validate_features=validate_features,
-                                base_margin=base_margin)
+        return self.client.sync(
+            self._predict_async,
+            X,
+            output_margin=output_margin,
+            validate_features=validate_features,
+            base_margin=base_margin
+        )
 
     def __await__(self) -> Awaitable[Any]:
         # Generate a coroutine wrapper to make this class awaitable.
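The assertion above points callers at model slicing as the replacement for `ntree_limit`. A hedged sketch of what that looks like, assuming the `Booster.__getitem__` slicing available in recent XGBoost releases:

    # Instead of model.predict(X, ntree_limit=10), slice the booster:
    booster = model.get_booster()
    first_ten_trees = booster[:10]
    predictions = xgb.dask.predict(client, first_ten_trees, X)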
@@ -1586,7 +1588,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     """,
 )
 class DaskXGBRanker(DaskScikitLearnBase):
-    def __init__(self, objective: str = "rank:pairwise", **kwargs: Any):
+    @_deprecate_positional_args
+    def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
         if callable(objective):
             raise ValueError("Custom objective function not supported by XGBRanker.")
         super().__init__(objective=objective, kwargs=kwargs)
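With the keyword-only signature, passing `objective` positionally now goes through the deprecation path instead of being silently accepted; callers should write, for example:

    # Keyword arguments only; DaskXGBRanker("rank:ndcg") is deprecated.
    ranker = xgb.dask.DaskXGBRanker(objective="rank:ndcg")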
@@ -1698,3 +1701,75 @@ class DaskXGBRanker(DaskScikitLearnBase):
 
     # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
     fit.__doc__ = XGBRanker.fit.__doc__
+
+
+@xgboost_model_doc(
+    "Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
+    ["model", "objective"],
+    extra_parameters="""
+    n_estimators : int
+        Number of trees in random forest to fit.
+""",
+)
+class DaskXGBRFRegressor(DaskXGBRegressor):
+    @_deprecate_positional_args
+    def __init__(
+        self,
+        *,
+        learning_rate: Optional[float] = 1,
+        subsample: Optional[float] = 0.8,
+        colsample_bynode: Optional[float] = 0.8,
+        reg_lambda: Optional[float] = 1e-5,
+        **kwargs: Any
+    ) -> None:
+        super().__init__(
+            learning_rate=learning_rate,
+            subsample=subsample,
+            colsample_bynode=colsample_bynode,
+            reg_lambda=reg_lambda,
+            **kwargs
+        )
+
+    def get_xgb_params(self) -> Dict[str, Any]:
+        params = super().get_xgb_params()
+        params["num_parallel_tree"] = self.n_estimators
+        return params
+
+    def get_num_boosting_rounds(self) -> int:
+        return 1
+
+
+@xgboost_model_doc(
+    "Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.",
+    ["model", "objective"],
+    extra_parameters="""
+    n_estimators : int
+        Number of trees in random forest to fit.
+""",
+)
+class DaskXGBRFClassifier(DaskXGBClassifier):
+    @_deprecate_positional_args
+    def __init__(
+        self,
+        *,
+        learning_rate: Optional[float] = 1,
+        subsample: Optional[float] = 0.8,
+        colsample_bynode: Optional[float] = 0.8,
+        reg_lambda: Optional[float] = 1e-5,
+        **kwargs: Any
+    ) -> None:
+        super().__init__(
+            learning_rate=learning_rate,
+            subsample=subsample,
+            colsample_bynode=colsample_bynode,
+            reg_lambda=reg_lambda,
+            **kwargs
+        )
+
+    def get_xgb_params(self) -> Dict[str, Any]:
+        params = super().get_xgb_params()
+        params["num_parallel_tree"] = self.n_estimators
+        return params
+
+    def get_num_boosting_rounds(self) -> int:
+        return 1
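Both subclasses implement random forests by folding `n_estimators` into `num_parallel_tree` and fixing the boosting round count at one, so the whole forest is grown in a single round. Under that reading, the regressor above is roughly equivalent to this functional-API call (a sketch, not code from the commit):

    import xgboost as xgb

    # DaskXGBRFRegressor(n_estimators=100).fit(X, y), approximately:
    dtrain = xgb.dask.DaskDMatrix(client, X, y)
    output = xgb.dask.train(
        client,
        {
            "num_parallel_tree": 100,   # injected by get_xgb_params()
            "learning_rate": 1,         # RF defaults from __init__
            "subsample": 0.8,
            "colsample_bynode": 0.8,
            "reg_lambda": 1e-5,
            "objective": "reg:squarederror",
        },
        dtrain,
        num_boost_round=1,              # get_num_boosting_rounds() returns 1
    )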
python-package/xgboost/sklearn.py
@@ -91,7 +91,7 @@ __model_doc = '''
         node of the tree.
     min_child_weight : float
         Minimum sum of instance weight(hessian) needed in a child.
-    max_delta_step : int
+    max_delta_step : float
         Maximum delta step we allow each tree's weight estimation to be.
     subsample : float
         Subsample ratio of the training instance.
@@ -1465,7 +1465,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
             xgb_model = xgb_model._Booster  # pylint: disable=protected-access
 
         self._Booster = train(params, train_dmatrix,
-                              self.n_estimators,
+                              self.get_num_boosting_rounds(),
                               early_stopping_rounds=early_stopping_rounds,
                               evals=evals,
                               evals_result=evals_result, feval=feval,
tests/python/test_with_dask.py
@@ -34,7 +34,7 @@ from xgboost.dask import DaskDMatrix
 if hasattr(HealthCheck, 'function_scoped_fixture'):
     suppress = [HealthCheck.function_scoped_fixture]
 else:
-    suppress = hypothesis.utils.conventions.not_set
+    suppress = hypothesis.utils.conventions.not_set  # type:ignore
 
 
 kRows = 1000
@@ -264,11 +264,14 @@ def test_dask_missing_value_cls() -> None:
     assert hasattr(cls, 'missing')
 
 
-def test_dask_regressor() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
+@pytest.mark.parametrize("model", ["boosting", "rf"])
+def test_dask_regressor(model: str, client: "Client") -> None:
     X, y, w = generate_array(with_weights=True)
-    regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+    if model == "boosting":
+        regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+    else:
+        regressor = xgb.dask.DaskXGBRFRegressor(verbosity=1, n_estimators=2)
 
     assert regressor._estimator_type == "regressor"
     assert sklearn.base.is_regressor(regressor)
@@ -286,16 +289,33 @@ def test_dask_regressor() -> None:
     assert isinstance(history, dict)
 
     assert list(history['validation_0'].keys())[0] == 'rmse'
-    assert len(history['validation_0']['rmse']) == 2
+    forest = int(
+        json.loads(regressor.get_booster().save_config())["learner"][
+            "gradient_booster"
+        ]["gbtree_train_param"]["num_parallel_tree"]
+    )
+
+    if model == "boosting":
+        assert len(history['validation_0']['rmse']) == 2
+        assert forest == 1
+    else:
+        assert len(history['validation_0']['rmse']) == 1
+        assert forest == 2
 
 
-def test_dask_classifier() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
+@pytest.mark.parametrize("model", ["boosting", "rf"])
+def test_dask_classifier(model: str, client: "Client") -> None:
     X, y, w = generate_array(with_weights=True)
     y = (y * 10).astype(np.int32)
-    classifier = xgb.dask.DaskXGBClassifier(
-        verbosity=1, n_estimators=2, eval_metric='merror')
+    if model == "boosting":
+        classifier = xgb.dask.DaskXGBClassifier(
+            verbosity=1, n_estimators=2, eval_metric="merror"
+        )
+    else:
+        classifier = xgb.dask.DaskXGBRFClassifier(
+            verbosity=1, n_estimators=2, eval_metric="merror"
+        )
 
     assert classifier._estimator_type == "classifier"
     assert sklearn.base.is_classifier(classifier)
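The `forest` checks parse the booster's JSON configuration; `save_config()` serializes parameter values as strings, which is why the tests wrap the lookup in `int(...)`. Abbreviated view, keeping only the path the tests read:

    import json

    config = json.loads(regressor.get_booster().save_config())
    # config["learner"]["gradient_booster"]["gbtree_train_param"]
    #   == {"num_parallel_tree": "2", ...}   # note: values are strings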
@@ -311,10 +331,20 @@ def test_dask_classifier() -> None:
     assert isinstance(prediction, da.Array)
     assert isinstance(history, dict)
 
-    assert list(history.keys())[0] == 'validation_0'
-    assert list(history['validation_0'].keys())[0] == 'merror'
-    assert len(list(history['validation_0'])) == 1
-    assert len(history['validation_0']['merror']) == 2
+    assert list(history.keys())[0] == "validation_0"
+    assert list(history["validation_0"].keys())[0] == "merror"
+    assert len(list(history["validation_0"])) == 1
+    forest = int(
+        json.loads(classifier.get_booster().save_config())["learner"][
+            "gradient_booster"
+        ]["gbtree_train_param"]["num_parallel_tree"]
+    )
+    if model == "boosting":
+        assert len(history["validation_0"]["merror"]) == 2
+        assert forest == 1
+    else:
+        assert len(history["validation_0"]["merror"]) == 1
+        assert forest == 2
 
     # Test .predict_proba()
     probas = classifier.predict_proba(X)
@@ -326,8 +356,7 @@ def test_dask_classifier() -> None:
     cls_booster = classifier.get_booster()
     single_node_proba = cls_booster.inplace_predict(X.compute())
 
-    np.testing.assert_allclose(single_node_proba,
-                               probas.compute())
+    np.testing.assert_allclose(single_node_proba, probas.compute())
 
     # Test with dataframe.
     X_d = dd.from_dask_array(X)
@@ -342,10 +371,8 @@ def test_dask_classifier() -> None:
 
 
 @pytest.mark.skipif(**tm.no_sklearn())
-def test_sklearn_grid_search() -> None:
+def test_sklearn_grid_search(client: "Client") -> None:
     from sklearn.model_selection import GridSearchCV
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
     X, y, _ = generate_array()
     reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
                                     tree_method='hist')
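The refactored tests take `client` as an argument instead of opening a `LocalCluster` inline, which implies a shared pytest fixture. A hypothetical sketch of such a fixture (the real one lives elsewhere in the test suite and is not part of this diff):

    import pytest
    from distributed import Client, LocalCluster

    @pytest.fixture
    def client():
        # One small cluster per test; the suite may scope this differently.
        with LocalCluster(n_workers=kWorkers) as cluster:
            with Client(cluster) as test_client:
                yield test_client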