[dask, sklearn] Fix predict proba. (#6566)

* For sklearn:
  - Handles user-defined objective functions.
  - Handles `softmax`.

* For dask:
  - Use the implementation from sklearn; the previous implementation didn't perform any extra handling.
Jiaming Yuan 2021-01-05 08:29:06 +08:00 committed by GitHub
parent 516a93d25c
commit 60cfd14349
5 changed files with 74 additions and 9 deletions
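
Illustrative only, not part of the diff: the sketch below shows, on synthetic data, the `predict_proba` behaviour that the new `_cls_predict_proba` helper (added in the sklearn hunk further down) gives the scikit-learn wrapper. The dataset and all variable names are assumptions made for this example.

import numpy as np
import xgboost as xgb

X = np.random.randn(100, 10)
y = np.random.randint(0, 3, size=100)

# The default multi-class objective is multi:softprob, so the booster already
# emits one probability per class and predict_proba returns it unchanged.
clf = xgb.XGBClassifier(n_estimators=4).fit(X, y)
assert clf.predict_proba(X).shape == (100, 3)

# multi:softmax only produces class labels, so predict_proba now raises a
# ValueError instead of silently returning values that are not probabilities.
try:
    xgb.XGBClassifier(n_estimators=4, objective="multi:softmax").fit(X, y).predict_proba(X)
except ValueError as err:
    print(err)

# With a user-defined objective (a callable), the raw booster output is
# returned as-is, since the required transformation is unknown to xgboost.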

View File: python-package/xgboost/dask.py

@@ -40,6 +40,7 @@ from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
 from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
 from .sklearn import xgboost_model_doc
+from .sklearn import _cls_predict_proba

 if TYPE_CHECKING:
@@ -1504,6 +1505,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
                               early_stopping_rounds=early_stopping_rounds,
                               callbacks=callbacks)
         self._Booster = results['booster']
+        if not callable(self.objective):
+            self.objective = params["objective"]
         # pylint: disable=attribute-defined-outside-init
         self.evals_result_ = results['history']
         return self
@@ -1554,7 +1559,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             data=test_dmatrix,
             validate_features=validate_features,
             output_margin=output_margin)
-        return pred_probs
+        return _cls_predict_proba(self.objective, pred_probs, da.vstack)

     # pylint: disable=arguments-differ,missing-docstring
     def predict_proba(
@@ -1593,6 +1598,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             output_margin=output_margin,
             validate_features=validate_features
         )
+        if output_margin:
+            return pred_probs
         if self.n_classes_ == 2:
             preds = (pred_probs > 0.5).astype(int)
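
Illustrative only, not part of the diff: a rough sketch of the dask wrapper after this change, which now routes its raw prediction through the shared `_cls_predict_proba` helper (passing `da.vstack` rather than `np.vstack`, so the result stays a lazy dask array) and exposes raw scores via `predict(..., output_margin=True)`. The local cluster, chunk sizes and synthetic data below are assumptions made for the example.

import numpy as np
import dask.array as da
from dask.distributed import Client, LocalCluster
from xgboost import dask as dxgb

if __name__ == "__main__":
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster, Client(cluster) as client:
        X = da.from_array(np.random.randn(200, 5), chunks=(50, 5))
        y = da.from_array(np.random.randint(0, 2, size=200), chunks=50)

        clf = dxgb.DaskXGBClassifier(n_estimators=4)
        clf.fit(X, y)

        # Lazy dask array of shape (200, 2); the binary probabilities are
        # expanded to two columns by _cls_predict_proba via da.vstack.
        proba = clf.predict_proba(X)
        assert proba.compute().shape == (200, 2)

        # Raw, untransformed scores (e.g. for base_margin) now come from predict.
        margin = clf.predict(X, output_margin=True)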

View File: python-package/xgboost/sklearn.py

@@ -819,6 +819,20 @@ class XGBModel(XGBModelBase):
         return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])

+
+def _cls_predict_proba(objective: Union[str, Callable], prediction: Any, vstack: Callable) -> Any:
+    if objective == 'multi:softmax':
+        raise ValueError('multi:softmax objective does not support predict_proba,'
+                         ' use `multi:softprob` or `binary:logistic` instead.')
+    if objective == 'multi:softprob' or callable(objective):
+        # Return prediction directly if objective is defined by user since we don't
+        # know how to perform the transformation
+        return prediction
+    # Lastly the binary logistic function
+    classone_probs = prediction
+    classzero_probs = 1.0 - classone_probs
+    return vstack((classzero_probs, classone_probs)).transpose()
+
 @xgboost_model_doc(
     "Implementation of the scikit-learn API for XGBoost classification.",
     ['model', 'objective'], extra_parameters='''
@@ -929,7 +943,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                               verbose_eval=verbose, xgb_model=model,
                               callbacks=callbacks)
-        self.objective = params["objective"]
+        if not callable(self.objective):
+            self.objective = params["objective"]
         if evals_result:
             for val in evals_result.items():
                 evals_result_key = list(val[1].keys())[0]
@@ -1031,7 +1047,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         Returns
         -------
         prediction : numpy array
-            a numpy array with the probability of each data example being of a given class.
+            a numpy array of shape (n_samples, n_classes) with the
+            probability of each data example being of a given class.
         """
         test_dmatrix = DMatrix(X, base_margin=base_margin,
                                missing=self.missing, nthread=self.n_jobs)
@@ -1040,11 +1057,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         class_probs = self.get_booster().predict(test_dmatrix,
                                                  ntree_limit=ntree_limit,
                                                  validate_features=validate_features)
-        if self.objective == "multi:softprob":
-            return class_probs
-        classone_probs = class_probs
-        classzero_probs = 1.0 - classone_probs
-        return np.vstack((classzero_probs, classone_probs)).transpose()
+        return _cls_predict_proba(self.objective, class_probs, np.vstack)

     def evals_result(self):
         """Return the evaluation results.

View File: tests/python/test_with_dask.py

@@ -160,7 +160,7 @@ def test_boost_from_prediction(tree_method: str) -> None:
         tree_method=tree_method,
     )
     model_0.fit(X=X_, y=y_)
-    margin = model_0.predict_proba(X_, output_margin=True)
+    margin = model_0.predict(X_, output_margin=True)

     model_1 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3,

View File: tests/python/test_with_sklearn.py

@@ -79,6 +79,18 @@ def test_multiclass_classification():
     check_pred(preds3, labels, output_margin=True)
     check_pred(preds4, labels, output_margin=False)

+    cls = xgb.XGBClassifier(n_estimators=4).fit(X, y)
+    assert cls.n_classes_ == 3
+    proba = cls.predict_proba(X)
+    assert proba.shape[0] == X.shape[0]
+    assert proba.shape[1] == cls.n_classes_
+
+    # custom objective, the default is multi:softprob so no transformation is required.
+    cls = xgb.XGBClassifier(n_estimators=4, objective=tm.softprob_obj(3)).fit(X, y)
+    proba = cls.predict_proba(X)
+    assert proba.shape[0] == X.shape[0]
+    assert proba.shape[1] == cls.n_classes_

 def test_ranking():
     # generate random data
@@ -788,6 +800,11 @@ def test_save_load_model():
         booster.save_model(model_path)
         cls = xgb.XGBClassifier()
         cls.load_model(model_path)
+
+        proba = cls.predict_proba(X)
+        assert proba.shape[0] == X.shape[0]
+        assert proba.shape[1] == 2  # binary
+
         predt_1 = cls.predict_proba(X)[:, 1]
         assert np.allclose(predt_0, predt_1)

View File: tests/python/testing.py

@@ -253,6 +253,34 @@ def eval_error_metric(predt, dtrain: xgb.DMatrix):
     return 'CustomErr', np.sum(r)

+
+def softmax(x):
+    e = np.exp(x)
+    return e / np.sum(e)
+
+
+def softprob_obj(classes):
+    def objective(labels, predt):
+        rows = labels.shape[0]
+        grad = np.zeros((rows, classes), dtype=float)
+        hess = np.zeros((rows, classes), dtype=float)
+        eps = 1e-6
+        for r in range(predt.shape[0]):
+            target = labels[r]
+            p = softmax(predt[r, :])
+            for c in range(predt.shape[1]):
+                assert target >= 0 and target <= classes
+                g = p[c] - 1.0 if c == target else p[c]
+                h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps)
+                grad[r, c] = g
+                hess[r, c] = h
+
+        grad = grad.reshape((rows * classes, 1))
+        hess = hess.reshape((rows * classes, 1))
+        return grad, hess
+
+    return objective
+
+
 class DirectoryExcursion:
     def __init__(self, path: os.PathLike, cleanup=False):
         '''Change directory. Change back and optionally clean up the directory on exit.
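
Illustrative only, not part of the diff: a small self-contained check (all names local to this sketch) that the analytic gradient used by `softprob_obj` above, `p[c] - 1.0` for the target class and `p[c]` otherwise, matches a finite-difference gradient of the softmax cross-entropy loss for a single row.

import numpy as np

def softmax(x):
    e = np.exp(x)
    return e / np.sum(e)

def loss(margin, target):
    # negative log-likelihood of the target class under softmax
    return -np.log(softmax(margin)[target])

margin = np.array([0.3, -1.2, 0.8])
target = 2
analytic = softmax(margin) - np.eye(3)[target]   # p - one_hot(target)

eps = 1e-6
numeric = np.array([
    (loss(margin + eps * np.eye(3)[c], target) - loss(margin - eps * np.eye(3)[c], target)) / (2 * eps)
    for c in range(3)
])
assert np.allclose(analytic, numeric, atol=1e-5)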