Add special handling for multi:softmax in sklearn predict (#7607)

* Add special handling for multi:softmax in sklearn predict

* Add test coverage
This commit is contained in:
Philip Hyunsu Cho 2022-01-29 15:54:49 -08:00 committed by GitHub
parent 7f738e7f6f
commit b4340abf56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View File

@ -1419,6 +1419,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
# multi-label # multi-label
column_indexes = np.zeros(class_probs.shape) column_indexes = np.zeros(class_probs.shape)
column_indexes[class_probs > 0.5] = 1 column_indexes[class_probs > 0.5] = 1
elif self.objective == "multi:softmax":
return class_probs.astype(np.int32)
else: else:
# turns soft logit into class label # turns soft logit into class label
column_indexes = np.repeat(0, class_probs.shape[0]) column_indexes = np.repeat(0, class_probs.shape[0])

View File

@ -36,7 +36,8 @@ def test_binary_classification():
assert err < 0.1 assert err < 0.1
def test_multiclass_classification(): @pytest.mark.parametrize('objective', ['multi:softmax', 'multi:softprob'])
def test_multiclass_classification(objective):
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.model_selection import KFold from sklearn.model_selection import KFold
@ -54,7 +55,7 @@ def test_multiclass_classification():
X = iris['data'] X = iris['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng) kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y): for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index]) xgb_model = xgb.XGBClassifier(objective=objective).fit(X[train_index], y[train_index])
assert (xgb_model.get_booster().num_boosted_rounds() == assert (xgb_model.get_booster().num_boosted_rounds() ==
xgb_model.n_estimators) xgb_model.n_estimators)
preds = xgb_model.predict(X[test_index]) preds = xgb_model.predict(X[test_index])