parent
c4aff733bb
commit
36e247aca4
@ -291,7 +291,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
|
||||
// labels is a vector of size n_samples.
|
||||
float label = labels[idx % n_samples] == class_id;
|
||||
|
||||
float w = get_weight(i % n_samples);
|
||||
float w = weights.empty() ? 1.0f : weights[d_sorted_idx[i] % n_samples];
|
||||
float fp = (1.0 - label) * w;
|
||||
float tp = label * w;
|
||||
return thrust::make_pair(fp, tp);
|
||||
|
||||
@ -143,7 +143,7 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
|
||||
}
|
||||
|
||||
xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> preds,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights,
|
||||
std::vector<xgboost::bst_uint> groups) {
|
||||
|
||||
@ -86,7 +86,7 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
|
||||
|
||||
xgboost::bst_float GetMetricEval(
|
||||
xgboost::Metric * metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> preds,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float>(),
|
||||
std::vector<xgboost::bst_uint> groups = std::vector<xgboost::bst_uint>());
|
||||
|
||||
@ -90,6 +90,16 @@ TEST(Metric, DeclareUnifiedTest(MultiAUC)) {
|
||||
},
|
||||
{0, 1, 1}); // no class 2.
|
||||
EXPECT_TRUE(std::isnan(auc)) << auc;
|
||||
|
||||
HostDeviceVector<float> predts{
|
||||
0.0f, 1.0f, 0.0f,
|
||||
1.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 1.0f,
|
||||
0.0f, 0.0f, 1.0f,
|
||||
};
|
||||
std::vector<float> labels {1.0f, 0.0f, 2.0f, 1.0f};
|
||||
auc = GetMetricEval(metric, predts, labels, {1.0f, 2.0f, 3.0f, 4.0f});
|
||||
ASSERT_GT(auc, 0.714);
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(RankingAUC)) {
|
||||
|
||||
@ -13,9 +13,11 @@ class TestGPUEvalMetrics:
|
||||
def test_roc_auc_binary(self, n_samples):
|
||||
self.cpu_test.run_roc_auc_binary("gpu_hist", n_samples)
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
|
||||
def test_roc_auc_multi(self, n_samples):
|
||||
self.cpu_test.run_roc_auc_multi("gpu_hist", n_samples)
|
||||
@pytest.mark.parametrize(
|
||||
"n_samples,weighted", [(4, False), (100, False), (1000, False), (1000, True)]
|
||||
)
|
||||
def test_roc_auc_multi(self, n_samples, weighted):
|
||||
self.cpu_test.run_roc_auc_multi("gpu_hist", n_samples, weighted)
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
|
||||
def test_roc_auc_ltr(self, n_samples):
|
||||
|
||||
@ -191,11 +191,11 @@ class TestEvalMetrics:
|
||||
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
|
||||
@pytest.mark.parametrize("n_samples", [100, 1000])
|
||||
def test_roc_auc(self, n_samples):
|
||||
self.run_roc_auc_binary("hist", n_samples)
|
||||
|
||||
def run_roc_auc_multi(self, tree_method, n_samples):
|
||||
def run_roc_auc_multi(self, tree_method, n_samples, weighted):
|
||||
import numpy as np
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.metrics import roc_auc_score
|
||||
@ -213,8 +213,14 @@ class TestEvalMetrics:
|
||||
n_classes=n_classes,
|
||||
random_state=rng
|
||||
)
|
||||
if weighted:
|
||||
weights = rng.randn(n_samples)
|
||||
weights -= weights.min()
|
||||
weights /= weights.max()
|
||||
else:
|
||||
weights = None
|
||||
|
||||
Xy = xgb.DMatrix(X, y)
|
||||
Xy = xgb.DMatrix(X, y, weight=weights)
|
||||
booster = xgb.train(
|
||||
{
|
||||
"tree_method": tree_method,
|
||||
@ -226,16 +232,22 @@ class TestEvalMetrics:
|
||||
num_boost_round=8,
|
||||
)
|
||||
score = booster.predict(Xy)
|
||||
skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
|
||||
skl_auc = roc_auc_score(
|
||||
y, score, average="weighted", sample_weight=weights, multi_class="ovr"
|
||||
)
|
||||
auc = float(booster.eval(Xy).split(":")[1])
|
||||
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
|
||||
|
||||
X = rng.randn(*X.shape)
|
||||
score = booster.predict(xgb.DMatrix(X))
|
||||
skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
|
||||
auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1])
|
||||
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
|
||||
score = booster.predict(xgb.DMatrix(X, weight=weights))
|
||||
skl_auc = roc_auc_score(
|
||||
y, score, average="weighted", sample_weight=weights, multi_class="ovr"
|
||||
)
|
||||
auc = float(booster.eval(xgb.DMatrix(X, y, weight=weights)).split(":")[1])
|
||||
np.testing.assert_allclose(skl_auc, auc, rtol=1e-5)
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
|
||||
def test_roc_auc_multi(self, n_samples):
|
||||
self.run_roc_auc_multi("hist", n_samples)
|
||||
@pytest.mark.parametrize(
|
||||
"n_samples,weighted", [(4, False), (100, False), (1000, False), (1000, True)]
|
||||
)
|
||||
def test_roc_auc_multi(self, n_samples, weighted):
|
||||
self.run_roc_auc_multi("hist", n_samples, weighted)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user