Initial support for multi-label classification. (#7521)

* Add support in sklearn classifier.
This commit is contained in:
Jiaming Yuan
2022-01-04 23:58:21 +08:00
committed by GitHub
parent 68cdbc9c16
commit 8f0a42a266
4 changed files with 70 additions and 2 deletions

View File

@@ -1194,6 +1194,24 @@ def test_estimator_type():
cls.load_model(path) # no error
def test_multilabel_classification() -> None:
from sklearn.datasets import make_multilabel_classification
X, y = make_multilabel_classification(
n_samples=32, n_classes=5, n_labels=3, random_state=0
)
clf = xgb.XGBClassifier(tree_method="hist")
clf.fit(X, y)
booster = clf.get_booster()
learner = json.loads(booster.save_config())["learner"]
assert int(learner["learner_model_param"]["num_target"]) == 5
np.testing.assert_allclose(clf.predict(X), y)
predt = (clf.predict_proba(X) > 0.5).astype(np.int64)
np.testing.assert_allclose(clf.predict(X), predt)
assert predt.dtype == np.int64
def run_data_initialization(DMatrix, model, X, y):
"""Assert that we don't create duplicated DMatrix."""