Add special handling for multi:softmax in sklearn predict (#7607)
* Add special handling for multi:softmax in sklearn predict * Add test coverage
This commit is contained in:
parent
7f738e7f6f
commit
b4340abf56
@ -1419,6 +1419,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
# multi-label
|
# multi-label
|
||||||
column_indexes = np.zeros(class_probs.shape)
|
column_indexes = np.zeros(class_probs.shape)
|
||||||
column_indexes[class_probs > 0.5] = 1
|
column_indexes[class_probs > 0.5] = 1
|
||||||
|
elif self.objective == "multi:softmax":
|
||||||
|
return class_probs.astype(np.int32)
|
||||||
else:
|
else:
|
||||||
# turns soft logit into class label
|
# turns soft logit into class label
|
||||||
column_indexes = np.repeat(0, class_probs.shape[0])
|
column_indexes = np.repeat(0, class_probs.shape[0])
|
||||||
|
|||||||
@ -36,7 +36,8 @@ def test_binary_classification():
|
|||||||
assert err < 0.1
|
assert err < 0.1
|
||||||
|
|
||||||
|
|
||||||
def test_multiclass_classification():
|
@pytest.mark.parametrize('objective', ['multi:softmax', 'multi:softprob'])
|
||||||
|
def test_multiclass_classification(objective):
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.model_selection import KFold
|
from sklearn.model_selection import KFold
|
||||||
|
|
||||||
@ -54,7 +55,7 @@ def test_multiclass_classification():
|
|||||||
X = iris['data']
|
X = iris['data']
|
||||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||||
for train_index, test_index in kf.split(X, y):
|
for train_index, test_index in kf.split(X, y):
|
||||||
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
|
xgb_model = xgb.XGBClassifier(objective=objective).fit(X[train_index], y[train_index])
|
||||||
assert (xgb_model.get_booster().num_boosted_rounds() ==
|
assert (xgb_model.get_booster().num_boosted_rounds() ==
|
||||||
xgb_model.n_estimators)
|
xgb_model.n_estimators)
|
||||||
preds = xgb_model.predict(X[test_index])
|
preds = xgb_model.predict(X[test_index])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user