Add early-stopping examples to the scikit-learn API demo script.

Hunk 1 extends the sklearn.cross_validation import so that
train_test_split is available next to KFold.

Hunk 2 adds two examples after the pickling example:
  * fit an XGBClassifier with early_stopping_rounds=10 and the built-in
    "auc" metric, evaluated on the held-out (X_test, y_test) split;
  * fit again with a custom metric callable, log_loss_eval, which returns
    a ("log-loss", value) pair. Its second parameter is named y_true, but
    the code calls .get_label() on it, so it is an object that exposes the
    labels rather than a plain label array.

NOTE(review): `digits` is used by the new code but is not defined in the
visible hunks; presumably it is loaded earlier in the script. Confirm.

diff --git a/demo/guide-python/sklearn_examples.py b/demo/guide-python/sklearn_examples.py
index ce8c8d01e..56fed1dd2 100755
--- a/demo/guide-python/sklearn_examples.py
+++ b/demo/guide-python/sklearn_examples.py
@@ -8,7 +8,7 @@ import pickle
 import xgboost as xgb
 
 import numpy as np
-from sklearn.cross_validation import KFold
+from sklearn.cross_validation import KFold, train_test_split
 from sklearn.metrics import confusion_matrix, mean_squared_error
 from sklearn.grid_search import GridSearchCV
 from sklearn.datasets import load_iris, load_digits, load_boston
@@ -65,3 +65,23 @@ print("Pickling sklearn API models")
 pickle.dump(clf, open("best_boston.pkl", "wb"))
 clf2 = pickle.load(open("best_boston.pkl", "rb"))
 print(np.allclose(clf.predict(X), clf2.predict(X)))
+
+# Early-stopping
+
+X = digits['data']
+y = digits['target']
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+clf = xgb.XGBClassifier()
+clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
+        eval_set=[(X_test, y_test)])
+
+# Custom evaluation function
+from sklearn.metrics import log_loss
+
+
+def log_loss_eval(y_pred, y_true):
+    return "log-loss", log_loss(y_true.get_label(), y_pred)
+
+
+clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric=log_loss_eval,
+        eval_set=[(X_test, y_test)])