ENH: Make XGBModel pickleable.

This commit is contained in:
Skipper Seabold 2015-05-06 12:33:43 -05:00
parent 3b4697786e
commit 11fa419720

View File

@ -15,6 +15,7 @@ import re
import ctypes
import platform
import collections
from tempfile import NamedTemporaryFile
import numpy as np
import scipy.sparse
@ -839,6 +840,31 @@ class XGBModel(XGBModelBase):
self._Booster = Booster()
def __getstate__(self):
# can't pickle ctypes pointers so save _Booster directly
this = self.__dict__.copy() # don't modify in place
# delete = False for x-platform compatibility
# https://bugs.python.org/issue14243
with NamedTemporaryFile(mode="wb", delete=False) as tmp:
this["_Booster"].save_model(tmp.name)
tmp.close()
booster = open(tmp.name, "rb").read()
os.remove(tmp.name)
this.update({"_Booster": booster})
return this
def __setstate__(self, state):
with NamedTemporaryFile(mode="wb", delete=False) as tmp:
tmp.write(state["_Booster"])
tmp.close()
booster = Booster(model_file=tmp.name)
os.remove(tmp.name)
state["_Booster"] = booster
self.__dict__.update(state)
def get_xgb_params(self):
xgb_params = self.get_params()