ENH: Don't use tempfiles for save/load

This commit is contained in:
Skipper Seabold 2015-05-06 14:59:14 -05:00
parent 11fa419720
commit 13837060f1

View File

@ -15,7 +15,7 @@ import re
import ctypes
import platform
import collections
from tempfile import NamedTemporaryFile
from io import BytesIO
import numpy as np
import scipy.sparse
@ -71,6 +71,8 @@ def load_xglib():
lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
lib.XGBoosterGetModelRaw.restype = ctypes.POINTER(ctypes.c_char)
lib.XGBoosterLoadModelFromBuffer.restype = ctypes.c_void_p
return lib
@ -468,10 +470,19 @@ class Booster(object):
Parameters
----------
fname : string
Output file name.
fname : string or file handle
Output file name or handle. If a handle is given must be a BytesIO
object or a file opened for writing in binary format.
"""
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
if isinstance(fname, string_types): # assume file name
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
else:
length = ctypes.c_ulong()
cptr = xglib.XGBoosterGetModelRaw(self.handle,
ctypes.byref(length))
address = ctypes.addressof(cptr.contents)
buf = (ctypes.c_char * length.value).from_address(address)
fname.write(buf)
def load_model(self, fname):
"""
@ -479,10 +490,16 @@ class Booster(object):
Parameters
----------
fname : string
Input file name.
fname : string of file handle
Input file name or file handle object.
"""
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
if isinstance(fname, string_types): # assume file name
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
else:
buf = fname.getbuffer()
length = ctypes.c_ulong(buf.nbytes)
ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf))
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
def dump_model(self, fo, fmap='', with_stats=False):
"""
@ -841,28 +858,20 @@ class XGBModel(XGBModelBase):
self._Booster = Booster()
def __getstate__(self):
# can't pickle ctypes pointers so save _Booster directly
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
this = self.__dict__.copy() # don't modify in place
# delete = False for x-platform compatibility
# https://bugs.python.org/issue14243
with NamedTemporaryFile(mode="wb", delete=False) as tmp:
this["_Booster"].save_model(tmp.name)
tmp.close()
booster = open(tmp.name, "rb").read()
os.remove(tmp.name)
this.update({"_Booster": booster})
tmp = BytesIO()
this["_Booster"].save_model(tmp)
tmp.seek(0)
this["_Booster"] = tmp
return this
def __setstate__(self, state):
with NamedTemporaryFile(mode="wb", delete=False) as tmp:
tmp.write(state["_Booster"])
tmp.close()
booster = Booster(model_file=tmp.name)
os.remove(tmp.name)
state["_Booster"] = booster
booster = state["_Booster"]
state["_Booster"] = Booster(model_file=booster)
self.__dict__.update(state)
def get_xgb_params(self):