diff --git a/demo/guide-python/README.md b/demo/guide-python/README.md
index a1fe84b27..42d2f77dc 100644
--- a/demo/guide-python/README.md
+++ b/demo/guide-python/README.md
@@ -2,6 +2,7 @@ XGBoost Python Feature Walkthrough
 ==================================
 * [Basic walkthrough of wrappers](basic_walkthrough.py)
 * [Customize loss function, and evaluation metric](custom_objective.py)
+* [Re-implement RMSLE as customized metric and objective](custom_rmsle.py)
 * [Boosting from existing prediction](boost_from_prediction.py)
 * [Predicting using first n trees](predict_first_ntree.py)
 * [Generalized Linear Model](generalized_linear_model.py)
diff --git a/demo/guide-python/custom_objective.py b/demo/guide-python/custom_objective.py
index 3cc73c0e3..5bbceccfb 100755
--- a/demo/guide-python/custom_objective.py
+++ b/demo/guide-python/custom_objective.py
@@ -16,8 +16,9 @@ param = {'max_depth': 2, 'eta': 1, 'silent': 1}
 watchlist = [(dtest, 'eval'), (dtrain, 'train')]
 num_round = 2
-# user define objective function, given prediction, return gradient and second order gradient
-# this is log likelihood loss
+
+# User-defined objective function: given the predictions, return the gradient
+# and the second-order gradient. This is the log-likelihood loss.
 def logregobj(preds, dtrain):
     labels = dtrain.get_label()
     preds = 1.0 / (1.0 + np.exp(-preds))
@@ -25,18 +26,24 @@ def logregobj(preds, dtrain):
     labels = dtrain.get_label()
     preds = 1.0 / (1.0 + np.exp(-preds))
     grad = preds - labels
     hess = preds * (1.0 - preds)
     return grad, hess
+
 # user defined evaluation function, return a pair metric_name, result
-# NOTE: when you do customized loss function, the default prediction value is margin
-# this may make builtin evaluation metric not function properly
-# for example, we are doing logistic loss, the prediction is score before logistic transformation
-# the builtin evaluation error assumes input is after logistic transformation
-# Take this in mind when you use the customization, and maybe you need write customized evaluation function
+
+# NOTE: with a customized loss function, the default prediction value is the
+# margin, which may make built-in evaluation metrics misbehave. For example,
+# with logistic loss the prediction is the score before the logistic
+# transformation, while the built-in evaluation error assumes the input comes
+# after the logistic transformation. Keep this in mind when customizing, as
+# you may also need to write a corresponding evaluation function.
 def evalerror(preds, dtrain):
     labels = dtrain.get_label()
-    # return a pair metric_name, result. The metric name must not contain a colon (:) or a space
-    # since preds are margin(before logistic transformation, cutoff at 0)
+    # Return a pair metric_name, result. The metric name must not contain a
+    # colon (:) or a space, since preds are margins (before the logistic
+    # transformation, with cutoff at 0).
    return 'my-error', float(sum(labels != (preds > 0.0))) / len(labels)
+
 # training with customized objective, we can also do step by step training
 # simply look at xgboost.py's implementation of train
-bst = xgb.train(param, dtrain, num_round, watchlist, obj=logregobj, feval=evalerror)
+bst = xgb.train(param, dtrain, num_round, watchlist, obj=logregobj,
+                feval=evalerror)
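``logregobj`` above returns the gradient and Hessian of the logistic loss
with respect to the margin: with ``p = sigmoid(margin)``, the first
derivative is ``p - labels`` and the second is ``p * (1 - p)``. When writing
such a pair yourself, a finite-difference comparison is a cheap safeguard; a
minimal sketch (the helper names below are illustrative, not part of the
demo):

    import numpy as np

    def logreg_grad(margin, labels):
        p = 1.0 / (1.0 + np.exp(-margin))
        return p - labels

    def numeric_grad(margin, labels, eps=1e-6):
        # Central difference of the negative log-likelihood.
        def loss(m):
            p = 1.0 / (1.0 + np.exp(-m))
            return -(labels * np.log(p) + (1.0 - labels) * np.log(1.0 - p))
        return (loss(margin + eps) - loss(margin - eps)) / (2.0 * eps)

    margin = np.random.randn(8)
    labels = (np.random.rand(8) > 0.5).astype(np.float64)
    assert np.allclose(logreg_grad(margin, labels),
                       numeric_grad(margin, labels), atol=1e-4)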
diff --git a/demo/guide-python/custom_rmsle.py b/demo/guide-python/custom_rmsle.py
new file mode 100644
index 000000000..23d6b2359
--- /dev/null
+++ b/demo/guide-python/custom_rmsle.py
@@ -0,0 +1,179 @@
+'''Demo for defining a customized metric and objective. Note that for
+simplicity, instance weights are not used in the following example. In this
+script, we implement the Squared Log Error (SLE) objective and RMSLE metric
+as customized functions, then compare them with the native implementations
+in XGBoost.
+
+See doc/tutorials/custom_metric_obj.rst for a step-by-step walkthrough, with
+other details.
+
+The `SLE` objective reduces the impact of outliers in the training dataset,
+hence we also compare its performance with the standard squared error.
+
+'''
+import numpy as np
+import xgboost as xgb
+from typing import Tuple, Dict, List
+from time import time
+import matplotlib
+from matplotlib import pyplot as plt
+
+# Shape of generated data.
+kRows = 4096
+kCols = 16
+
+kOutlier = 10000   # upper bound of generated outliers
+kNumberOfOutliers = 64
+
+kRatio = 0.7
+kSeed = 1994
+
+kBoostRound = 20
+
+np.random.seed(seed=kSeed)
+
+
+def generate_data() -> Tuple[xgb.DMatrix, xgb.DMatrix]:
+    '''Generate data containing outliers.'''
+    x = np.random.randn(kRows, kCols)
+    y = np.random.randn(kRows)
+    y += np.abs(np.min(y))
+
+    # Create outliers.
+    for i in range(0, kNumberOfOutliers):
+        ind = np.random.randint(0, len(y)-1)
+        y[ind] += np.random.randint(0, kOutlier)
+
+    train_portion = int(kRows * kRatio)
+
+    # rmsle requires all labels to be greater than -1.
+    assert np.all(y > -1.0)
+
+    train_x: np.ndarray = x[: train_portion]
+    train_y: np.ndarray = y[: train_portion]
+    dtrain = xgb.DMatrix(train_x, label=train_y)
+
+    test_x = x[train_portion:]
+    test_y = y[train_portion:]
+    dtest = xgb.DMatrix(test_x, label=test_y)
+    return dtrain, dtest
+
+
+def native_rmse(dtrain: xgb.DMatrix,
+                dtest: xgb.DMatrix) -> Dict[str, Dict[str, List[float]]]:
+    '''Train using the native implementation of Root Mean Squared Error.'''
+    print('Squared Error')
+    squared_error = {
+        'objective': 'reg:squarederror',
+        'eval_metric': 'rmse',
+        'tree_method': 'hist',
+        'seed': kSeed
+    }
+    start = time()
+    results: Dict[str, Dict[str, List[float]]] = {}
+    xgb.train(squared_error,
+              dtrain=dtrain,
+              num_boost_round=kBoostRound,
+              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
+              evals_result=results)
+    print('Finished Squared Error in:', time() - start, '\n')
+    return results
+
+
+def native_rmsle(dtrain: xgb.DMatrix,
+                 dtest: xgb.DMatrix) -> Dict[str, Dict[str, List[float]]]:
+    '''Train using the native implementation of Squared Log Error.'''
+    print('Squared Log Error')
+    results: Dict[str, Dict[str, List[float]]] = {}
+    squared_log_error = {
+        'objective': 'reg:squaredlogerror',
+        'eval_metric': 'rmsle',
+        'tree_method': 'hist',
+        'seed': kSeed
+    }
+    start = time()
+    xgb.train(squared_log_error,
+              dtrain=dtrain,
+              num_boost_round=kBoostRound,
+              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
+              evals_result=results)
+    print('Finished Squared Log Error in:', time() - start)
+    return results
+
+
+def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict:
+    '''Train using a Python implementation of Squared Log Error.'''
+    def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
+        '''Compute the gradient of squared log error.'''
+        y = dtrain.get_label()
+        return (np.log1p(predt) - np.log1p(y)) / (predt + 1)
+
+    def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
+        '''Compute the hessian of squared log error.'''
+        y = dtrain.get_label()
+        return ((-np.log1p(predt) + np.log1p(y) + 1) /
+                np.power(predt + 1, 2))
+
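+    # A derivation note for the two helpers above: differentiating
+    #     0.5 * (log1p(predt) - log1p(y)) ** 2
+    # with respect to predt once gives
+    #     (log1p(predt) - log1p(y)) / (predt + 1),
+    # and differentiating again gives
+    #     (1 - log1p(predt) + log1p(y)) / (predt + 1) ** 2,
+    # matching ``gradient`` and ``hessian`` exactly.
+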
+    def squared_log(predt: np.ndarray,
+                    dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
+        r'''Squared Log Error objective. A simplified version of RMSLE, used
+        as the objective function.
+
+        :math:`\frac{1}{2}[\log(pred + 1) - \log(label + 1)]^2`
+
+        '''
+        predt[predt < -1] = -1 + 1e-6
+        grad = gradient(predt, dtrain)
+        hess = hessian(predt, dtrain)
+        return grad, hess
+
+    def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
+        r'''Root mean squared log error metric.
+
+        :math:`\sqrt{\frac{1}{N}\sum_{i=1}^{N}[\log(pred_i + 1) - \log(label_i + 1)]^2}`
+        '''
+        y = dtrain.get_label()
+        predt[predt < -1] = -1 + 1e-6
+        elements = np.power(np.log1p(y) - np.log1p(predt), 2)
+        return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))
+
+    results: Dict[str, Dict[str, List[float]]] = {}
+    xgb.train({'tree_method': 'hist', 'seed': kSeed,
+               'disable_default_eval_metric': 1},
+              dtrain=dtrain,
+              num_boost_round=kBoostRound,
+              obj=squared_log,
+              feval=rmsle,
+              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
+              evals_result=results)
+
+    return results
+
+
+if __name__ == '__main__':
+    dtrain, dtest = generate_data()
+    rmse_evals = native_rmse(dtrain, dtest)
+    rmsle_evals = native_rmsle(dtrain, dtest)
+    py_rmsle_evals = py_rmsle(dtrain, dtest)
+
+    fig, axs = plt.subplots(3, 1)
+    ax0: matplotlib.axes.Axes = axs[0]
+    ax1: matplotlib.axes.Axes = axs[1]
+    ax2: matplotlib.axes.Axes = axs[2]
+
+    x = np.arange(0, kBoostRound, 1)
+
+    ax0.plot(x, rmse_evals['dtrain']['rmse'], label='train-RMSE')
+    ax0.plot(x, rmse_evals['dtest']['rmse'], label='test-RMSE')
+    ax0.legend()
+
+    ax1.plot(x, rmsle_evals['dtrain']['rmsle'], label='train-native-RMSLE')
+    ax1.plot(x, rmsle_evals['dtest']['rmsle'], label='test-native-RMSLE')
+    ax1.legend()
+
+    ax2.plot(x, py_rmsle_evals['dtrain']['PyRMSLE'], label='train-PyRMSLE')
+    ax2.plot(x, py_rmsle_evals['dtest']['PyRMSLE'], label='test-PyRMSLE')
+    ax2.legend()
+
+    plt.show()
+    plt.close()
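The ``rmsle`` metric above computes the same quantity that scikit-learn
exposes as ``mean_squared_log_error``, up to the final square root. A
minimal cross-check sketch, assuming scikit-learn is installed (the arrays
below are illustrative):

    import numpy as np
    from sklearn.metrics import mean_squared_log_error

    y = np.array([1.0, 2.0, 3.0])
    predt = np.array([1.1, 1.9, 3.2])

    ours = float(np.sqrt(np.sum(np.power(np.log1p(y) - np.log1p(predt), 2)) / len(y)))
    theirs = float(np.sqrt(mean_squared_log_error(y, predt)))
    assert np.isclose(ours, theirs)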
diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst
index 0f5e6317a..4aefc1d6f 100644
--- a/doc/gpu/index.rst
+++ b/doc/gpu/index.rst
@@ -225,4 +225,4 @@ Many thanks to the following contributors (alphabetical order):
 * Shankara Rao Thejaswi Nanditale
 * Vinay Deshpande
-Please report bugs to the user forum https://discuss.xgboost.ai/.
+Please report bugs to the XGBoost issues list: https://github.com/dmlc/xgboost/issues. For general questions please visit our user forum: https://discuss.xgboost.ai/.
diff --git a/doc/jvm/java_intro.rst b/doc/jvm/java_intro.rst
index 57a5866f9..d20eab1e5 100644
--- a/doc/jvm/java_intro.rst
+++ b/doc/jvm/java_intro.rst
@@ -25,27 +25,27 @@ supported.
 * Pass arrays to DMatrix constructor to load from sparse matrix. Suppose we have a sparse matrix
-  
+
 .. code-block:: none
-  
+
   1 0 2 0
   4 0 0 3
   3 1 2 0
-  
+
 We can express the sparse matrix in `Compressed Sparse Row (CSR) `_ format:
-  
+
 .. code-block:: java
-  
+
   long[] rowHeaders = new long[] {0,2,4,7};
   float[] data = new float[] {1f,2f,4f,3f,3f,1f,2f};
   int[] colIndex = new int[] {0,2,0,3,0,1,2};
   int numColumn = 4;
   DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, numColumn);
-  
+
 ... or in `Compressed Sparse Column (CSC) `_ format:
-  
+
 .. code-block:: java
-  
+
   long[] colHeaders = new long[] {0,3,4,6,7};
   float[] data = new float[] {1f,4f,3f,1f,2f,2f,3f};
   int[] rowIndex = new int[] {0,1,2,2,0,2,1};
@@ -157,4 +157,3 @@ After training and loading a model, you can use it to make prediction for other
   float[][] predicts = booster.predict(dtest);
   // predict leaf
   float[][] leafPredicts = booster.predictLeaf(dtest, 0);
-
diff --git a/doc/tutorials/custom_metric_obj.rst b/doc/tutorials/custom_metric_obj.rst
new file mode 100644
index 000000000..9722d25ca
--- /dev/null
+++ b/doc/tutorials/custom_metric_obj.rst
@@ -0,0 +1,138 @@
+######################################
+Custom Objective and Evaluation Metric
+######################################
+
+XGBoost is designed to be an extensible library. One way to extend it is by
+providing our own objective function for training and a corresponding metric
+for performance monitoring. This document introduces implementing a
+customized elementwise evaluation metric and objective for XGBoost. Although
+the introduction uses Python for demonstration, the concepts should be
+readily applicable to other language bindings.
+
+.. note::
+
+  * The ranking task does not support customized functions.
+  * The customized functions defined here are only applicable to single-node
+    training. A distributed environment requires syncing with
+    ``xgboost.rabit``; that interface is subject to change and hence beyond
+    the scope of this tutorial.
+  * We also plan to re-design the interface for multi-class objectives in the
+    future.
+
+In the following sections, we will provide a step-by-step walkthrough of
+implementing the Squared Log Error (SLE) objective function:
+
+.. math::
+   \frac{1}{2}[\log(pred + 1) - \log(label + 1)]^2
+
+and its default metric, Root Mean Squared Log Error (RMSLE):
+
+.. math::
+   \sqrt{\frac{1}{N}\sum_{i=1}^{N}[\log(pred_i + 1) - \log(label_i + 1)]^2}
+
+Although XGBoost has native support for these functions, using them for
+demonstration gives us the opportunity to compare the results from our own
+implementation with those from XGBoost's internal implementation, for
+learning purposes. After finishing this tutorial, we should be able to
+provide our own functions for rapid experiments.
+
+*****************************
+Customized Objective Function
+*****************************
+
+During model training, the objective function plays an important role: it
+provides gradient information, both first and second order, based on model
+predictions and observed data labels (or targets). Therefore, a valid
+objective function should accept two inputs, namely predictions and labels.
+For implementing SLE, we define:
+
+.. code-block:: python
+
+    import numpy as np
+    import xgboost as xgb
+    from typing import Tuple
+
+    def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
+        '''Compute the gradient of squared log error.'''
+        y = dtrain.get_label()
+        return (np.log1p(predt) - np.log1p(y)) / (predt + 1)
+
+    def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
+        '''Compute the hessian of squared log error.'''
+        y = dtrain.get_label()
+        return ((-np.log1p(predt) + np.log1p(y) + 1) /
+                np.power(predt + 1, 2))
+
+    def squared_log(predt: np.ndarray,
+                    dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
+        '''Squared Log Error objective. A simplified version of RMSLE, used
+        as the objective function.
+        '''
+        predt[predt < -1] = -1 + 1e-6
+        grad = gradient(predt, dtrain)
+        hess = hessian(predt, dtrain)
+        return grad, hess
+
+In the above code snippet, ``squared_log`` is the objective function we want.
+It accepts a numpy array ``predt`` as model predictions, and the training
+DMatrix for obtaining required information, including labels and weights (not
+used here). This objective is then used as a callback function for XGBoost
+during training, by passing it as an argument to ``xgb.train``:
+
+.. code-block:: python
+
+    xgb.train({'tree_method': 'hist', 'seed': 1994},  # any other tree method is fine.
+              dtrain=dtrain,
+              num_boost_round=10,
+              obj=squared_log)
+
+Notice that in our definition of the objective, whether we subtract the
+labels from the predictions or the other way around matters. If you find the
+training error going up instead of down, this might be the reason.
+
+
+**************************
+Customized Metric Function
+**************************
+
+After defining a customized objective, we might also need a corresponding
+metric to monitor our model's performance. As mentioned above, the default
+metric for SLE is RMSLE. Similarly, we define another callback-like function
+as the new metric:
+
+.. code-block:: python
+
+    def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
+        '''Root mean squared log error metric.'''
+        y = dtrain.get_label()
+        predt[predt < -1] = -1 + 1e-6
+        elements = np.power(np.log1p(y) - np.log1p(predt), 2)
+        return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))
+
+Since we are demonstrating in Python, the metric (or objective) need not be a
+function; any callable object should suffice, as in the sketch below.
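+For instance, a metric can be written as a small class implementing
+``__call__``, which is handy when the metric carries configuration such as
+its display name. The class below is an illustrative sketch, not part of
+XGBoost's API; it simply re-packages ``rmsle`` from above:
+
+.. code-block:: python
+
+    class RMSLE:
+        '''Callable metric object; the reported name is configurable.'''
+        def __init__(self, name: str = 'PyRMSLE') -> None:
+            self.name = name
+
+        def __call__(self, predt: np.ndarray,
+                     dtrain: xgb.DMatrix) -> Tuple[str, float]:
+            y = dtrain.get_label()
+            predt[predt < -1] = -1 + 1e-6
+            elements = np.power(np.log1p(y) - np.log1p(predt), 2)
+            return self.name, float(np.sqrt(np.sum(elements) / len(y)))
+
+Passing ``feval=RMSLE()`` to ``xgb.train`` behaves the same as passing the
+``rmsle`` function.
+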
+Similar to the objective function, our metric also accepts ``predt`` and
+``dtrain`` as inputs, but returns the name of the metric itself and a
+floating point value as the result. We pass it into XGBoost as the ``feval``
+argument:
+
+.. code-block:: python
+
+    results = {}  # records per-round evaluation results
+    xgb.train({'tree_method': 'hist', 'seed': 1994,
+               'disable_default_eval_metric': 1},
+              dtrain=dtrain,
+              num_boost_round=10,
+              obj=squared_log,
+              feval=rmsle,
+              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
+              evals_result=results)
+
+XGBoost will then print something like:
+
+.. code-block:: none
+
+    [0]	dtrain-PyRMSLE:1.37153	dtest-PyRMSLE:1.31487
+    [1]	dtrain-PyRMSLE:1.26619	dtest-PyRMSLE:1.20899
+    [2]	dtrain-PyRMSLE:1.17508	dtest-PyRMSLE:1.11629
+    [3]	dtrain-PyRMSLE:1.09836	dtest-PyRMSLE:1.03871
+    [4]	dtrain-PyRMSLE:1.03557	dtest-PyRMSLE:0.977186
+    [5]	dtrain-PyRMSLE:0.985783	dtest-PyRMSLE:0.93057
+    ...
+
+Notice that the parameter ``disable_default_eval_metric`` is used to suppress
+the default metric in XGBoost.
+
+For fully reproducible source code and comparison plots, see
+`custom_rmsle.py `_.
diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index 280ee1c01..a5c27f3f5 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -19,4 +19,4 @@ See `Awesome XGBoost `_ for mo
   input_format
   param_tuning
   external_memory
-  
+  custom_metric_obj
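After the final training call in the tutorial above, the ``results``
dictionary passed as ``evals_result`` is populated round by round. It is a
nested mapping from the names given in ``evals`` to metric names to per-round
values, the same structure the demo's plotting code indexes; a minimal sketch
of reading it back:

    # Keys follow the names used in ``evals`` and the metric name
    # returned by ``rmsle``.
    print(results['dtrain']['PyRMSLE'])      # one value per boosting round
    print(results['dtest']['PyRMSLE'][-1])   # final test RMSLE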