Improved sklearn compatibility (#5255)
This commit is contained in:
parent
a5cc112eea
commit
71e7e3b96f
@ -4,7 +4,7 @@ ignore=tests
|
||||
|
||||
extension-pkg-whitelist=numpy
|
||||
|
||||
disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance,import-outside-toplevel,unsubscriptable-object
|
||||
disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance,import-outside-toplevel,unsubscriptable-object,attribute-defined-outside-init
|
||||
|
||||
dummy-variables-rgx=(unused|)_.*
|
||||
|
||||
|
||||
@ -105,9 +105,8 @@ __model_doc = '''
|
||||
Using gblinear booster with shotgun updater is nondeterministic as
|
||||
it uses Hogwild algorithm.
|
||||
|
||||
missing : float, optional
|
||||
Value in the data which needs to be present as a missing value. If
|
||||
None, defaults to np.nan.
|
||||
missing : float, default np.nan
|
||||
Value in the data which needs to be present as a missing value.
|
||||
num_parallel_tree: int
|
||||
Used for boosting random forest.
|
||||
monotone_constraints : str
|
||||
@ -208,7 +207,7 @@ class XGBModel(XGBModelBase):
|
||||
colsample_bytree=None, colsample_bylevel=None,
|
||||
colsample_bynode=None, reg_alpha=None, reg_lambda=None,
|
||||
scale_pos_weight=None, base_score=None, random_state=None,
|
||||
missing=None, num_parallel_tree=None,
|
||||
missing=np.nan, num_parallel_tree=None,
|
||||
monotone_constraints=None, interaction_constraints=None,
|
||||
importance_type="gain", gpu_id=None,
|
||||
validate_parameters=False, **kwargs):
|
||||
@ -234,10 +233,9 @@ class XGBModel(XGBModelBase):
|
||||
self.reg_lambda = reg_lambda
|
||||
self.scale_pos_weight = scale_pos_weight
|
||||
self.base_score = base_score
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self.missing = missing
|
||||
self.num_parallel_tree = num_parallel_tree
|
||||
self.kwargs = kwargs
|
||||
self._Booster = None
|
||||
self.random_state = random_state
|
||||
self.n_jobs = n_jobs
|
||||
self.monotone_constraints = monotone_constraints
|
||||
@ -249,15 +247,6 @@ class XGBModel(XGBModelBase):
|
||||
# not.
|
||||
self.validate_parameters = validate_parameters
|
||||
|
||||
def __setstate__(self, state):
|
||||
# backward compatibility code
|
||||
# load booster from raw if it is raw
|
||||
# the booster now support pickle
|
||||
bst = state["_Booster"]
|
||||
if bst is not None and not isinstance(bst, Booster):
|
||||
state["_Booster"] = Booster(model_file=bst)
|
||||
self.__dict__.update(state)
|
||||
|
||||
def get_booster(self):
|
||||
"""Get the underlying xgboost Booster of this model.
|
||||
|
||||
@ -267,7 +256,7 @@ class XGBModel(XGBModelBase):
|
||||
-------
|
||||
booster : a xgboost booster of underlying model
|
||||
"""
|
||||
if self._Booster is None:
|
||||
if not hasattr(self, '_Booster'):
|
||||
raise XGBoostError('need to call fit or load_model beforehand')
|
||||
return self._Booster
|
||||
|
||||
@ -415,7 +404,7 @@ class XGBModel(XGBModelBase):
|
||||
Input file name.
|
||||
|
||||
"""
|
||||
if self._Booster is None:
|
||||
if not hasattr(self, '_Booster'):
|
||||
self._Booster = Booster({'n_jobs': self.n_jobs})
|
||||
self._Booster.load_model(fname)
|
||||
meta = self._Booster.attr('scikit_learn')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user