Doc and demo for customized metric and obj. (#4598)
Co-Authored-By: Theodore Vasiloudis <theodoros.vasiloudis@gmail.com>
commit 5b2f805e74
parent 8bdf15120a
@@ -2,6 +2,7 @@ XGBoost Python Feature Walkthrough
 ==================================
 * [Basic walkthrough of wrappers](basic_walkthrough.py)
 * [Customize loss function, and evaluation metric](custom_objective.py)
+* [Re-implement RMSLE as customized metric and objective](custom_rmsle.py)
 * [Boosting from existing prediction](boost_from_prediction.py)
 * [Predicting using first n trees](predict_first_ntree.py)
 * [Generalized Linear Model](generalized_linear_model.py)

@@ -16,8 +16,9 @@ param = {'max_depth': 2, 'eta': 1, 'silent': 1}
 watchlist = [(dtest, 'eval'), (dtrain, 'train')]
 num_round = 2


-# user define objective function, given prediction, return gradient and second order gradient
-# this is log likelihood loss
+# user defined objective function: given prediction, return gradient and second
+# order gradient. This is log likelihood loss.
 def logregobj(preds, dtrain):
     labels = dtrain.get_label()
     preds = 1.0 / (1.0 + np.exp(-preds))

@@ -25,18 +26,24 @@ def logregobj(preds, dtrain):
     hess = preds * (1.0 - preds)
     return grad, hess


 # user defined evaluation function, return a pair metric_name, result
-# NOTE: when you do customized loss function, the default prediction value is margin
-# this may make builtin evaluation metric not function properly
-# for example, we are doing logistic loss, the prediction is score before logistic transformation
-# the builtin evaluation error assumes input is after logistic transformation
-# Take this in mind when you use the customization, and maybe you need write customized evaluation function
+# NOTE: when you use a customized loss function, the default prediction value is
+# margin. This may make the builtin evaluation metric not function properly. For
+# example, when doing logistic loss, the prediction is the score before logistic
+# transformation, while the builtin evaluation error assumes input is after logistic
+# transformation. Keep this in mind when you use the customization; you may need
+# to write a customized evaluation function.
 def evalerror(preds, dtrain):
     labels = dtrain.get_label()
-    # return a pair metric_name, result. The metric name must not contain a colon (:) or a space
-    # since preds are margin(before logistic transformation, cutoff at 0)
+    # return a pair metric_name, result. The metric name must not contain a
+    # colon (:) or a space since preds are margin (before logistic
+    # transformation, cutoff at 0)
     return 'my-error', float(sum(labels != (preds > 0.0))) / len(labels)


 # training with customized objective, we can also do step by step training
 # simply look at xgboost.py's implementation of train
-bst = xgb.train(param, dtrain, num_round, watchlist, obj=logregobj, feval=evalerror)
+bst = xgb.train(param, dtrain, num_round, watchlist, obj=logregobj,
+                feval=evalerror)

demo/guide-python/custom_rmsle.py (new file, 179 lines)
@@ -0,0 +1,179 @@
'''Demo for defining a customized metric and objective.  Notice that for
simplicity reasons, weight is not used in the following example.  In this
script, we implement the Squared Log Error (SLE) objective and RMSLE metric
as customized functions, then compare them with the native implementation in
XGBoost.

See doc/tutorials/custom_metric_obj.rst for a step by step walkthrough, with
other details.

The `SLE` objective reduces the impact of outliers in the training dataset,
hence here we also compare its performance with the standard squared error.

'''
import numpy as np
import xgboost as xgb
from typing import Tuple, Dict, List
from time import time
import matplotlib
from matplotlib import pyplot as plt

# shape of generated data.
kRows = 4096
kCols = 16

kOutlier = 10000    # mean of generated outliers
kNumberOfOutliers = 64

kRatio = 0.7
kSeed = 1994

kBoostRound = 20

np.random.seed(seed=kSeed)


def generate_data() -> Tuple[xgb.DMatrix, xgb.DMatrix]:
    '''Generate data containing outliers.'''
    x = np.random.randn(kRows, kCols)
    y = np.random.randn(kRows)
    y += np.abs(np.min(y))

    # Create outliers
    for i in range(0, kNumberOfOutliers):
        ind = np.random.randint(0, len(y)-1)
        y[ind] += np.random.randint(0, kOutlier)

    train_portion = int(kRows * kRatio)

    # rmsle requires all label be greater than -1.
    assert np.all(y > -1.0)

    train_x: np.ndarray = x[: train_portion]
    train_y: np.ndarray = y[: train_portion]
    dtrain = xgb.DMatrix(train_x, label=train_y)

    test_x = x[train_portion:]
    test_y = y[train_portion:]
    dtest = xgb.DMatrix(test_x, label=test_y)
    return dtrain, dtest


def native_rmse(dtrain: xgb.DMatrix,
                dtest: xgb.DMatrix) -> Dict[str, Dict[str, List[float]]]:
    '''Train using native implementation of Root Mean Squared Loss.'''
    print('Squared Error')
    squared_error = {
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',
        'tree_method': 'hist',
        'seed': kSeed
    }
    start = time()
    results: Dict[str, Dict[str, List[float]]] = {}
    xgb.train(squared_error,
              dtrain=dtrain,
              num_boost_round=kBoostRound,
              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
              evals_result=results)
    print('Finished Squared Error in:', time() - start, '\n')
    return results


def native_rmsle(dtrain: xgb.DMatrix,
                 dtest: xgb.DMatrix) -> Dict[str, Dict[str, List[float]]]:
    '''Train using native implementation of Squared Log Error.'''
    print('Squared Log Error')
    results: Dict[str, Dict[str, List[float]]] = {}
    squared_log_error = {
        'objective': 'reg:squaredlogerror',
        'eval_metric': 'rmsle',
        'tree_method': 'hist',
        'seed': kSeed
    }
    start = time()
    xgb.train(squared_log_error,
              dtrain=dtrain,
              num_boost_round=kBoostRound,
              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
              evals_result=results)
    print('Finished Squared Log Error in:', time() - start)
    return results


def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict:
    '''Train using Python implementation of Squared Log Error.'''
    def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
        '''Compute the gradient for squared log error.'''
        y = dtrain.get_label()
        return (np.log1p(predt) - np.log1p(y)) / (predt + 1)

    def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
        '''Compute the hessian for squared log error.'''
        y = dtrain.get_label()
        return ((-np.log1p(predt) + np.log1p(y) + 1) /
                np.power(predt + 1, 2))

    def squared_log(predt: np.ndarray,
                    dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
        '''Squared Log Error objective. A simplified version for RMSLE used as
        objective function.

        :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`

        '''
        predt[predt < -1] = -1 + 1e-6
        grad = gradient(predt, dtrain)
        hess = hessian(predt, dtrain)
        return grad, hess

    def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
        '''Root mean squared log error metric.

        :math:`\sqrt{\frac{1}{N}[log(pred + 1) - log(label + 1)]^2}`
        '''
        y = dtrain.get_label()
        predt[predt < -1] = -1 + 1e-6
        elements = np.power(np.log1p(y) - np.log1p(predt), 2)
        return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))

    results: Dict[str, Dict[str, List[float]]] = {}
    xgb.train({'tree_method': 'hist', 'seed': kSeed,
               'disable_default_eval_metric': 1},
              dtrain=dtrain,
              num_boost_round=kBoostRound,
              obj=squared_log,
              feval=rmsle,
              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
              evals_result=results)

    return results


if __name__ == '__main__':
    dtrain, dtest = generate_data()
    rmse_evals = native_rmse(dtrain, dtest)
    rmsle_evals = native_rmsle(dtrain, dtest)
    py_rmsle_evals = py_rmsle(dtrain, dtest)

    fig, axs = plt.subplots(3, 1)
    ax0: matplotlib.axes.Axes = axs[0]
    ax1: matplotlib.axes.Axes = axs[1]
    ax2: matplotlib.axes.Axes = axs[2]

    x = np.arange(0, kBoostRound, 1)

    ax0.plot(x, rmse_evals['dtrain']['rmse'], label='train-RMSE')
    ax0.plot(x, rmse_evals['dtest']['rmse'], label='test-RMSE')
    ax0.legend()

    ax1.plot(x, rmsle_evals['dtrain']['rmsle'], label='train-native-RMSLE')
    ax1.plot(x, rmsle_evals['dtest']['rmsle'], label='test-native-RMSLE')
    ax1.legend()

    ax2.plot(x, py_rmsle_evals['dtrain']['PyRMSLE'], label='train-PyRMSLE')
    ax2.plot(x, py_rmsle_evals['dtest']['PyRMSLE'], label='test-PyRMSLE')
    ax2.legend()

    plt.show()
    plt.close()

@@ -225,4 +225,4 @@ Many thanks to the following contributors (alphabetical order):
 * Shankara Rao Thejaswi Nanditale
 * Vinay Deshpande

-Please report bugs to the user forum https://discuss.xgboost.ai/.
+Please report bugs to the XGBoost issues list: https://github.com/dmlc/xgboost/issues. For general questions please visit our user forum: https://discuss.xgboost.ai/.

@@ -157,4 +157,3 @@ After training and loading a model, you can use it to make prediction for other
 float[][] predicts = booster.predict(dtest);
 // predict leaf
 float[][] leafPredicts = booster.predictLeaf(dtest, 0);
-

doc/tutorials/custom_metric_obj.rst (new file, 138 lines)
@@ -0,0 +1,138 @@
######################################
Custom Objective and Evaluation Metric
######################################

XGBoost is designed to be an extensible library.  One way to extend it is by providing our
own objective function for training and a corresponding metric for performance monitoring.
This document introduces implementing a customized elementwise evaluation metric and
objective for XGBoost.  Although the introduction uses Python for demonstration, the
concepts should be readily applicable to other language bindings.

.. note::

   * The ranking task does not support customized functions.
   * The customized functions defined here are only applicable to single node training.
     A distributed environment requires syncing with ``xgboost.rabit``; that interface is
     subject to change and hence beyond the scope of this tutorial.
   * We also plan to re-design the interface for the multi-class objective in the future.

In the following sections, we will provide a step by step walkthrough of implementing
the ``Squared Log Error (SLE)`` objective function:

.. math::

   \frac{1}{2}[log(pred + 1) - log(label + 1)]^2

and its default metric ``Root Mean Squared Log Error (RMSLE)``:

.. math::

   \sqrt{\frac{1}{N}[log(pred + 1) - log(label + 1)]^2}
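
As a quick numerical sanity check, the two formulas can be evaluated directly with NumPy;
this is a standalone sketch, not part of the demo, and the array values are made up for
illustration:

.. code-block:: python

   import numpy as np

   pred = np.array([1.5, 0.3, 4.0])     # made-up predictions
   label = np.array([1.0, 0.5, 3.0])    # made-up labels

   # per-element SLE, matching the first formula
   sle = 0.5 * (np.log1p(pred) - np.log1p(label)) ** 2
   # RMSLE over the whole array, matching the second formula
   rmsle = np.sqrt(np.mean((np.log1p(pred) - np.log1p(label)) ** 2))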

Although XGBoost has native support for said functions, using them for demonstration
provides us the opportunity of comparing the result from our own implementation with the
one from XGBoost's internals, for learning purposes.  After finishing this tutorial, we
should be able to provide our own functions for rapid experiments.

*****************************
Customized Objective Function
*****************************

During model training, the objective function plays an important role: it provides gradient
information, both first and second order gradients, based on model predictions and observed
data labels (or targets).  Therefore, a valid objective function should accept two inputs,
namely the predictions and the labels.  For implementing ``SLE``, we define:

.. code-block:: python

   from typing import Tuple

   import numpy as np
   import xgboost as xgb

   def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
       '''Compute the gradient for squared log error.'''
       y = dtrain.get_label()
       return (np.log1p(predt) - np.log1p(y)) / (predt + 1)

   def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
       '''Compute the hessian for squared log error.'''
       y = dtrain.get_label()
       return ((-np.log1p(predt) + np.log1p(y) + 1) /
               np.power(predt + 1, 2))

   def squared_log(predt: np.ndarray,
                   dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
       '''Squared Log Error objective. A simplified version for RMSLE used as
       objective function.
       '''
       predt[predt < -1] = -1 + 1e-6
       grad = gradient(predt, dtrain)
       hess = hessian(predt, dtrain)
       return grad, hess

In the above code snippet, ``squared_log`` is the objective function we want.  It accepts a
numpy array ``predt`` as model prediction, and the training DMatrix for obtaining required
information, including labels and weights (not used here).  This objective is then used as
a callback function for XGBoost during training by passing it as an argument to
``xgb.train``:

.. code-block:: python

   xgb.train({'tree_method': 'hist', 'seed': 1994},  # any other tree method is fine.
             dtrain=dtrain,
             num_boost_round=10,
             obj=squared_log)

Notice that in our definition of the objective, whether we subtract the labels from the
prediction or the other way around is important.  If you find the training error goes up
instead of down, this might be the reason.
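
For reference, here is a quick sketch of the calculus behind the ``gradient`` and
``hessian`` functions above; note that the prediction term comes first, which is the sign
convention the previous paragraph refers to:

.. math::

   \frac{\partial}{\partial pred} \frac{1}{2}[log(pred + 1) - log(label + 1)]^2 =
   \frac{log(pred + 1) - log(label + 1)}{pred + 1}

.. math::

   \frac{\partial^2}{\partial pred^2} \frac{1}{2}[log(pred + 1) - log(label + 1)]^2 =
   \frac{-log(pred + 1) + log(label + 1) + 1}{(pred + 1)^2}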

**************************
Customized Metric Function
**************************

So after having a customized objective, we might also need a corresponding metric to
monitor our model's performance.  As mentioned above, the default metric for ``SLE`` is
``RMSLE``.  Similarly, we define another callback-like function as the new metric:

.. code-block:: python

   def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
       '''Root mean squared log error metric.'''
       y = dtrain.get_label()
       predt[predt < -1] = -1 + 1e-6
       elements = np.power(np.log1p(y) - np.log1p(predt), 2)
       return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))

Since we are demonstrating in Python, the metric or objective need not be a function;
any callable object should suffice.  Similar to the objective function, our metric also
accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and a
floating point value as the result.  We then pass it into XGBoost as the argument of the
``feval`` parameter:

.. code-block:: python

   xgb.train({'tree_method': 'hist', 'seed': 1994,
              'disable_default_eval_metric': 1},
             dtrain=dtrain,
             num_boost_round=10,
             obj=squared_log,
             feval=rmsle,
             evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
             evals_result=results)

We will be able to see XGBoost printing something like:

.. code-block:: none

   [0] dtrain-PyRMSLE:1.37153  dtest-PyRMSLE:1.31487
   [1] dtrain-PyRMSLE:1.26619  dtest-PyRMSLE:1.20899
   [2] dtrain-PyRMSLE:1.17508  dtest-PyRMSLE:1.11629
   [3] dtrain-PyRMSLE:1.09836  dtest-PyRMSLE:1.03871
   [4] dtrain-PyRMSLE:1.03557  dtest-PyRMSLE:0.977186
   [5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
   ...

Notice that the parameter ``disable_default_eval_metric`` is used to suppress the default
metric in XGBoost.
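
As noted earlier, the metric does not have to be a plain function; any callable object
works.  Below is a minimal sketch of the same metric written as a callable class, reusing
the imports from the snippets above; the name ``RMSLEMetric`` is illustrative and not part
of the demo:

.. code-block:: python

   class RMSLEMetric:
       '''Callable-object version of the ``rmsle`` metric above.'''

       def __call__(self, predt: np.ndarray,
                    dtrain: xgb.DMatrix) -> Tuple[str, float]:
           y = dtrain.get_label()
           predt[predt < -1] = -1 + 1e-6
           elements = np.power(np.log1p(y) - np.log1p(predt), 2)
           return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))

   # An instance can then be passed in the same way, e.g. feval=RMSLEMetric().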

For fully reproducible source code and comparison plots, see `custom_rmsle.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_rmsle.py>`_.

@@ -19,4 +19,4 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
   input_format
   param_tuning
   external_memory
+  custom_metric_obj