Merge branch 'jseabold-xgb-pickleable'

This commit is contained in:
tqchen 2015-05-06 16:03:36 -07:00
commit 3244f1e9ae

View File

@ -15,6 +15,7 @@ import re
import ctypes
import platform
import collections
from io import BytesIO
import numpy as np
import scipy.sparse
@ -70,6 +71,8 @@ def load_xglib():
lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
lib.XGBoosterGetModelRaw.restype = ctypes.POINTER(ctypes.c_char)
lib.XGBoosterLoadModelFromBuffer.restype = ctypes.c_void_p
return lib
@ -467,10 +470,19 @@ class Booster(object):
Parameters
----------
fname : string
Output file name.
fname : string or file handle
Output file name or handle. If a handle is given, it must be a BytesIO
object or a file opened for writing in binary mode.
"""
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
if isinstance(fname, string_types): # assume file name
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
else:
length = ctypes.c_ulong()
cptr = xglib.XGBoosterGetModelRaw(self.handle,
ctypes.byref(length))
address = ctypes.addressof(cptr.contents)
buf = (ctypes.c_char * length.value).from_address(address)
fname.write(buf)
def load_model(self, fname):
"""
@ -478,10 +490,16 @@ class Booster(object):
Parameters
----------
fname : string
Input file name.
fname : string or file handle
Input file name or file handle object.
"""
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
if isinstance(fname, string_types): # assume file name
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
else:
buf = fname.getbuffer()
length = ctypes.c_ulong(buf.nbytes)
ptr = ctypes.byref(ctypes.c_void_p.from_buffer(buf))
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
def dump_model(self, fo, fmap='', with_stats=False):
"""
@ -839,6 +857,23 @@ class XGBModel(XGBModelBase):
self._Booster = Booster()
def __getstate__(self):
    """Return a picklable snapshot of the instance attributes.

    The ``_Booster`` entry is replaced by a ``BytesIO`` object holding
    whatever ``_Booster.save_model`` writes to it, rewound to position 0.
    The instance's own ``__dict__`` is left untouched.
    """
    # ctypes pointers cannot be pickled, so the booster is swapped for an
    # in-memory byte stream in a shallow copy of the attribute dict.
    state = dict(vars(self))
    serialized = BytesIO()
    state["_Booster"].save_model(serialized)
    serialized.seek(0)
    state["_Booster"] = serialized
    return state
def __setstate__(self, state):
    """Restore instance attributes from a state mapping.

    Parameters
    ----------
    state : dict
        Attribute mapping whose ``"_Booster"`` entry is the serialized
        model; it is passed to ``Booster(model_file=...)`` to rebuild the
        booster.

    The ``state`` argument itself is not modified, so the same mapping
    can be used to restore more than one instance.
    """
    # Work on a shallow copy: overwriting state["_Booster"] in place would
    # leave a live Booster in the caller's dict, and a second restore from
    # that dict would pass the Booster (not the serialized buffer) as
    # model_file.
    restored = dict(state)
    restored["_Booster"] = Booster(model_file=state["_Booster"])
    self.__dict__.update(restored)
def get_xgb_params(self):
xgb_params = self.get_params()