[Breaking] Set output margin to True for custom objective. (#5564)
* Set output margin to True for custom objective in Python and R.
* Add a demo for writing a multi-class custom objective function.
* Run tests on selected demos.
parent fcbedcedf8
commit 9c1103e06c
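The gist of the breaking change: the prediction array handed to a custom objective is now the raw margin (untransformed sum of leaf weights) instead of the transformed output. Below is a minimal sketch, not part of the diff, of what user code has to do after this change, assuming the usual `obj=` callback interface; the toy data and the `logistic_obj` helper are hypothetical and only illustrate applying the sigmoid inside the callback instead of expecting probabilities:

import numpy as np
import xgboost as xgb

# Hypothetical toy data, purely for illustration.
X = np.random.randn(64, 4)
y = np.random.randint(0, 2, size=64)
dtrain = xgb.DMatrix(X, label=y)

def logistic_obj(predt, dtrain):
    # `predt` now arrives as the raw margin, so the callback applies the
    # sigmoid itself before computing the usual logistic gradient/hessian.
    labels = dtrain.get_label()
    prob = 1.0 / (1.0 + np.exp(-predt))
    grad = prob - labels
    hess = prob * (1.0 - prob)
    return grad, hess

bst = xgb.train({'max_depth': 2, 'eta': 1}, dtrain,
                num_boost_round=2, obj=logistic_obj)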
@@ -145,7 +145,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
   if (is.null(obj)) {
     .Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
   } else {
-    pred <- predict(booster_handle, dtrain, training = TRUE)
+    pred <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE)
     gpair <- obj(pred, dtrain)
     .Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess)
   }
@@ -58,3 +58,20 @@ test_that("custom objective using DMatrix attr works", {
   bst <- xgb.train(param, dtrain, num_round, watchlist)
   expect_equal(class(bst), "xgb.Booster")
 })
+
+test_that("custom objective with multi-class works", {
+  data = as.matrix(iris[, -5])
+  label = as.numeric(iris$Species) - 1
+  dtrain <- xgb.DMatrix(data = data, label = label)
+  nclasses <- 3
+
+  fake_softprob <- function(preds, dtrain) {
+    expect_true(all(matrix(preds) == 0.5))
+    grad <- rnorm(dim(as.matrix(preds))[1])
+    expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
+    hess <- rnorm(dim(as.matrix(preds))[1])
+    return (list(grad = grad, hess = hess))
+  }
+  param$objective = fake_softprob
+  bst <- xgb.train(param, dtrain, 1, num_class=nclasses)
+})
@@ -3,6 +3,7 @@ XGBoost Python Feature Walkthrough
 * [Basic walkthrough of wrappers](basic_walkthrough.py)
 * [Customize loss function, and evaluation metric](custom_objective.py)
 * [Re-implement RMSLE as customized metric and objective](custom_rmsle.py)
+* [Re-Implement `multi:softmax` objective as customized objective](custom_softmax.py)
 * [Boosting from existing prediction](boost_from_prediction.py)
 * [Predicting using first n trees](predict_first_ntree.py)
 * [Generalized Linear Model](generalized_linear_model.py)
@@ -1,16 +1,22 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 import numpy as np
 import scipy.sparse
 import pickle
 import xgboost as xgb
+import os
 
-### simple example
+# Make sure the demo knows where to load the data.
+CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+XGBOOST_ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR))
+DEMO_DIR = os.path.join(XGBOOST_ROOT_DIR, 'demo')
+
+# simple example
 # load file from text file, also binary buffer generated by xgboost
-dtrain = xgb.DMatrix('../data/agaricus.txt.train')
-dtest = xgb.DMatrix('../data/agaricus.txt.test')
+dtrain = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.train'))
+dtest = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.test'))
 
 # specify parameters via map, definition are same as c++ version
-param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
+param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
 
 # specify validations set to watch performance
 watchlist = [(dtest, 'eval'), (dtrain, 'train')]
@@ -20,12 +26,14 @@ bst = xgb.train(param, dtrain, num_round, watchlist)
 # this is prediction
 preds = bst.predict(dtest)
 labels = dtest.get_label()
-print('error=%f' % (sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))))
+print('error=%f' %
+      (sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) /
+       float(len(preds))))
 bst.save_model('0001.model')
 # dump model
 bst.dump_model('dump.raw.txt')
 # dump model with feature map
-bst.dump_model('dump.nice.txt', '../data/featmap.txt')
+bst.dump_model('dump.nice.txt', os.path.join(DEMO_DIR, 'data/featmap.txt'))
 
 # save dmatrix into binary buffer
 dtest.save_binary('dtest.buffer')
@@ -50,14 +58,18 @@ assert np.sum(np.abs(preds3 - preds)) == 0
 # build dmatrix from scipy.sparse
 print('start running example of build DMatrix from scipy.sparse CSR Matrix')
 labels = []
-row = []; col = []; dat = []
+row = []
+col = []
+dat = []
 i = 0
-for l in open('../data/agaricus.txt.train'):
+for l in open(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.train')):
     arr = l.split()
     labels.append(int(arr[0]))
     for it in arr[1:]:
-        k,v = it.split(':')
-        row.append(i); col.append(int(k)); dat.append(float(v))
+        k, v = it.split(':')
+        row.append(i)
+        col.append(int(k))
+        dat.append(float(v))
     i += 1
 csr = scipy.sparse.csr_matrix((dat, (row, col)))
 dtrain = xgb.DMatrix(csr, label=labels)
@@ -72,8 +84,8 @@ watchlist = [(dtest, 'eval'), (dtrain, 'train')]
 bst = xgb.train(param, dtrain, num_round, watchlist)
 
 print('start running example of build DMatrix from numpy array')
-# NOTE: npymat is numpy array, we will convert it into scipy.sparse.csr_matrix in internal implementation
-# then convert to DMatrix
+# NOTE: npymat is numpy array, we will convert it into scipy.sparse.csr_matrix
+# in internal implementation then convert to DMatrix
 npymat = csr.todense()
 dtrain = xgb.DMatrix(npymat, label=labels)
 watchlist = [(dtest, 'eval'), (dtrain, 'train')]
@@ -15,6 +15,7 @@ import numpy as np
 import xgboost as xgb
 from typing import Tuple, Dict, List
 from time import time
+import argparse
 import matplotlib
 from matplotlib import pyplot as plt
 
@@ -150,12 +151,7 @@ def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict:
     return results
 
 
-if __name__ == '__main__':
-    dtrain, dtest = generate_data()
-    rmse_evals = native_rmse(dtrain, dtest)
-    rmsle_evals = native_rmsle(dtrain, dtest)
-    py_rmsle_evals = py_rmsle(dtrain, dtest)
-
+def plot_history(rmse_evals, rmsle_evals, py_rmsle_evals):
     fig, axs = plt.subplots(3, 1)
     ax0: matplotlib.axes.Axes = axs[0]
     ax1: matplotlib.axes.Axes = axs[1]
@@ -177,3 +173,25 @@ if __name__ == '__main__':
 
     plt.show()
     plt.close()
+
+
+def main(args):
+    dtrain, dtest = generate_data()
+    rmse_evals = native_rmse(dtrain, dtest)
+    rmsle_evals = native_rmsle(dtrain, dtest)
+    py_rmsle_evals = py_rmsle(dtrain, dtest)
+
+    if args.plot != 0:
+        plot_history(rmse_evals, rmsle_evals, py_rmsle_evals)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Arguments for custom RMSLE objective function demo.')
+    parser.add_argument(
+        '--plot',
+        type=int,
+        default=1,
+        help='Set to 0 to disable plotting the evaluation history.')
+    args = parser.parse_args()
+    main(args)
demo/guide-python/custom_softmax.py (new file, 148 lines)
@@ -0,0 +1,148 @@
'''Demo for creating customized multi-class objective function. This demo is
only applicable after (excluding) XGBoost 1.0.0, as before this version XGBoost
returns transformed prediction for multi-class objective function. More
details in comments.

'''

import numpy as np
import xgboost as xgb
from matplotlib import pyplot as plt
import argparse

np.random.seed(1994)

kRows = 100
kCols = 10
kClasses = 4  # number of classes

kRounds = 10  # number of boosting rounds.

# Generate some random data for demo.
X = np.random.randn(kRows, kCols)
y = np.random.randint(0, 4, size=kRows)

m = xgb.DMatrix(X, y)


def softmax(x):
    '''Softmax function with x as input vector.'''
    e = np.exp(x)
    return e / np.sum(e)


def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
    '''Loss function. Computing the gradient and approximated hessian (diagonal).
    Reimplements the `multi:softprob` inside XGBoost.

    '''
    labels = data.get_label()
    if data.get_weight().size == 0:
        # Use 1 as weight if we don't have custom weight.
        weights = np.ones((kRows, 1), dtype=float)
    else:
        weights = data.get_weight()

    # The prediction is of shape (rows, classes), each element in a row
    # represents a raw prediction (leaf weight, hasn't gone through softmax
    # yet). In XGBoost 1.0.0, the prediction is transformed by a softmax
    # function, fixed in later versions.
    assert predt.shape == (kRows, kClasses)

    grad = np.zeros((kRows, kClasses), dtype=float)
    hess = np.zeros((kRows, kClasses), dtype=float)

    eps = 1e-6

    # compute the gradient and hessian, slow iterations in Python, only
    # suitable for demo. Also the one in native XGBoost core is more robust to
    # numeric overflow as we don't do anything to mitigate the `exp` in
    # `softmax` here.
    for r in range(predt.shape[0]):
        target = labels[r]
        p = softmax(predt[r, :])
        for c in range(predt.shape[1]):
            assert target >= 0 or target <= kClasses
            g = p[c] - 1.0 if c == target else p[c]
            g = g * weights[r]
            h = max((2.0 * p[c] * (1.0 - p[c]) * weights[r]).item(), eps)
            grad[r, c] = g
            hess[r, c] = h

    # Right now (XGBoost 1.0.0), reshaping is necessary
    grad = grad.reshape((kRows * kClasses, 1))
    hess = hess.reshape((kRows * kClasses, 1))
    return grad, hess


def predict(booster, X):
    '''A customized prediction function that converts raw prediction to
    target class.

    '''
    # Output margin means we want to obtain the raw prediction obtained from
    # tree leaf weight.
    predt = booster.predict(X, output_margin=True)
    out = np.zeros(kRows)
    for r in range(predt.shape[0]):
        # the class with maximum prob (not strictly prob as it haven't gone
        # through softmax yet so it doesn't sum to 1, but result is the same
        # for argmax).
        i = np.argmax(predt[r])
        out[r] = i
    return out


def plot_history(custom_results, native_results):
    fig, axs = plt.subplots(2, 1)
    ax0 = axs[0]
    ax1 = axs[1]

    x = np.arange(0, kRounds, 1)
    ax0.plot(x, custom_results['train']['merror'], label='Custom objective')
    ax0.legend()
    ax1.plot(x, native_results['train']['merror'], label='multi:softmax')
    ax1.legend()

    plt.show()


def main(args):
    custom_results = {}
    # Use our custom objective function
    booster_custom = xgb.train({'num_class': kClasses},
                               m,
                               num_boost_round=kRounds,
                               obj=softprob_obj,
                               evals_result=custom_results,
                               evals=[(m, 'train')])

    predt_custom = predict(booster_custom, m)

    native_results = {}
    # Use the same objective function defined in XGBoost.
    booster_native = xgb.train({'num_class': kClasses},
                               m,
                               num_boost_round=kRounds,
                               evals_result=native_results,
                               evals=[(m, 'train')])
    predt_native = booster_native.predict(m)

    # We are reimplementing the loss function in XGBoost, so it should
    # be the same for normal cases.
    assert np.all(predt_custom == predt_native)

    if args.plot != 0:
        plot_history(custom_results, native_results)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Arguments for custom softmax objective function demo.')
    parser.add_argument(
        '--plot',
        type=int,
        default=1,
        help='Set to 0 to disable plotting the evaluation history.')
    args = parser.parse_args()
    main(args)
@@ -14,7 +14,7 @@ concepts should be readily applicable to other language bindings.
 * The customized functions defined here are only applicable to single node training.
   Distributed environment requires syncing with ``xgboost.rabit``, the interface is
   subject to change hence beyond the scope of this tutorial.
-* We also plan to re-design the interface for multi-classes objective in the future.
+* We also plan to improve the interface for multi-classes objective in the future.
 
 In the following sections, we will provide a step by step walk through of implementing
 ``Squared Log Error(SLE)`` objective function:
@@ -136,3 +136,12 @@ Notice that the parameter ``disable_default_eval_metric`` is used to suppress the
 in XGBoost.
 
 For fully reproducible source code and comparison plots, see `custom_rmsle.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_rmsle.py>`_.
+
+
+******************************
+Multi-class objective function
+******************************
+
+A similar demo for multi-class objective function is also available, see
+`demo/guide-python/custom_softmax.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_softmax.py>`_
+for details.
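As a companion to the new tutorial section, here is a condensed, vectorized sketch of the shape contract the multi-class demo relies on. It is not part of the diff; the names `n_samples`/`n_classes`, the toy data, and the `toy_softprob_obj` helper are assumptions for illustration. The key points it mirrors are that, after this commit, the prediction reaches the callback untransformed with shape (rows, classes) and the returned gradient/hessian are flattened to rows * classes:

import numpy as np
import xgboost as xgb

n_samples, n_features, n_classes = 32, 5, 3
X = np.random.randn(n_samples, n_features)
y = np.random.randint(0, n_classes, size=n_samples)
dtrain = xgb.DMatrix(X, label=y)

def toy_softprob_obj(predt, dtrain):
    # Raw margins of shape (n_samples, n_classes); apply a row-wise softmax.
    assert predt.shape == (n_samples, n_classes)
    labels = dtrain.get_label().astype(int)
    e = np.exp(predt - predt.max(axis=1, keepdims=True))
    prob = e / e.sum(axis=1, keepdims=True)
    grad = prob.copy()
    grad[np.arange(n_samples), labels] -= 1.0           # softprob gradient
    hess = np.maximum(2.0 * prob * (1.0 - prob), 1e-6)  # diagonal approximation
    # Flatten to one entry per (row, class) pair, as the demo does.
    return grad.reshape(n_samples * n_classes), hess.reshape(n_samples * n_classes)

bst = xgb.train({'num_class': n_classes}, dtrain,
                num_boost_round=2, obj=toy_softprob_obj)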
@@ -1367,7 +1367,7 @@ class Booster(object):
                                                     ctypes.c_int(iteration),
                                                     dtrain.handle))
         else:
-            pred = self.predict(dtrain, training=True)
+            pred = self.predict(dtrain, output_margin=True, training=True)
             grad, hess = fobj(pred, dtrain)
             self.boost(dtrain, grad, hess)
 
@@ -295,6 +295,9 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
     }
   }
   // update the trees
+  CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_)
+      << "Mismatching size between number of rows from input data and size of "
+         "gradient vector.";
   for (auto& up : updaters_) {
     up->Update(gpair, p_fmat, new_trees);
   }
tests/python/test_demos.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import os
import subprocess
import sys
import pytest


CURRENT_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR))
DEMO_DIR = os.path.join(ROOT_DIR, 'demo', 'guide-python')


def test_basic_walkthrough():
    script = os.path.join(DEMO_DIR, 'basic_walkthrough.py')
    cmd = ['python', script]
    subprocess.check_call(cmd)
    os.remove('dump.nice.txt')
    os.remove('dump.raw.txt')


def test_custom_multiclass_objective():
    script = os.path.join(DEMO_DIR, 'custom_softmax.py')
    cmd = ['python', script, '--plot=0']
    subprocess.check_call(cmd)


def test_custom_rmsle_objective():
    major, minor = sys.version_info[:2]
    if minor < 6:
        pytest.skip('Skipping RMLSE test due to Python version being too low.')
    script = os.path.join(DEMO_DIR, 'custom_rmsle.py')
    cmd = ['python', script, '--plot=0']
    subprocess.check_call(cmd)