ENH: Make XGBModel pickleable.
This commit is contained in:
parent
3b4697786e
commit
11fa419720
@ -15,6 +15,7 @@ import re
|
|||||||
import ctypes
|
import ctypes
|
||||||
import platform
|
import platform
|
||||||
import collections
|
import collections
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
@ -839,6 +840,31 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
self._Booster = Booster()
|
self._Booster = Booster()
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
# can't pickle ctypes pointers so save _Booster directly
|
||||||
|
this = self.__dict__.copy() # don't modify in place
|
||||||
|
|
||||||
|
# delete = False for x-platform compatibility
|
||||||
|
# https://bugs.python.org/issue14243
|
||||||
|
with NamedTemporaryFile(mode="wb", delete=False) as tmp:
|
||||||
|
this["_Booster"].save_model(tmp.name)
|
||||||
|
tmp.close()
|
||||||
|
booster = open(tmp.name, "rb").read()
|
||||||
|
os.remove(tmp.name)
|
||||||
|
this.update({"_Booster": booster})
|
||||||
|
|
||||||
|
return this
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
with NamedTemporaryFile(mode="wb", delete=False) as tmp:
|
||||||
|
tmp.write(state["_Booster"])
|
||||||
|
tmp.close()
|
||||||
|
booster = Booster(model_file=tmp.name)
|
||||||
|
os.remove(tmp.name)
|
||||||
|
|
||||||
|
state["_Booster"] = booster
|
||||||
|
self.__dict__.update(state)
|
||||||
|
|
||||||
def get_xgb_params(self):
|
def get_xgb_params(self):
|
||||||
xgb_params = self.get_params()
|
xgb_params = self.get_params()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user