[doc] Brief introduction to base_score. (#9882)

This commit is contained in:
Jiaming Yuan 2023-12-17 13:34:34 +08:00 committed by GitHub
parent db7f952ed6
commit 0edd600f3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 114 additions and 6 deletions

View File

@ -391,6 +391,8 @@ Specify the learning task and the corresponding learning objective. The objectiv
- If ``base_margin`` is supplied, ``base_score`` will not be added. - If ``base_margin`` is supplied, ``base_score`` will not be added.
- For sufficient number of iterations, changing this value will not have too much effect. - For sufficient number of iterations, changing this value will not have too much effect.
See :doc:`/tutorials/intercept` for more info.
* ``eval_metric`` [default according to objective] * ``eval_metric`` [default according to objective]
- Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.) - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.)

View File

@ -271,7 +271,8 @@ available in XGBoost:
We use ``multi:softmax`` to illustrate the differences of transformed prediction. With We use ``multi:softmax`` to illustrate the differences of transformed prediction. With
``softprob`` the output prediction array has shape ``(n_samples, n_classes)`` while for ``softprob`` the output prediction array has shape ``(n_samples, n_classes)`` while for
``softmax`` it's ``(n_samples, )``. A demo for multi-class objective function is also ``softmax`` it's ``(n_samples, )``. A demo for multi-class objective function is also
available at :ref:`sphx_glr_python_examples_custom_softmax.py`. available at :ref:`sphx_glr_python_examples_custom_softmax.py`. Also, see
:doc:`/tutorials/intercept` for some more explanation.
********************** **********************

View File

@ -30,4 +30,5 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
input_format input_format
param_tuning param_tuning
custom_metric_obj custom_metric_obj
intercept
privacy_preserving privacy_preserving

104
doc/tutorials/intercept.rst Normal file
View File

@ -0,0 +1,104 @@
#########
Intercept
#########
.. versionadded:: 2.0.0
Since 2.0.0, XGBoost supports estimating the model intercept (named ``base_score``)
automatically based on targets upon training. The behavior can be controlled by setting
``base_score`` to a constant value. The following snippet disables the automatic
estimation:
.. code-block:: python
import xgboost as xgb
reg = xgb.XGBRegressor()
reg.set_params(base_score=0.5)
In addition, here 0.5 represents the value after applying the inverse link function. See
the end of the document for a description.
Other than the ``base_score``, users can also provide global bias via the data field
``base_margin``, which is a vector or a matrix depending on the task. With multi-output
and multi-class, the ``base_margin`` is a matrix with size ``(n_samples, n_targets)`` or
``(n_samples, n_classes)``.
.. code-block:: python
import xgboost as xgb
from sklearn.datasets import make_regression
X, y = make_regression()
reg = xgb.XGBRegressor()
reg.fit(X, y)
# Request for raw prediction
m = reg.predict(X, output_margin=True)
reg_1 = xgb.XGBRegressor()
# Feed the prediction into the next model
reg.fit(X, y, base_margin=m)
reg.predict(X, base_margin=m)
It specifies the bias for each sample and can be used for stacking an XGBoost model on top
of other models, see :ref:`sphx_glr_python_examples_boost_from_prediction.py` for a worked
example. When ``base_margin`` is specified, it automatically overrides the ``base_score``
parameter. If you are stacking XGBoost models, then the usage should be relatively
straightforward, with the previous model providing raw prediction and a new model using
the prediction as bias. For more customized inputs, users need to take extra care of the
link function. Let :math:`F` be the model and :math:`g` be the link function, since
``base_score`` is overridden when sample-specific ``base_margin`` is available, we will
omit it here:
.. math::
g(E[y_i]) = F(x_i)
When base margin :math:`b` is provided, it's added to the raw model output :math:`F`:
.. math::
g(E[y_i]) = F(x_i) + b_i
and the output of the final model is:
.. math::
g^{-1}(F(x_i) + b_i)
Using the gamma deviance objective ``reg:gamma`` as an example, which has a log link
function, hence:
.. math::
\ln{(E[y_i])} = F(x_i) + b_i \\
E[y_i] = \exp{(F(x_i) + b_i)}
As a result, if you are feeding outputs from models like GLM with a corresponding
objective function, make sure the outputs are not yet transformed by the inverse link.
In the case of ``base_score`` (intercept), it can be accessed through
:py:meth:`~xgboost.Booster.save_config` after estimation. Unlike the ``base_margin``, the
returned value represents a value after applying inverse link. With logistic regression
and the logit link function as an example, given the ``base_score`` as 0.5,
:math:`g(intercept) = logit(0.5) = 0` is added to the raw model output:
.. math::
E[y_i] = g^{-1}{(F(x_i) + g(intercept))}
and 0.5 is the same as :math:`base_score = g^{-1}(0) = 0.5`. This is more intuitive if you
remove the model and consider only the intercept, which is estimated before the model is
fitted:
.. math::
E[y] = g^{-1}{g(intercept))} \\
E[y] = intercept
For some objectives like MAE, there are close solutions, while for others it's estimated
with one step Newton method.

View File

@ -785,7 +785,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
so it doesn't make sense to assign weights to individual data points. so it doesn't make sense to assign weights to individual data points.
base_margin : base_margin :
Base margin used for boosting from existing model. Global bias for each instance. See :doc:`/tutorials/intercept` for details.
missing : missing :
Value in the input data which needs to be present as a missing value. If Value in the input data which needs to be present as a missing value. If
None, defaults to np.nan. None, defaults to np.nan.

View File

@ -1006,7 +1006,7 @@ class XGBModel(XGBModelBase):
sample_weight : sample_weight :
instance weights instance weights
base_margin : base_margin :
global bias for each instance. Global bias for each instance. See :doc:`/tutorials/intercept` for details.
eval_set : eval_set :
A list of (X, y) tuple pairs to use as validation sets, for which A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed. metrics will be computed.
@ -1146,7 +1146,7 @@ class XGBModel(XGBModelBase):
When this is True, validate that the Booster's and data's feature_names are When this is True, validate that the Booster's and data's feature_names are
identical. Otherwise, it is assumed that the feature_names are the same. identical. Otherwise, it is assumed that the feature_names are the same.
base_margin : base_margin :
Margin added to prediction. Global bias for each instance. See :doc:`/tutorials/intercept` for details.
iteration_range : iteration_range :
Specifies which layer of trees are used in prediction. For example, if a Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying ``iteration_range=(10, random forest is trained with 100 rounds. Specifying ``iteration_range=(10,
@ -1599,7 +1599,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
When this is True, validate that the Booster's and data's feature_names are When this is True, validate that the Booster's and data's feature_names are
identical. Otherwise, it is assumed that the feature_names are the same. identical. Otherwise, it is assumed that the feature_names are the same.
base_margin : base_margin :
Margin added to prediction. Global bias for each instance. See :doc:`/tutorials/intercept` for details.
iteration_range : iteration_range :
Specifies which layer of trees are used in prediction. For example, if a Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying `iteration_range=(10, random forest is trained with 100 rounds. Specifying `iteration_range=(10,
@ -1942,7 +1942,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
weights to individual data points. weights to individual data points.
base_margin : base_margin :
Global bias for each instance. Global bias for each instance. See :doc:`/tutorials/intercept` for details.
eval_set : eval_set :
A list of (X, y) tuple pairs to use as validation sets, for which A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed. metrics will be computed.