Initial support for multi-label classification. (#7521)
* Add support in sklearn classifier.
This commit is contained in:
@@ -1194,6 +1194,24 @@ def test_estimator_type():
|
||||
cls.load_model(path) # no error
|
||||
|
||||
|
||||
def test_multilabel_classification() -> None:
|
||||
from sklearn.datasets import make_multilabel_classification
|
||||
|
||||
X, y = make_multilabel_classification(
|
||||
n_samples=32, n_classes=5, n_labels=3, random_state=0
|
||||
)
|
||||
clf = xgb.XGBClassifier(tree_method="hist")
|
||||
clf.fit(X, y)
|
||||
booster = clf.get_booster()
|
||||
learner = json.loads(booster.save_config())["learner"]
|
||||
assert int(learner["learner_model_param"]["num_target"]) == 5
|
||||
|
||||
np.testing.assert_allclose(clf.predict(X), y)
|
||||
predt = (clf.predict_proba(X) > 0.5).astype(np.int64)
|
||||
np.testing.assert_allclose(clf.predict(X), predt)
|
||||
assert predt.dtype == np.int64
|
||||
|
||||
|
||||
def run_data_initialization(DMatrix, model, X, y):
|
||||
"""Assert that we don't create duplicated DMatrix."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user