Merge pull request #833 from AlexisMignon/master

Added the possibility to use custom objective function in the sklearn…
This commit is contained in:
Yuan (Terry) Tang 2016-02-18 09:36:38 -06:00
commit 75d23c8bb2
2 changed files with 188 additions and 57 deletions

View File

@ -11,6 +11,39 @@ from .compat import (SKLEARN_INSTALLED, XGBModelBase,
XGBClassifierBase, XGBRegressorBase, LabelEncoder)
def _objective_decorator(func):
"""Decorate an objective function
Converts an objective function using the typical sklearn metrics
signature so that it is usable with ``xgboost.training.train``
Parameters
----------
func: callable
Expects a callable with signature ``func(y_true, y_pred)``:
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples]
The predicted values
Returns
-------
new_func: callable
The new objective function as expected by ``xgboost.training.train``.
The signature is ``new_func(preds, dmatrix)``:
preds: array_like, shape [n_samples]
The predicted values
dmatrix: ``DMatrix``
The training set from which the labels will be extracted using
``dmatrix.get_label()``
"""
def inner(preds, dmatrix):
labels = dmatrix.get_label()
return func(labels, preds)
return inner
class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
"""Implementation of the Scikit-Learn API for XGBoost.
@ -25,9 +58,9 @@ class XGBModel(XGBModelBase):
Number of boosted trees to fit.
silent : boolean
Whether to print messages while running boosting.
objective : string
Specify the learning task and the corresponding learning objective.
objective : string or callable
Specify the learning task and the corresponding learning objective or
a custom objective function to be used (see note below).
nthread : int
Number of parallel threads used to run xgboost.
gamma : float
@ -56,6 +89,22 @@ class XGBModel(XGBModelBase):
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
Note
----
A custom objective function can be provided for the ``objective``
parameter. In this case, it should have the signature
``objective(y_true, y_pred) -> grad, hess``:
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples]
The predicted values
grad: array_like of shape [n_samples]
The value of the gradient for each sample point.
hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
silent=True, objective="reg:linear",
@ -174,6 +223,12 @@ class XGBModel(XGBModelBase):
params = self.get_xgb_params()
if callable(self.objective):
obj = _objective_decorator(self.objective)
params["objective"] = "reg:linear"
else:
obj = None
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
@ -184,7 +239,7 @@ class XGBModel(XGBModelBase):
self._Booster = train(params, trainDmatrix,
self.n_estimators, evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, feval=feval,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose)
if evals_result:
@ -302,13 +357,21 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
evals_result = {}
self.classes_ = list(np.unique(y))
self.n_classes_ = len(self.classes_)
xgb_options = self.get_xgb_params()
if callable(self.objective):
obj = _objective_decorator(self.objective)
# Use default value. Is it really not used ?
xgb_options["objective"] = "binary:logistic"
else:
obj = None
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying XGB instance
self.objective = "multi:softprob"
xgb_options = self.get_xgb_params()
xgb_options["objective"] = "multi:softprob"
xgb_options['num_class'] = self.n_classes_
else:
xgb_options = self.get_xgb_params()
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
@ -341,7 +404,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, feval=feval,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose)
if evals_result:

View File

@ -1,6 +1,6 @@
import xgboost as xgb
import numpy as np
from sklearn.cross_validation import KFold, train_test_split
from sklearn.cross_validation import KFold
from sklearn.metrics import mean_squared_error
from sklearn.grid_search import GridSearchCV
from sklearn.datasets import load_iris, load_digits, load_boston
@ -8,57 +8,125 @@ from sklearn.datasets import load_iris, load_digits, load_boston
rng = np.random.RandomState(1994)
def test_binary_classification():
digits = load_digits(2)
y = digits['target']
X = digits['data']
kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
for train_index, test_index in kf:
xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
assert err < 0.1
digits = load_digits(2)
y = digits['target']
X = digits['data']
kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
for train_index, test_index in kf:
xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
assert err < 0.1
def test_multiclass_classification():
iris = load_iris()
y = iris['target']
X = iris['data']
kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
for train_index, test_index in kf:
xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index])
preds = xgb_model.predict(X[test_index])
# test other params in XGBClassifier().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
labels = y[test_index]
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
assert err < 0.4
iris = load_iris()
y = iris['target']
X = iris['data']
kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
for train_index, test_index in kf:
xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index])
preds = xgb_model.predict(X[test_index])
# test other params in XGBClassifier().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
labels = y[test_index]
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
assert err < 0.4
def test_boston_housing_regression():
boston = load_boston()
y = boston['target']
X = boston['data']
kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
for train_index, test_index in kf:
xgb_model = xgb.XGBRegressor().fit(X[train_index],y[train_index])
preds = xgb_model.predict(X[test_index])
# test other params in XGBRegressor().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
labels = y[test_index]
assert mean_squared_error(preds, labels) < 25
boston = load_boston()
y = boston['target']
X = boston['data']
kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
for train_index, test_index in kf:
xgb_model = xgb.XGBRegressor().fit(X[train_index],y[train_index])
preds = xgb_model.predict(X[test_index])
# test other params in XGBRegressor().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
labels = y[test_index]
assert mean_squared_error(preds, labels) < 25
def test_parameter_tuning():
boston = load_boston()
y = boston['target']
X = boston['data']
xgb_model = xgb.XGBRegressor()
clf = GridSearchCV(xgb_model,
{'max_depth': [2,4,6],
'n_estimators': [50,100,200]}, verbose=1)
clf.fit(X,y)
assert clf.best_score_ < 0.7
assert clf.best_params_ == {'n_estimators': 100, 'max_depth': 4}
boston = load_boston()
y = boston['target']
X = boston['data']
xgb_model = xgb.XGBRegressor()
clf = GridSearchCV(xgb_model,
{'max_depth': [2,4,6],
'n_estimators': [50,100,200]}, verbose=1)
clf.fit(X,y)
assert clf.best_score_ < 0.7
assert clf.best_params_ == {'n_estimators': 100, 'max_depth': 4}
def test_regression_with_custom_objective():
def objective_ls(y_true, y_pred):
grad = (y_pred - y_true)
hess = np.ones(len(y_true))
return grad, hess
boston = load_boston()
y = boston['target']
X = boston['data']
kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
for train_index, test_index in kf:
xgb_model = xgb.XGBRegressor(objective=objective_ls).fit(
X[train_index], y[train_index]
)
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
assert mean_squared_error(preds, labels) < 25
# Test that the custom objective function is actually used
class XGBCustomObjectiveException(Exception):
pass
def dummy_objective(y_true, y_pred):
raise XGBCustomObjectiveException()
xgb_model = xgb.XGBRegressor(objective=dummy_objective)
np.testing.assert_raises(
XGBCustomObjectiveException,
xgb_model.fit,
X, y
)
def test_classification_with_custom_objective():
def logregobj(y_true, y_pred):
y_pred = 1.0 / (1.0 + np.exp(-y_pred))
grad = y_pred - y_true
hess = y_pred * (1.0-y_pred)
return grad, hess
digits = load_digits(2)
y = digits['target']
X = digits['data']
kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
for train_index, test_index in kf:
xgb_model = xgb.XGBClassifier(objective=logregobj).fit(
X[train_index],y[train_index]
)
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds))
if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
assert err < 0.1
# Test that the custom objective function is actually used
class XGBCustomObjectiveException(Exception):
pass
def dummy_objective(y_true, y_preds):
raise XGBCustomObjectiveException()
xgb_model = xgb.XGBClassifier(objective=dummy_objective)
np.testing.assert_raises(
XGBCustomObjectiveException,
xgb_model.fit,
X, y
)