Doc and demo for customized metric and obj. (#4598)

Co-Authored-By: Theodore Vasiloudis <theodoros.vasiloudis@gmail.com>
Jiaming Yuan 2019-06-26 16:13:12 +08:00 committed by GitHub
parent 8bdf15120a
commit 5b2f805e74
7 changed files with 345 additions and 21 deletions

View File

@ -2,6 +2,7 @@ XGBoost Python Feature Walkthrough
==================================
* [Basic walkthrough of wrappers](basic_walkthrough.py)
* [Customize loss function, and evaluation metric](custom_objective.py)
* [Re-implement RMSLE as customized metric and objective](custom_rmsle.py)
* [Boosting from existing prediction](boost_from_prediction.py)
* [Predicting using first n trees](predict_first_ntree.py)
* [Generalized Linear Model](generalized_linear_model.py)

View File

@ -16,8 +16,9 @@ param = {'max_depth': 2, 'eta': 1, 'silent': 1}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2
# user define objective function, given prediction, return gradient and second order gradient
# this is log likelihood loss
# user defined objective function: given predictions, return the gradient and
# second order gradient. This is the log likelihood loss.
def logregobj(preds, dtrain):
labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
@ -25,18 +26,24 @@ def logregobj(preds, dtrain):
hess = preds * (1.0 - preds)
return grad, hess
# user defined evaluation function, return a pair metric_name, result
# NOTE: when you do customized loss function, the default prediction value is margin
# this may make builtin evaluation metric not function properly
# for example, we are doing logistic loss, the prediction is score before logistic transformation
# the builtin evaluation error assumes input is after logistic transformation
# Take this in mind when you use the customization, and maybe you need write customized evaluation function
# NOTE: when you use a customized loss function, the default prediction value
# is the margin, which may make builtin evaluation metrics not function
# properly. For example, with logistic loss the prediction is the score
# before the logistic transformation, while the builtin evaluation error
# assumes the input is after the logistic transformation. Keep this in mind
# when you use the customization; you may need to write a customized
# evaluation function.
def evalerror(preds, dtrain):
labels = dtrain.get_label()
# return a pair metric_name, result. The metric name must not contain a colon (:) or a space
# since preds are margin(before logistic transformation, cutoff at 0)
    # return a pair (metric_name, result). The metric name must not contain
    # a colon (:) or a space. Since preds are margins (before the logistic
    # transformation, cutoff at 0), we compare them against 0.0 directly.
return 'my-error', float(sum(labels != (preds > 0.0))) / len(labels)
# training with a customized objective; we can also do step by step training,
# simply look at xgboost.py's implementation of train
bst = xgb.train(param, dtrain, num_round, watchlist, obj=logregobj, feval=evalerror)
bst = xgb.train(param, dtrain, num_round, watchlist, obj=logregobj,
feval=evalerror)
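
# A minimal sketch, not part of the original demo: if you prefer to evaluate
# on probabilities instead of margins, apply the logistic transformation
# yourself before computing the error. `evalerror_prob` is a hypothetical
# name used only for illustration.
def evalerror_prob(preds, dtrain):
    labels = dtrain.get_label()
    preds = 1.0 / (1.0 + np.exp(-preds))  # margin -> probability
    return 'my-error-prob', float(sum(labels != (preds > 0.5))) / len(labels)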

View File

@ -0,0 +1,179 @@
'''Demo for defining a customized metric and objective. Notice that for
simplicity reasons, weight is not used in the following example. In this
script, we implement the Squared Log Error (SLE) objective and RMSLE metric
as customized functions, then compare them with the native implementations
in XGBoost. See doc/tutorials/custom_metric_obj.rst for a step by step
walkthrough, with other details.

The `SLE` objective reduces the impact of outliers in the training dataset,
hence here we also compare its performance with the standard squared
error.
'''
import numpy as np
import xgboost as xgb
from typing import Tuple, Dict, List
from time import time
import matplotlib
from matplotlib import pyplot as plt
# shape of generated data.
kRows = 4096
kCols = 16
kOutlier = 10000   # upper bound of generated outliers
kNumberOfOutliers = 64
kRatio = 0.7
kSeed = 1994
kBoostRound = 20
np.random.seed(seed=kSeed)
def generate_data() -> Tuple[xgb.DMatrix, xgb.DMatrix]:
'''Generate data containing outliers.'''
x = np.random.randn(kRows, kCols)
y = np.random.randn(kRows)
y += np.abs(np.min(y))
# Create outliers
for i in range(0, kNumberOfOutliers):
        ind = np.random.randint(0, len(y))  # the upper bound is exclusive
y[ind] += np.random.randint(0, kOutlier)
train_portion = int(kRows * kRatio)
    # rmsle requires all labels to be greater than -1.
assert np.all(y > -1.0)
train_x: np.ndarray = x[: train_portion]
train_y: np.ndarray = y[: train_portion]
dtrain = xgb.DMatrix(train_x, label=train_y)
test_x = x[train_portion:]
test_y = y[train_portion:]
dtest = xgb.DMatrix(test_x, label=test_y)
return dtrain, dtest
def native_rmse(dtrain: xgb.DMatrix,
dtest: xgb.DMatrix) -> Dict[str, Dict[str, List[float]]]:
    '''Train using native implementation of Root Mean Squared Error.'''
print('Squared Error')
squared_error = {
'objective': 'reg:squarederror',
'eval_metric': 'rmse',
'tree_method': 'hist',
'seed': kSeed
}
start = time()
results: Dict[str, Dict[str, List[float]]] = {}
xgb.train(squared_error,
dtrain=dtrain,
num_boost_round=kBoostRound,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results)
print('Finished Squared Error in:', time() - start, '\n')
return results
def native_rmsle(dtrain: xgb.DMatrix,
dtest: xgb.DMatrix) -> Dict[str, Dict[str, List[float]]]:
'''Train using native implementation of Squared Log Error.'''
print('Squared Log Error')
results: Dict[str, Dict[str, List[float]]] = {}
squared_log_error = {
'objective': 'reg:squaredlogerror',
'eval_metric': 'rmsle',
'tree_method': 'hist',
'seed': kSeed
}
start = time()
xgb.train(squared_log_error,
dtrain=dtrain,
num_boost_round=kBoostRound,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results)
print('Finished Squared Log Error in:', time() - start)
return results
def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict:
'''Train using Python implementation of Squared Log Error.'''
def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
        '''Compute the gradient for squared log error.'''
y = dtrain.get_label()
return (np.log1p(predt) - np.log1p(y)) / (predt + 1)
def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
'''Compute the hessian for squared log error.'''
y = dtrain.get_label()
return ((-np.log1p(predt) + np.log1p(y) + 1) /
np.power(predt + 1, 2))
def squared_log(predt: np.ndarray,
dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
'''Squared Log Error objective. A simplified version for RMSLE used as
objective function.
        :math:`\frac{1}{2}[\log(pred + 1) - \log(label + 1)]^2`
'''
predt[predt < -1] = -1 + 1e-6
grad = gradient(predt, dtrain)
hess = hessian(predt, dtrain)
return grad, hess
def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
''' Root mean squared log error metric.
        :math:`\sqrt{\frac{1}{N}\sum_{i=1}^{N}[\log(pred_i + 1) - \log(label_i + 1)]^2}`
'''
y = dtrain.get_label()
predt[predt < -1] = -1 + 1e-6
elements = np.power(np.log1p(y) - np.log1p(predt), 2)
return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))
results: Dict[str, Dict[str, List[float]]] = {}
xgb.train({'tree_method': 'hist', 'seed': kSeed,
'disable_default_eval_metric': 1},
dtrain=dtrain,
num_boost_round=kBoostRound,
obj=squared_log,
feval=rmsle,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results)
return results
if __name__ == '__main__':
dtrain, dtest = generate_data()
rmse_evals = native_rmse(dtrain, dtest)
rmsle_evals = native_rmsle(dtrain, dtest)
py_rmsle_evals = py_rmsle(dtrain, dtest)
fig, axs = plt.subplots(3, 1)
ax0: matplotlib.axes.Axes = axs[0]
ax1: matplotlib.axes.Axes = axs[1]
ax2: matplotlib.axes.Axes = axs[2]
x = np.arange(0, kBoostRound, 1)
ax0.plot(x, rmse_evals['dtrain']['rmse'], label='train-RMSE')
ax0.plot(x, rmse_evals['dtest']['rmse'], label='test-RMSE')
ax0.legend()
ax1.plot(x, rmsle_evals['dtrain']['rmsle'], label='train-native-RMSLE')
ax1.plot(x, rmsle_evals['dtest']['rmsle'], label='test-native-RMSLE')
ax1.legend()
ax2.plot(x, py_rmsle_evals['dtrain']['PyRMSLE'], label='train-PyRMSLE')
ax2.plot(x, py_rmsle_evals['dtest']['PyRMSLE'], label='test-PyRMSLE')
ax2.legend()
plt.show()
plt.close()

View File

@ -225,4 +225,4 @@ Many thanks to the following contributors (alphabetical order):
* Shankara Rao Thejaswi Nanditale
* Vinay Deshpande
Please report bugs to the user forum https://discuss.xgboost.ai/.
Please report bugs to the XGBoost issues list: https://github.com/dmlc/xgboost/issues. For general questions please visit our user forum: https://discuss.xgboost.ai/.

View File

@ -157,4 +157,3 @@ After training and loading a model, you can use it to make prediction for other
float[][] predicts = booster.predict(dtest);
// predict leaf
float[][] leafPredicts = booster.predictLeaf(dtest, 0);

View File

@ -0,0 +1,138 @@
######################################
Custom Objective and Evaluation Metric
######################################
XGBoost is designed to be an extensible library. One way to extend it is by providing our
own objective function for training and corresponding metric for performance monitoring.
This document introduces implementing a customized elementwise evaluation metric and
objective for XGBoost. Although the introduction uses Python for demonstration, the
concepts should be readily applicable to other language bindings.
.. note::
* The ranking task does not support customized functions.
* The customized functions defined here are only applicable to single node training.
  A distributed environment requires syncing with ``xgboost.rabit``; the interface is
  subject to change and hence beyond the scope of this tutorial.
* We also plan to re-design the interface for multi-class objectives in the future.
In the following sections, we will provide a step-by-step walkthrough of implementing
the ``Squared Log Error (SLE)`` objective function:
.. math::
   \frac{1}{2}[\log(pred + 1) - \log(label + 1)]^2

and its default metric ``Root Mean Squared Log Error (RMSLE)``:

.. math::
   \sqrt{\frac{1}{N}\sum_{i=1}^{N}[\log(pred_i + 1) - \log(label_i + 1)]^2}
Although XGBoost has native support for said functions, using them for demonstration
provides us the opportunity of comparing the results from our own implementation and
the ones from XGBoost internals, for learning purposes. After finishing this tutorial,
we should be able to provide our own functions for rapid experiments.
*****************************
Customized Objective Function
*****************************
During model training, the objective function plays an important role: it provides the
gradient information, both first and second order gradients, based on model predictions
and observed data labels (or targets). Therefore, a valid objective function should
accept two inputs, namely predictions and labels. For implementing ``SLE``, we define:
.. code-block:: python
    from typing import Tuple

    import numpy as np
    import xgboost as xgb
def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
        '''Compute the gradient for squared log error.'''
y = dtrain.get_label()
return (np.log1p(predt) - np.log1p(y)) / (predt + 1)
def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
'''Compute the hessian for squared log error.'''
y = dtrain.get_label()
return ((-np.log1p(predt) + np.log1p(y) + 1) /
np.power(predt + 1, 2))
def squared_log(predt: np.ndarray,
dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
'''Squared Log Error objective. A simplified version for RMSLE used as
objective function.
'''
predt[predt < -1] = -1 + 1e-6
grad = gradient(predt, dtrain)
hess = hessian(predt, dtrain)
return grad, hess
In the above code snippet, ``squared_log`` is the objective function we want. It accepts a
numpy array ``predt`` as model prediction, and the training DMatrix for obtaining required
information, including labels and weights (not used here). This objective is then used as
a callback function for XGBoost during training by passing it as an argument to
``xgb.train``:
.. code-block:: python
xgb.train({'tree_method': 'hist', 'seed': 1994}, # any other tree method is fine.
dtrain=dtrain,
num_boost_round=10,
obj=squared_log)
Notice that in our definition of the objective, whether we subtract the labels from the
prediction or the other way around is important. If you find the training error goes up
instead of down, this might be the reason.
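For instance, a sketch with the subtraction flipped (``squared_log_flipped`` is a
hypothetical name, reusing the ``gradient`` and ``hessian`` helpers above, shown only
to illustrate the failure mode) returns the negated gradient, so the booster climbs
the loss instead of descending it:

.. code-block:: python

    def squared_log_flipped(predt: np.ndarray,
                            dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
        '''Wrong sign: label minus prediction instead of prediction minus label.'''
        predt[predt < -1] = -1 + 1e-6
        grad = -gradient(predt, dtrain)  # flipped subtraction negates the gradient
        hess = hessian(predt, dtrain)    # the hessian is unaffected by the flip
        return grad, hess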
**************************
Customized Metric Function
**************************
After having a customized objective, we might also need a corresponding metric to
monitor our model's performance. As mentioned above, the default metric for ``SLE`` is
``RMSLE``. Similarly, we define another callback-like function as the new metric:
.. code-block:: python
def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
''' Root mean squared log error metric.'''
y = dtrain.get_label()
predt[predt < -1] = -1 + 1e-6
elements = np.power(np.log1p(y) - np.log1p(predt), 2)
return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))
Since we are demonstrating in Python, the metric or objective need not be a function;
any callable object should suffice. Similar to the objective function, our metric also
accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself
and a floating point value as the result. We pass it into XGBoost through the ``feval``
parameter:
.. code-block:: python
xgb.train({'tree_method': 'hist', 'seed': 1994,
'disable_default_eval_metric': 1},
dtrain=dtrain,
num_boost_round=10,
obj=squared_log,
feval=rmsle,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results)
We will be able to see XGBoost printing something like:
.. code-block::
[0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487
[1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899
[2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629
[3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871
[4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186
[5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
...
Notice that the parameter ``disable_default_eval_metric`` is used to suppress the default metric
in XGBoost.
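The builtin metric can also be kept alongside the custom one as a sanity check. A
minimal sketch, assuming builtin metrics and ``feval`` may be combined this way
(verify against your XGBoost version): set ``eval_metric`` instead of disabling it,
and XGBoost reports both ``rmsle`` and ``PyRMSLE``.

.. code-block:: python

    # Sketch: cross-check our PyRMSLE against the builtin rmsle.
    xgb.train({'tree_method': 'hist', 'seed': 1994,
               'eval_metric': 'rmsle'},
              dtrain=dtrain,
              num_boost_round=10,
              obj=squared_log,
              feval=rmsle,
              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
              evals_result=results)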
For fully reproducible source code and comparison plots, see `custom_rmsle.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_rmsle.py>`_.

View File

@ -19,4 +19,4 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
input_format
param_tuning
external_memory
custom_metric_obj