Initial support for multi-label classification. (#7521)
* Add support in sklearn classifier.
This commit is contained in:
@@ -1215,6 +1215,14 @@ PredtT = TypeVar("PredtT", bound=np.ndarray)
|
||||
def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> PredtT:
|
||||
assert len(prediction.shape) <= 2
|
||||
if len(prediction.shape) == 2 and prediction.shape[1] == n_classes:
|
||||
# multi-class
|
||||
return prediction
|
||||
if (
|
||||
len(prediction.shape) == 2
|
||||
and n_classes == 2
|
||||
and prediction.shape[1] >= n_classes
|
||||
):
|
||||
# multi-label
|
||||
return prediction
|
||||
# binary logistic function
|
||||
classone_probs = prediction
|
||||
@@ -1374,9 +1382,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
# If output_margin is active, simply return the scores
|
||||
return class_probs
|
||||
|
||||
if len(class_probs.shape) > 1:
|
||||
# turns softprob into softmax
|
||||
if len(class_probs.shape) > 1 and self.n_classes_ != 2:
|
||||
# multi-class, turns softprob into softmax
|
||||
column_indexes: np.ndarray = np.argmax(class_probs, axis=1) # type: ignore
|
||||
elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1:
|
||||
# multi-label
|
||||
column_indexes = np.zeros(class_probs.shape)
|
||||
column_indexes[class_probs > 0.5] = 1
|
||||
else:
|
||||
# turns soft logit into class label
|
||||
column_indexes = np.repeat(0, class_probs.shape[0])
|
||||
|
||||
Reference in New Issue
Block a user