From 68444a06269013efd133ee6e5535faad203110b9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 7 May 2015 18:11:40 -0700 Subject: [PATCH] fix pkl problem --- wrapper/xgboost.py | 43 ++++++++++++++++++++++++++----------- wrapper/xgboost_wrapper.cpp | 7 ++++-- wrapper/xgboost_wrapper.h | 2 +- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 40ffaa84b..f0a516e2d 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -31,6 +31,9 @@ except ImportError: class XGBoostLibraryNotFound(Exception): pass +class XGBoostError(Exception): + pass + __all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train'] if sys.version_info[0] == 3: @@ -483,7 +486,7 @@ class Booster(object): if isinstance(fname, string_types): # assume file name xglib.XGBoosterSaveModel(self.handle, c_str(fname)) else: - raise Exception("fname must be a string") + raise TypeError("fname must be a string") def save_raw(self): """ @@ -852,7 +855,7 @@ class XGBModel(XGBModelBase): nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1, base_score=0.5, seed=0): if not SKLEARN_INSTALLED: - raise Exception('sklearn needs to be installed in order to use this module') + raise XGBError('sklearn needs to be installed in order to use this module') self.max_depth = max_depth self.learning_rate = learning_rate self.n_estimators = n_estimators @@ -869,22 +872,36 @@ class XGBModel(XGBModelBase): self.base_score = base_score self.seed = seed - self._Booster = Booster() + self._Booster = None def __getstate__(self): # can't pickle ctypes pointers so put _Booster in a BytesIO obj - - this = self.__dict__.copy() # don't modify in place - raw = this["_Booster"].save_raw() - this["_Booster"] = raw - + this = self.__dict__.copy() # don't modify in place + bst = this["_Booster"] + if bst is not None: + raw = this["_Booster"].save_raw() + this["_Booster"] = raw return this def __setstate__(self, state): - booster = state["_Booster"] - state["_Booster"] = Booster(model_file=booster) + bst = state["_Booster"] + if bst is not None: + state["_Booster"] = Booster(model_file=booster) self.__dict__.update(state) + def booster(self): + """ + get the underlying xgboost Booster of this model + will raise an exception when fit was not called + + Returns + ------- + booster : a xgboost booster of underlying model + """ + if self._Booster is None: + raise XGBError('need to call fit beforehand') + return self._Booster + def get_xgb_params(self): xgb_params = self.get_params() @@ -901,7 +918,7 @@ class XGBModel(XGBModelBase): def predict(self, X): testDmatrix = DMatrix(X) - return self._Booster.predict(testDmatrix) + return self.booster().predict(testDmatrix) class XGBClassifier(XGBModel, XGBClassifier): @@ -942,7 +959,7 @@ class XGBClassifier(XGBModel, XGBClassifier): def predict(self, X): testDmatrix = DMatrix(X) - class_probs = self._Booster.predict(testDmatrix) + class_probs = self.booster().predict(testDmatrix) if len(class_probs.shape) > 1: column_indexes = np.argmax(class_probs, axis=1) else: @@ -952,7 +969,7 @@ class XGBClassifier(XGBModel, XGBClassifier): def predict_proba(self, X): testDmatrix = DMatrix(X) - class_probs = self._Booster.predict(testDmatrix) + class_probs = self.booster().predict(testDmatrix) if self.objective == "multi:softprob": return class_probs else: diff --git a/wrapper/xgboost_wrapper.cpp b/wrapper/xgboost_wrapper.cpp index be2a2001c..4d7828faf 100644 --- a/wrapper/xgboost_wrapper.cpp +++ b/wrapper/xgboost_wrapper.cpp @@ -62,6 +62,7 @@ class Booster: public learner::BoostLearner { this->init_model = true; } inline const char *GetModelRaw(bst_ulong *out_len) { + this->CheckInitModel(); model_str.resize(0); utils::MemoryBufferStream fs(&model_str); learner::BoostLearner::SaveModel(fs, false); @@ -322,8 +323,10 @@ extern "C"{ void XGBoosterLoadModel(void *handle, const char *fname) { static_cast(handle)->LoadModel(fname); } - void XGBoosterSaveModel(const void *handle, const char *fname) { - static_cast(handle)->SaveModel(fname, false); + void XGBoosterSaveModel(void *handle, const char *fname) { + Booster *bst = static_cast(handle); + bst->CheckInitModel(); + bst->SaveModel(fname, false); } void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) { static_cast(handle)->LoadModelFromBuffer(buf, len); diff --git a/wrapper/xgboost_wrapper.h b/wrapper/xgboost_wrapper.h index f1d2cc92a..88a327d0d 100644 --- a/wrapper/xgboost_wrapper.h +++ b/wrapper/xgboost_wrapper.h @@ -203,7 +203,7 @@ extern "C" { * \param handle handle * \param fname file name */ - XGB_DLL void XGBoosterSaveModel(const void *handle, const char *fname); + XGB_DLL void XGBoosterSaveModel(void *handle, const char *fname); /*! * \brief load model from in memory buffer * \param handle handle