Merge branch 'jseabold-xgb-pickleable'
This commit is contained in:
commit
382dcf6c34
@ -15,6 +15,7 @@ import re
|
|||||||
import ctypes
|
import ctypes
|
||||||
import platform
|
import platform
|
||||||
import collections
|
import collections
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
@ -70,6 +71,8 @@ def load_xglib():
|
|||||||
lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
|
lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
|
lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
|
||||||
lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
|
lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
|
||||||
|
lib.XGBoosterGetModelRaw.restype = ctypes.POINTER(ctypes.c_char)
|
||||||
|
lib.XGBoosterLoadModelFromBuffer.restype = ctypes.c_void_p
|
||||||
|
|
||||||
return lib
|
return lib
|
||||||
|
|
||||||
@ -467,10 +470,19 @@ class Booster(object):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
fname : string
|
fname : string or file handle
|
||||||
Output file name.
|
Output file name or handle. If a handle is given must be a BytesIO
|
||||||
|
object or a file opened for writing in binary format.
|
||||||
"""
|
"""
|
||||||
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
if isinstance(fname, string_types): # assume file name
|
||||||
|
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
||||||
|
else:
|
||||||
|
length = ctypes.c_ulong()
|
||||||
|
cptr = xglib.XGBoosterGetModelRaw(self.handle,
|
||||||
|
ctypes.byref(length))
|
||||||
|
address = ctypes.addressof(cptr.contents)
|
||||||
|
buf = (ctypes.c_char * length.value).from_address(address)
|
||||||
|
fname.write(buf)
|
||||||
|
|
||||||
def load_model(self, fname):
|
def load_model(self, fname):
|
||||||
"""
|
"""
|
||||||
@ -478,10 +490,16 @@ class Booster(object):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
fname : string
|
fname : string of file handle
|
||||||
Input file name.
|
Input file name or file handle object.
|
||||||
"""
|
"""
|
||||||
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
|
if isinstance(fname, string_types): # assume file name
|
||||||
|
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||||
|
else:
|
||||||
|
buf = fname.getbuffer()
|
||||||
|
length = ctypes.c_ulong(buf.nbytes)
|
||||||
|
ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf))
|
||||||
|
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
|
||||||
|
|
||||||
def dump_model(self, fo, fmap='', with_stats=False):
|
def dump_model(self, fo, fmap='', with_stats=False):
|
||||||
"""
|
"""
|
||||||
@ -839,6 +857,23 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
self._Booster = Booster()
|
self._Booster = Booster()
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
|
||||||
|
|
||||||
|
this = self.__dict__.copy() # don't modify in place
|
||||||
|
|
||||||
|
tmp = BytesIO()
|
||||||
|
this["_Booster"].save_model(tmp)
|
||||||
|
tmp.seek(0)
|
||||||
|
this["_Booster"] = tmp
|
||||||
|
|
||||||
|
return this
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
booster = state["_Booster"]
|
||||||
|
state["_Booster"] = Booster(model_file=booster)
|
||||||
|
self.__dict__.update(state)
|
||||||
|
|
||||||
def get_xgb_params(self):
|
def get_xgb_params(self):
|
||||||
xgb_params = self.get_params()
|
xgb_params = self.get_params()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user