From c8714f587a7f2627082b4898e6d079d5a4cbac94 Mon Sep 17 00:00:00 2001
From: Alexis Mignon
Date: Mon, 15 Feb 2016 17:13:13 +0100
Subject: [PATCH] Added the possibility to use a custom objective function in
 the sklearn API

---
 python-package/xgboost/sklearn.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 763551abf..53c787f5a 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -25,7 +25,7 @@ class XGBModel(XGBModelBase):
         Number of boosted trees to fit.
     silent : boolean
         Whether to print messages while running boosting.
-    objective : string
+    objective : string or callable
         Specify the learning task and the corresponding learning objective.
 
     nthread : int
@@ -174,6 +174,12 @@ class XGBModel(XGBModelBase):
 
         params = self.get_xgb_params()
 
+        if callable(self.objective):
+            obj = self.objective
+            params["objective"] = "reg:linear"
+        else:
+            obj = None
+
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
             if callable(eval_metric):
@@ -184,7 +190,7 @@ class XGBModel(XGBModelBase):
         self._Booster = train(params, trainDmatrix,
                               self.n_estimators, evals=evals,
                               early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, feval=feval,
+                              evals_result=evals_result, obj=obj, feval=feval,
                               verbose_eval=verbose)
 
         if evals_result:
@@ -302,13 +308,20 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         evals_result = {}
         self.classes_ = list(np.unique(y))
         self.n_classes_ = len(self.classes_)
+
+
+        xgb_options = self.get_xgb_params()
+
+        if callable(self.objective):
+            obj = self.objective
+            xgb_options["objective"] = "binary:logistic"
+        else:
+            obj = None
+
         if self.n_classes_ > 2:
             # Switch to using a multiclass objective in the underlying XGB instance
-            self.objective = "multi:softprob"
-            xgb_options = self.get_xgb_params()
+            xgb_options["objective"] = "multi:softprob"
             xgb_options['num_class'] = self.n_classes_
-        else:
-            xgb_options = self.get_xgb_params()
 
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
@@ -339,7 +352,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
                               evals=evals,
                               early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, feval=feval,
+                              evals_result=evals_result, obj=obj, feval=feval,
                               verbose_eval=verbose)
 
         if evals_result:
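
Usage sketch (illustrative; the helper name and toy data are made up): with this change a
callable passed as 'objective' is forwarded unchanged to xgboost.train() as its 'obj'
callback, so it is expected to take (preds, dtrain) and return the per-row gradient and
hessian.

    import numpy as np
    from xgboost.sklearn import XGBRegressor

    def squared_error_obj(preds, dtrain):
        # Gradient and hessian of 0.5 * (pred - label)**2 with respect to pred.
        labels = dtrain.get_label()
        grad = preds - labels
        hess = np.ones_like(preds)
        return grad, hess

    # Toy regression data; any dataset would do.
    rng = np.random.RandomState(0)
    X = rng.rand(100, 4)
    y = X.sum(axis=1)

    model = XGBRegressor(objective=squared_error_obj, n_estimators=10)
    model.fit(X, y)
    print(model.predict(X[:5]))

The classifier path mirrors this: the patch keeps "binary:logistic" (or "multi:softprob"
for more than two classes) as the declared objective in the parameters while training
against the custom gradient, which presumably keeps the booster's prediction
post-processing consistent.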