fix pkl problem
This commit is contained in:
parent
60bf389825
commit
68444a0626
@ -31,6 +31,9 @@ except ImportError:
|
|||||||
class XGBoostLibraryNotFound(Exception):
|
class XGBoostLibraryNotFound(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class XGBoostError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
|
__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
|
||||||
|
|
||||||
if sys.version_info[0] == 3:
|
if sys.version_info[0] == 3:
|
||||||
@ -483,7 +486,7 @@ class Booster(object):
|
|||||||
if isinstance(fname, string_types): # assume file name
|
if isinstance(fname, string_types): # assume file name
|
||||||
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
||||||
else:
|
else:
|
||||||
raise Exception("fname must be a string")
|
raise TypeError("fname must be a string")
|
||||||
|
|
||||||
def save_raw(self):
|
def save_raw(self):
|
||||||
"""
|
"""
|
||||||
@ -852,7 +855,7 @@ class XGBModel(XGBModelBase):
|
|||||||
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1,
|
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1,
|
||||||
base_score=0.5, seed=0):
|
base_score=0.5, seed=0):
|
||||||
if not SKLEARN_INSTALLED:
|
if not SKLEARN_INSTALLED:
|
||||||
raise Exception('sklearn needs to be installed in order to use this module')
|
raise XGBError('sklearn needs to be installed in order to use this module')
|
||||||
self.max_depth = max_depth
|
self.max_depth = max_depth
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.n_estimators = n_estimators
|
self.n_estimators = n_estimators
|
||||||
@ -869,22 +872,36 @@ class XGBModel(XGBModelBase):
|
|||||||
self.base_score = base_score
|
self.base_score = base_score
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
self._Booster = Booster()
|
self._Booster = None
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
|
# can't pickle ctypes pointers so put _Booster in a BytesIO obj
|
||||||
|
|
||||||
this = self.__dict__.copy() # don't modify in place
|
this = self.__dict__.copy() # don't modify in place
|
||||||
raw = this["_Booster"].save_raw()
|
bst = this["_Booster"]
|
||||||
this["_Booster"] = raw
|
if bst is not None:
|
||||||
|
raw = this["_Booster"].save_raw()
|
||||||
|
this["_Booster"] = raw
|
||||||
return this
|
return this
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
booster = state["_Booster"]
|
bst = state["_Booster"]
|
||||||
state["_Booster"] = Booster(model_file=booster)
|
if bst is not None:
|
||||||
|
state["_Booster"] = Booster(model_file=booster)
|
||||||
self.__dict__.update(state)
|
self.__dict__.update(state)
|
||||||
|
|
||||||
|
def booster(self):
|
||||||
|
"""
|
||||||
|
get the underlying xgboost Booster of this model
|
||||||
|
will raise an exception when fit was not called
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
booster : a xgboost booster of underlying model
|
||||||
|
"""
|
||||||
|
if self._Booster is None:
|
||||||
|
raise XGBError('need to call fit beforehand')
|
||||||
|
return self._Booster
|
||||||
|
|
||||||
def get_xgb_params(self):
|
def get_xgb_params(self):
|
||||||
xgb_params = self.get_params()
|
xgb_params = self.get_params()
|
||||||
|
|
||||||
@ -901,7 +918,7 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
testDmatrix = DMatrix(X)
|
testDmatrix = DMatrix(X)
|
||||||
return self._Booster.predict(testDmatrix)
|
return self.booster().predict(testDmatrix)
|
||||||
|
|
||||||
|
|
||||||
class XGBClassifier(XGBModel, XGBClassifier):
|
class XGBClassifier(XGBModel, XGBClassifier):
|
||||||
@ -942,7 +959,7 @@ class XGBClassifier(XGBModel, XGBClassifier):
|
|||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
testDmatrix = DMatrix(X)
|
testDmatrix = DMatrix(X)
|
||||||
class_probs = self._Booster.predict(testDmatrix)
|
class_probs = self.booster().predict(testDmatrix)
|
||||||
if len(class_probs.shape) > 1:
|
if len(class_probs.shape) > 1:
|
||||||
column_indexes = np.argmax(class_probs, axis=1)
|
column_indexes = np.argmax(class_probs, axis=1)
|
||||||
else:
|
else:
|
||||||
@ -952,7 +969,7 @@ class XGBClassifier(XGBModel, XGBClassifier):
|
|||||||
|
|
||||||
def predict_proba(self, X):
|
def predict_proba(self, X):
|
||||||
testDmatrix = DMatrix(X)
|
testDmatrix = DMatrix(X)
|
||||||
class_probs = self._Booster.predict(testDmatrix)
|
class_probs = self.booster().predict(testDmatrix)
|
||||||
if self.objective == "multi:softprob":
|
if self.objective == "multi:softprob":
|
||||||
return class_probs
|
return class_probs
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -62,6 +62,7 @@ class Booster: public learner::BoostLearner {
|
|||||||
this->init_model = true;
|
this->init_model = true;
|
||||||
}
|
}
|
||||||
inline const char *GetModelRaw(bst_ulong *out_len) {
|
inline const char *GetModelRaw(bst_ulong *out_len) {
|
||||||
|
this->CheckInitModel();
|
||||||
model_str.resize(0);
|
model_str.resize(0);
|
||||||
utils::MemoryBufferStream fs(&model_str);
|
utils::MemoryBufferStream fs(&model_str);
|
||||||
learner::BoostLearner::SaveModel(fs, false);
|
learner::BoostLearner::SaveModel(fs, false);
|
||||||
@ -322,8 +323,10 @@ extern "C"{
|
|||||||
void XGBoosterLoadModel(void *handle, const char *fname) {
|
void XGBoosterLoadModel(void *handle, const char *fname) {
|
||||||
static_cast<Booster*>(handle)->LoadModel(fname);
|
static_cast<Booster*>(handle)->LoadModel(fname);
|
||||||
}
|
}
|
||||||
void XGBoosterSaveModel(const void *handle, const char *fname) {
|
void XGBoosterSaveModel(void *handle, const char *fname) {
|
||||||
static_cast<const Booster*>(handle)->SaveModel(fname, false);
|
Booster *bst = static_cast<Booster*>(handle);
|
||||||
|
bst->CheckInitModel();
|
||||||
|
bst->SaveModel(fname, false);
|
||||||
}
|
}
|
||||||
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
|
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
|
||||||
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);
|
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);
|
||||||
|
|||||||
@ -203,7 +203,7 @@ extern "C" {
|
|||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \param fname file name
|
* \param fname file name
|
||||||
*/
|
*/
|
||||||
XGB_DLL void XGBoosterSaveModel(const void *handle, const char *fname);
|
XGB_DLL void XGBoosterSaveModel(void *handle, const char *fname);
|
||||||
/*!
|
/*!
|
||||||
* \brief load model from in memory buffer
|
* \brief load model from in memory buffer
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user