fix saveraw
This commit is contained in:
parent
382dcf6c34
commit
594bed34e4
@ -91,7 +91,14 @@ def ctypes2numpy(cptr, length, dtype):
|
||||
raise RuntimeError('memmove failed')
|
||||
return res
|
||||
|
||||
|
||||
def ctypes2buffer(cptr, length):
|
||||
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
|
||||
raise RuntimeError('expected char pointer')
|
||||
res = np.zeros(length, dtype='uint8')
|
||||
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
|
||||
raise RuntimeError('memmove failed')
|
||||
return res
|
||||
|
||||
def c_str(string):
|
||||
return ctypes.c_char_p(string.encode('utf-8'))
|
||||
|
||||
@ -470,19 +477,26 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : string or file handle
|
||||
Output file name or handle. If a handle is given must be a BytesIO
|
||||
object or a file opened for writing in binary format.
|
||||
fname : string
|
||||
Output file name or handle
|
||||
"""
|
||||
if isinstance(fname, string_types): # assume file name
|
||||
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
||||
else:
|
||||
length = ctypes.c_ulong()
|
||||
cptr = xglib.XGBoosterGetModelRaw(self.handle,
|
||||
ctypes.byref(length))
|
||||
address = ctypes.addressof(cptr.contents)
|
||||
buf = (ctypes.c_char * length.value).from_address(address)
|
||||
fname.write(buf)
|
||||
raise Exception("fname must be a string")
|
||||
|
||||
def save_raw(self):
|
||||
"""
|
||||
Save the model to a in memory buffer represetation
|
||||
|
||||
Returns
|
||||
-------
|
||||
a in memory buffer represetation of the model
|
||||
"""
|
||||
length = ctypes.c_ulong()
|
||||
cptr = xglib.XGBoosterGetModelRaw(self.handle,
|
||||
ctypes.byref(length))
|
||||
return ctypes2buffer(cptr, length.value)
|
||||
|
||||
def load_model(self, fname):
|
||||
"""
|
||||
@ -491,14 +505,14 @@ class Booster(object):
|
||||
Parameters
|
||||
----------
|
||||
fname : string of file handle
|
||||
Input file name or file handle object.
|
||||
Input file name or memory buffer(see also save_raw)
|
||||
"""
|
||||
if isinstance(fname, string_types): # assume file name
|
||||
if isinstance(fname, str): # assume file name
|
||||
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||
else:
|
||||
buf = fname.getbuffer()
|
||||
length = ctypes.c_ulong(buf.nbytes)
|
||||
ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf))
|
||||
buf = fname
|
||||
length = ctypes.c_ulong(len(buf))
|
||||
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
|
||||
|
||||
def dump_model(self, fo, fmap='', with_stats=False):
|
||||
@ -861,12 +875,9 @@ class XGBModel(XGBModelBase):
|
||||
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
|
||||
|
||||
this = self.__dict__.copy() # don't modify in place
|
||||
|
||||
tmp = BytesIO()
|
||||
this["_Booster"].save_model(tmp)
|
||||
tmp.seek(0)
|
||||
this["_Booster"] = tmp
|
||||
|
||||
raw = this["_Booster"].save_raw()
|
||||
this["_Booster"] = raw
|
||||
|
||||
return this
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user