From 11fa4197208bf3c3054c93a9609602eb7b8b9c4d Mon Sep 17 00:00:00 2001 From: Skipper Seabold Date: Wed, 6 May 2015 12:33:43 -0500 Subject: [PATCH 1/2] ENH: Make XGBModel pickleable. --- wrapper/xgboost.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 8ef82b2c7..4a5248818 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -15,6 +15,7 @@ import re import ctypes import platform import collections +from tempfile import NamedTemporaryFile import numpy as np import scipy.sparse @@ -839,6 +840,31 @@ class XGBModel(XGBModelBase): self._Booster = Booster() + def __getstate__(self): + # can't pickle ctypes pointers so save _Booster directly + this = self.__dict__.copy() # don't modify in place + + # delete = False for x-platform compatibility + # https://bugs.python.org/issue14243 + with NamedTemporaryFile(mode="wb", delete=False) as tmp: + this["_Booster"].save_model(tmp.name) + tmp.close() + booster = open(tmp.name, "rb").read() + os.remove(tmp.name) + this.update({"_Booster": booster}) + + return this + + def __setstate__(self, state): + with NamedTemporaryFile(mode="wb", delete=False) as tmp: + tmp.write(state["_Booster"]) + tmp.close() + booster = Booster(model_file=tmp.name) + os.remove(tmp.name) + + state["_Booster"] = booster + self.__dict__.update(state) + def get_xgb_params(self): xgb_params = self.get_params() From 13837060f1a17aa691c1ac57969ce21f7969e0f2 Mon Sep 17 00:00:00 2001 From: Skipper Seabold Date: Wed, 6 May 2015 14:59:14 -0500 Subject: [PATCH 2/2] ENH: Don't use tempfiles for save/load --- wrapper/xgboost.py | 55 +++++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 4a5248818..48fa02b76 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -15,7 +15,7 @@ import re import ctypes import platform import collections -from tempfile import NamedTemporaryFile +from io import BytesIO import numpy as np import scipy.sparse @@ -71,6 +71,8 @@ def load_xglib(): lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float) lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) + lib.XGBoosterGetModelRaw.restype = ctypes.POINTER(ctypes.c_char) + lib.XGBoosterLoadModelFromBuffer.restype = ctypes.c_void_p return lib @@ -468,10 +470,19 @@ class Booster(object): Parameters ---------- - fname : string - Output file name. + fname : string or file handle + Output file name or handle. If a handle is given must be a BytesIO + object or a file opened for writing in binary format. """ - xglib.XGBoosterSaveModel(self.handle, c_str(fname)) + if isinstance(fname, string_types): # assume file name + xglib.XGBoosterSaveModel(self.handle, c_str(fname)) + else: + length = ctypes.c_ulong() + cptr = xglib.XGBoosterGetModelRaw(self.handle, + ctypes.byref(length)) + address = ctypes.addressof(cptr.contents) + buf = (ctypes.c_char * length.value).from_address(address) + fname.write(buf) def load_model(self, fname): """ @@ -479,10 +490,16 @@ class Booster(object): Parameters ---------- - fname : string - Input file name. + fname : string of file handle + Input file name or file handle object. """ - xglib.XGBoosterLoadModel(self.handle, c_str(fname)) + if isinstance(fname, string_types): # assume file name + xglib.XGBoosterLoadModel(self.handle, c_str(fname)) + else: + buf = fname.getbuffer() + length = ctypes.c_ulong(buf.nbytes) + ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf)) + xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length) def dump_model(self, fo, fmap='', with_stats=False): """ @@ -841,28 +858,20 @@ class XGBModel(XGBModelBase): self._Booster = Booster() def __getstate__(self): - # can't pickle ctypes pointers so save _Booster directly + # can't pickle ctypes pointers so put _Booster in a BytesIO obj + this = self.__dict__.copy() # don't modify in place - # delete = False for x-platform compatibility - # https://bugs.python.org/issue14243 - with NamedTemporaryFile(mode="wb", delete=False) as tmp: - this["_Booster"].save_model(tmp.name) - tmp.close() - booster = open(tmp.name, "rb").read() - os.remove(tmp.name) - this.update({"_Booster": booster}) + tmp = BytesIO() + this["_Booster"].save_model(tmp) + tmp.seek(0) + this["_Booster"] = tmp return this def __setstate__(self, state): - with NamedTemporaryFile(mode="wb", delete=False) as tmp: - tmp.write(state["_Booster"]) - tmp.close() - booster = Booster(model_file=tmp.name) - os.remove(tmp.name) - - state["_Booster"] = booster + booster = state["_Booster"] + state["_Booster"] = Booster(model_file=booster) self.__dict__.update(state) def get_xgb_params(self):