From b0f7ddaa2ee3411b33f95b42be46a3325b0ac23b Mon Sep 17 00:00:00 2001
From: Skipper Seabold
Date: Tue, 30 Jun 2015 11:42:14 -0500
Subject: [PATCH] REF: Combine eval_metric and feval to one parameter

---
 wrapper/xgboost.py | 48 ++++++++++++++++++++++++++++------------------
 1 file changed, 29 insertions(+), 19 deletions(-)

diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py
index adb21a00b..95e0bf6ff 100644
--- a/wrapper/xgboost.py
+++ b/wrapper/xgboost.py
@@ -1093,7 +1093,7 @@ class XGBModel(XGBModelBase):
         return xgb_params
 
     def fit(self, X, y, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, feval=None, verbose=True):
+            early_stopping_rounds=None, verbose=True):
         # pylint: disable=missing-docstring,invalid-name
         """
         Fit the gradient boosting model
@@ -1107,8 +1107,14 @@ class XGBModel(XGBModelBase):
         eval_set : list, optional
             A list of (X, y) tuple pairs to use as a validation set for
             early-stopping
-        eval_metric : str, optional
-            Built-in evaluation metric to use. See doc/parameter.md.
+        eval_metric : str, callable, optional
+            If a str, should be a built-in evaluation metric to use. See
+            doc/parameter.md. If callable, a custom evaluation metric. The call
+            signature is func(y_predicted, y_true) where y_true will be a
+            DMatrix object such that you may need to call the get_label
+            method. It must return a str, value pair where the str is a name
+            for the evaluation and value is the value of the evaluation
+            function. This objective is always minimized.
         early_stopping_rounds : int
             Activates early stopping. Validation error needs to decrease at
             least every <early_stopping_rounds> round(s) to continue training.
@@ -1116,11 +1122,6 @@ class XGBModel(XGBModelBase):
             will use the last. Returns the model from the last iteration
             (not the best one). If early stopping occurs, the model will
             have two additional fields: bst.best_score and bst.best_iteration.
-        feval : function, optional
-            Custom evaluation metric to use. The call signature is
-            feval(y_predicted, y_true) where y_true will be a DMatrix object
-            such that you may need to call the get_label method. This objective
-            if always assumed to be minimized, so use -feval when appropriate.
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
@@ -1137,13 +1138,17 @@ class XGBModel(XGBModelBase):
 
         params = self.get_xgb_params()
 
+        feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
-            params.update({'eval_metric': eval_metric})
+            if callable(eval_metric):
+                eval_metric = None
+            else:
+                params.update({'eval_metric': eval_metric})
 
         self._Booster = train(params, trainDmatrix,
                               self.n_estimators, evals=evals,
                               early_stopping_rounds=early_stopping_rounds,
-                              evals_result=eval_results, feval=None,
+                              evals_result=eval_results, feval=feval,
                               verbose_eval=verbose)
         if eval_results:
             eval_results = {k: np.array(v, dtype=float)
@@ -1180,7 +1185,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                          base_score, seed, missing)
 
     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, feval=None, versbose=True):
+            early_stopping_rounds=None, verbose=True):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """
         Fit gradient boosting classifier
@@ -1196,8 +1201,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         eval_set : list, optional
             A list of (X, y) pairs to use as a validation set for
             early-stopping
-        eval_metric : str
-            Built-in evaluation metric to use. See doc/parameter.md.
+        eval_metric : str, callable, optional
+            If a str, should be a built-in evaluation metric to use. See
+            doc/parameter.md. If callable, a custom evaluation metric. The call
+            signature is func(y_predicted, y_true) where y_true will be a
+            DMatrix object such that you may need to call the get_label
+            method. It must return a str, value pair where the str is a name
+            for the evaluation and value is the value of the evaluation
+            function. This objective is always minimized.
         early_stopping_rounds : int, optional
             Activates early stopping. Validation error needs to decrease at
             least every <early_stopping_rounds> round(s) to continue training.
@@ -1205,11 +1216,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             will use the last. Returns the model from the last iteration
             (not the best one). If early stopping occurs, the model will
             have two additional fields: bst.best_score and bst.best_iteration.
-        feval : function, optional
-            Custom evaluation metric to use. The call signature is
-            feval(y_predicted, y_true) where y_true will be a DMatrix object
-            such that you may need to call the get_label method. This objective
-            if always assumed to be minimized, so use -feval when appropriate.
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
@@ -1225,8 +1231,12 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         else:
             xgb_options = self.get_xgb_params()
 
+        feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
-            xgb_options.update({"eval_metric": eval_metric})
+            if callable(eval_metric):
+                eval_metric = None
+            else:
+                xgb_options.update({"eval_metric": eval_metric})
 
         if eval_set is not None:
             # TODO: use sample_weight if given?
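
After this change, a custom metric goes through the same eval_metric parameter
that previously accepted only built-in metric names. A minimal usage sketch
follows; the XGBRegressor class, the toy data, and the mae metric are
illustrative assumptions for this sketch, not part of the patch:

    import numpy as np
    import xgboost as xgb

    def mae(y_predicted, y_true):
        # y_true arrives as a DMatrix, so pull the labels out with get_label().
        labels = y_true.get_label()
        # Return a (name, value) pair; the value is always minimized,
        # so negate any score where bigger is better.
        return 'mae', float(np.mean(np.abs(y_predicted - labels)))

    X = np.random.rand(100, 5)
    y = np.random.rand(100)

    model = xgb.XGBRegressor(n_estimators=10)
    # A built-in metric is still passed as a plain string:
    #   model.fit(X, y, eval_set=[(X, y)], eval_metric='rmse')
    # A callable is detected with callable() and routed to train()'s
    # feval argument, per the hunks above:
    model.fit(X, y, eval_set=[(X, y)], eval_metric=mae, verbose=False)

Because the callable is stripped out of eval_metric before the params dict is
built, the booster never sees a non-string 'eval_metric' entry, which is the
point of the refactor.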