fix pkl problem

This commit is contained in:
tqchen 2015-05-07 18:11:40 -07:00
parent 60bf389825
commit 68444a0626
3 changed files with 36 additions and 16 deletions

View File

@ -31,6 +31,9 @@ except ImportError:
class XGBoostLibraryNotFound(Exception): class XGBoostLibraryNotFound(Exception):
pass pass
class XGBoostError(Exception):
pass
__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train'] __all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
@ -483,7 +486,7 @@ class Booster(object):
if isinstance(fname, string_types): # assume file name if isinstance(fname, string_types): # assume file name
xglib.XGBoosterSaveModel(self.handle, c_str(fname)) xglib.XGBoosterSaveModel(self.handle, c_str(fname))
else: else:
raise Exception("fname must be a string") raise TypeError("fname must be a string")
def save_raw(self): def save_raw(self):
""" """
@ -852,7 +855,7 @@ class XGBModel(XGBModelBase):
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1, nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1,
base_score=0.5, seed=0): base_score=0.5, seed=0):
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
raise Exception('sklearn needs to be installed in order to use this module') raise XGBError('sklearn needs to be installed in order to use this module')
self.max_depth = max_depth self.max_depth = max_depth
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.n_estimators = n_estimators self.n_estimators = n_estimators
@ -869,22 +872,36 @@ class XGBModel(XGBModelBase):
self.base_score = base_score self.base_score = base_score
self.seed = seed self.seed = seed
self._Booster = Booster() self._Booster = None
def __getstate__(self): def __getstate__(self):
# can't pickle ctypes pointers so put _Booster in a BytesIO obj # can't pickle ctypes pointers so put _Booster in a BytesIO obj
this = self.__dict__.copy() # don't modify in place this = self.__dict__.copy() # don't modify in place
bst = this["_Booster"]
if bst is not None:
raw = this["_Booster"].save_raw() raw = this["_Booster"].save_raw()
this["_Booster"] = raw this["_Booster"] = raw
return this return this
def __setstate__(self, state): def __setstate__(self, state):
booster = state["_Booster"] bst = state["_Booster"]
if bst is not None:
state["_Booster"] = Booster(model_file=booster) state["_Booster"] = Booster(model_file=booster)
self.__dict__.update(state) self.__dict__.update(state)
def booster(self):
"""
get the underlying xgboost Booster of this model
will raise an exception when fit was not called
Returns
-------
booster : a xgboost booster of underlying model
"""
if self._Booster is None:
raise XGBError('need to call fit beforehand')
return self._Booster
def get_xgb_params(self): def get_xgb_params(self):
xgb_params = self.get_params() xgb_params = self.get_params()
@ -901,7 +918,7 @@ class XGBModel(XGBModelBase):
def predict(self, X): def predict(self, X):
testDmatrix = DMatrix(X) testDmatrix = DMatrix(X)
return self._Booster.predict(testDmatrix) return self.booster().predict(testDmatrix)
class XGBClassifier(XGBModel, XGBClassifier): class XGBClassifier(XGBModel, XGBClassifier):
@ -942,7 +959,7 @@ class XGBClassifier(XGBModel, XGBClassifier):
def predict(self, X): def predict(self, X):
testDmatrix = DMatrix(X) testDmatrix = DMatrix(X)
class_probs = self._Booster.predict(testDmatrix) class_probs = self.booster().predict(testDmatrix)
if len(class_probs.shape) > 1: if len(class_probs.shape) > 1:
column_indexes = np.argmax(class_probs, axis=1) column_indexes = np.argmax(class_probs, axis=1)
else: else:
@ -952,7 +969,7 @@ class XGBClassifier(XGBModel, XGBClassifier):
def predict_proba(self, X): def predict_proba(self, X):
testDmatrix = DMatrix(X) testDmatrix = DMatrix(X)
class_probs = self._Booster.predict(testDmatrix) class_probs = self.booster().predict(testDmatrix)
if self.objective == "multi:softprob": if self.objective == "multi:softprob":
return class_probs return class_probs
else: else:

View File

@ -62,6 +62,7 @@ class Booster: public learner::BoostLearner {
this->init_model = true; this->init_model = true;
} }
inline const char *GetModelRaw(bst_ulong *out_len) { inline const char *GetModelRaw(bst_ulong *out_len) {
this->CheckInitModel();
model_str.resize(0); model_str.resize(0);
utils::MemoryBufferStream fs(&model_str); utils::MemoryBufferStream fs(&model_str);
learner::BoostLearner::SaveModel(fs, false); learner::BoostLearner::SaveModel(fs, false);
@ -322,8 +323,10 @@ extern "C"{
void XGBoosterLoadModel(void *handle, const char *fname) { void XGBoosterLoadModel(void *handle, const char *fname) {
static_cast<Booster*>(handle)->LoadModel(fname); static_cast<Booster*>(handle)->LoadModel(fname);
} }
void XGBoosterSaveModel(const void *handle, const char *fname) { void XGBoosterSaveModel(void *handle, const char *fname) {
static_cast<const Booster*>(handle)->SaveModel(fname, false); Booster *bst = static_cast<Booster*>(handle);
bst->CheckInitModel();
bst->SaveModel(fname, false);
} }
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) { void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len); static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);

View File

@ -203,7 +203,7 @@ extern "C" {
* \param handle handle * \param handle handle
* \param fname file name * \param fname file name
*/ */
XGB_DLL void XGBoosterSaveModel(const void *handle, const char *fname); XGB_DLL void XGBoosterSaveModel(void *handle, const char *fname);
/*! /*!
* \brief load model from in memory buffer * \brief load model from in memory buffer
* \param handle handle * \param handle handle