Added the possibility to use custom objective function in the sklearn API
This commit is contained in:
parent
9b2b81e6a4
commit
c8714f587a
@ -25,7 +25,7 @@ class XGBModel(XGBModelBase):
|
||||
Number of boosted trees to fit.
|
||||
silent : boolean
|
||||
Whether to print messages while running boosting.
|
||||
objective : string
|
||||
objective : string or callable
|
||||
Specify the learning task and the corresponding learning objective.
|
||||
|
||||
nthread : int
|
||||
@ -174,6 +174,12 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
params = self.get_xgb_params()
|
||||
|
||||
if callable(self.objective):
|
||||
obj = self.objective
|
||||
params["objective"] = "reg:linear"
|
||||
else:
|
||||
obj = None
|
||||
|
||||
feval = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
@ -184,7 +190,7 @@ class XGBModel(XGBModelBase):
|
||||
self._Booster = train(params, trainDmatrix,
|
||||
self.n_estimators, evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
evals_result=evals_result, feval=feval,
|
||||
evals_result=evals_result, obj=obj, feval=feval,
|
||||
verbose_eval=verbose)
|
||||
|
||||
if evals_result:
|
||||
@ -302,13 +308,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
evals_result = {}
|
||||
self.classes_ = list(np.unique(y))
|
||||
self.n_classes_ = len(self.classes_)
|
||||
|
||||
|
||||
xgb_options = self.get_xgb_params()
|
||||
|
||||
if callable(self.objective):
|
||||
obj = self.objective
|
||||
xgb_options["objective"] = "binary:logistic"
|
||||
else:
|
||||
obj = None
|
||||
|
||||
if self.n_classes_ > 2:
|
||||
# Switch to using a multiclass objective in the underlying XGB instance
|
||||
self.objective = "multi:softprob"
|
||||
xgb_options = self.get_xgb_params()
|
||||
xgb_options["objective"] = "multi:softprob"
|
||||
xgb_options['num_class'] = self.n_classes_
|
||||
else:
|
||||
xgb_options = self.get_xgb_params()
|
||||
|
||||
feval = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
@ -339,7 +352,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
|
||||
evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
evals_result=evals_result, feval=feval,
|
||||
evals_result=evals_result, obj=obj, feval=feval,
|
||||
verbose_eval=verbose)
|
||||
|
||||
if evals_result:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user