python sklearn api: defaulting to best_ntree_limit if defined, otherwise current behaviour (#3445)
* python sklearn api: defaulting to best_ntree_limit if defined, otherwise current behaviour * Fix whitespace
This commit is contained in:
parent
cb017d0c9a
commit
6bed54ac39
@ -335,9 +335,13 @@ class XGBModel(XGBModelBase):
|
||||
self.best_ntree_limit = self._Booster.best_ntree_limit
|
||||
return self
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||
def predict(self, data, output_margin=False, ntree_limit=None):
|
||||
# pylint: disable=missing-docstring,invalid-name
|
||||
test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
|
||||
# get ntree_limit to use - if none specified, default to
|
||||
# best_ntree_limit if defined, otherwise 0.
|
||||
if ntree_limit is None:
|
||||
ntree_limit = getattr(self, "best_ntree_limit", 0)
|
||||
return self.get_booster().predict(test_dmatrix,
|
||||
output_margin=output_margin,
|
||||
ntree_limit=ntree_limit)
|
||||
@ -556,7 +560,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||
def predict(self, data, output_margin=False, ntree_limit=None):
|
||||
"""
|
||||
Predict with `data`.
|
||||
NOTE: This function is not thread safe.
|
||||
@ -570,12 +574,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
output_margin : bool
|
||||
Whether to output the raw untransformed margin value.
|
||||
ntree_limit : int
|
||||
Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
Limit number of trees in the prediction; defaults to best_ntree_limit if defined
|
||||
(i.e. it has been trained with early stopping), otherwise 0 (use all trees).
|
||||
Returns
|
||||
-------
|
||||
prediction : numpy array
|
||||
"""
|
||||
test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
|
||||
if ntree_limit is None:
|
||||
ntree_limit = getattr(self, "best_ntree_limit", 0)
|
||||
class_probs = self.get_booster().predict(test_dmatrix,
|
||||
output_margin=output_margin,
|
||||
ntree_limit=ntree_limit)
|
||||
@ -586,7 +593,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
column_indexes[class_probs > 0.5] = 1
|
||||
return self._le.inverse_transform(column_indexes)
|
||||
|
||||
def predict_proba(self, data, ntree_limit=0):
|
||||
def predict_proba(self, data, ntree_limit=None):
|
||||
"""
|
||||
Predict the probability of each `data` example being of a given class.
|
||||
NOTE: This function is not thread safe.
|
||||
@ -598,13 +605,16 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
data : DMatrix
|
||||
The dmatrix storing the input.
|
||||
ntree_limit : int
|
||||
Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
Limit number of trees in the prediction; defaults to best_ntree_limit if defined
|
||||
(i.e. it has been trained with early stopping), otherwise 0 (use all trees).
|
||||
Returns
|
||||
-------
|
||||
prediction : numpy array
|
||||
a numpy array with the probability of each data example being of a given class.
|
||||
"""
|
||||
test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
|
||||
if ntree_limit is None:
|
||||
ntree_limit = getattr(self, "best_ntree_limit", 0)
|
||||
class_probs = self.get_booster().predict(test_dmatrix,
|
||||
ntree_limit=ntree_limit)
|
||||
if self.objective == "multi:softprob":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user