fix saveraw

This commit is contained in:
tqchen 2015-05-06 16:42:27 -07:00
parent 382dcf6c34
commit 594bed34e4

View File

@ -91,7 +91,14 @@ def ctypes2numpy(cptr, length, dtype):
raise RuntimeError('memmove failed')
return res
def ctypes2buffer(cptr, length):
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
raise RuntimeError('expected char pointer')
res = np.zeros(length, dtype='uint8')
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
raise RuntimeError('memmove failed')
return res
def c_str(string):
return ctypes.c_char_p(string.encode('utf-8'))
@ -470,19 +477,26 @@ class Booster(object):
Parameters
----------
fname : string or file handle
Output file name or handle. If a handle is given must be a BytesIO
object or a file opened for writing in binary format.
fname : string
Output file name or handle
"""
if isinstance(fname, string_types): # assume file name
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
else:
length = ctypes.c_ulong()
cptr = xglib.XGBoosterGetModelRaw(self.handle,
ctypes.byref(length))
address = ctypes.addressof(cptr.contents)
buf = (ctypes.c_char * length.value).from_address(address)
fname.write(buf)
raise Exception("fname must be a string")
def save_raw(self):
"""
Save the model to a in memory buffer represetation
Returns
-------
a in memory buffer represetation of the model
"""
length = ctypes.c_ulong()
cptr = xglib.XGBoosterGetModelRaw(self.handle,
ctypes.byref(length))
return ctypes2buffer(cptr, length.value)
def load_model(self, fname):
"""
@ -491,14 +505,14 @@ class Booster(object):
Parameters
----------
fname : string of file handle
Input file name or file handle object.
Input file name or memory buffer(see also save_raw)
"""
if isinstance(fname, string_types): # assume file name
if isinstance(fname, str): # assume file name
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
else:
buf = fname.getbuffer()
length = ctypes.c_ulong(buf.nbytes)
ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf))
buf = fname
length = ctypes.c_ulong(len(buf))
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
def dump_model(self, fo, fmap='', with_stats=False):
@ -861,12 +875,9 @@ class XGBModel(XGBModelBase):
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
this = self.__dict__.copy() # don't modify in place
tmp = BytesIO()
this["_Booster"].save_model(tmp)
tmp.seek(0)
this["_Booster"] = tmp
raw = this["_Booster"].save_raw()
this["_Booster"] = raw
return this
def __setstate__(self, state):