From 4a37b852a03b1320d1c41f948a4ec212a981ad1d Mon Sep 17 00:00:00 2001
From: Skipper Seabold
Date: Tue, 30 Jun 2015 11:42:28 -0500
Subject: [PATCH] DOC: Add early stopping example

---
Reviewer notes (commentary only; not part of the diff below):

* Hunk 1 extends the existing sklearn.cross_validation import with
  train_test_split, which hunk 2 uses to split X / y into train and
  test parts (random_state=0).
* Hunk 2 adds two examples after the pickling round-trip check:
  - fitting xgb.XGBClassifier with early_stopping_rounds=10, the
    eval_metric string "auc", and (X_test, y_test) as the eval_set;
  - fitting again with a callable eval_metric, log_loss_eval, which
    returns the pair ("log-loss", value) where value comes from
    sklearn.metrics.log_loss.
* log_loss_eval's second parameter is named y_true, but the body calls
  .get_label() on it, so it is an object that carries the labels rather
  than a plain label array -- presumably an xgboost DMatrix; confirm
  against the xgboost version this demo targets.
* NOTE(review): "auc" is normally a binary-classification metric, and
  `digits` is defined outside this patch -- confirm it is a two-class
  dataset.
* NOTE(review): the blank context line in hunk 1 (between the xgboost
  and numpy imports) is inferred from the "-8,7" line count; verify it
  against the target file before applying.

 demo/guide-python/sklearn_examples.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/demo/guide-python/sklearn_examples.py b/demo/guide-python/sklearn_examples.py
index ce8c8d01e..56fed1dd2 100755
--- a/demo/guide-python/sklearn_examples.py
+++ b/demo/guide-python/sklearn_examples.py
@@ -8,7 +8,7 @@ import pickle
 import xgboost as xgb
 
 import numpy as np
-from sklearn.cross_validation import KFold
+from sklearn.cross_validation import KFold, train_test_split
 from sklearn.metrics import confusion_matrix, mean_squared_error
 from sklearn.grid_search import GridSearchCV
 from sklearn.datasets import load_iris, load_digits, load_boston
@@ -65,3 +65,23 @@ print("Pickling sklearn API models")
 pickle.dump(clf, open("best_boston.pkl", "wb"))
 clf2 = pickle.load(open("best_boston.pkl", "rb"))
 print(np.allclose(clf.predict(X), clf2.predict(X)))
+
+# Early-stopping
+
+X = digits['data']
+y = digits['target']
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+clf = xgb.XGBClassifier()
+clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
+        eval_set=[(X_test, y_test)])
+
+# Custom evaluation function
+from sklearn.metrics import log_loss
+
+
+def log_loss_eval(y_pred, y_true):
+    return "log-loss", log_loss(y_true.get_label(), y_pred)
+
+
+clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric=log_loss_eval,
+        eval_set=[(X_test, y_test)])