fix saveraw
This commit is contained in:
parent
382dcf6c34
commit
594bed34e4
@ -91,6 +91,13 @@ def ctypes2numpy(cptr, length, dtype):
|
|||||||
raise RuntimeError('memmove failed')
|
raise RuntimeError('memmove failed')
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def ctypes2buffer(cptr, length):
|
||||||
|
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
|
||||||
|
raise RuntimeError('expected char pointer')
|
||||||
|
res = np.zeros(length, dtype='uint8')
|
||||||
|
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
|
||||||
|
raise RuntimeError('memmove failed')
|
||||||
|
return res
|
||||||
|
|
||||||
def c_str(string):
|
def c_str(string):
|
||||||
return ctypes.c_char_p(string.encode('utf-8'))
|
return ctypes.c_char_p(string.encode('utf-8'))
|
||||||
@ -470,19 +477,26 @@ class Booster(object):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
fname : string or file handle
|
fname : string
|
||||||
Output file name or handle. If a handle is given must be a BytesIO
|
Output file name or handle
|
||||||
object or a file opened for writing in binary format.
|
|
||||||
"""
|
"""
|
||||||
if isinstance(fname, string_types): # assume file name
|
if isinstance(fname, string_types): # assume file name
|
||||||
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
||||||
else:
|
else:
|
||||||
|
raise Exception("fname must be a string")
|
||||||
|
|
||||||
|
def save_raw(self):
|
||||||
|
"""
|
||||||
|
Save the model to a in memory buffer represetation
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
a in memory buffer represetation of the model
|
||||||
|
"""
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
cptr = xglib.XGBoosterGetModelRaw(self.handle,
|
cptr = xglib.XGBoosterGetModelRaw(self.handle,
|
||||||
ctypes.byref(length))
|
ctypes.byref(length))
|
||||||
address = ctypes.addressof(cptr.contents)
|
return ctypes2buffer(cptr, length.value)
|
||||||
buf = (ctypes.c_char * length.value).from_address(address)
|
|
||||||
fname.write(buf)
|
|
||||||
|
|
||||||
def load_model(self, fname):
|
def load_model(self, fname):
|
||||||
"""
|
"""
|
||||||
@ -491,14 +505,14 @@ class Booster(object):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
fname : string of file handle
|
fname : string of file handle
|
||||||
Input file name or file handle object.
|
Input file name or memory buffer(see also save_raw)
|
||||||
"""
|
"""
|
||||||
if isinstance(fname, string_types): # assume file name
|
if isinstance(fname, str): # assume file name
|
||||||
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
|
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||||
else:
|
else:
|
||||||
buf = fname.getbuffer()
|
buf = fname
|
||||||
length = ctypes.c_ulong(buf.nbytes)
|
length = ctypes.c_ulong(len(buf))
|
||||||
ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf))
|
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||||
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
|
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
|
||||||
|
|
||||||
def dump_model(self, fo, fmap='', with_stats=False):
|
def dump_model(self, fo, fmap='', with_stats=False):
|
||||||
@ -861,11 +875,8 @@ class XGBModel(XGBModelBase):
|
|||||||
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
|
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
|
||||||
|
|
||||||
this = self.__dict__.copy() # don't modify in place
|
this = self.__dict__.copy() # don't modify in place
|
||||||
|
raw = this["_Booster"].save_raw()
|
||||||
tmp = BytesIO()
|
this["_Booster"] = raw
|
||||||
this["_Booster"].save_model(tmp)
|
|
||||||
tmp.seek(0)
|
|
||||||
this["_Booster"] = tmp
|
|
||||||
|
|
||||||
return this
|
return this
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user