ENH: Don't use tempfiles for save/load

This commit is contained in:
Skipper Seabold 2015-05-06 14:59:14 -05:00
parent 11fa419720
commit 13837060f1

View File

@ -15,7 +15,7 @@ import re
import ctypes import ctypes
import platform import platform
import collections import collections
from tempfile import NamedTemporaryFile from io import BytesIO
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
@ -71,6 +71,8 @@ def load_xglib():
lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float) lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
lib.XGBoosterGetModelRaw.restype = ctypes.POINTER(ctypes.c_char)
lib.XGBoosterLoadModelFromBuffer.restype = ctypes.c_void_p
return lib return lib
@ -468,10 +470,19 @@ class Booster(object):
Parameters Parameters
---------- ----------
fname : string fname : string or file handle
Output file name. Output file name or handle. If a handle is given, it must be a BytesIO
object or a file opened for writing in binary format.
""" """
xglib.XGBoosterSaveModel(self.handle, c_str(fname)) if isinstance(fname, string_types): # assume file name
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
else:
length = ctypes.c_ulong()
cptr = xglib.XGBoosterGetModelRaw(self.handle,
ctypes.byref(length))
address = ctypes.addressof(cptr.contents)
buf = (ctypes.c_char * length.value).from_address(address)
fname.write(buf)
def load_model(self, fname): def load_model(self, fname):
""" """
@ -479,10 +490,16 @@ class Booster(object):
Parameters Parameters
---------- ----------
fname : string fname : string or file handle
Input file name. Input file name or file handle object.
""" """
xglib.XGBoosterLoadModel(self.handle, c_str(fname)) if isinstance(fname, string_types): # assume file name
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
else:
buf = fname.getbuffer()
length = ctypes.c_ulong(buf.nbytes)
ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf))
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
def dump_model(self, fo, fmap='', with_stats=False): def dump_model(self, fo, fmap='', with_stats=False):
""" """
@ -841,28 +858,20 @@ class XGBModel(XGBModelBase):
self._Booster = Booster() self._Booster = Booster()
def __getstate__(self): def __getstate__(self):
# can't pickle ctypes pointers so save _Booster directly # can't pickle ctypes pointers so put _Booster in a BytesIO obj
this = self.__dict__.copy() # don't modify in place this = self.__dict__.copy() # don't modify in place
# delete = False for x-platform compatibility tmp = BytesIO()
# https://bugs.python.org/issue14243 this["_Booster"].save_model(tmp)
with NamedTemporaryFile(mode="wb", delete=False) as tmp: tmp.seek(0)
this["_Booster"].save_model(tmp.name) this["_Booster"] = tmp
tmp.close()
booster = open(tmp.name, "rb").read()
os.remove(tmp.name)
this.update({"_Booster": booster})
return this return this
def __setstate__(self, state): def __setstate__(self, state):
with NamedTemporaryFile(mode="wb", delete=False) as tmp: booster = state["_Booster"]
tmp.write(state["_Booster"]) state["_Booster"] = Booster(model_file=booster)
tmp.close()
booster = Booster(model_file=tmp.name)
os.remove(tmp.name)
state["_Booster"] = booster
self.__dict__.update(state) self.__dict__.update(state)
def get_xgb_params(self): def get_xgb_params(self):