[doc] Add introduction and notes for the sklearn interface. (#8948)
This commit is contained in:
parent
bf88dadb61
commit
21a52c7f98
@ -2,6 +2,9 @@
|
|||||||
Collection of examples for using sklearn interface
|
Collection of examples for using sklearn interface
|
||||||
==================================================
|
==================================================
|
||||||
|
|
||||||
|
For an introduction to XGBoost's scikit-learn estimator interface, see
|
||||||
|
:doc:`/python/sklearn_estimator`.
|
||||||
|
|
||||||
Created on 1 Apr 2015
|
Created on 1 Apr 2015
|
||||||
|
|
||||||
@author: Jamie Hall
|
@author: Jamie Hall
|
||||||
|
|||||||
@ -10,6 +10,7 @@ Contents
|
|||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
python_intro
|
python_intro
|
||||||
|
sklearn_estimator
|
||||||
python_api
|
python_api
|
||||||
callbacks
|
callbacks
|
||||||
model
|
model
|
||||||
|
|||||||
@ -41,6 +41,7 @@ Learning API
|
|||||||
|
|
||||||
Scikit-Learn API
|
Scikit-Learn API
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
.. automodule:: xgboost.sklearn
|
.. automodule:: xgboost.sklearn
|
||||||
.. autoclass:: xgboost.XGBRegressor
|
.. autoclass:: xgboost.XGBRegressor
|
||||||
:members:
|
:members:
|
||||||
|
|||||||
@ -305,7 +305,8 @@ Scikit-Learn interface
|
|||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
XGBoost provides an easy to use scikit-learn interface for some pre-defined models
|
XGBoost provides an easy to use scikit-learn interface for some pre-defined models
|
||||||
including regression, classification and ranking.
|
including regression, classification and ranking. See :doc:`/python/sklearn_estimator`
|
||||||
|
for more info.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
|||||||
162
doc/python/sklearn_estimator.rst
Normal file
162
doc/python/sklearn_estimator.rst
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
##########################################
|
||||||
|
Using the Scikit-Learn Estimator Interface
|
||||||
|
##########################################
|
||||||
|
|
||||||
|
**Contents**
|
||||||
|
|
||||||
|
.. contents::
|
||||||
|
:backlinks: none
|
||||||
|
:local:
|
||||||
|
|
||||||
|
********
|
||||||
|
Overview
|
||||||
|
********
|
||||||
|
|
||||||
|
In addition to the native interface, XGBoost features a sklearn estimator interface that
|
||||||
|
conforms to `sklearn estimator guideline
|
||||||
|
<https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator>`__. It
|
||||||
|
supports regression, classification, and learning to rank. Survival training for the
|
||||||
|
sklearn estimator interface is still working in progress.
|
||||||
|
|
||||||
|
You can find some some quick start examples at
|
||||||
|
:ref:`sphx_glr_python_examples_sklearn_examples.py`. The main advantage of using sklearn
|
||||||
|
interface is that it works with most of the utilites provided by sklearn like
|
||||||
|
:py:func:`sklearn.model_selection.cross_validate`. Also, many other libraries recognize
|
||||||
|
the sklearn estimator interface thanks to its popularity.
|
||||||
|
|
||||||
|
With the sklearn estimator interface, we can train a classification model with only a
|
||||||
|
couple lines of Python code. Here's an example for training a classification model:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from sklearn.datasets import load_breast_cancer
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=94)
|
||||||
|
|
||||||
|
# Use "hist" for constructing the trees, with early stopping enabled.
|
||||||
|
clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=2)
|
||||||
|
# Fit the model, test sets are used for early stopping.
|
||||||
|
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
|
||||||
|
# Save model into JSON format.
|
||||||
|
clf.save_model("clf.json")
|
||||||
|
|
||||||
|
|
||||||
|
The ``tree_method`` parameter specifies the method to use for constructing the trees, and
|
||||||
|
the early_stopping_rounds parameter enables early stopping. Early stopping can help
|
||||||
|
prevent overfitting and save time during training.
|
||||||
|
|
||||||
|
**************
|
||||||
|
Early Stopping
|
||||||
|
**************
|
||||||
|
|
||||||
|
As demonstrated in the previous example, early stopping can be enabled by the parameter
|
||||||
|
``early_stopping_rounds``. Alternatively, there's a callback function that can be used
|
||||||
|
:py:class:`xgboost.callback.EarlyStopping` to specify more details about the behavior of
|
||||||
|
early stopping, including whether XGBoost should return the best model instead of the full
|
||||||
|
stack of trees:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
early_stop = xgb.callback.EarlyStopping(
|
||||||
|
rounds=2, metric_name='logloss', data_name='Validation_0', save_best=True
|
||||||
|
)
|
||||||
|
clf = xgb.XGBClassifier(tree_method="hist", callbacks=[early_stop])
|
||||||
|
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
|
||||||
|
|
||||||
|
At present, XGBoost doesn't implement data spliting logic within the estimator and relies
|
||||||
|
on the ``eval_set`` parameter of the :py:meth:`xgboost.XGBModel.fit` method. If you want
|
||||||
|
to use early stopping to prevent overfitting, you'll need to manually split your data into
|
||||||
|
training and testing sets using the :py:func:`sklearn.model_selection.train_test_split`
|
||||||
|
function from the `sklearn` library. Some other machine learning algorithms, like those in
|
||||||
|
`sklearn`, include early stopping as part of the estimator and may work with cross
|
||||||
|
validation. However, using early stopping during cross validation may not be a perfect
|
||||||
|
approach because it changes the model's number of trees for each validation fold, leading
|
||||||
|
to different model. A better approach is to retrain the model after cross validation using
|
||||||
|
the best hyperparameters along with early stopping. If you want to experiment with idea of
|
||||||
|
using cross validation with early stopping, here is a snippet to begin with:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from sklearn.base import clone
|
||||||
|
from sklearn.datasets import load_breast_cancer
|
||||||
|
from sklearn.model_selection import StratifiedKFold, cross_validate
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
|
|
||||||
|
|
||||||
|
def fit_and_score(estimator, X_train, X_test, y_train, y_test):
|
||||||
|
"""Fit the estimator on the train set and score it on both sets"""
|
||||||
|
estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)])
|
||||||
|
|
||||||
|
train_score = estimator.score(X_train, y_train)
|
||||||
|
test_score = estimator.score(X_test, y_test)
|
||||||
|
|
||||||
|
return estimator, train_score, test_score
|
||||||
|
|
||||||
|
|
||||||
|
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=94)
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=3)
|
||||||
|
|
||||||
|
resutls = {}
|
||||||
|
|
||||||
|
for train, test in cv.split(X, y):
|
||||||
|
X_train = X[train]
|
||||||
|
X_test = X[test]
|
||||||
|
y_train = y[train]
|
||||||
|
y_test = y[test]
|
||||||
|
est, train_score, test_score = fit_and_score(
|
||||||
|
clone(clf), X_train, X_test, y_train, y_test
|
||||||
|
)
|
||||||
|
resutls[est] = (train_score, test_score)
|
||||||
|
|
||||||
|
|
||||||
|
***********************************
|
||||||
|
Obtaining the native booster object
|
||||||
|
***********************************
|
||||||
|
|
||||||
|
The sklearn estimator interface primarily facilitates training and doesn't implement all
|
||||||
|
features available in XGBoost. For instance, in order to have cached predictions,
|
||||||
|
:py:class:`xgboost.DMatrix` needs to be used with :py:meth:`xgboost.Booster.predict`. One
|
||||||
|
can obtain the booster object from the sklearn interface using
|
||||||
|
:py:meth:`xgboost.XGBModel.get_booster`:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
booster = clf.get_booster()
|
||||||
|
print(booster.num_boosted_rounds())
|
||||||
|
|
||||||
|
|
||||||
|
**********
|
||||||
|
Prediction
|
||||||
|
**********
|
||||||
|
|
||||||
|
When early stopping is enabled, prediction functions including the
|
||||||
|
:py:meth:`xgboost.XGBModel.predict`, :py:meth:`xgboost.XGBModel.score`, and
|
||||||
|
:py:meth:`xgboost.XGBModel.apply` methods will use the best model automatically. Meaning
|
||||||
|
the :py:attr:`xgboost.XGBModel.best_iteration` is used to specify the range of trees used
|
||||||
|
in prediction.
|
||||||
|
|
||||||
|
To have cached results for incremental prediction, please use the
|
||||||
|
:py:meth:`xgboost.Booster.predict` method instead.
|
||||||
|
|
||||||
|
|
||||||
|
**************************
|
||||||
|
Number of parallel threads
|
||||||
|
**************************
|
||||||
|
|
||||||
|
When working with XGBoost and other sklearn tools, you can specify how many threads you
|
||||||
|
want to use by using the ``n_jobs`` parameter. By default, XGBoost uses all the available
|
||||||
|
threads on your computer, which can lead to some interesting consequences when combined
|
||||||
|
with other sklearn functions like :py:func:`sklearn.model_selection.cross_validate`. If
|
||||||
|
both XGBoost and sklearn are set to use all threads, your computer may start to slow down
|
||||||
|
significantly due to something called "thread thrashing". To avoid this, you can simply
|
||||||
|
set the ``n_jobs`` parameter for XGBoost to `None` (which uses all threads) and the
|
||||||
|
``n_jobs`` parameter for sklearn to `1`. This way, both programs will be able to work
|
||||||
|
together smoothly without causing any unnecessary computer strain.
|
||||||
@ -368,18 +368,21 @@ __model_doc = f"""
|
|||||||
|
|
||||||
.. versionadded:: 1.6.0
|
.. versionadded:: 1.6.0
|
||||||
|
|
||||||
Activates early stopping. Validation metric needs to improve at least once in
|
- Activates early stopping. Validation metric needs to improve at least once in
|
||||||
every **early_stopping_rounds** round(s) to continue training. Requires at least
|
every **early_stopping_rounds** round(s) to continue training. Requires at
|
||||||
one item in **eval_set** in :py:meth:`fit`.
|
least one item in **eval_set** in :py:meth:`fit`.
|
||||||
|
|
||||||
The method returns the model from the last iteration (not the best one). If
|
- The method returns the model from the last iteration, not the best one, use a
|
||||||
there's more than one item in **eval_set**, the last entry will be used for early
|
callback :py:class:`xgboost.callback.EarlyStopping` if returning the best
|
||||||
stopping. If there's more than one metric in **eval_metric**, the last metric
|
model is preferred.
|
||||||
will be used for early stopping.
|
|
||||||
|
|
||||||
If early stopping occurs, the model will have three additional fields:
|
- If there's more than one item in **eval_set**, the last entry will be used for
|
||||||
:py:attr:`best_score`, :py:attr:`best_iteration` and
|
early stopping. If there's more than one metric in **eval_metric**, the last
|
||||||
:py:attr:`best_ntree_limit`.
|
metric will be used for early stopping.
|
||||||
|
|
||||||
|
- If early stopping occurs, the model will have three additional fields:
|
||||||
|
:py:attr:`best_score`, :py:attr:`best_iteration` and
|
||||||
|
:py:attr:`best_ntree_limit`.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
@ -479,7 +482,9 @@ Parameters
|
|||||||
doc.extend([get_doc(i) for i in items])
|
doc.extend([get_doc(i) for i in items])
|
||||||
if end_note:
|
if end_note:
|
||||||
doc.append(end_note)
|
doc.append(end_note)
|
||||||
full_doc = [header + "\n\n"]
|
full_doc = [
|
||||||
|
header + "\nSee :doc:`/python/sklearn_estimator` for more information.\n"
|
||||||
|
]
|
||||||
full_doc.extend(doc)
|
full_doc.extend(doc)
|
||||||
cls.__doc__ = "".join(full_doc)
|
cls.__doc__ = "".join(full_doc)
|
||||||
return cls
|
return cls
|
||||||
@ -1146,10 +1151,10 @@ class XGBModel(XGBModelBase):
|
|||||||
base_margin: Optional[ArrayLike] = None,
|
base_margin: Optional[ArrayLike] = None,
|
||||||
iteration_range: Optional[Tuple[int, int]] = None,
|
iteration_range: Optional[Tuple[int, int]] = None,
|
||||||
) -> ArrayLike:
|
) -> ArrayLike:
|
||||||
"""Predict with `X`. If the model is trained with early stopping, then `best_iteration`
|
"""Predict with `X`. If the model is trained with early stopping, then
|
||||||
is used automatically. For tree models, when data is on GPU, like cupy array or
|
:py:attr:`best_iteration` is used automatically. For tree models, when data is
|
||||||
cuDF dataframe and `predictor` is not specified, the prediction is run on GPU
|
on GPU, like cupy array or cuDF dataframe and `predictor` is not specified, the
|
||||||
automatically, otherwise it will run on CPU.
|
prediction is run on GPU automatically, otherwise it will run on CPU.
|
||||||
|
|
||||||
.. note:: This function is only thread safe for `gbtree` and `dart`.
|
.. note:: This function is only thread safe for `gbtree` and `dart`.
|
||||||
|
|
||||||
@ -1224,8 +1229,8 @@ class XGBModel(XGBModelBase):
|
|||||||
ntree_limit: int = 0,
|
ntree_limit: int = 0,
|
||||||
iteration_range: Optional[Tuple[int, int]] = None,
|
iteration_range: Optional[Tuple[int, int]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Return the predicted leaf every tree for each sample. If the model is trained with
|
"""Return the predicted leaf every tree for each sample. If the model is trained
|
||||||
early stopping, then `best_iteration` is used automatically.
|
with early stopping, then :py:attr:`best_iteration` is used automatically.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -1635,7 +1640,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
base_margin: Optional[ArrayLike] = None,
|
base_margin: Optional[ArrayLike] = None,
|
||||||
iteration_range: Optional[Tuple[int, int]] = None,
|
iteration_range: Optional[Tuple[int, int]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Predict the probability of each `X` example being of a given class.
|
"""Predict the probability of each `X` example being of a given class. If the
|
||||||
|
model is trained with early stopping, then :py:attr:`best_iteration` is used
|
||||||
|
automatically.
|
||||||
|
|
||||||
.. note:: This function is only thread safe for `gbtree` and `dart`.
|
.. note:: This function is only thread safe for `gbtree` and `dart`.
|
||||||
|
|
||||||
@ -1661,6 +1668,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
prediction :
|
prediction :
|
||||||
a numpy array of shape array-like of shape (n_samples, n_classes) with the
|
a numpy array of shape array-like of shape (n_samples, n_classes) with the
|
||||||
probability of each data example being of a given class.
|
probability of each data example being of a given class.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# custom obj: Do nothing as we don't know what to do.
|
# custom obj: Do nothing as we don't know what to do.
|
||||||
# softprob: Do nothing, output is proba.
|
# softprob: Do nothing, output is proba.
|
||||||
@ -2122,11 +2130,13 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
return super().apply(X, ntree_limit, iteration_range)
|
return super().apply(X, ntree_limit, iteration_range)
|
||||||
|
|
||||||
def score(self, X: ArrayLike, y: ArrayLike) -> float:
|
def score(self, X: ArrayLike, y: ArrayLike) -> float:
|
||||||
"""Evaluate score for data using the last evaluation metric.
|
"""Evaluate score for data using the last evaluation metric. If the model is
|
||||||
|
trained with early stopping, then :py:attr:`best_iteration` is used
|
||||||
|
automatically.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
X : pd.DataFrame|cudf.DataFrame
|
X : Union[pd.DataFrame, cudf.DataFrame]
|
||||||
Feature matrix. A DataFrame with a special `qid` column.
|
Feature matrix. A DataFrame with a special `qid` column.
|
||||||
|
|
||||||
y :
|
y :
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user