[doc] Add introduction and notes for the sklearn interface. (#8948)
parent: bf88dadb61
commit: 21a52c7f98
@ -2,6 +2,9 @@
Collection of examples for using sklearn interface
==================================================

For an introduction to XGBoost's scikit-learn estimator interface, see
:doc:`/python/sklearn_estimator`.

Created on 1 Apr 2015

@author: Jamie Hall
@ -10,6 +10,7 @@ Contents

.. toctree::
  python_intro
  sklearn_estimator
  python_api
  callbacks
  model
@ -41,6 +41,7 @@ Learning API

Scikit-Learn API
----------------

.. automodule:: xgboost.sklearn
.. autoclass:: xgboost.XGBRegressor
  :members:
@ -305,7 +305,8 @@ Scikit-Learn interface
----------------------

XGBoost provides an easy to use scikit-learn interface for some pre-defined models
including regression, classification and ranking.
including regression, classification and ranking. See :doc:`/python/sklearn_estimator`
for more info.

.. code-block:: python
doc/python/sklearn_estimator.rst (new file, 162 lines)
@ -0,0 +1,162 @@
##########################################
Using the Scikit-Learn Estimator Interface
##########################################

**Contents**

.. contents::
  :backlinks: none
  :local:

********
Overview
********

In addition to the native interface, XGBoost features a sklearn estimator interface that
conforms to the `sklearn estimator guideline
<https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator>`__. It
supports regression, classification, and learning to rank. Survival training for the
sklearn estimator interface is still a work in progress.

You can find some quick start examples at
:ref:`sphx_glr_python_examples_sklearn_examples.py`. The main advantage of using the
sklearn interface is that it works with most of the utilities provided by sklearn, such as
:py:func:`sklearn.model_selection.cross_validate`. Also, many other libraries recognize
the sklearn estimator interface thanks to its popularity.
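
As a quick illustration of that compatibility, here is a minimal sketch of running cross
validation on an XGBoost estimator; the synthetic data and parameter values are
placeholders for this example only:

.. code-block:: python

  from sklearn.datasets import make_regression
  from sklearn.model_selection import cross_validate

  import xgboost as xgb

  # Synthetic data, used only to keep the sketch self-contained.
  X, y = make_regression(n_samples=512, n_features=16, random_state=94)

  reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64)
  # The estimator plugs directly into sklearn utilities such as cross_validate.
  results = cross_validate(reg, X, y, cv=5, scoring="neg_root_mean_squared_error")
  print(results["test_score"])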

With the sklearn estimator interface, we can train a classification model with only a
couple of lines of Python code. Here's an example for training a classification model:

.. code-block:: python

  from sklearn.datasets import load_breast_cancer
  from sklearn.model_selection import train_test_split

  import xgboost as xgb

  X, y = load_breast_cancer(return_X_y=True)
  X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=94)

  # Use "hist" for constructing the trees, with early stopping enabled.
  clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=2)
  # Fit the model, test sets are used for early stopping.
  clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
  # Save model into JSON format.
  clf.save_model("clf.json")

The ``tree_method`` parameter specifies the method to use for constructing the trees, and
the ``early_stopping_rounds`` parameter enables early stopping. Early stopping can help
prevent overfitting and save time during training.
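
Continuing the example above, the saved JSON model can be loaded back into a fresh
estimator for later prediction; a short sketch:

.. code-block:: python

  # Load the previously saved model and predict on the held-out set.
  clf2 = xgb.XGBClassifier()
  clf2.load_model("clf.json")
  print(clf2.predict(X_test)[:5])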

**************
Early Stopping
**************

As demonstrated in the previous example, early stopping can be enabled by the parameter
``early_stopping_rounds``. Alternatively, there's a callback,
:py:class:`xgboost.callback.EarlyStopping`, that can be used to specify more details about
the behavior of early stopping, including whether XGBoost should return the best model
instead of the full stack of trees:

.. code-block:: python

  early_stop = xgb.callback.EarlyStopping(
      rounds=2, metric_name='logloss', data_name='Validation_0', save_best=True
  )
  clf = xgb.XGBClassifier(tree_method="hist", callbacks=[early_stop])
  clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
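
After fitting with ``save_best=True``, one way to confirm where training stopped is to
check the recorded best iteration and score; a small sketch continuing the example above:

.. code-block:: python

  # The attributes reflect the iteration at which the best model was recorded.
  print(clf.best_iteration)
  print(clf.best_score)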

At present, XGBoost doesn't implement data splitting logic within the estimator and relies
on the ``eval_set`` parameter of the :py:meth:`xgboost.XGBModel.fit` method. If you want
to use early stopping to prevent overfitting, you'll need to manually split your data into
training and testing sets using the :py:func:`sklearn.model_selection.train_test_split`
function from the `sklearn` library. Some other machine learning algorithms, like those in
`sklearn`, include early stopping as part of the estimator and may work with cross
validation. However, using early stopping during cross validation may not be a perfect
approach because it changes the model's number of trees for each validation fold, leading
to different models. A better approach is to retrain the model after cross validation
using the best hyperparameters along with early stopping. If you want to experiment with
the idea of using cross validation with early stopping, here is a snippet to begin with:

.. code-block:: python

  from sklearn.base import clone
  from sklearn.datasets import load_breast_cancer
  from sklearn.model_selection import StratifiedKFold, cross_validate

  import xgboost as xgb

  X, y = load_breast_cancer(return_X_y=True)


  def fit_and_score(estimator, X_train, X_test, y_train, y_test):
      """Fit the estimator on the train set and score it on both sets."""
      estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)])

      train_score = estimator.score(X_train, y_train)
      test_score = estimator.score(X_test, y_test)

      return estimator, train_score, test_score


  cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=94)

  clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=3)

  results = {}

  for train, test in cv.split(X, y):
      X_train = X[train]
      X_test = X[test]
      y_train = y[train]
      y_test = y[test]
      est, train_score, test_score = fit_and_score(
          clone(clf), X_train, X_test, y_train, y_test
      )
      results[est] = (train_score, test_score)
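
As a small follow-up sketch, you can inspect how many boosting rounds each fold kept,
which makes the fold-to-fold variation described above explicit:

.. code-block:: python

  # Each fold may stop at a different number of trees, hence the different models.
  for est, (train_score, test_score) in results.items():
      print(est.best_iteration, train_score, test_score)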

***********************************
Obtaining the native booster object
***********************************

The sklearn estimator interface primarily facilitates training and doesn't implement all
features available in XGBoost. For instance, in order to have cached predictions,
:py:class:`xgboost.DMatrix` needs to be used with :py:meth:`xgboost.Booster.predict`. One
can obtain the booster object from the sklearn interface using
:py:meth:`xgboost.XGBModel.get_booster`:

.. code-block:: python

  booster = clf.get_booster()
  print(booster.num_boosted_rounds())
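
Continuing the sketch, the booster can then be used with a :py:class:`xgboost.DMatrix`
that is constructed once and reused across calls; the variable names below come from the
earlier examples:

.. code-block:: python

  # Build the DMatrix once and reuse it for subsequent Booster.predict calls.
  dtest = xgb.DMatrix(X_test)
  preds = booster.predict(dtest)
  print(preds[:5])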

**********
Prediction
**********

When early stopping is enabled, prediction functions including the
:py:meth:`xgboost.XGBModel.predict`, :py:meth:`xgboost.XGBModel.score`, and
:py:meth:`xgboost.XGBModel.apply` methods will use the best model automatically, meaning
that the :py:attr:`xgboost.XGBModel.best_iteration` is used to specify the range of trees
used in prediction.
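
For illustration, this automatic behavior is roughly equivalent to passing the tree range
explicitly. A small sketch, reusing a classifier fitted with ``early_stopping_rounds`` as
in the first example:

.. code-block:: python

  # Restrict prediction to the trees up to the recorded best iteration.
  explicit = clf.predict(X_test, iteration_range=(0, clf.best_iteration + 1))
  automatic = clf.predict(X_test)
  assert (explicit == automatic).all()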

To have cached results for incremental prediction, please use the
:py:meth:`xgboost.Booster.predict` method instead.


**************************
Number of parallel threads
**************************

When working with XGBoost and other sklearn tools, you can specify how many threads you
want to use via the ``n_jobs`` parameter. By default, XGBoost uses all the available
threads on your computer, which can lead to some interesting consequences when combined
with other sklearn functions like :py:func:`sklearn.model_selection.cross_validate`. If
both XGBoost and sklearn are set to use all threads, your computer may start to slow down
significantly due to something called "thread thrashing". To avoid this, you can simply
set the ``n_jobs`` parameter for XGBoost to `None` (which uses all threads) and the
``n_jobs`` parameter for sklearn to `1`. This way, both programs will be able to work
together smoothly without putting unnecessary strain on your machine.
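
A minimal sketch of that division of labour, reusing the data from the earlier examples:
the fold loop is kept single-threaded in sklearn while XGBoost parallelizes each fit.

.. code-block:: python

  from sklearn.model_selection import cross_validate

  # XGBoost uses all available threads within each fit; sklearn runs folds sequentially.
  clf = xgb.XGBClassifier(tree_method="hist", n_jobs=None)
  results = cross_validate(clf, X, y, cv=5, n_jobs=1)
  print(results["test_score"])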
@ -368,18 +368,21 @@ __model_doc = f"""

.. versionadded:: 1.6.0

Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training. Requires at least
one item in **eval_set** in :py:meth:`fit`.

- Activates early stopping. Validation metric needs to improve at least once in
  every **early_stopping_rounds** round(s) to continue training. Requires at
  least one item in **eval_set** in :py:meth:`fit`.

The method returns the model from the last iteration (not the best one). If
there's more than one item in **eval_set**, the last entry will be used for early
stopping. If there's more than one metric in **eval_metric**, the last metric
will be used for early stopping.

- The method returns the model from the last iteration, not the best one; use the
  :py:class:`xgboost.callback.EarlyStopping` callback if returning the best
  model is preferred.

If early stopping occurs, the model will have three additional fields:
:py:attr:`best_score`, :py:attr:`best_iteration` and
:py:attr:`best_ntree_limit`.

- If there's more than one item in **eval_set**, the last entry will be used for
  early stopping. If there's more than one metric in **eval_metric**, the last
  metric will be used for early stopping.

- If early stopping occurs, the model will have three additional fields:
  :py:attr:`best_score`, :py:attr:`best_iteration` and
  :py:attr:`best_ntree_limit`.

.. note::
@ -479,7 +482,9 @@ Parameters
    doc.extend([get_doc(i) for i in items])
    if end_note:
        doc.append(end_note)
    full_doc = [header + "\n\n"]
    full_doc = [
        header + "\nSee :doc:`/python/sklearn_estimator` for more information.\n"
    ]
    full_doc.extend(doc)
    cls.__doc__ = "".join(full_doc)
    return cls
@ -1146,10 +1151,10 @@ class XGBModel(XGBModelBase):
        base_margin: Optional[ArrayLike] = None,
        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> ArrayLike:
        """Predict with `X`. If the model is trained with early stopping, then `best_iteration`
        is used automatically. For tree models, when data is on GPU, like cupy array or
        cuDF dataframe and `predictor` is not specified, the prediction is run on GPU
        automatically, otherwise it will run on CPU.
        """Predict with `X`. If the model is trained with early stopping, then
        :py:attr:`best_iteration` is used automatically. For tree models, when data is
        on GPU, like cupy array or cuDF dataframe and `predictor` is not specified, the
        prediction is run on GPU automatically, otherwise it will run on CPU.

        .. note:: This function is only thread safe for `gbtree` and `dart`.
@ -1224,8 +1229,8 @@ class XGBModel(XGBModelBase):
        ntree_limit: int = 0,
        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """Return the predicted leaf every tree for each sample. If the model is trained with
        early stopping, then `best_iteration` is used automatically.
        """Return the predicted leaf every tree for each sample. If the model is trained
        with early stopping, then :py:attr:`best_iteration` is used automatically.

        Parameters
        ----------
@ -1635,7 +1640,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
        base_margin: Optional[ArrayLike] = None,
        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """Predict the probability of each `X` example being of a given class.
        """Predict the probability of each `X` example being of a given class. If the
        model is trained with early stopping, then :py:attr:`best_iteration` is used
        automatically.

        .. note:: This function is only thread safe for `gbtree` and `dart`.
@ -1661,6 +1668,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
        prediction :
            a numpy array of shape (n_samples, n_classes) with the
            probability of each data example being of a given class.

        """
        # custom obj: Do nothing as we don't know what to do.
        # softprob: Do nothing, output is proba.
@ -2122,11 +2130,13 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
        return super().apply(X, ntree_limit, iteration_range)

    def score(self, X: ArrayLike, y: ArrayLike) -> float:
        """Evaluate score for data using the last evaluation metric.
        """Evaluate score for data using the last evaluation metric. If the model is
        trained with early stopping, then :py:attr:`best_iteration` is used
        automatically.

        Parameters
        ----------
        X : pd.DataFrame|cudf.DataFrame
        X : Union[pd.DataFrame, cudf.DataFrame]
            Feature matrix. A DataFrame with a special `qid` column.

        y :