diff --git a/demo/guide-python/basic_walkthrough.py b/demo/guide-python/basic_walkthrough.py index cdff65c33..5bfa55935 100755 --- a/demo/guide-python/basic_walkthrough.py +++ b/demo/guide-python/basic_walkthrough.py @@ -3,7 +3,6 @@ import numpy as np import scipy.sparse import pickle import xgboost as xgb -import copy ### simple example # load file from text file, also binary buffer generated by xgboost diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 46a229cd6..8d9c82b80 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -127,7 +127,6 @@ class DMatrix(object): weight : list or numpy 1-D array (optional) Weight for each instance. """ - # force into void_p, mac need to pass things in as void_p if data is None: self.handle = None @@ -348,6 +347,46 @@ class Booster(object): def __del__(self): xglib.XGBoosterFree(self.handle) + def __getstate__(self): + # can't pickle ctypes pointers + # put model content in bytearray + this = self.__dict__.copy() + handle = this['handle'] + if handle is not None: + raw = self.save_raw() + this["handle"] = raw + return this + + def __setstate__(self, state): + # reconstruct handle from raw data + handle = state['handle'] + if handle is not None: + buf = handle + dmats = c_array(ctypes.c_void_p, []) + handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, 0)) + length = ctypes.c_ulong(len(buf)) + ptr = (ctypes.c_char * len(buf)).from_buffer(buf) + xglib.XGBoosterLoadModelFromBuffer(handle, ptr, length) + state['handle'] = handle + self.__dict__.update(state) + self.set_param({'seed': 0}) + + def __copy__(self): + return self.__deepcopy__() + + def __deepcopy__(self): + return Booster(model_file = self.save_raw()) + + def copy(self): + """ + Copy the booster object + + Returns + -------- + a copied booster model + """ + return self.__copy__() + def set_param(self, params, pv=None): if isinstance(params, collections.Mapping): params = params.items() @@ -440,6 +479,11 @@ class Booster(object): """ Predict with data. 
+ NOTE: This function is not thread-safe. + For each booster object, predict can only be called from one thread. + If you want to run prediction using multiple threads, call bst.copy() to make copies + of the model object and then call predict + Parameters ---------- data : DMatrix @@ -874,18 +918,12 @@ class XGBModel(XGBModelBase): self._Booster = None - def __getstate__(self): - # can't pickle ctypes pointers so put _Booster in a bytearray object - this = self.__dict__.copy() # don't modify in place - bst = this["_Booster"] - if bst is not None: - raw = this["_Booster"].save_raw() - this["_Booster"] = raw - return this - def __setstate__(self, state): + # backward compatibility code + # load booster from raw if it is raw + # the booster now supports pickle bst = state["_Booster"] - if bst is not None: + if bst is not None and not isinstance(bst, Booster): state["_Booster"] = Booster(model_file=bst) self.__dict__.update(state) @@ -977,7 +1015,6 @@ class XGBClassifier(XGBModel, XGBClassifier): classzero_probs = 1.0 - classone_probs return np.vstack((classzero_probs, classone_probs)).transpose() - class XGBRegressor(XGBModel, XGBRegressor): __doc__ = """ Implementation of the scikit-learn API for XGBoost regression