Improved sklearn compatibility (#5255)
This commit is contained in:
parent
a5cc112eea
commit
71e7e3b96f
@ -4,7 +4,7 @@ ignore=tests
|
|||||||
|
|
||||||
extension-pkg-whitelist=numpy
|
extension-pkg-whitelist=numpy
|
||||||
|
|
||||||
disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance,import-outside-toplevel,unsubscriptable-object
|
disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance,import-outside-toplevel,unsubscriptable-object,attribute-defined-outside-init
|
||||||
|
|
||||||
dummy-variables-rgx=(unused|)_.*
|
dummy-variables-rgx=(unused|)_.*
|
||||||
|
|
||||||
|
|||||||
@ -105,9 +105,8 @@ __model_doc = '''
|
|||||||
Using gblinear booster with shotgun updater is nondeterministic as
|
Using gblinear booster with shotgun updater is nondeterministic as
|
||||||
it uses Hogwild algorithm.
|
it uses Hogwild algorithm.
|
||||||
|
|
||||||
missing : float, optional
|
missing : float, default np.nan
|
||||||
Value in the data which needs to be present as a missing value. If
|
Value in the data which needs to be present as a missing value.
|
||||||
None, defaults to np.nan.
|
|
||||||
num_parallel_tree: int
|
num_parallel_tree: int
|
||||||
Used for boosting random forest.
|
Used for boosting random forest.
|
||||||
monotone_constraints : str
|
monotone_constraints : str
|
||||||
@ -208,7 +207,7 @@ class XGBModel(XGBModelBase):
|
|||||||
colsample_bytree=None, colsample_bylevel=None,
|
colsample_bytree=None, colsample_bylevel=None,
|
||||||
colsample_bynode=None, reg_alpha=None, reg_lambda=None,
|
colsample_bynode=None, reg_alpha=None, reg_lambda=None,
|
||||||
scale_pos_weight=None, base_score=None, random_state=None,
|
scale_pos_weight=None, base_score=None, random_state=None,
|
||||||
missing=None, num_parallel_tree=None,
|
missing=np.nan, num_parallel_tree=None,
|
||||||
monotone_constraints=None, interaction_constraints=None,
|
monotone_constraints=None, interaction_constraints=None,
|
||||||
importance_type="gain", gpu_id=None,
|
importance_type="gain", gpu_id=None,
|
||||||
validate_parameters=False, **kwargs):
|
validate_parameters=False, **kwargs):
|
||||||
@ -234,10 +233,9 @@ class XGBModel(XGBModelBase):
|
|||||||
self.reg_lambda = reg_lambda
|
self.reg_lambda = reg_lambda
|
||||||
self.scale_pos_weight = scale_pos_weight
|
self.scale_pos_weight = scale_pos_weight
|
||||||
self.base_score = base_score
|
self.base_score = base_score
|
||||||
self.missing = missing if missing is not None else np.nan
|
self.missing = missing
|
||||||
self.num_parallel_tree = num_parallel_tree
|
self.num_parallel_tree = num_parallel_tree
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
self._Booster = None
|
|
||||||
self.random_state = random_state
|
self.random_state = random_state
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
self.monotone_constraints = monotone_constraints
|
self.monotone_constraints = monotone_constraints
|
||||||
@ -249,15 +247,6 @@ class XGBModel(XGBModelBase):
|
|||||||
# not.
|
# not.
|
||||||
self.validate_parameters = validate_parameters
|
self.validate_parameters = validate_parameters
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
# backward compatibility code
|
|
||||||
# load booster from raw if it is raw
|
|
||||||
# the booster now support pickle
|
|
||||||
bst = state["_Booster"]
|
|
||||||
if bst is not None and not isinstance(bst, Booster):
|
|
||||||
state["_Booster"] = Booster(model_file=bst)
|
|
||||||
self.__dict__.update(state)
|
|
||||||
|
|
||||||
def get_booster(self):
|
def get_booster(self):
|
||||||
"""Get the underlying xgboost Booster of this model.
|
"""Get the underlying xgboost Booster of this model.
|
||||||
|
|
||||||
@ -267,7 +256,7 @@ class XGBModel(XGBModelBase):
|
|||||||
-------
|
-------
|
||||||
booster : a xgboost booster of underlying model
|
booster : a xgboost booster of underlying model
|
||||||
"""
|
"""
|
||||||
if self._Booster is None:
|
if not hasattr(self, '_Booster'):
|
||||||
raise XGBoostError('need to call fit or load_model beforehand')
|
raise XGBoostError('need to call fit or load_model beforehand')
|
||||||
return self._Booster
|
return self._Booster
|
||||||
|
|
||||||
@ -415,7 +404,7 @@ class XGBModel(XGBModelBase):
|
|||||||
Input file name.
|
Input file name.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self._Booster is None:
|
if not hasattr(self, '_Booster'):
|
||||||
self._Booster = Booster({'n_jobs': self.n_jobs})
|
self._Booster = Booster({'n_jobs': self.n_jobs})
|
||||||
self._Booster.load_model(fname)
|
self._Booster.load_model(fname)
|
||||||
meta = self._Booster.attr('scikit_learn')
|
meta = self._Booster.attr('scikit_learn')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user