From 790dc877c37469de013204e842eea3c2512eb996 Mon Sep 17 00:00:00 2001 From: catena Date: Thu, 25 Feb 2016 13:42:19 +0530 Subject: [PATCH] return best_ntree_limit if early stopped --- python-package/xgboost/sklearn.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 6843afbac..5fd0f7495 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -206,7 +206,10 @@ class XGBModel(XGBModelBase): Requires at least one item in evals. If there's more than one, will use the last. Returns the model from the last iteration (not the best one). If early stopping occurs, the model will - have two additional fields: bst.best_score and bst.best_iteration. + have three additional fields: bst.best_score, bst.best_iteration + and bst.best_ntree_limit. + (Use bst.best_ntree_limit to get the correct value if num_parallel_tree + and/or num_class appears in the parameters) verbose : bool If `verbose` and an evaluation set is used, writes the evaluation metric measured on the validation set to stderr. @@ -251,6 +254,7 @@ class XGBModel(XGBModelBase): if early_stopping_rounds is not None: self.best_score = self._Booster.best_score self.best_iteration = self._Booster.best_iteration + self.best_ntree_limit = self._Booster.best_ntree_limit return self def predict(self, data, output_margin=False, ntree_limit=0): @@ -349,7 +353,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase): Requires at least one item in evals. If there's more than one, will use the last. Returns the model from the last iteration (not the best one). If early stopping occurs, the model will - have two additional fields: bst.best_score and bst.best_iteration. + have three additional fields: bst.best_score, bst.best_iteration + and bst.best_ntree_limit. + (Use bst.best_ntree_limit to get the correct value if num_parallel_tree + and/or num_class appears in the parameters) verbose : bool If `verbose` and an evaluation set is used, writes the evaluation metric measured on the validation set to stderr. @@ -416,6 +423,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): if early_stopping_rounds is not None: self.best_score = self._Booster.best_score self.best_iteration = self._Booster.best_iteration + self.best_ntree_limit = self._Booster.best_ntree_limit return self