From 594bed34e474028274cc4925fe08764b10538859 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 6 May 2015 16:42:27 -0700 Subject: [PATCH] fix saveraw --- wrapper/xgboost.py | 53 ++++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 48fa02b76..9051061d0 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -91,7 +91,14 @@ def ctypes2numpy(cptr, length, dtype): raise RuntimeError('memmove failed') return res - +def ctypes2buffer(cptr, length): + if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)): + raise RuntimeError('expected char pointer') + res = np.zeros(length, dtype='uint8') + if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): + raise RuntimeError('memmove failed') + return res + def c_str(string): return ctypes.c_char_p(string.encode('utf-8')) @@ -470,19 +477,26 @@ class Booster(object): Parameters ---------- - fname : string or file handle - Output file name or handle. If a handle is given must be a BytesIO - object or a file opened for writing in binary format. + fname : string + Output file name or handle """ if isinstance(fname, string_types): # assume file name xglib.XGBoosterSaveModel(self.handle, c_str(fname)) else: - length = ctypes.c_ulong() - cptr = xglib.XGBoosterGetModelRaw(self.handle, - ctypes.byref(length)) - address = ctypes.addressof(cptr.contents) - buf = (ctypes.c_char * length.value).from_address(address) - fname.write(buf) + raise Exception("fname must be a string") + + def save_raw(self): + """ + Save the model to a in memory buffer represetation + + Returns + ------- + a in memory buffer represetation of the model + """ + length = ctypes.c_ulong() + cptr = xglib.XGBoosterGetModelRaw(self.handle, + ctypes.byref(length)) + return ctypes2buffer(cptr, length.value) def load_model(self, fname): """ @@ -491,14 +505,14 @@ class Booster(object): Parameters ---------- fname : string of file handle - Input file name or file handle object. + Input file name or memory buffer(see also save_raw) """ - if isinstance(fname, string_types): # assume file name + if isinstance(fname, str): # assume file name xglib.XGBoosterLoadModel(self.handle, c_str(fname)) else: - buf = fname.getbuffer() - length = ctypes.c_ulong(buf.nbytes) - ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf)) + buf = fname + length = ctypes.c_ulong(len(buf)) + ptr = (ctypes.c_char * len(buf)).from_buffer(buf) xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length) def dump_model(self, fo, fmap='', with_stats=False): @@ -861,12 +875,9 @@ class XGBModel(XGBModelBase): # can't pickle ctypes pointers so put _Booster in a BytesIO obj this = self.__dict__.copy() # don't modify in place - - tmp = BytesIO() - this["_Booster"].save_model(tmp) - tmp.seek(0) - this["_Booster"] = tmp - + raw = this["_Booster"].save_raw() + this["_Booster"] = raw + return this def __setstate__(self, state):