python sklearn api: defaulting to best_ntree_limit if defined, otherwise current behaviour (#3445)

* python sklearn api: defaulting to best_ntree_limit if defined, otherwise current behaviour

* Fix whitespace
Authored by kodonnell on 2018-07-09 09:35:52 +12:00; committed by Philip Hyunsu Cho
parent cb017d0c9a
commit 6bed54ac39


@@ -335,9 +335,13 @@ class XGBModel(XGBModelBase):
             self.best_ntree_limit = self._Booster.best_ntree_limit
         return self
 
-    def predict(self, data, output_margin=False, ntree_limit=0):
+    def predict(self, data, output_margin=False, ntree_limit=None):
         # pylint: disable=missing-docstring,invalid-name
         test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        # get ntree_limit to use - if none specified, default to
+        # best_ntree_limit if defined, otherwise 0.
+        if ntree_limit is None:
+            ntree_limit = getattr(self, "best_ntree_limit", 0)
         return self.get_booster().predict(test_dmatrix,
                                           output_margin=output_margin,
                                           ntree_limit=ntree_limit)
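
From a user's point of view, the hunk above means that an estimator trained with early stopping (which sets best_ntree_limit on the model) now truncates predictions at that limit automatically, while passing ntree_limit explicitly keeps the previous behaviour. A minimal usage sketch, not part of the diff, assuming this era's sklearn wrapper and synthetic data:

import numpy as np
from xgboost import XGBRegressor

# Synthetic regression data split into train/validation folds.
rng = np.random.RandomState(0)
X = rng.rand(200, 5)
y = X.sum(axis=1) + rng.rand(200) * 0.1
X_train, X_valid = X[:150], X[150:]
y_train, y_valid = y[:150], y[150:]

model = XGBRegressor(n_estimators=500)
model.fit(X_train, y_train,
          eval_set=[(X_valid, y_valid)],
          early_stopping_rounds=10)   # early stopping sets model.best_ntree_limit

# New default: best_ntree_limit is picked up automatically when it exists.
preds = model.predict(X_valid)

# Explicitly passing ntree_limit=0 still evaluates all trees (old behaviour).
preds_all = model.predict(X_valid, ntree_limit=0)
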
@@ -556,7 +560,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         return self
 
-    def predict(self, data, output_margin=False, ntree_limit=0):
+    def predict(self, data, output_margin=False, ntree_limit=None):
         """
         Predict with `data`.
 
         NOTE: This function is not thread safe.
@@ -570,12 +574,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         output_margin : bool
             Whether to output the raw untransformed margin value.
         ntree_limit : int
-            Limit number of trees in the prediction; defaults to 0 (use all trees).
+            Limit number of trees in the prediction; defaults to best_ntree_limit if defined
+            (i.e. it has been trained with early stopping), otherwise 0 (use all trees).
         Returns
         -------
         prediction : numpy array
         """
         test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        if ntree_limit is None:
+            ntree_limit = getattr(self, "best_ntree_limit", 0)
         class_probs = self.get_booster().predict(test_dmatrix,
                                                  output_margin=output_margin,
                                                  ntree_limit=ntree_limit)
@@ -586,7 +593,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         column_indexes[class_probs > 0.5] = 1
         return self._le.inverse_transform(column_indexes)
 
-    def predict_proba(self, data, ntree_limit=0):
+    def predict_proba(self, data, ntree_limit=None):
         """
         Predict the probability of each `data` example being of a given class.
 
         NOTE: This function is not thread safe.
@@ -598,13 +605,16 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         data : DMatrix
             The dmatrix storing the input.
         ntree_limit : int
-            Limit number of trees in the prediction; defaults to 0 (use all trees).
+            Limit number of trees in the prediction; defaults to best_ntree_limit if defined
+            (i.e. it has been trained with early stopping), otherwise 0 (use all trees).
         Returns
         -------
         prediction : numpy array
             a numpy array with the probability of each data example being of a given class.
         """
         test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        if ntree_limit is None:
+            ntree_limit = getattr(self, "best_ntree_limit", 0)
         class_probs = self.get_booster().predict(test_dmatrix,
                                                  ntree_limit=ntree_limit)
         if self.objective == "multi:softprob":
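
The classifier methods get the same fallback, and getattr(self, "best_ntree_limit", 0) means a model trained without early stopping still evaluates all trees, i.e. the previous behaviour. A short, hypothetical illustration of both cases, again assuming this era's sklearn API:

import numpy as np
from xgboost import XGBClassifier

rng = np.random.RandomState(0)
X = rng.rand(200, 5)
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)
X_train, X_valid = X[:150], X[150:]
y_train, y_valid = y[:150], y[150:]

# Case 1: trained with early stopping, so best_ntree_limit exists and is
# used by default in both predict() and predict_proba().
clf = XGBClassifier(n_estimators=300)
clf.fit(X_train, y_train,
        eval_set=[(X_valid, y_valid)],
        early_stopping_rounds=5)
labels = clf.predict(X_valid)
proba = clf.predict_proba(X_valid)

# Case 2: trained without early stopping, so the attribute is absent and the
# getattr fallback yields 0, i.e. all trees are used as before this change.
clf_plain = XGBClassifier(n_estimators=50)
clf_plain.fit(X_train, y_train)
proba_all = clf_plain.predict_proba(X_valid)
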