ENH: Make XGBModel pickleable.
This commit is contained in:
parent
3b4697786e
commit
11fa419720
@ -15,6 +15,7 @@ import re
|
||||
import ctypes
|
||||
import platform
|
||||
import collections
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
@ -839,6 +840,31 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
self._Booster = Booster()
|
||||
|
||||
def __getstate__(self):
|
||||
# can't pickle ctypes pointers so save _Booster directly
|
||||
this = self.__dict__.copy() # don't modify in place
|
||||
|
||||
# delete = False for x-platform compatibility
|
||||
# https://bugs.python.org/issue14243
|
||||
with NamedTemporaryFile(mode="wb", delete=False) as tmp:
|
||||
this["_Booster"].save_model(tmp.name)
|
||||
tmp.close()
|
||||
booster = open(tmp.name, "rb").read()
|
||||
os.remove(tmp.name)
|
||||
this.update({"_Booster": booster})
|
||||
|
||||
return this
|
||||
|
||||
def __setstate__(self, state):
|
||||
with NamedTemporaryFile(mode="wb", delete=False) as tmp:
|
||||
tmp.write(state["_Booster"])
|
||||
tmp.close()
|
||||
booster = Booster(model_file=tmp.name)
|
||||
os.remove(tmp.name)
|
||||
|
||||
state["_Booster"] = booster
|
||||
self.__dict__.update(state)
|
||||
|
||||
def get_xgb_params(self):
|
||||
xgb_params = self.get_params()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user