Initial support for multi-label classification. (#7521)

* Add support in sklearn classifier.
This commit is contained in:
Jiaming Yuan
2022-01-04 23:58:21 +08:00
committed by GitHub
parent 68cdbc9c16
commit 8f0a42a266
4 changed files with 70 additions and 2 deletions

View File

@@ -1215,6 +1215,14 @@ PredtT = TypeVar("PredtT", bound=np.ndarray)
def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> PredtT:
assert len(prediction.shape) <= 2
if len(prediction.shape) == 2 and prediction.shape[1] == n_classes:
# multi-class
return prediction
if (
len(prediction.shape) == 2
and n_classes == 2
and prediction.shape[1] >= n_classes
):
# multi-label
return prediction
# binary logistic function
classone_probs = prediction
@@ -1374,9 +1382,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
# If output_margin is active, simply return the scores
return class_probs
if len(class_probs.shape) > 1:
# turns softprob into softmax
if len(class_probs.shape) > 1 and self.n_classes_ != 2:
# multi-class, turns softprob into softmax
column_indexes: np.ndarray = np.argmax(class_probs, axis=1) # type: ignore
elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1:
# multi-label
column_indexes = np.zeros(class_probs.shape)
column_indexes[class_probs > 0.5] = 1
else:
# turns soft logit into class label
column_indexes = np.repeat(0, class_probs.shape[0])