Added the possibility to use custom objective function in the sklearn API
This commit is contained in:
parent
9b2b81e6a4
commit
c8714f587a
@ -25,7 +25,7 @@ class XGBModel(XGBModelBase):
|
|||||||
Number of boosted trees to fit.
|
Number of boosted trees to fit.
|
||||||
silent : boolean
|
silent : boolean
|
||||||
Whether to print messages while running boosting.
|
Whether to print messages while running boosting.
|
||||||
objective : string
|
objective : string or callable
|
||||||
Specify the learning task and the corresponding learning objective.
|
Specify the learning task and the corresponding learning objective.
|
||||||
|
|
||||||
nthread : int
|
nthread : int
|
||||||
@ -174,6 +174,12 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
params = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
|
|
||||||
|
if callable(self.objective):
|
||||||
|
obj = self.objective
|
||||||
|
params["objective"] = "reg:linear"
|
||||||
|
else:
|
||||||
|
obj = None
|
||||||
|
|
||||||
feval = eval_metric if callable(eval_metric) else None
|
feval = eval_metric if callable(eval_metric) else None
|
||||||
if eval_metric is not None:
|
if eval_metric is not None:
|
||||||
if callable(eval_metric):
|
if callable(eval_metric):
|
||||||
@ -184,7 +190,7 @@ class XGBModel(XGBModelBase):
|
|||||||
self._Booster = train(params, trainDmatrix,
|
self._Booster = train(params, trainDmatrix,
|
||||||
self.n_estimators, evals=evals,
|
self.n_estimators, evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result, feval=feval,
|
evals_result=evals_result, obj=obj, feval=feval,
|
||||||
verbose_eval=verbose)
|
verbose_eval=verbose)
|
||||||
|
|
||||||
if evals_result:
|
if evals_result:
|
||||||
@ -302,13 +308,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
evals_result = {}
|
evals_result = {}
|
||||||
self.classes_ = list(np.unique(y))
|
self.classes_ = list(np.unique(y))
|
||||||
self.n_classes_ = len(self.classes_)
|
self.n_classes_ = len(self.classes_)
|
||||||
|
|
||||||
|
|
||||||
|
xgb_options = self.get_xgb_params()
|
||||||
|
|
||||||
|
if callable(self.objective):
|
||||||
|
obj = self.objective
|
||||||
|
xgb_options["objective"] = "binary:logistic"
|
||||||
|
else:
|
||||||
|
obj = None
|
||||||
|
|
||||||
if self.n_classes_ > 2:
|
if self.n_classes_ > 2:
|
||||||
# Switch to using a multiclass objective in the underlying XGB instance
|
# Switch to using a multiclass objective in the underlying XGB instance
|
||||||
self.objective = "multi:softprob"
|
xgb_options["objective"] = "multi:softprob"
|
||||||
xgb_options = self.get_xgb_params()
|
|
||||||
xgb_options['num_class'] = self.n_classes_
|
xgb_options['num_class'] = self.n_classes_
|
||||||
else:
|
|
||||||
xgb_options = self.get_xgb_params()
|
|
||||||
|
|
||||||
feval = eval_metric if callable(eval_metric) else None
|
feval = eval_metric if callable(eval_metric) else None
|
||||||
if eval_metric is not None:
|
if eval_metric is not None:
|
||||||
@ -339,7 +352,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
|
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
|
||||||
evals=evals,
|
evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result, feval=feval,
|
evals_result=evals_result, obj=obj, feval=feval,
|
||||||
verbose_eval=verbose)
|
verbose_eval=verbose)
|
||||||
|
|
||||||
if evals_result:
|
if evals_result:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user