Improved sklearn compatibility (#5255)

This commit is contained in:
David Díaz Vico 2020-02-03 06:30:45 +01:00 committed by GitHub
parent a5cc112eea
commit 71e7e3b96f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 18 deletions

View File

@ -4,7 +4,7 @@ ignore=tests
extension-pkg-whitelist=numpy
disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance,import-outside-toplevel,unsubscriptable-object
disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance,import-outside-toplevel,unsubscriptable-object,attribute-defined-outside-init
dummy-variables-rgx=(unused|)_.*

View File

@ -105,9 +105,8 @@ __model_doc = '''
Using gblinear booster with shotgun updater is nondeterministic as
it uses Hogwild algorithm.
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
missing : float, default np.nan
Value in the data which needs to be present as a missing value.
num_parallel_tree: int
Used for boosting random forest.
monotone_constraints : str
@ -208,7 +207,7 @@ class XGBModel(XGBModelBase):
colsample_bytree=None, colsample_bylevel=None,
colsample_bynode=None, reg_alpha=None, reg_lambda=None,
scale_pos_weight=None, base_score=None, random_state=None,
missing=None, num_parallel_tree=None,
missing=np.nan, num_parallel_tree=None,
monotone_constraints=None, interaction_constraints=None,
importance_type="gain", gpu_id=None,
validate_parameters=False, **kwargs):
@ -234,10 +233,9 @@ class XGBModel(XGBModelBase):
self.reg_lambda = reg_lambda
self.scale_pos_weight = scale_pos_weight
self.base_score = base_score
self.missing = missing if missing is not None else np.nan
self.missing = missing
self.num_parallel_tree = num_parallel_tree
self.kwargs = kwargs
self._Booster = None
self.random_state = random_state
self.n_jobs = n_jobs
self.monotone_constraints = monotone_constraints
@ -249,15 +247,6 @@ class XGBModel(XGBModelBase):
# not.
self.validate_parameters = validate_parameters
def __setstate__(self, state):
# backward compatibility code
# load booster from raw if it is raw
# the booster now support pickle
bst = state["_Booster"]
if bst is not None and not isinstance(bst, Booster):
state["_Booster"] = Booster(model_file=bst)
self.__dict__.update(state)
def get_booster(self):
"""Get the underlying xgboost Booster of this model.
@ -267,7 +256,7 @@ class XGBModel(XGBModelBase):
-------
booster : a xgboost booster of underlying model
"""
if self._Booster is None:
if not hasattr(self, '_Booster'):
raise XGBoostError('need to call fit or load_model beforehand')
return self._Booster
@ -415,7 +404,7 @@ class XGBModel(XGBModelBase):
Input file name.
"""
if self._Booster is None:
if not hasattr(self, '_Booster'):
self._Booster = Booster({'n_jobs': self.n_jobs})
self._Booster.load_model(fname)
meta = self._Booster.attr('scikit_learn')