checkin copy
This commit is contained in:
parent
e6b8b23a2c
commit
91a5390929
@ -3,7 +3,6 @@ import numpy as np
|
|||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
import pickle
|
import pickle
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import copy
|
|
||||||
|
|
||||||
### simple example
|
### simple example
|
||||||
# load file from text file, also binary buffer generated by xgboost
|
# load file from text file, also binary buffer generated by xgboost
|
||||||
|
|||||||
@ -127,7 +127,6 @@ class DMatrix(object):
|
|||||||
weight : list or numpy 1-D array (optional)
|
weight : list or numpy 1-D array (optional)
|
||||||
Weight for each instance.
|
Weight for each instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# force into void_p, mac need to pass things in as void_p
|
# force into void_p, mac need to pass things in as void_p
|
||||||
if data is None:
|
if data is None:
|
||||||
self.handle = None
|
self.handle = None
|
||||||
@ -348,6 +347,46 @@ class Booster(object):
|
|||||||
def __del__(self):
|
def __del__(self):
|
||||||
xglib.XGBoosterFree(self.handle)
|
xglib.XGBoosterFree(self.handle)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
# can't pickle ctypes pointers
|
||||||
|
# put model content in bytearray
|
||||||
|
this = self.__dict__.copy()
|
||||||
|
handle = this['handle']
|
||||||
|
if handle is not None:
|
||||||
|
raw = self.save_raw()
|
||||||
|
this["handle"] = raw
|
||||||
|
return this
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
# reconstruct handle from raw data
|
||||||
|
handle = state['handle']
|
||||||
|
if handle is not None:
|
||||||
|
buf = handle
|
||||||
|
dmats = c_array(ctypes.c_void_p, [])
|
||||||
|
handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, 0))
|
||||||
|
length = ctypes.c_ulong(len(buf))
|
||||||
|
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||||
|
xglib.XGBoosterLoadModelFromBuffer(handle, ptr, length)
|
||||||
|
state['handle'] = handle
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self.set_param({'seed': 0})
|
||||||
|
|
||||||
|
def __copy__(self):
|
||||||
|
return self.__deepcopy__()
|
||||||
|
|
||||||
|
def __deepcopy__(self):
|
||||||
|
return Booster(model_file = self.save_raw())
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
"""
|
||||||
|
Copy the booster object
|
||||||
|
|
||||||
|
Returns
|
||||||
|
--------
|
||||||
|
a copied booster model
|
||||||
|
"""
|
||||||
|
return self.__copy__()
|
||||||
|
|
||||||
def set_param(self, params, pv=None):
|
def set_param(self, params, pv=None):
|
||||||
if isinstance(params, collections.Mapping):
|
if isinstance(params, collections.Mapping):
|
||||||
params = params.items()
|
params = params.items()
|
||||||
@ -440,6 +479,11 @@ class Booster(object):
|
|||||||
"""
|
"""
|
||||||
Predict with data.
|
Predict with data.
|
||||||
|
|
||||||
|
NOTE: This function is not thread safe.
|
||||||
|
For each booster object, predict can only be called from one thread.
|
||||||
|
If you want to run prediction using multiple thread, call bst.copy() to make copies
|
||||||
|
of model object and then call predict
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
data : DMatrix
|
data : DMatrix
|
||||||
@ -874,18 +918,12 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
self._Booster = None
|
self._Booster = None
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
# can't pickle ctypes pointers so put _Booster in a bytearray object
|
|
||||||
this = self.__dict__.copy() # don't modify in place
|
|
||||||
bst = this["_Booster"]
|
|
||||||
if bst is not None:
|
|
||||||
raw = this["_Booster"].save_raw()
|
|
||||||
this["_Booster"] = raw
|
|
||||||
return this
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
|
# backward compatiblity code
|
||||||
|
# load booster from raw if it is raw
|
||||||
|
# the booster now support pickle
|
||||||
bst = state["_Booster"]
|
bst = state["_Booster"]
|
||||||
if bst is not None:
|
if bst is not None and not isinstance(bst, Booster):
|
||||||
state["_Booster"] = Booster(model_file=bst)
|
state["_Booster"] = Booster(model_file=bst)
|
||||||
self.__dict__.update(state)
|
self.__dict__.update(state)
|
||||||
|
|
||||||
@ -977,7 +1015,6 @@ class XGBClassifier(XGBModel, XGBClassifier):
|
|||||||
classzero_probs = 1.0 - classone_probs
|
classzero_probs = 1.0 - classone_probs
|
||||||
return np.vstack((classzero_probs, classone_probs)).transpose()
|
return np.vstack((classzero_probs, classone_probs)).transpose()
|
||||||
|
|
||||||
|
|
||||||
class XGBRegressor(XGBModel, XGBRegressor):
|
class XGBRegressor(XGBModel, XGBRegressor):
|
||||||
__doc__ = """
|
__doc__ = """
|
||||||
Implementation of the scikit-learn API for XGBoost regression
|
Implementation of the scikit-learn API for XGBoost regression
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user