diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 4a5248818..48fa02b76 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -15,7 +15,7 @@ import re import ctypes import platform import collections -from tempfile import NamedTemporaryFile +from io import BytesIO import numpy as np import scipy.sparse @@ -71,6 +71,8 @@ def load_xglib(): lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float) lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) + lib.XGBoosterGetModelRaw.restype = ctypes.POINTER(ctypes.c_char) + lib.XGBoosterLoadModelFromBuffer.restype = ctypes.c_void_p return lib @@ -468,10 +470,19 @@ class Booster(object): Parameters ---------- - fname : string - Output file name. + fname : string or file handle + Output file name or handle. If a handle is given must be a BytesIO + object or a file opened for writing in binary format. """ - xglib.XGBoosterSaveModel(self.handle, c_str(fname)) + if isinstance(fname, string_types): # assume file name + xglib.XGBoosterSaveModel(self.handle, c_str(fname)) + else: + length = ctypes.c_ulong() + cptr = xglib.XGBoosterGetModelRaw(self.handle, + ctypes.byref(length)) + address = ctypes.addressof(cptr.contents) + buf = (ctypes.c_char * length.value).from_address(address) + fname.write(buf) def load_model(self, fname): """ @@ -479,10 +490,16 @@ class Booster(object): Parameters ---------- - fname : string - Input file name. + fname : string or file handle + Input file name or file handle object. 
""" - xglib.XGBoosterLoadModel(self.handle, c_str(fname)) + if isinstance(fname, string_types): # assume file name + xglib.XGBoosterLoadModel(self.handle, c_str(fname)) + else: + buf = fname.getbuffer() + length = ctypes.c_ulong(buf.nbytes) + ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf)) + xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length) def dump_model(self, fo, fmap='', with_stats=False): """ @@ -841,28 +858,20 @@ class XGBModel(XGBModelBase): self._Booster = Booster() def __getstate__(self): - # can't pickle ctypes pointers so save _Booster directly + # can't pickle ctypes pointers so put _Booster in a BytesIO obj + this = self.__dict__.copy() # don't modify in place - # delete = False for x-platform compatibility - # https://bugs.python.org/issue14243 - with NamedTemporaryFile(mode="wb", delete=False) as tmp: - this["_Booster"].save_model(tmp.name) - tmp.close() - booster = open(tmp.name, "rb").read() - os.remove(tmp.name) - this.update({"_Booster": booster}) + tmp = BytesIO() + this["_Booster"].save_model(tmp) + tmp.seek(0) + this["_Booster"] = tmp return this def __setstate__(self, state): - with NamedTemporaryFile(mode="wb", delete=False) as tmp: - tmp.write(state["_Booster"]) - tmp.close() - booster = Booster(model_file=tmp.name) - os.remove(tmp.name) - - state["_Booster"] = booster + booster = state["_Booster"] + state["_Booster"] = Booster(model_file=booster) self.__dict__.update(state) def get_xgb_params(self):