Add special handling for multi:softmax in sklearn predict (#7607)

* Add special handling for multi:softmax in sklearn predict

* Add test coverage
This commit is contained in:
Philip Hyunsu Cho
2022-01-29 15:54:49 -08:00
committed by GitHub
parent 7f738e7f6f
commit b4340abf56
2 changed files with 5 additions and 2 deletions

View File

@@ -1419,6 +1419,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
# multi-label
column_indexes = np.zeros(class_probs.shape)
column_indexes[class_probs > 0.5] = 1
elif self.objective == "multi:softmax":
return class_probs.astype(np.int32)
else:
# turns soft logit into class label
column_indexes = np.repeat(0, class_probs.shape[0])