Merge pull request #299 from jseabold/pickle-xgbooster

ENH: Pickle xgbooster enhancements. Thanks!
This commit is contained in:
Tianqi Chen 2015-05-11 08:44:36 -07:00
commit 8b9e87790a
4 changed files with 51 additions and 11 deletions

View File

@ -4,18 +4,17 @@ Created on 1 Apr 2015
@author: Jamie Hall
'''
import pickle
import xgboost as xgb
import numpy as np
from sklearn.cross_validation import KFold
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import confusion_matrix, mean_squared_error
from sklearn.grid_search import GridSearchCV
from sklearn.datasets import load_iris, load_digits, load_boston
rng = np.random.RandomState(31337)
print("Zeros and Ones from the Digits dataset: binary classification")
digits = load_digits(2)
y = digits['target']
@ -60,4 +59,9 @@ clf.fit(X,y)
print(clf.best_score_)
print(clf.best_params_)
# The sklearn API models are picklable
print("Pickling sklearn API models")
# Must open in binary format to pickle. The context managers guarantee that
# the dump is flushed and the handle closed before the file is read back,
# instead of relying on garbage collection to close the files.
with open("best_boston.pkl", "wb") as pkl_file:
    pickle.dump(clf, pkl_file)
with open("best_boston.pkl", "rb") as pkl_file:
    clf2 = pickle.load(pkl_file)
# The reloaded model should reproduce the original model's predictions.
print(np.allclose(clf.predict(X), clf2.predict(X)))

View File

@ -0,0 +1,35 @@
import os
if __name__ == "__main__":
    # NOTE: on posix systems the start method *has* to be selected here, inside
    # the `__name__ == "__main__"` clause, before any worker process exists.
    # If XGBoost was built with OpenMP support, parallel worker processes must
    # be started with "forkserver" (or "spawn") rather than with "fork".
    # Otherwise, if you build XGBoost without OpenMP support, you can use
    # fork, which is the default backend for joblib, and omit this.
    try:
        from multiprocessing import set_start_method
    except ImportError:
        # Deliberately no `raise ... from`: the file must still parse on
        # Python 2 so that this explanatory message is what the user sees.
        raise ImportError("Unable to import multiprocessing.set_start_method."
                          " This example requires Python 3.4 or later")
    set_start_method("forkserver")

    # Limit the OpenMP threads used by each process. Set before XGBoost is
    # imported so the value is already in the environment when its native
    # library loads (assumption: the OpenMP runtime reads it at start-up).
    os.environ["OMP_NUM_THREADS"] = "2"  # or to whatever you want

    from sklearn.grid_search import GridSearchCV
    from sklearn.datasets import load_boston
    import xgboost as xgb

    print("Parallel Parameter optimization")
    boston = load_boston()
    y = boston['target']
    X = boston['data']

    # Grid-search tree depth and ensemble size using two worker processes.
    xgb_model = xgb.XGBRegressor()
    param_grid = {'max_depth': [2, 4, 6],
                  'n_estimators': [50, 100, 200]}
    clf = GridSearchCV(xgb_model, param_grid, verbose=1, n_jobs=2)
    clf.fit(X, y)
    print(clf.best_score_)
    print(clf.best_params_)

View File

@ -7,6 +7,8 @@ Python
* To make the python module, type ```./build.sh``` in the root directory of project
* Install with `python setup.py install` from this directory.
* Refer also to the walk through example in [demo folder](../demo/guide-python)
* **NOTE**: If you want to run XGBoost in parallel using the fork backend for joblib/multiprocessing, you must build XGBoost without OpenMP support by running `make no_omp=1`. Otherwise, use the forkserver (available in Python 3.4 and later) or spawn backend. See the sklearn_examples.py demo.
R
=====

View File

@ -15,7 +15,6 @@ import re
import ctypes
import platform
import collections
from io import BytesIO
import numpy as np
import scipy.sparse
@ -492,7 +491,7 @@ class Booster(object):
def save_raw(self):
"""
Save the model to an in-memory buffer representation
Returns
-------
an in-memory buffer representation of the model
@ -876,12 +875,12 @@ class XGBModel(XGBModelBase):
self._Booster = None
def __getstate__(self):
    """Return a picklable copy of this object's state.

    The ``_Booster`` entry, when set, is replaced by the buffer returned
    from its ``save_raw()`` method, because ctypes pointers can't be
    pickled. The live instance itself is left untouched.

    Returns
    -------
    state : dict
        shallow copy of ``self.__dict__`` with ``_Booster`` replaced by
        its raw buffer (or left as ``None`` when no booster exists)
    """
    # Shallow copy so the swap below doesn't modify the instance in place.
    state = self.__dict__.copy()
    booster = state["_Booster"]
    if booster is not None:
        # can't pickle ctypes pointers so store the raw model buffer instead
        state["_Booster"] = booster.save_raw()
    return state
def __setstate__(self, state):
@ -894,7 +893,7 @@ class XGBModel(XGBModelBase):
"""
get the underlying xgboost Booster of this model
will raise an exception when fit was not called
Returns
-------
booster : a xgboost booster of underlying model
@ -902,7 +901,7 @@ class XGBModel(XGBModelBase):
if self._Booster is None:
raise XGBError('need to call fit beforehand')
return self._Booster
def get_xgb_params(self):
xgb_params = self.get_params()