[dask, sklearn] Fix predict proba. (#6566)
* For sklearn:
  - Handles user-defined objective functions.
  - Handles `softmax`.
* For dask:
  - Use the implementation from sklearn; the previous implementation didn't perform any extra handling.
parent 516a93d25c, commit 60cfd14349
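A minimal sketch of the user-facing behaviour this change targets (the data, estimator settings, and variable names below are illustrative, not part of the diff): with a built-in probability objective, predict_proba should return an (n_samples, n_classes) array, while a user-defined objective is passed through untransformed and multi:softmax is rejected.

import numpy as np
import xgboost as xgb

# Toy multi-class data, chosen only to exercise predict_proba.
X = np.random.randn(120, 4)
y = np.arange(120) % 3

clf = xgb.XGBClassifier(n_estimators=4, objective='multi:softprob').fit(X, y)
proba = clf.predict_proba(X)
assert proba.shape == (120, 3)  # one column per class
np.testing.assert_allclose(proba.sum(axis=1), 1.0, rtol=1e-5)  # rows are probabilities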
@@ -40,6 +40,7 @@ from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
 from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
 from .sklearn import xgboost_model_doc
+from .sklearn import _cls_predict_proba


 if TYPE_CHECKING:
@@ -1504,6 +1505,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
                                    early_stopping_rounds=early_stopping_rounds,
                                    callbacks=callbacks)
         self._Booster = results['booster']
+
+        if not callable(self.objective):
+            self.objective = params["objective"]
+
         # pylint: disable=attribute-defined-outside-init
         self.evals_result_ = results['history']
         return self
@@ -1554,7 +1559,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
                                    data=test_dmatrix,
                                    validate_features=validate_features,
                                    output_margin=output_margin)
-        return pred_probs
+        return _cls_predict_proba(self.objective, pred_probs, da.vstack)

     # pylint: disable=arguments-differ,missing-docstring
     def predict_proba(
@@ -1593,6 +1598,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             output_margin=output_margin,
             validate_features=validate_features
         )
+        if output_margin:
+            return pred_probs

         if self.n_classes_ == 2:
             preds = (pred_probs > 0.5).astype(int)
@@ -819,6 +819,20 @@ class XGBModel(XGBModelBase):
         return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])


+def _cls_predict_proba(objective: Union[str, Callable], prediction: Any, vstack: Callable) -> Any:
+    if objective == 'multi:softmax':
+        raise ValueError('multi:softmax objective does not support predict_proba,'
+                         ' use `multi:softprob` or `binary:logistic` instead.')
+    if objective == 'multi:softprob' or callable(objective):
+        # Return prediction directly if objective is defined by user since we don't
+        # know how to perform the transformation
+        return prediction
+    # Lastly the binary logistic function
+    classone_probs = prediction
+    classzero_probs = 1.0 - classone_probs
+    return vstack((classzero_probs, classone_probs)).transpose()
+
+
 @xgboost_model_doc(
     "Implementation of the scikit-learn API for XGBoost classification.",
     ['model', 'objective'], extra_parameters='''
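For reference, a standalone sketch of the binary branch in _cls_predict_proba above, with numpy standing in for the vstack argument (exactly what the non-dask caller passes): binary:logistic produces one probability per row, and the helper stacks its complement to build the two-column result.

import numpy as np

classone_probs = np.array([0.2, 0.9, 0.5])   # P(y = 1) as emitted by binary:logistic
classzero_probs = 1.0 - classone_probs       # P(y = 0)
proba = np.vstack((classzero_probs, classone_probs)).transpose()
# proba has shape (3, 2); each row is [P(y = 0), P(y = 1)] and sums to 1.

The dask path is identical except that da.vstack is passed, so the stacking stays lazy on dask arrays.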
@@ -929,7 +943,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                               verbose_eval=verbose, xgb_model=model,
                               callbacks=callbacks)
-        self.objective = params["objective"]
+
+        if not callable(self.objective):
+            self.objective = params["objective"]

         if evals_result:
             for val in evals_result.items():
                 evals_result_key = list(val[1].keys())[0]
@@ -1031,7 +1047,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         Returns
         -------
         prediction : numpy array
-            a numpy array with the probability of each data example being of a given class.
+            a numpy array of shape (n_samples, n_classes) with the
+            probability of each data example being of a given class.
         """
         test_dmatrix = DMatrix(X, base_margin=base_margin,
                                missing=self.missing, nthread=self.n_jobs)
@@ -1040,11 +1057,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         class_probs = self.get_booster().predict(test_dmatrix,
                                                  ntree_limit=ntree_limit,
                                                  validate_features=validate_features)
-        if self.objective == "multi:softprob":
-            return class_probs
-        classone_probs = class_probs
-        classzero_probs = 1.0 - classone_probs
-        return np.vstack((classzero_probs, classone_probs)).transpose()
+        return _cls_predict_proba(self.objective, class_probs, np.vstack)

     def evals_result(self):
         """Return the evaluation results.
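One visible consequence of delegating to _cls_predict_proba here is that multi:softmax is now rejected by predict_proba instead of silently stacking the predicted class indices as if they were binary probabilities. A hedged sketch of that behaviour as of this change (toy data, illustrative only):

import numpy as np
import xgboost as xgb

X = np.random.randn(60, 5)
y = np.arange(60) % 3

clf = xgb.XGBClassifier(n_estimators=2, objective='multi:softmax').fit(X, y)
try:
    clf.predict_proba(X)
except ValueError as err:
    print(err)  # points the user at multi:softprob or binary:logistic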
@@ -160,7 +160,7 @@ def test_boost_from_prediction(tree_method: str) -> None:
         tree_method=tree_method,
     )
     model_0.fit(X=X_, y=y_)
-    margin = model_0.predict_proba(X_, output_margin=True)
+    margin = model_0.predict(X_, output_margin=True)

     model_1 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3,
@@ -79,6 +79,18 @@ def test_multiclass_classification():
     check_pred(preds3, labels, output_margin=True)
     check_pred(preds4, labels, output_margin=False)

+    cls = xgb.XGBClassifier(n_estimators=4).fit(X, y)
+    assert cls.n_classes_ == 3
+    proba = cls.predict_proba(X)
+    assert proba.shape[0] == X.shape[0]
+    assert proba.shape[1] == cls.n_classes_
+
+    # custom objective, the default is multi:softprob so no transformation is required.
+    cls = xgb.XGBClassifier(n_estimators=4, objective=tm.softprob_obj(3)).fit(X, y)
+    proba = cls.predict_proba(X)
+    assert proba.shape[0] == X.shape[0]
+    assert proba.shape[1] == cls.n_classes_
+

 def test_ranking():
     # generate random data
@@ -788,6 +800,11 @@ def test_save_load_model():
         booster.save_model(model_path)
         cls = xgb.XGBClassifier()
         cls.load_model(model_path)
+
+        proba = cls.predict_proba(X)
+        assert proba.shape[0] == X.shape[0]
+        assert proba.shape[1] == 2  # binary
+
         predt_1 = cls.predict_proba(X)[:, 1]
         assert np.allclose(predt_0, predt_1)

@@ -253,6 +253,34 @@ def eval_error_metric(predt, dtrain: xgb.DMatrix):
     return 'CustomErr', np.sum(r)


+def softmax(x):
+    e = np.exp(x)
+    return e / np.sum(e)
+
+
+def softprob_obj(classes):
+    def objective(labels, predt):
+        rows = labels.shape[0]
+        grad = np.zeros((rows, classes), dtype=float)
+        hess = np.zeros((rows, classes), dtype=float)
+        eps = 1e-6
+        for r in range(predt.shape[0]):
+            target = labels[r]
+            p = softmax(predt[r, :])
+            for c in range(predt.shape[1]):
+                assert target >= 0 and target <= classes
+                g = p[c] - 1.0 if c == target else p[c]
+                h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps)
+                grad[r, c] = g
+                hess[r, c] = h
+
+        grad = grad.reshape((rows * classes, 1))
+        hess = hess.reshape((rows * classes, 1))
+        return grad, hess
+
+    return objective
+
+
 class DirectoryExcursion:
     def __init__(self, path: os.PathLike, cleanup=False):
         '''Change directory. Change back and optionally cleaning up the directory when exit.
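A hedged usage sketch of the softprob_obj helper above, mirroring the new assertions in test_multiclass_classification (the data is illustrative, and softprob_obj is assumed to be in scope, e.g. imported from the testing helpers shown above): because the objective is a user-defined callable, predict_proba returns the booster output without any extra transformation.

import numpy as np
import xgboost as xgb

# softprob_obj is the helper defined in the hunk above; assumed to be in scope here.
X = np.random.randn(90, 4)
y = np.arange(90) % 3

clf = xgb.XGBClassifier(n_estimators=4, objective=softprob_obj(3)).fit(X, y)
proba = clf.predict_proba(X)
assert proba.shape == (90, 3)  # one column per class, returned untransformed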