Added the possibility to use custom objective function in the sklearn API

This commit is contained in:
Alexis Mignon 2016-02-15 17:13:13 +01:00
parent 9b2b81e6a4
commit c8714f587a

View File

@ -25,7 +25,7 @@ class XGBModel(XGBModelBase):
Number of boosted trees to fit.
silent : boolean
Whether to print messages while running boosting.
objective : string
objective : string or callable
Specify the learning task and the corresponding learning objective.
nthread : int
@ -174,6 +174,12 @@ class XGBModel(XGBModelBase):
params = self.get_xgb_params()
if callable(self.objective):
obj = self.objective
params["objective"] = "reg:linear"
else:
obj = None
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
@ -184,7 +190,7 @@ class XGBModel(XGBModelBase):
self._Booster = train(params, trainDmatrix,
self.n_estimators, evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, feval=feval,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose)
if evals_result:
@ -302,13 +308,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
evals_result = {}
self.classes_ = list(np.unique(y))
self.n_classes_ = len(self.classes_)
xgb_options = self.get_xgb_params()
if callable(self.objective):
obj = self.objective
xgb_options["objective"] = "binary:logistic"
else:
obj = None
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying XGB instance
self.objective = "multi:softprob"
xgb_options = self.get_xgb_params()
xgb_options["objective"] = "multi:softprob"
xgb_options['num_class'] = self.n_classes_
else:
xgb_options = self.get_xgb_params()
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
@ -339,7 +352,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, feval=feval,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose)
if evals_result: