[Breaking] Set output margin to True for custom objective. (#5564)

* Set output margin to True for custom objective in Python and R.

* Add a demo for writing multi-class custom objective function.

* Run tests on selected demos.
This commit is contained in:
Jiaming Yuan 2020-04-20 20:44:12 +08:00 committed by GitHub
parent fcbedcedf8
commit 9c1103e06c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 262 additions and 22 deletions

View File

@ -145,7 +145,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
if (is.null(obj)) {
.Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
} else {
pred <- predict(booster_handle, dtrain, training = TRUE)
pred <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE)
gpair <- obj(pred, dtrain)
.Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess)
}

View File

@ -58,3 +58,20 @@ test_that("custom objective using DMatrix attr works", {
bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster")
})
test_that("custom objective with multi-class works", {
data = as.matrix(iris[, -5])
label = as.numeric(iris$Species) - 1
dtrain <- xgb.DMatrix(data = data, label = label)
nclasses <- 3
fake_softprob <- function(preds, dtrain) {
expect_true(all(matrix(preds) == 0.5))
grad <- rnorm(dim(as.matrix(preds))[1])
expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
hess <- rnorm(dim(as.matrix(preds))[1])
return (list(grad = grad, hess = hess))
}
param$objective = fake_softprob
bst <- xgb.train(param, dtrain, 1, num_class=nclasses)
})

View File

@ -3,6 +3,7 @@ XGBoost Python Feature Walkthrough
* [Basic walkthrough of wrappers](basic_walkthrough.py)
* [Customize loss function, and evaluation metric](custom_objective.py)
* [Re-implement RMSLE as customized metric and objective](custom_rmsle.py)
* [Re-Implement `multi:softmax` objective as customized objective](custom_softmax.py)
* [Boosting from existing prediction](boost_from_prediction.py)
* [Predicting using first n trees](predict_first_ntree.py)
* [Generalized Linear Model](generalized_linear_model.py)

View File

@ -1,16 +1,22 @@
#!/usr/bin/python
#!/usr/bin/env python
import numpy as np
import scipy.sparse
import pickle
import xgboost as xgb
import os
### simple example
# Make sure the demo knows where to load the data.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
XGBOOST_ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR))
DEMO_DIR = os.path.join(XGBOOST_ROOT_DIR, 'demo')
# simple example
# load file from text file, also binary buffer generated by xgboost
dtrain = xgb.DMatrix('../data/agaricus.txt.train')
dtest = xgb.DMatrix('../data/agaricus.txt.test')
dtrain = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.test'))
# specify parameters via map, definition are same as c++ version
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
@ -20,12 +26,14 @@ bst = xgb.train(param, dtrain, num_round, watchlist)
# this is prediction
preds = bst.predict(dtest)
labels = dtest.get_label()
print('error=%f' % (sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))))
print('error=%f' %
(sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) /
float(len(preds))))
bst.save_model('0001.model')
# dump model
bst.dump_model('dump.raw.txt')
# dump model with feature map
bst.dump_model('dump.nice.txt', '../data/featmap.txt')
bst.dump_model('dump.nice.txt', os.path.join(DEMO_DIR, 'data/featmap.txt'))
# save dmatrix into binary buffer
dtest.save_binary('dtest.buffer')
@ -50,14 +58,18 @@ assert np.sum(np.abs(preds3 - preds)) == 0
# build dmatrix from scipy.sparse
print('start running example of build DMatrix from scipy.sparse CSR Matrix')
labels = []
row = []; col = []; dat = []
row = []
col = []
dat = []
i = 0
for l in open('../data/agaricus.txt.train'):
for l in open(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.train')):
arr = l.split()
labels.append(int(arr[0]))
for it in arr[1:]:
k, v = it.split(':')
row.append(i); col.append(int(k)); dat.append(float(v))
row.append(i)
col.append(int(k))
dat.append(float(v))
i += 1
csr = scipy.sparse.csr_matrix((dat, (row, col)))
dtrain = xgb.DMatrix(csr, label=labels)
@ -72,8 +84,8 @@ watchlist = [(dtest, 'eval'), (dtrain, 'train')]
bst = xgb.train(param, dtrain, num_round, watchlist)
print('start running example of build DMatrix from numpy array')
# NOTE: npymat is numpy array, we will convert it into scipy.sparse.csr_matrix in internal implementation
# then convert to DMatrix
# NOTE: npymat is numpy array, we will convert it into scipy.sparse.csr_matrix
# in internal implementation then convert to DMatrix
npymat = csr.todense()
dtrain = xgb.DMatrix(npymat, label=labels)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]

View File

@ -15,6 +15,7 @@ import numpy as np
import xgboost as xgb
from typing import Tuple, Dict, List
from time import time
import argparse
import matplotlib
from matplotlib import pyplot as plt
@ -150,12 +151,7 @@ def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict:
return results
if __name__ == '__main__':
dtrain, dtest = generate_data()
rmse_evals = native_rmse(dtrain, dtest)
rmsle_evals = native_rmsle(dtrain, dtest)
py_rmsle_evals = py_rmsle(dtrain, dtest)
def plot_history(rmse_evals, rmsle_evals, py_rmsle_evals):
fig, axs = plt.subplots(3, 1)
ax0: matplotlib.axes.Axes = axs[0]
ax1: matplotlib.axes.Axes = axs[1]
@ -177,3 +173,25 @@ if __name__ == '__main__':
plt.show()
plt.close()
def main(args):
dtrain, dtest = generate_data()
rmse_evals = native_rmse(dtrain, dtest)
rmsle_evals = native_rmsle(dtrain, dtest)
py_rmsle_evals = py_rmsle(dtrain, dtest)
if args.plot != 0:
plot_history(rmse_evals, rmsle_evals, py_rmsle_evals)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Arguments for custom RMSLE objective function demo.')
parser.add_argument(
'--plot',
type=int,
default=1,
help='Set to 0 to disable plotting the evaluation history.')
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,148 @@
'''Demo for creating customized multi-class objective function. This demo is
only applicable after (excluding) XGBoost 1.0.0, as before this version XGBoost
returns transformed prediction for multi-class objective function. More
details in comments.
'''
import numpy as np
import xgboost as xgb
from matplotlib import pyplot as plt
import argparse
np.random.seed(1994)
kRows = 100
kCols = 10
kClasses = 4 # number of classes
kRounds = 10 # number of boosting rounds.
# Generate some random data for demo.
X = np.random.randn(kRows, kCols)
y = np.random.randint(0, 4, size=kRows)
m = xgb.DMatrix(X, y)
def softmax(x):
'''Softmax function with x as input vector.'''
e = np.exp(x)
return e / np.sum(e)
def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
'''Loss function. Computing the gradient and approximated hessian (diagonal).
Reimplements the `multi:softprob` inside XGBoost.
'''
labels = data.get_label()
if data.get_weight().size == 0:
# Use 1 as weight if we don't have custom weight.
weights = np.ones((kRows, 1), dtype=float)
else:
weights = data.get_weight()
# The prediction is of shape (rows, classes), each element in a row
# represents a raw prediction (leaf weight, hasn't gone through softmax
# yet). In XGBoost 1.0.0, the prediction is transformed by a softmax
# function, fixed in later versions.
assert predt.shape == (kRows, kClasses)
grad = np.zeros((kRows, kClasses), dtype=float)
hess = np.zeros((kRows, kClasses), dtype=float)
eps = 1e-6
# compute the gradient and hessian, slow iterations in Python, only
# suitable for demo. Also the one in native XGBoost core is more robust to
# numeric overflow as we don't do anything to mitigate the `exp` in
# `softmax` here.
for r in range(predt.shape[0]):
target = labels[r]
p = softmax(predt[r, :])
for c in range(predt.shape[1]):
assert target >= 0 or target <= kClasses
g = p[c] - 1.0 if c == target else p[c]
g = g * weights[r]
h = max((2.0 * p[c] * (1.0 - p[c]) * weights[r]).item(), eps)
grad[r, c] = g
hess[r, c] = h
# Right now (XGBoost 1.0.0), reshaping is necessary
grad = grad.reshape((kRows * kClasses, 1))
hess = hess.reshape((kRows * kClasses, 1))
return grad, hess
def predict(booster, X):
'''A customized prediction function that converts raw prediction to
target class.
'''
# Output margin means we want to obtain the raw prediction obtained from
# tree leaf weight.
predt = booster.predict(X, output_margin=True)
out = np.zeros(kRows)
for r in range(predt.shape[0]):
# the class with maximum prob (not strictly prob as it haven't gone
# through softmax yet so it doesn't sum to 1, but result is the same
# for argmax).
i = np.argmax(predt[r])
out[r] = i
return out
def plot_history(custom_results, native_results):
fig, axs = plt.subplots(2, 1)
ax0 = axs[0]
ax1 = axs[1]
x = np.arange(0, kRounds, 1)
ax0.plot(x, custom_results['train']['merror'], label='Custom objective')
ax0.legend()
ax1.plot(x, native_results['train']['merror'], label='multi:softmax')
ax1.legend()
plt.show()
def main(args):
custom_results = {}
# Use our custom objective function
booster_custom = xgb.train({'num_class': kClasses},
m,
num_boost_round=kRounds,
obj=softprob_obj,
evals_result=custom_results,
evals=[(m, 'train')])
predt_custom = predict(booster_custom, m)
native_results = {}
# Use the same objective function defined in XGBoost.
booster_native = xgb.train({'num_class': kClasses},
m,
num_boost_round=kRounds,
evals_result=native_results,
evals=[(m, 'train')])
predt_native = booster_native.predict(m)
# We are reimplementing the loss function in XGBoost, so it should
# be the same for normal cases.
assert np.all(predt_custom == predt_native)
if args.plot != 0:
plot_history(custom_results, native_results)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Arguments for custom softmax objective function demo.')
parser.add_argument(
'--plot',
type=int,
default=1,
help='Set to 0 to disable plotting the evaluation history.')
args = parser.parse_args()
main(args)

View File

@ -14,7 +14,7 @@ concepts should be readily applicable to other language bindings.
* The customized functions defined here are only applicable to single node training.
Distributed environment requires syncing with ``xgboost.rabit``, the interface is
subject to change hence beyond the scope of this tutorial.
* We also plan to re-design the interface for multi-classes objective in the future.
* We also plan to improve the interface for multi-classes objective in the future.
In the following sections, we will provide a step by step walk through of implementing
``Squared Log Error(SLE)`` objective function:
@ -136,3 +136,12 @@ Notice that the parameter ``disable_default_eval_metric`` is used to suppress th
in XGBoost.
For fully reproducible source code and comparison plots, see `custom_rmsle.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_rmsle.py>`_.
******************************
Multi-class objective function
******************************
A similiar demo for multi-class objective funtion is also available, see
`demo/guide-python/custom_softmax.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_rmsle.py>`_
for details.

View File

@ -1367,7 +1367,7 @@ class Booster(object):
ctypes.c_int(iteration),
dtrain.handle))
else:
pred = self.predict(dtrain, training=True)
pred = self.predict(dtrain, output_margin=True, training=True)
grad, hess = fobj(pred, dtrain)
self.boost(dtrain, grad, hess)

View File

@ -295,6 +295,9 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
}
}
// update the trees
CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_)
<< "Mismatching size between number of rows from input data and size of "
"gradient vector.";
for (auto& up : updaters_) {
up->Update(gpair, p_fmat, new_trees);
}

View File

@ -0,0 +1,32 @@
import os
import subprocess
import sys
import pytest
CURRENT_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR))
DEMO_DIR = os.path.join(ROOT_DIR, 'demo', 'guide-python')
def test_basic_walkthrough():
script = os.path.join(DEMO_DIR, 'basic_walkthrough.py')
cmd = ['python', script]
subprocess.check_call(cmd)
os.remove('dump.nice.txt')
os.remove('dump.raw.txt')
def test_custom_multiclass_objective():
script = os.path.join(DEMO_DIR, 'custom_softmax.py')
cmd = ['python', script, '--plot=0']
subprocess.check_call(cmd)
def test_custom_rmsle_objective():
major, minor = sys.version_info[:2]
if minor < 6:
pytest.skip('Skipping RMLSE test due to Python version being too low.')
script = os.path.join(DEMO_DIR, 'custom_rmsle.py')
cmd = ['python', script, '--plot=0']
subprocess.check_call(cmd)