fix pkl problem

This commit is contained in:
tqchen 2015-05-07 18:11:40 -07:00
parent 60bf389825
commit 68444a0626
3 changed files with 36 additions and 16 deletions

View File

@ -31,6 +31,9 @@ except ImportError:
class XGBoostLibraryNotFound(Exception):
pass
class XGBoostError(Exception):
pass
__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
if sys.version_info[0] == 3:
@ -483,7 +486,7 @@ class Booster(object):
if isinstance(fname, string_types): # assume file name
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
else:
raise Exception("fname must be a string")
raise TypeError("fname must be a string")
def save_raw(self):
"""
@ -852,7 +855,7 @@ class XGBModel(XGBModelBase):
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1,
base_score=0.5, seed=0):
if not SKLEARN_INSTALLED:
raise Exception('sklearn needs to be installed in order to use this module')
raise XGBError('sklearn needs to be installed in order to use this module')
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
@ -869,22 +872,36 @@ class XGBModel(XGBModelBase):
self.base_score = base_score
self.seed = seed
self._Booster = Booster()
self._Booster = None
def __getstate__(self):
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
this = self.__dict__.copy() # don't modify in place
raw = this["_Booster"].save_raw()
this["_Booster"] = raw
this = self.__dict__.copy() # don't modify in place
bst = this["_Booster"]
if bst is not None:
raw = this["_Booster"].save_raw()
this["_Booster"] = raw
return this
def __setstate__(self, state):
booster = state["_Booster"]
state["_Booster"] = Booster(model_file=booster)
bst = state["_Booster"]
if bst is not None:
state["_Booster"] = Booster(model_file=booster)
self.__dict__.update(state)
def booster(self):
"""
get the underlying xgboost Booster of this model
will raise an exception when fit was not called
Returns
-------
booster : a xgboost booster of underlying model
"""
if self._Booster is None:
raise XGBError('need to call fit beforehand')
return self._Booster
def get_xgb_params(self):
xgb_params = self.get_params()
@ -901,7 +918,7 @@ class XGBModel(XGBModelBase):
def predict(self, X):
testDmatrix = DMatrix(X)
return self._Booster.predict(testDmatrix)
return self.booster().predict(testDmatrix)
class XGBClassifier(XGBModel, XGBClassifier):
@ -942,7 +959,7 @@ class XGBClassifier(XGBModel, XGBClassifier):
def predict(self, X):
testDmatrix = DMatrix(X)
class_probs = self._Booster.predict(testDmatrix)
class_probs = self.booster().predict(testDmatrix)
if len(class_probs.shape) > 1:
column_indexes = np.argmax(class_probs, axis=1)
else:
@ -952,7 +969,7 @@ class XGBClassifier(XGBModel, XGBClassifier):
def predict_proba(self, X):
testDmatrix = DMatrix(X)
class_probs = self._Booster.predict(testDmatrix)
class_probs = self.booster().predict(testDmatrix)
if self.objective == "multi:softprob":
return class_probs
else:

View File

@ -62,6 +62,7 @@ class Booster: public learner::BoostLearner {
this->init_model = true;
}
inline const char *GetModelRaw(bst_ulong *out_len) {
this->CheckInitModel();
model_str.resize(0);
utils::MemoryBufferStream fs(&model_str);
learner::BoostLearner::SaveModel(fs, false);
@ -322,8 +323,10 @@ extern "C"{
void XGBoosterLoadModel(void *handle, const char *fname) {
static_cast<Booster*>(handle)->LoadModel(fname);
}
void XGBoosterSaveModel(const void *handle, const char *fname) {
static_cast<const Booster*>(handle)->SaveModel(fname, false);
void XGBoosterSaveModel(void *handle, const char *fname) {
Booster *bst = static_cast<Booster*>(handle);
bst->CheckInitModel();
bst->SaveModel(fname, false);
}
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);

View File

@ -203,7 +203,7 @@ extern "C" {
* \param handle handle
* \param fname file name
*/
XGB_DLL void XGBoosterSaveModel(const void *handle, const char *fname);
XGB_DLL void XGBoosterSaveModel(void *handle, const char *fname);
/*!
* \brief load model from in memory buffer
* \param handle handle