Re-implement ROC-AUC. (#6747)

* Re-implement ROC-AUC.

* Binary
* MultiClass
* LTR
* Add documents.

This PR resolves a few issues:
  - Define a value when the dataset is invalid, which can happen if there's an
  empty dataset, or when the dataset contains only positive or negative values.
  - Define ROC-AUC for multi-class classification.
  - Define weighted average value for distributed setting.
  - A correct implementation for learning to rank task.  Previous
  implementation is just binary classification with averaging across groups,
  which doesn't measure ordered learning to rank.
This commit is contained in:
Jiaming Yuan
2021-03-20 16:52:40 +08:00
committed by GitHub
parent 4ee8340e79
commit bcc0277338
27 changed files with 1622 additions and 461 deletions

View File

@@ -123,3 +123,90 @@ class TestEvalMetrics:
gamma_dev = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1].split(":")[0])
skl_gamma_dev = mean_gamma_deviance(y, score)
np.testing.assert_allclose(gamma_dev, skl_gamma_dev, rtol=1e-6)
def run_roc_auc_binary(self, tree_method, n_samples):
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
rng = np.random.RandomState(1994)
n_samples = n_samples
n_features = 10
X, y = make_classification(
n_samples,
n_features,
n_informative=n_features,
n_redundant=0,
random_state=rng
)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{
"tree_method": tree_method,
"eval_metric": "auc",
"objective": "binary:logistic",
},
Xy,
num_boost_round=8,
)
score = booster.predict(Xy)
skl_auc = roc_auc_score(y, score)
auc = float(booster.eval(Xy).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
X = rng.randn(*X.shape)
score = booster.predict(xgb.DMatrix(X))
skl_auc = roc_auc_score(y, score)
auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc(self, n_samples):
self.run_roc_auc_binary("hist", n_samples)
def run_roc_auc_multi(self, tree_method, n_samples):
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
rng = np.random.RandomState(1994)
n_samples = n_samples
n_features = 10
n_classes = 4
X, y = make_classification(
n_samples,
n_features,
n_informative=n_features,
n_redundant=0,
n_classes=n_classes,
random_state=rng
)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{
"tree_method": tree_method,
"eval_metric": "auc",
"objective": "multi:softprob",
"num_class": n_classes,
},
Xy,
num_boost_round=8,
)
score = booster.predict(Xy)
skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
auc = float(booster.eval(Xy).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
X = rng.randn(*X.shape)
score = booster.predict(xgb.DMatrix(X))
skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc_multi(self, n_samples):
self.run_roc_auc_multi("hist", n_samples)