From 15ea00540a7c31d208d6a11c096b7a172ddadad2 Mon Sep 17 00:00:00 2001 From: Skipper Seabold Date: Mon, 11 May 2015 09:30:51 -0500 Subject: [PATCH] EX: Make separate example for fork issue. --- demo/guide-python/sklearn_examples.py | 130 +++++++++++--------------- demo/guide-python/sklearn_parallel.py | 35 +++++++ 2 files changed, 89 insertions(+), 76 deletions(-) create mode 100644 demo/guide-python/sklearn_parallel.py diff --git a/demo/guide-python/sklearn_examples.py b/demo/guide-python/sklearn_examples.py index b378d28cc..ce8c8d01e 100755 --- a/demo/guide-python/sklearn_examples.py +++ b/demo/guide-python/sklearn_examples.py @@ -4,86 +4,64 @@ Created on 1 Apr 2015 @author: Jamie Hall ''' -if __name__ == "__main__": - # NOTE: This *has* to be here and in the `__name__ == "__main__"` clause - # to run XGBoost in parallel, if XGBoost was built with OpenMP support. - # Otherwise, you can use fork, which is the default backend for joblib, - # and omit this. - from multiprocessing import set_start_method - set_start_method("forkserver") +import pickle +import xgboost as xgb - import pickle - import os - import xgboost as xgb +import numpy as np +from sklearn.cross_validation import KFold +from sklearn.metrics import confusion_matrix, mean_squared_error +from sklearn.grid_search import GridSearchCV +from sklearn.datasets import load_iris, load_digits, load_boston - import numpy as np - from sklearn.cross_validation import KFold - from sklearn.grid_search import GridSearchCV - from sklearn.metrics import confusion_matrix, mean_squared_error - from sklearn.datasets import load_iris, load_digits, load_boston +rng = np.random.RandomState(31337) - rng = np.random.RandomState(31337) +print("Zeros and Ones from the Digits dataset: binary classification") +digits = load_digits(2) +y = digits['target'] +X = digits['data'] +kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) +for train_index, test_index in kf: + xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index]) + predictions = xgb_model.predict(X[test_index]) + actuals = y[test_index] + print(confusion_matrix(actuals, predictions)) - print("Zeros and Ones from the Digits dataset: binary classification") - digits = load_digits(2) - y = digits['target'] - X = digits['data'] - kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) - for train_index, test_index in kf: - xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index]) - predictions = xgb_model.predict(X[test_index]) - actuals = y[test_index] - print(confusion_matrix(actuals, predictions)) +print("Iris: multiclass classification") +iris = load_iris() +y = iris['target'] +X = iris['data'] +kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) +for train_index, test_index in kf: + xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index]) + predictions = xgb_model.predict(X[test_index]) + actuals = y[test_index] + print(confusion_matrix(actuals, predictions)) - print("Iris: multiclass classification") - iris = load_iris() - y = iris['target'] - X = iris['data'] - kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) - for train_index, test_index in kf: - xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index]) - predictions = xgb_model.predict(X[test_index]) - actuals = y[test_index] - print(confusion_matrix(actuals, predictions)) +print("Boston Housing: regression") +boston = load_boston() +y = boston['target'] +X = boston['data'] +kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) +for train_index, test_index in kf: + xgb_model = xgb.XGBRegressor().fit(X[train_index],y[train_index]) + predictions = xgb_model.predict(X[test_index]) + actuals = y[test_index] + print(mean_squared_error(actuals, predictions)) - print("Boston Housing: regression") - boston = load_boston() - y = boston['target'] - X = boston['data'] - kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) - for train_index, test_index in kf: - xgb_model = xgb.XGBRegressor().fit(X[train_index],y[train_index]) - predictions = xgb_model.predict(X[test_index]) - actuals = y[test_index] - print(mean_squared_error(actuals, predictions)) +print("Parameter optimization") +y = boston['target'] +X = boston['data'] +xgb_model = xgb.XGBRegressor() +clf = GridSearchCV(xgb_model, + {'max_depth': [2,4,6], + 'n_estimators': [50,100,200]}, verbose=1) +clf.fit(X,y) +print(clf.best_score_) +print(clf.best_params_) - print("Parameter optimization") - y = boston['target'] - X = boston['data'] - xgb_model = xgb.XGBRegressor() - clf = GridSearchCV(xgb_model, - {'max_depth': [2,4,6], - 'n_estimators': [50,100,200]}, verbose=1) - clf.fit(X,y) - print(clf.best_score_) - print(clf.best_params_) - - # The sklearn API models are picklable - print("Pickling sklearn API models") - # must open in binary format to pickle - pickle.dump(clf, open("best_boston.pkl", "wb")) - clf2 = pickle.load(open("best_boston.pkl", "rb")) - print(np.allclose(clf.predict(X), clf2.predict(X))) - - print("Parallel Parameter optimization") - os.environ["OMP_NUM_THREADS"] = "1" - y = boston['target'] - X = boston['data'] - xgb_model = xgb.XGBRegressor() - clf = GridSearchCV(xgb_model, - {'max_depth': [2,4,6], - 'n_estimators': [50,100,200]}, verbose=1, - n_jobs=2) - clf.fit(X, y) - print(clf.best_score_) - print(clf.best_params_) +# The sklearn API models are picklable +print("Pickling sklearn API models") +# must open in binary format to pickle +pickle.dump(clf, open("best_boston.pkl", "wb")) +clf2 = pickle.load(open("best_boston.pkl", "rb")) +print(np.allclose(clf.predict(X), clf2.predict(X))) diff --git a/demo/guide-python/sklearn_parallel.py b/demo/guide-python/sklearn_parallel.py new file mode 100644 index 000000000..803f3fac8 --- /dev/null +++ b/demo/guide-python/sklearn_parallel.py @@ -0,0 +1,35 @@ +import os + +if __name__ == "__main__": + # NOTE: on posix systems, this *has* to be here and in the + # `__name__ == "__main__"` clause to run XGBoost in parallel processes + # using fork, if XGBoost was built with OpenMP support. Otherwise, if you + # build XGBoost without OpenMP support, you can use fork, which is the + # default backend for joblib, and omit this. + try: + from multiprocessing import set_start_method + except ImportError: + raise ImportError("Unable to import multiprocessing.set_start_method." + " This example only runs on Python 3.4") + set_start_method("forkserver") + + import numpy as np + from sklearn.grid_search import GridSearchCV + from sklearn.datasets import load_boston + import xgboost as xgb + + rng = np.random.RandomState(31337) + + print("Parallel Parameter optimization") + boston = load_boston() + + os.environ["OMP_NUM_THREADS"] = "2" # or to whatever you want + y = boston['target'] + X = boston['data'] + xgb_model = xgb.XGBRegressor() + clf = GridSearchCV(xgb_model, {'max_depth': [2, 4, 6], + 'n_estimators': [50, 100, 200]}, verbose=1, + n_jobs=2) + clf.fit(X, y) + print(clf.best_score_) + print(clf.best_params_)