From 60cfd14349a3ed4593de276525fba068adcaea53 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 5 Jan 2021 08:29:06 +0800 Subject: [PATCH] [dask, sklearn] Fix predict proba. (#6566) * For sklearn: - Handles user defined objective function. - Handles `softmax`. * For dask: - Use the implementation from sklearn, the previous implementation doesn't perform any extra handling. --- python-package/xgboost/dask.py | 9 ++++++++- python-package/xgboost/sklearn.py | 27 ++++++++++++++++++++------- tests/python/test_with_dask.py | 2 +- tests/python/test_with_sklearn.py | 17 +++++++++++++++++ tests/python/testing.py | 28 ++++++++++++++++++++++++++++ 5 files changed, 74 insertions(+), 9 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index b80454797..bf5992928 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -40,6 +40,7 @@ from .training import train as worker_train from .tracker import RabitTracker, get_host_ip from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator from .sklearn import xgboost_model_doc +from .sklearn import _cls_predict_proba if TYPE_CHECKING: @@ -1504,6 +1505,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): early_stopping_rounds=early_stopping_rounds, callbacks=callbacks) self._Booster = results['booster'] + + if not callable(self.objective): + self.objective = params["objective"] + # pylint: disable=attribute-defined-outside-init self.evals_result_ = results['history'] return self @@ -1554,7 +1559,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): data=test_dmatrix, validate_features=validate_features, output_margin=output_margin) - return pred_probs + return _cls_predict_proba(self.objective, pred_probs, da.vstack) # pylint: disable=arguments-differ,missing-docstring def predict_proba( @@ -1593,6 +1598,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): output_margin=output_margin, 
validate_features=validate_features ) + if output_margin: + return pred_probs if self.n_classes_ == 2: preds = (pred_probs > 0.5).astype(int) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 00a78522e..7359ec124 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -819,6 +819,20 @@ class XGBModel(XGBModelBase): return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias']) +def _cls_predict_proba(objective: Union[str, Callable], prediction: Any, vstack: Callable) -> Any: + if objective == 'multi:softmax': + raise ValueError('multi:softmax objective does not support predict_proba,' + ' use `multi:softprob` or `binary:logistic` instead.') + if objective == 'multi:softprob' or callable(objective): + # Return prediction directly if objective is defined by user since we don't + # know how to perform the transformation + return prediction + # Lastly the binary logistic function + classone_probs = prediction + classzero_probs = 1.0 - classone_probs + return vstack((classzero_probs, classone_probs)).transpose() + + @xgboost_model_doc( "Implementation of the scikit-learn API for XGBoost classification.", ['model', 'objective'], extra_parameters=''' @@ -929,7 +943,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase): verbose_eval=verbose, xgb_model=model, callbacks=callbacks) - self.objective = params["objective"] + if not callable(self.objective): + self.objective = params["objective"] + if evals_result: for val in evals_result.items(): evals_result_key = list(val[1].keys())[0] @@ -1031,7 +1047,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase): Returns ------- prediction : numpy array - a numpy array with the probability of each data example being of a given class. + a numpy array of shape (n_samples, n_classes) with the + probability of each data example being of a given class. 
""" test_dmatrix = DMatrix(X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs) @@ -1040,11 +1057,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): class_probs = self.get_booster().predict(test_dmatrix, ntree_limit=ntree_limit, validate_features=validate_features) - if self.objective == "multi:softprob": - return class_probs - classone_probs = class_probs - classzero_probs = 1.0 - classone_probs - return np.vstack((classzero_probs, classone_probs)).transpose() + return _cls_predict_proba(self.objective, class_probs, np.vstack) def evals_result(self): """Return the evaluation results. diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 9826ddbff..d3c2f988a 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -160,7 +160,7 @@ def test_boost_from_prediction(tree_method: str) -> None: tree_method=tree_method, ) model_0.fit(X=X_, y=y_) - margin = model_0.predict_proba(X_, output_margin=True) + margin = model_0.predict(X_, output_margin=True) model_1 = xgb.dask.DaskXGBClassifier( learning_rate=0.3, diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 75cd19636..5d105b5a0 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -79,6 +79,18 @@ def test_multiclass_classification(): check_pred(preds3, labels, output_margin=True) check_pred(preds4, labels, output_margin=False) + cls = xgb.XGBClassifier(n_estimators=4).fit(X, y) + assert cls.n_classes_ == 3 + proba = cls.predict_proba(X) + assert proba.shape[0] == X.shape[0] + assert proba.shape[1] == cls.n_classes_ + + # custom objective, the default is multi:softprob so no transformation is required. 
+ cls = xgb.XGBClassifier(n_estimators=4, objective=tm.softprob_obj(3)).fit(X, y) + proba = cls.predict_proba(X) + assert proba.shape[0] == X.shape[0] + assert proba.shape[1] == cls.n_classes_ + def test_ranking(): # generate random data @@ -788,6 +800,11 @@ def test_save_load_model(): booster.save_model(model_path) cls = xgb.XGBClassifier() cls.load_model(model_path) + + proba = cls.predict_proba(X) + assert proba.shape[0] == X.shape[0] + assert proba.shape[1] == 2 # binary + predt_1 = cls.predict_proba(X)[:, 1] assert np.allclose(predt_0, predt_1) diff --git a/tests/python/testing.py b/tests/python/testing.py index df2f7ba02..0530d270e 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -253,6 +253,34 @@ def eval_error_metric(predt, dtrain: xgb.DMatrix): return 'CustomErr', np.sum(r) +def softmax(x): + e = np.exp(x) + return e / np.sum(e) + + +def softprob_obj(classes): + def objective(labels, predt): + rows = labels.shape[0] + grad = np.zeros((rows, classes), dtype=float) + hess = np.zeros((rows, classes), dtype=float) + eps = 1e-6 + for r in range(predt.shape[0]): + target = labels[r] + p = softmax(predt[r, :]) + for c in range(predt.shape[1]): + assert target >= 0 or target <= classes + g = p[c] - 1.0 if c == target else p[c] + h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps) + grad[r, c] = g + hess[r, c] = h + + grad = grad.reshape((rows * classes, 1)) + hess = hess.reshape((rows * classes, 1)) + return grad, hess + + return objective + + class DirectoryExcursion: def __init__(self, path: os.PathLike, cleanup=False): '''Change directory. Change back and optionally cleaning up the directory when exit.