diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 8ef82b2c7..48fa02b76 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -15,6 +15,7 @@ import re import ctypes import platform import collections +from io import BytesIO import numpy as np import scipy.sparse @@ -70,6 +71,8 @@ def load_xglib(): lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float) lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) + lib.XGBoosterGetModelRaw.restype = ctypes.POINTER(ctypes.c_char) + lib.XGBoosterLoadModelFromBuffer.restype = ctypes.c_void_p return lib @@ -467,10 +470,19 @@ class Booster(object): Parameters ---------- - fname : string - Output file name. + fname : string or file handle + Output file name or handle. If a handle is given must be a BytesIO + object or a file opened for writing in binary format. """ - xglib.XGBoosterSaveModel(self.handle, c_str(fname)) + if isinstance(fname, string_types): # assume file name + xglib.XGBoosterSaveModel(self.handle, c_str(fname)) + else: + length = ctypes.c_ulong() + cptr = xglib.XGBoosterGetModelRaw(self.handle, + ctypes.byref(length)) + address = ctypes.addressof(cptr.contents) + buf = (ctypes.c_char * length.value).from_address(address) + fname.write(buf) def load_model(self, fname): """ @@ -478,10 +490,16 @@ class Booster(object): Parameters ---------- - fname : string - Input file name. + fname : string of file handle + Input file name or file handle object. """ - xglib.XGBoosterLoadModel(self.handle, c_str(fname)) + if isinstance(fname, string_types): # assume file name + xglib.XGBoosterLoadModel(self.handle, c_str(fname)) + else: + buf = fname.getbuffer() + length = ctypes.c_ulong(buf.nbytes) + ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf)) + xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length) def dump_model(self, fo, fmap='', with_stats=False): """ @@ -839,6 +857,23 @@ class XGBModel(XGBModelBase): self._Booster = Booster() + def __getstate__(self): + # can't pickle ctypes pointers so put _Booster in a BytesIO obj + + this = self.__dict__.copy() # don't modify in place + + tmp = BytesIO() + this["_Booster"].save_model(tmp) + tmp.seek(0) + this["_Booster"] = tmp + + return this + + def __setstate__(self, state): + booster = state["_Booster"] + state["_Booster"] = Booster(model_file=booster) + self.__dict__.update(state) + def get_xgb_params(self): xgb_params = self.get_params()