From 5cb51a191e80e7fe235a78bfd7aedd47d6c9909e Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Wed, 19 May 2021 13:50:45 +0800
Subject: [PATCH] [dask][doc] Add small example for sklearn interface. (#6970)

---
 doc/tutorials/dask.rst | 36 ++++++++++++++++++++++++++++++++++--
 doc/tutorials/rf.rst | 4 ++--
 2 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst
index 440e595c0..b449de926 100644
--- a/doc/tutorials/dask.rst
+++ b/doc/tutorials/dask.rst
@@ -115,8 +115,8 @@ See next section for details.
 
 Alternatively, XGBoost also implements the Scikit-Learn interface with
 ``DaskXGBClassifier``, ``DaskXGBRegressor``, ``DaskXGBRanker`` and 2 random forest
 variances. This wrapper is similar to the single node Scikit-Learn interface in xgboost,
-with dask collection as inputs and has an additional ``client`` attribute. See
-``xgboost/demo/dask`` for more examples.
+with dask collection as inputs and has an additional ``client`` attribute. See following
+sections and ``xgboost/demo/dask`` for more examples.
 
 
 ******************
@@ -191,6 +191,38 @@ Scikit-Learn wrapper object:
 
     booster = cls.get_booster()
 
+**********************
+Scikit-Learn interface
+**********************
+
+As mentioned previously, there's another interface that mimics the scikit-learn estimators
+with higher level of abstraction. The interface is easier to use compared to the
+functional interface but with more constraints. It's worth mentioning that, although the
+interface mimics scikit-learn estimators, it doesn't work with normal scikit-learn
+utilities like ``GridSearchCV`` as scikit-learn doesn't understand distributed dask data
+collection.
+
+
+.. code-block:: python
+
+    from distributed import LocalCluster, Client
+    import xgboost as xgb
+
+
+    def main(client: Client) -> None:
+        X, y = load_data()
+        clf = xgb.dask.DaskXGBClassifier(n_estimators=100, tree_method="hist")
+        clf.client = client  # assign the client
+        clf.fit(X, y, eval_set=[(X, y)])
+        proba = clf.predict_proba(X)
+
+
+    if __name__ == "__main__":
+        with LocalCluster() as cluster:
+            with Client(cluster) as client:
+                main(client)
+
+
 ***************************
 Working with other clusters
 ***************************
diff --git a/doc/tutorials/rf.rst b/doc/tutorials/rf.rst
index 808dd3850..b68204e63 100644
--- a/doc/tutorials/rf.rst
+++ b/doc/tutorials/rf.rst
@@ -1,6 +1,6 @@
-#########################
+#############################
 Random Forests(TM) in XGBoost
-#########################
+#############################
 
 XGBoost is normally used to train gradient-boosted decision trees and other gradient
 boosted models. Random Forests use the same model representation and inference, as