checkin copy

This commit is contained in:
tqchen 2015-05-17 21:29:51 -07:00
parent e6b8b23a2c
commit 91a5390929
2 changed files with 49 additions and 13 deletions

View File

@ -3,7 +3,6 @@ import numpy as np
import scipy.sparse import scipy.sparse
import pickle import pickle
import xgboost as xgb import xgboost as xgb
import copy
### simple example ### simple example
# load file from text file, also binary buffer generated by xgboost # load file from text file, also binary buffer generated by xgboost

View File

@ -127,7 +127,6 @@ class DMatrix(object):
weight : list or numpy 1-D array (optional) weight : list or numpy 1-D array (optional)
Weight for each instance. Weight for each instance.
""" """
# force into void_p, mac need to pass things in as void_p # force into void_p, mac need to pass things in as void_p
if data is None: if data is None:
self.handle = None self.handle = None
@ -348,6 +347,46 @@ class Booster(object):
def __del__(self): def __del__(self):
xglib.XGBoosterFree(self.handle) xglib.XGBoosterFree(self.handle)
def __getstate__(self):
    """Return a picklable snapshot of the instance attributes.

    ctypes pointers cannot be pickled, so when ``handle`` is set it is
    replaced in the snapshot by the result of ``self.save_raw()``.
    The instance's own ``__dict__`` is not modified.

    Returns
    -------
    dict
        Copy of ``self.__dict__`` with ``handle`` swapped for the raw
        model content (or left as None).
    """
    state = dict(self.__dict__)
    if state['handle'] is not None:
        # swap the native pointer for the serialized model content
        state['handle'] = self.save_raw()
    return state
def __setstate__(self, state):
    """Restore the instance from a pickled state dict.

    When ``state['handle']`` holds raw model content, a new native
    booster is created and the content is loaded into it; the entry in
    ``state`` is then replaced by the new handle before the attributes
    are restored.

    Parameters
    ----------
    state : dict
        Attribute dict as produced by ``__getstate__``.
    """
    raw = state['handle']
    if raw is not None:
        # rebuild a native booster, then load the serialized model into it
        no_dmats = c_array(ctypes.c_void_p, [])
        booster = ctypes.c_void_p(xglib.XGBoosterCreate(no_dmats, 0))
        nbytes = ctypes.c_ulong(len(raw))
        # from_buffer shares memory with ``raw`` instead of copying; it
        # needs a writable buffer (presumably a bytearray -- confirm
        # against save_raw)
        view = (ctypes.c_char * len(raw)).from_buffer(raw)
        xglib.XGBoosterLoadModelFromBuffer(booster, view, nbytes)
        state['handle'] = booster
    self.__dict__.update(state)
    self.set_param({'seed': 0})
def __copy__(self):
    """Support ``copy.copy`` by delegating to ``__deepcopy__``."""
    duplicate = self.__deepcopy__()
    return duplicate
def __deepcopy__(self, memo=None):
    """Return a new Booster built from this model's raw content.

    Parameters
    ----------
    memo : dict, optional
        Memo dictionary passed by ``copy.deepcopy``.  It is not used
        here, but it has to be accepted: ``copy.deepcopy`` invokes
        ``__deepcopy__(memo)``, which would raise TypeError if this
        method took no argument.  The default keeps direct calls
        without arguments working.

    Returns
    -------
    Booster
        Booster constructed with ``model_file=self.save_raw()``.
    """
    return Booster(model_file=self.save_raw())
def copy(self):
    """
    Make a copy of this booster object.

    Returns
    -------
    the copied booster model, as produced by ``__copy__``
    """
    clone = self.__copy__()
    return clone
def set_param(self, params, pv=None): def set_param(self, params, pv=None):
if isinstance(params, collections.Mapping): if isinstance(params, collections.Mapping):
params = params.items() params = params.items()
@ -440,6 +479,11 @@ class Booster(object):
""" """
Predict with data. Predict with data.
NOTE: This function is not thread safe.
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple threads, call bst.copy() to make copies
of the model object and then call predict on each copy.
Parameters Parameters
---------- ----------
data : DMatrix data : DMatrix
@ -874,18 +918,12 @@ class XGBModel(XGBModelBase):
self._Booster = None self._Booster = None
def __getstate__(self):
    """Return a picklable snapshot of the instance attributes.

    ctypes pointers cannot be pickled, so a non-None ``_Booster`` is
    replaced in the snapshot by the result of its ``save_raw()``.
    The instance's own ``__dict__`` is not modified.
    """
    state = dict(self.__dict__)  # work on a copy, never in place
    booster = state["_Booster"]
    if booster is not None:
        state["_Booster"] = booster.save_raw()
    return state
def __setstate__(self, state): def __setstate__(self, state):
# backward compatibility code
# load booster from raw if it is raw
# the booster now supports pickle
bst = state["_Booster"] bst = state["_Booster"]
if bst is not None: if bst is not None and not isinstance(bst, Booster):
state["_Booster"] = Booster(model_file=bst) state["_Booster"] = Booster(model_file=bst)
self.__dict__.update(state) self.__dict__.update(state)
@ -977,7 +1015,6 @@ class XGBClassifier(XGBModel, XGBClassifier):
classzero_probs = 1.0 - classone_probs classzero_probs = 1.0 - classone_probs
return np.vstack((classzero_probs, classone_probs)).transpose() return np.vstack((classzero_probs, classone_probs)).transpose()
class XGBRegressor(XGBModel, XGBRegressor): class XGBRegressor(XGBModel, XGBRegressor):
__doc__ = """ __doc__ = """
Implementation of the scikit-learn API for XGBoost regression Implementation of the scikit-learn API for XGBoost regression