checkin copy
This commit is contained in:
parent
e6b8b23a2c
commit
91a5390929
@ -3,7 +3,6 @@ import numpy as np
|
||||
import scipy.sparse
|
||||
import pickle
|
||||
import xgboost as xgb
|
||||
import copy
|
||||
|
||||
### simple example
|
||||
# load file from text file, also binary buffer generated by xgboost
|
||||
|
||||
@ -127,7 +127,6 @@ class DMatrix(object):
|
||||
weight : list or numpy 1-D array (optional)
|
||||
Weight for each instance.
|
||||
"""
|
||||
|
||||
# force into void_p, mac need to pass things in as void_p
|
||||
if data is None:
|
||||
self.handle = None
|
||||
@ -348,6 +347,46 @@ class Booster(object):
|
||||
def __del__(self):
|
||||
xglib.XGBoosterFree(self.handle)
|
||||
|
||||
def __getstate__(self):
|
||||
# can't pickle ctypes pointers
|
||||
# put model content in bytearray
|
||||
this = self.__dict__.copy()
|
||||
handle = this['handle']
|
||||
if handle is not None:
|
||||
raw = self.save_raw()
|
||||
this["handle"] = raw
|
||||
return this
|
||||
|
||||
def __setstate__(self, state):
|
||||
# reconstruct handle from raw data
|
||||
handle = state['handle']
|
||||
if handle is not None:
|
||||
buf = handle
|
||||
dmats = c_array(ctypes.c_void_p, [])
|
||||
handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, 0))
|
||||
length = ctypes.c_ulong(len(buf))
|
||||
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||
xglib.XGBoosterLoadModelFromBuffer(handle, ptr, length)
|
||||
state['handle'] = handle
|
||||
self.__dict__.update(state)
|
||||
self.set_param({'seed': 0})
|
||||
|
||||
def __copy__(self):
|
||||
return self.__deepcopy__()
|
||||
|
||||
def __deepcopy__(self):
|
||||
return Booster(model_file = self.save_raw())
|
||||
|
||||
def copy(self):
|
||||
"""
|
||||
Copy the booster object
|
||||
|
||||
Returns
|
||||
--------
|
||||
a copied booster model
|
||||
"""
|
||||
return self.__copy__()
|
||||
|
||||
def set_param(self, params, pv=None):
|
||||
if isinstance(params, collections.Mapping):
|
||||
params = params.items()
|
||||
@ -440,6 +479,11 @@ class Booster(object):
|
||||
"""
|
||||
Predict with data.
|
||||
|
||||
NOTE: This function is not thread safe.
|
||||
For each booster object, predict can only be called from one thread.
|
||||
If you want to run prediction using multiple thread, call bst.copy() to make copies
|
||||
of model object and then call predict
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : DMatrix
|
||||
@ -874,18 +918,12 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
self._Booster = None
|
||||
|
||||
def __getstate__(self):
|
||||
# can't pickle ctypes pointers so put _Booster in a bytearray object
|
||||
this = self.__dict__.copy() # don't modify in place
|
||||
bst = this["_Booster"]
|
||||
if bst is not None:
|
||||
raw = this["_Booster"].save_raw()
|
||||
this["_Booster"] = raw
|
||||
return this
|
||||
|
||||
def __setstate__(self, state):
|
||||
# backward compatiblity code
|
||||
# load booster from raw if it is raw
|
||||
# the booster now support pickle
|
||||
bst = state["_Booster"]
|
||||
if bst is not None:
|
||||
if bst is not None and not isinstance(bst, Booster):
|
||||
state["_Booster"] = Booster(model_file=bst)
|
||||
self.__dict__.update(state)
|
||||
|
||||
@ -977,7 +1015,6 @@ class XGBClassifier(XGBModel, XGBClassifier):
|
||||
classzero_probs = 1.0 - classone_probs
|
||||
return np.vstack((classzero_probs, classone_probs)).transpose()
|
||||
|
||||
|
||||
class XGBRegressor(XGBModel, XGBRegressor):
|
||||
__doc__ = """
|
||||
Implementation of the scikit-learn API for XGBoost regression
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user