Merge branch 'jseabold-xgb-pickleable'
This commit is contained in:
commit
382dcf6c34
@ -15,6 +15,7 @@ import re
|
||||
import ctypes
|
||||
import platform
|
||||
import collections
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
@ -70,6 +71,8 @@ def load_xglib():
|
||||
lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
|
||||
lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
|
||||
lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
|
||||
lib.XGBoosterGetModelRaw.restype = ctypes.POINTER(ctypes.c_char)
|
||||
lib.XGBoosterLoadModelFromBuffer.restype = ctypes.c_void_p
|
||||
|
||||
return lib
|
||||
|
||||
@ -467,10 +470,19 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : string
|
||||
Output file name.
|
||||
fname : string or file handle
|
||||
Output file name or handle. If a handle is given must be a BytesIO
|
||||
object or a file opened for writing in binary format.
|
||||
"""
|
||||
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
||||
if isinstance(fname, string_types): # assume file name
|
||||
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
||||
else:
|
||||
length = ctypes.c_ulong()
|
||||
cptr = xglib.XGBoosterGetModelRaw(self.handle,
|
||||
ctypes.byref(length))
|
||||
address = ctypes.addressof(cptr.contents)
|
||||
buf = (ctypes.c_char * length.value).from_address(address)
|
||||
fname.write(buf)
|
||||
|
||||
def load_model(self, fname):
|
||||
"""
|
||||
@ -478,10 +490,16 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : string
|
||||
Input file name.
|
||||
fname : string of file handle
|
||||
Input file name or file handle object.
|
||||
"""
|
||||
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||
if isinstance(fname, string_types): # assume file name
|
||||
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||
else:
|
||||
buf = fname.getbuffer()
|
||||
length = ctypes.c_ulong(buf.nbytes)
|
||||
ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf))
|
||||
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
|
||||
|
||||
def dump_model(self, fo, fmap='', with_stats=False):
|
||||
"""
|
||||
@ -839,6 +857,23 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
self._Booster = Booster()
|
||||
|
||||
def __getstate__(self):
|
||||
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
|
||||
|
||||
this = self.__dict__.copy() # don't modify in place
|
||||
|
||||
tmp = BytesIO()
|
||||
this["_Booster"].save_model(tmp)
|
||||
tmp.seek(0)
|
||||
this["_Booster"] = tmp
|
||||
|
||||
return this
|
||||
|
||||
def __setstate__(self, state):
|
||||
booster = state["_Booster"]
|
||||
state["_Booster"] = Booster(model_file=booster)
|
||||
self.__dict__.update(state)
|
||||
|
||||
def get_xgb_params(self):
|
||||
xgb_params = self.get_params()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user