Merge pull request #299 from jseabold/pickle-xgbooster

ENH: Pickle xgbooster enhancements. Thanks!
This commit is contained in:
Tianqi Chen 2015-05-11 08:44:36 -07:00
commit 8b9e87790a
4 changed files with 51 additions and 11 deletions

View File

@ -4,18 +4,17 @@ Created on 1 Apr 2015
@author: Jamie Hall @author: Jamie Hall
''' '''
import pickle
import xgboost as xgb import xgboost as xgb
import numpy as np import numpy as np
from sklearn.cross_validation import KFold from sklearn.cross_validation import KFold
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import confusion_matrix, mean_squared_error from sklearn.metrics import confusion_matrix, mean_squared_error
from sklearn.grid_search import GridSearchCV
from sklearn.datasets import load_iris, load_digits, load_boston from sklearn.datasets import load_iris, load_digits, load_boston
rng = np.random.RandomState(31337) rng = np.random.RandomState(31337)
print("Zeros and Ones from the Digits dataset: binary classification") print("Zeros and Ones from the Digits dataset: binary classification")
digits = load_digits(2) digits = load_digits(2)
y = digits['target'] y = digits['target']
@ -60,4 +59,9 @@ clf.fit(X,y)
print(clf.best_score_) print(clf.best_score_)
print(clf.best_params_) print(clf.best_params_)
# The sklearn API models are picklable
print("Pickling sklearn API models")
# Pickle streams are bytes, so the file must be opened in binary mode.
# The context managers close (and flush) each handle deterministically, so
# the file is completely written before it is read back, rather than
# relying on garbage collection to close the handles.
with open("best_boston.pkl", "wb") as pkl_file:
    pickle.dump(clf, pkl_file)
with open("best_boston.pkl", "rb") as pkl_file:
    clf2 = pickle.load(pkl_file)
# The restored model should predict the same values as the original one.
print(np.allclose(clf.predict(X), clf2.predict(X)))

View File

@ -0,0 +1,35 @@
import os

if __name__ == "__main__":
    # NOTE: on POSIX systems the start method has to be selected here, inside
    # the `__name__ == "__main__"` clause, for an XGBoost build with OpenMP
    # support to run in parallel processes. If XGBoost is built without
    # OpenMP support, fork (the default backend for joblib) works and this
    # selection can be omitted.
    #
    # Only syntax that also parses on Python 2 is used below, so that older
    # interpreters reach the explanatory ImportError instead of a SyntaxError.
    try:
        from multiprocessing import set_start_method
    except ImportError:
        raise ImportError("Unable to import multiprocessing.set_start_method."
                          " This example requires Python 3.4 or newer")
    set_start_method("forkserver")

    # Put the thread limit into the environment before xgboost is imported:
    # an OpenMP runtime may read OMP_NUM_THREADS only once, when the native
    # library is loaded.
    os.environ["OMP_NUM_THREADS"] = "2"  # or to whatever you want

    from sklearn.grid_search import GridSearchCV
    from sklearn.datasets import load_boston
    import xgboost as xgb

    print("Parallel Parameter optimization")
    boston = load_boston()
    y = boston['target']
    X = boston['data']

    # Search over tree depth and number of boosting rounds, evaluating the
    # candidates in two worker processes.
    param_grid = {'max_depth': [2, 4, 6],
                  'n_estimators': [50, 100, 200]}
    xgb_model = xgb.XGBRegressor()
    clf = GridSearchCV(xgb_model, param_grid, verbose=1, n_jobs=2)
    clf.fit(X, y)
    print(clf.best_score_)
    print(clf.best_params_)

View File

@ -7,6 +7,8 @@ Python
* To make the python module, type ```./build.sh``` in the root directory of project * To make the python module, type ```./build.sh``` in the root directory of project
* Install with `python setup.py install` from this directory. * Install with `python setup.py install` from this directory.
* Refer also to the walk through example in [demo folder](../demo/guide-python) * Refer also to the walk through example in [demo folder](../demo/guide-python)
* **NOTE**: If you want to run XGBoost in parallel processes using the fork backend of joblib/multiprocessing, you must build XGBoost without OpenMP support (`make no_omp=1`). Otherwise, use the forkserver backend (available since Python 3.4) or the spawn backend. See the sklearn_examples.py demo.
R R
===== =====

View File

@ -15,7 +15,6 @@ import re
import ctypes import ctypes
import platform import platform
import collections import collections
from io import BytesIO
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
@ -492,7 +491,7 @@ class Booster(object):
def save_raw(self): def save_raw(self):
""" """
Save the model to a in memory buffer represetation Save the model to a in memory buffer represetation
Returns Returns
------- -------
a in memory buffer represetation of the model a in memory buffer represetation of the model
@ -876,12 +875,12 @@ class XGBModel(XGBModelBase):
self._Booster = None self._Booster = None
def __getstate__(self): def __getstate__(self):
# can't pickle ctypes pointers so put _Booster in a BytesIO obj # can't pickle ctypes pointers so put _Booster in a bytearray object
this = self.__dict__.copy() # don't modify in place this = self.__dict__.copy() # don't modify in place
bst = this["_Booster"] bst = this["_Booster"]
if bst is not None: if bst is not None:
raw = this["_Booster"].save_raw() raw = this["_Booster"].save_raw()
this["_Booster"] = raw this["_Booster"] = raw
return this return this
def __setstate__(self, state): def __setstate__(self, state):
@ -894,7 +893,7 @@ class XGBModel(XGBModelBase):
""" """
get the underlying xgboost Booster of this model get the underlying xgboost Booster of this model
will raise an exception when fit was not called will raise an exception when fit was not called
Returns Returns
------- -------
booster : a xgboost booster of underlying model booster : a xgboost booster of underlying model
@ -902,7 +901,7 @@ class XGBModel(XGBModelBase):
if self._Booster is None: if self._Booster is None:
raise XGBError('need to call fit beforehand') raise XGBError('need to call fit beforehand')
return self._Booster return self._Booster
def get_xgb_params(self): def get_xgb_params(self):
xgb_params = self.get_params() xgb_params = self.get_params()