Merge pull request #833 from AlexisMignon/master
Added the possibility to use custom objective function in the sklearn…
Commit 75d23c8bb2
@@ -11,6 +11,39 @@ from .compat import (SKLEARN_INSTALLED, XGBModelBase,
                      XGBClassifierBase, XGBRegressorBase, LabelEncoder)
 
 
+def _objective_decorator(func):
+    """Decorate an objective function
+
+    Converts an objective function using the typical sklearn metrics
+    signature so that it is usable with ``xgboost.training.train``
+
+    Parameters
+    ----------
+    func: callable
+        Expects a callable with signature ``func(y_true, y_pred)``:
+
+        y_true: array_like of shape [n_samples]
+            The target values
+        y_pred: array_like of shape [n_samples]
+            The predicted values
+
+    Returns
+    -------
+    new_func: callable
+        The new objective function as expected by ``xgboost.training.train``.
+        The signature is ``new_func(preds, dmatrix)``:
+
+        preds: array_like, shape [n_samples]
+            The predicted values
+        dmatrix: ``DMatrix``
+            The training set from which the labels will be extracted using
+            ``dmatrix.get_label()``
+    """
+    def inner(preds, dmatrix):
+        labels = dmatrix.get_label()
+        return func(labels, preds)
+    return inner
+
+
 class XGBModel(XGBModelBase):
     # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
     """Implementation of the Scikit-Learn API for XGBoost.
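
The decorator added above bridges two calling conventions: a scikit-learn-style objective receives ``(y_true, y_pred)``, while ``xgboost.training.train`` expects ``obj(preds, dmatrix)`` and reads the labels from the ``DMatrix``. A minimal sketch of what the wrapping amounts to (the ``squared_error`` helper and the random data are illustrative, not part of this commit):

import numpy as np
import xgboost as xgb

def squared_error(y_true, y_pred):
    # least-squares objective, as in the objective_ls helper in the tests below
    grad = y_pred - y_true
    hess = np.ones(len(y_true))
    return grad, hess

# What _objective_decorator(squared_error) effectively returns:
def wrapped(preds, dmatrix):
    labels = dmatrix.get_label()          # labels come out of the DMatrix
    return squared_error(labels, preds)

# Illustrative call with random data
X = np.random.rand(20, 3)
y = np.random.rand(20)
dtrain = xgb.DMatrix(X, label=y)
grad, hess = wrapped(np.zeros(20), dtrain)  # two arrays of shape (20,)
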
@@ -25,9 +58,9 @@ class XGBModel(XGBModelBase):
         Number of boosted trees to fit.
     silent : boolean
         Whether to print messages while running boosting.
-    objective : string
-        Specify the learning task and the corresponding learning objective.
+    objective : string or callable
+        Specify the learning task and the corresponding learning objective or
+        a custom objective function to be used (see note below).
     nthread : int
         Number of parallel threads used to run xgboost.
     gamma : float
@@ -56,6 +89,22 @@ class XGBModel(XGBModelBase):
     missing : float, optional
         Value in the data which needs to be present as a missing value. If
         None, defaults to np.nan.
+
+    Note
+    ----
+    A custom objective function can be provided for the ``objective``
+    parameter. In this case, it should have the signature
+    ``objective(y_true, y_pred) -> grad, hess``:
+
+    y_true: array_like of shape [n_samples]
+        The target values
+    y_pred: array_like of shape [n_samples]
+        The predicted values
+
+    grad: array_like of shape [n_samples]
+        The value of the gradient for each sample point.
+    hess: array_like of shape [n_samples]
+        The value of the second derivative for each sample point
     """
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
                  silent=True, objective="reg:linear",
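
A concrete objective matching this signature is the binary logistic helper used in the new tests (``logregobj``): it receives the raw margin predictions and returns per-sample gradient and second-derivative arrays of the log loss.

import numpy as np

def logregobj(y_true, y_pred):
    # y_pred arrives as raw margins; convert to probabilities first
    y_pred = 1.0 / (1.0 + np.exp(-y_pred))
    grad = y_pred - y_true          # gradient of the log loss w.r.t. the margin
    hess = y_pred * (1.0 - y_pred)  # second derivative
    return grad, hess
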
@@ -174,6 +223,12 @@ class XGBModel(XGBModelBase):
 
         params = self.get_xgb_params()
+
+        if callable(self.objective):
+            obj = _objective_decorator(self.objective)
+            params["objective"] = "reg:linear"
+        else:
+            obj = None
+
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
             if callable(eval_metric):
@@ -184,7 +239,7 @@ class XGBModel(XGBModelBase):
         self._Booster = train(params, trainDmatrix,
                               self.n_estimators, evals=evals,
                               early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, feval=feval,
+                              evals_result=evals_result, obj=obj, feval=feval,
                               verbose_eval=verbose)
 
         if evals_result:
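
Taken together, these two hunks mean ``XGBModel.fit`` wraps a callable objective with ``_objective_decorator``, keeps ``"reg:linear"`` as the placeholder objective string in the booster parameters, and forwards the wrapped callable as ``obj`` to ``train``. From the caller's side this reduces to something like the sketch below (the data is random and purely illustrative):

import numpy as np
import xgboost as xgb

def squared_error(y_true, y_pred):
    # same shape of helper as objective_ls in the tests below
    return y_pred - y_true, np.ones(len(y_true))

X = np.random.rand(100, 5)   # illustrative data
y = np.random.rand(100)

# The callable is wrapped internally and handed to train(..., obj=...)
model = xgb.XGBRegressor(objective=squared_error).fit(X, y)
preds = model.predict(X)
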
@@ -302,13 +357,21 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             evals_result = {}
         self.classes_ = list(np.unique(y))
         self.n_classes_ = len(self.classes_)
+
+        xgb_options = self.get_xgb_params()
+
+        if callable(self.objective):
+            obj = _objective_decorator(self.objective)
+            # Use default value. Is it really not used ?
+            xgb_options["objective"] = "binary:logistic"
+        else:
+            obj = None
+
         if self.n_classes_ > 2:
             # Switch to using a multiclass objective in the underlying XGB instance
-            self.objective = "multi:softprob"
-            xgb_options = self.get_xgb_params()
+            xgb_options["objective"] = "multi:softprob"
             xgb_options['num_class'] = self.n_classes_
-        else:
-            xgb_options = self.get_xgb_params()
+
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
@@ -341,7 +404,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
                               evals=evals,
                               early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, feval=feval,
+                              evals_result=evals_result, obj=obj, feval=feval,
                               verbose_eval=verbose)
 
         if evals_result:
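
The classifier follows the same pattern: a callable objective is wrapped and passed as ``obj``, the recorded objective string falls back to ``"binary:logistic"``, and multiclass problems still switch to ``"multi:softprob"`` with ``num_class`` set. A usage sketch mirroring the new classification test (the keyword form of ``load_digits`` is an illustrative choice for newer scikit-learn; the tests below use ``load_digits(2)``):

import numpy as np
import xgboost as xgb
from sklearn.datasets import load_digits

def logregobj(y_true, y_pred):
    # same helper as in test_classification_with_custom_objective below
    p = 1.0 / (1.0 + np.exp(-y_pred))
    return p - y_true, p * (1.0 - p)

X, y = load_digits(n_class=2, return_X_y=True)
clf = xgb.XGBClassifier(objective=logregobj).fit(X, y)
preds = clf.predict(X)
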
@@ -1,6 +1,6 @@
 import xgboost as xgb
 import numpy as np
-from sklearn.cross_validation import KFold, train_test_split
+from sklearn.cross_validation import KFold
 from sklearn.metrics import mean_squared_error
 from sklearn.grid_search import GridSearchCV
 from sklearn.datasets import load_iris, load_digits, load_boston
@@ -8,57 +8,125 @@ from sklearn.datasets import load_iris, load_digits, load_boston
 rng = np.random.RandomState(1994)
 
 def test_binary_classification():
     digits = load_digits(2)
     y = digits['target']
     X = digits['data']
     kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
     for train_index, test_index in kf:
         xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index])
         preds = xgb_model.predict(X[test_index])
         labels = y[test_index]
         err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
         assert err < 0.1
 
 def test_multiclass_classification():
     iris = load_iris()
     y = iris['target']
     X = iris['data']
     kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
     for train_index, test_index in kf:
         xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index])
         preds = xgb_model.predict(X[test_index])
         # test other params in XGBClassifier().fit
         preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
         preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
         preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
         labels = y[test_index]
         err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
         assert err < 0.4
 
 def test_boston_housing_regression():
     boston = load_boston()
     y = boston['target']
     X = boston['data']
     kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
     for train_index, test_index in kf:
         xgb_model = xgb.XGBRegressor().fit(X[train_index],y[train_index])
         preds = xgb_model.predict(X[test_index])
         # test other params in XGBRegressor().fit
         preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
         preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
         preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
         labels = y[test_index]
         assert mean_squared_error(preds, labels) < 25
 
 def test_parameter_tuning():
     boston = load_boston()
     y = boston['target']
     X = boston['data']
     xgb_model = xgb.XGBRegressor()
     clf = GridSearchCV(xgb_model,
                        {'max_depth': [2,4,6],
                         'n_estimators': [50,100,200]}, verbose=1)
     clf.fit(X,y)
     assert clf.best_score_ < 0.7
     assert clf.best_params_ == {'n_estimators': 100, 'max_depth': 4}
 
+def test_regression_with_custom_objective():
+    def objective_ls(y_true, y_pred):
+        grad = (y_pred - y_true)
+        hess = np.ones(len(y_true))
+        return grad, hess
+
+    boston = load_boston()
+    y = boston['target']
+    X = boston['data']
+    kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+    for train_index, test_index in kf:
+        xgb_model = xgb.XGBRegressor(objective=objective_ls).fit(
+            X[train_index], y[train_index]
+        )
+        preds = xgb_model.predict(X[test_index])
+        labels = y[test_index]
+        assert mean_squared_error(preds, labels) < 25
+
+    # Test that the custom objective function is actually used
+    class XGBCustomObjectiveException(Exception):
+        pass
+
+    def dummy_objective(y_true, y_pred):
+        raise XGBCustomObjectiveException()
+
+    xgb_model = xgb.XGBRegressor(objective=dummy_objective)
+    np.testing.assert_raises(
+        XGBCustomObjectiveException,
+        xgb_model.fit,
+        X, y
+    )
+
+def test_classification_with_custom_objective():
+    def logregobj(y_true, y_pred):
+        y_pred = 1.0 / (1.0 + np.exp(-y_pred))
+        grad = y_pred - y_true
+        hess = y_pred * (1.0-y_pred)
+        return grad, hess
+
+    digits = load_digits(2)
+    y = digits['target']
+    X = digits['data']
+    kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+    for train_index, test_index in kf:
+        xgb_model = xgb.XGBClassifier(objective=logregobj).fit(
+            X[train_index], y[train_index]
+        )
+        preds = xgb_model.predict(X[test_index])
+        labels = y[test_index]
+        err = sum(1 for i in range(len(preds))
+                  if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
+        assert err < 0.1
+
+    # Test that the custom objective function is actually used
+    class XGBCustomObjectiveException(Exception):
+        pass
+
+    def dummy_objective(y_true, y_preds):
+        raise XGBCustomObjectiveException()
+
+    xgb_model = xgb.XGBClassifier(objective=dummy_objective)
+    np.testing.assert_raises(
+        XGBCustomObjectiveException,
+        xgb_model.fit,
+        X, y
+    )