fix pkl problem
This commit is contained in:
parent
60bf389825
commit
68444a0626
@ -31,6 +31,9 @@ except ImportError:
|
||||
class XGBoostLibraryNotFound(Exception):
|
||||
pass
|
||||
|
||||
class XGBoostError(Exception):
|
||||
pass
|
||||
|
||||
__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
@ -483,7 +486,7 @@ class Booster(object):
|
||||
if isinstance(fname, string_types): # assume file name
|
||||
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
||||
else:
|
||||
raise Exception("fname must be a string")
|
||||
raise TypeError("fname must be a string")
|
||||
|
||||
def save_raw(self):
|
||||
"""
|
||||
@ -852,7 +855,7 @@ class XGBModel(XGBModelBase):
|
||||
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1,
|
||||
base_score=0.5, seed=0):
|
||||
if not SKLEARN_INSTALLED:
|
||||
raise Exception('sklearn needs to be installed in order to use this module')
|
||||
raise XGBError('sklearn needs to be installed in order to use this module')
|
||||
self.max_depth = max_depth
|
||||
self.learning_rate = learning_rate
|
||||
self.n_estimators = n_estimators
|
||||
@ -869,22 +872,36 @@ class XGBModel(XGBModelBase):
|
||||
self.base_score = base_score
|
||||
self.seed = seed
|
||||
|
||||
self._Booster = Booster()
|
||||
self._Booster = None
|
||||
|
||||
def __getstate__(self):
|
||||
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
|
||||
|
||||
this = self.__dict__.copy() # don't modify in place
|
||||
raw = this["_Booster"].save_raw()
|
||||
this["_Booster"] = raw
|
||||
|
||||
this = self.__dict__.copy() # don't modify in place
|
||||
bst = this["_Booster"]
|
||||
if bst is not None:
|
||||
raw = this["_Booster"].save_raw()
|
||||
this["_Booster"] = raw
|
||||
return this
|
||||
|
||||
def __setstate__(self, state):
|
||||
booster = state["_Booster"]
|
||||
state["_Booster"] = Booster(model_file=booster)
|
||||
bst = state["_Booster"]
|
||||
if bst is not None:
|
||||
state["_Booster"] = Booster(model_file=booster)
|
||||
self.__dict__.update(state)
|
||||
|
||||
def booster(self):
|
||||
"""
|
||||
get the underlying xgboost Booster of this model
|
||||
will raise an exception when fit was not called
|
||||
|
||||
Returns
|
||||
-------
|
||||
booster : a xgboost booster of underlying model
|
||||
"""
|
||||
if self._Booster is None:
|
||||
raise XGBError('need to call fit beforehand')
|
||||
return self._Booster
|
||||
|
||||
def get_xgb_params(self):
|
||||
xgb_params = self.get_params()
|
||||
|
||||
@ -901,7 +918,7 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
def predict(self, X):
|
||||
testDmatrix = DMatrix(X)
|
||||
return self._Booster.predict(testDmatrix)
|
||||
return self.booster().predict(testDmatrix)
|
||||
|
||||
|
||||
class XGBClassifier(XGBModel, XGBClassifier):
|
||||
@ -942,7 +959,7 @@ class XGBClassifier(XGBModel, XGBClassifier):
|
||||
|
||||
def predict(self, X):
|
||||
testDmatrix = DMatrix(X)
|
||||
class_probs = self._Booster.predict(testDmatrix)
|
||||
class_probs = self.booster().predict(testDmatrix)
|
||||
if len(class_probs.shape) > 1:
|
||||
column_indexes = np.argmax(class_probs, axis=1)
|
||||
else:
|
||||
@ -952,7 +969,7 @@ class XGBClassifier(XGBModel, XGBClassifier):
|
||||
|
||||
def predict_proba(self, X):
|
||||
testDmatrix = DMatrix(X)
|
||||
class_probs = self._Booster.predict(testDmatrix)
|
||||
class_probs = self.booster().predict(testDmatrix)
|
||||
if self.objective == "multi:softprob":
|
||||
return class_probs
|
||||
else:
|
||||
|
||||
@ -62,6 +62,7 @@ class Booster: public learner::BoostLearner {
|
||||
this->init_model = true;
|
||||
}
|
||||
inline const char *GetModelRaw(bst_ulong *out_len) {
|
||||
this->CheckInitModel();
|
||||
model_str.resize(0);
|
||||
utils::MemoryBufferStream fs(&model_str);
|
||||
learner::BoostLearner::SaveModel(fs, false);
|
||||
@ -322,8 +323,10 @@ extern "C"{
|
||||
void XGBoosterLoadModel(void *handle, const char *fname) {
|
||||
static_cast<Booster*>(handle)->LoadModel(fname);
|
||||
}
|
||||
void XGBoosterSaveModel(const void *handle, const char *fname) {
|
||||
static_cast<const Booster*>(handle)->SaveModel(fname, false);
|
||||
void XGBoosterSaveModel(void *handle, const char *fname) {
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->CheckInitModel();
|
||||
bst->SaveModel(fname, false);
|
||||
}
|
||||
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
|
||||
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);
|
||||
|
||||
@ -203,7 +203,7 @@ extern "C" {
|
||||
* \param handle handle
|
||||
* \param fname file name
|
||||
*/
|
||||
XGB_DLL void XGBoosterSaveModel(const void *handle, const char *fname);
|
||||
XGB_DLL void XGBoosterSaveModel(void *handle, const char *fname);
|
||||
/*!
|
||||
* \brief load model from in memory buffer
|
||||
* \param handle handle
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user