Fix weighted samples in multi-class AUC. (#7300)

This commit is contained in:
Jiaming Yuan
2021-10-11 15:12:29 +08:00
committed by GitHub
parent 69d3b1b8b4
commit 298af6f409
6 changed files with 41 additions and 17 deletions

View File

@@ -291,7 +291,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
// labels is a vector of size n_samples.
float label = labels[idx % n_samples] == class_id;
float w = get_weight(i % n_samples);
float w = weights.empty() ? 1.0f : weights[d_sorted_idx[i] % n_samples];
float fp = (1.0 - label) * w;
float tp = label * w;
return thrust::make_pair(fp, tp);