From 5f9f42292c82afea411a3939e58544ef4cc723d2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 29 Jul 2015 17:49:55 -0700 Subject: [PATCH] fix sklearn best score --- wrapper/xgboost.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 77f5bedb8..32f9a52b4 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -866,6 +866,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, bst.best_iteration = best_score_i return bst + class CVPack(object): """"Auxiliary datastruct to hold one fold of CV.""" def __init__(self, dtrain, dtest, param): @@ -1154,9 +1155,11 @@ class XGBModel(XGBModelBase): eval_results = {k: np.array(v, dtype=float) for k, v in eval_results.items()} eval_results = {k: np.array(v) for k, v in eval_results.items()} - self.eval_results_ = eval_results - self.best_score_ = self._Booster.best_score - self.best_iteration_ = self._Booster.best_iteration + self.eval_results = eval_results + + if early_stopping_rounds is not None: + self.best_score = self._Booster.best_score + self.best_iteration = self._Booster.best_iteration return self def predict(self, data): @@ -1266,9 +1269,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase): if eval_results: eval_results = {k: np.array(v, dtype=float) for k, v in eval_results.items()} - self.eval_results_ = eval_results - self.best_score_ = self._Booster.best_score - self.best_iteration_ = self._Booster.best_iteration + self.eval_results = eval_results + + if early_stopping_rounds is not None: + self.best_score = self._Booster.best_score + self.best_iteration = self._Booster.best_iteration return self