From 6bed54ac39f5936b62887b0e0ed5b17a57209e15 Mon Sep 17 00:00:00 2001 From: kodonnell Date: Mon, 9 Jul 2018 09:35:52 +1200 Subject: [PATCH] python sklearn api: defaulting to best_ntree_limit if defined, otherwise current behaviour (#3445) * python sklearn api: defaulting to best_ntree_limit if defined, otherwise current behaviour * Fix whitespace --- python-package/xgboost/sklearn.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 66eff2abd..5d8e98b5c 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -335,9 +335,13 @@ class XGBModel(XGBModelBase): self.best_ntree_limit = self._Booster.best_ntree_limit return self - def predict(self, data, output_margin=False, ntree_limit=0): + def predict(self, data, output_margin=False, ntree_limit=None): # pylint: disable=missing-docstring,invalid-name test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs) + # get ntree_limit to use - if none specified, default to + # best_ntree_limit if defined, otherwise 0. + if ntree_limit is None: + ntree_limit = getattr(self, "best_ntree_limit", 0) return self.get_booster().predict(test_dmatrix, output_margin=output_margin, ntree_limit=ntree_limit) @@ -556,7 +560,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): return self - def predict(self, data, output_margin=False, ntree_limit=0): + def predict(self, data, output_margin=False, ntree_limit=None): """ Predict with `data`. NOTE: This function is not thread safe. @@ -570,12 +574,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase): output_margin : bool Whether to output the raw untransformed margin value. ntree_limit : int - Limit number of trees in the prediction; defaults to 0 (use all trees). + Limit number of trees in the prediction; defaults to best_ntree_limit if defined + (i.e. it has been trained with early stopping), otherwise 0 (use all trees). Returns ------- prediction : numpy array """ test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs) + if ntree_limit is None: + ntree_limit = getattr(self, "best_ntree_limit", 0) class_probs = self.get_booster().predict(test_dmatrix, output_margin=output_margin, ntree_limit=ntree_limit) @@ -586,7 +593,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): column_indexes[class_probs > 0.5] = 1 return self._le.inverse_transform(column_indexes) - def predict_proba(self, data, ntree_limit=0): + def predict_proba(self, data, ntree_limit=None): """ Predict the probability of each `data` example being of a given class. NOTE: This function is not thread safe. @@ -598,13 +605,16 @@ class XGBClassifier(XGBModel, XGBClassifierBase): data : DMatrix The dmatrix storing the input. ntree_limit : int - Limit number of trees in the prediction; defaults to 0 (use all trees). + Limit number of trees in the prediction; defaults to best_ntree_limit if defined + (i.e. it has been trained with early stopping), otherwise 0 (use all trees). Returns ------- prediction : numpy array a numpy array with the probability of each data example being of a given class. """ test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs) + if ntree_limit is None: + ntree_limit = getattr(self, "best_ntree_limit", 0) class_probs = self.get_booster().predict(test_dmatrix, ntree_limit=ntree_limit) if self.objective == "multi:softprob":