Fix weighted samples in multi-class AUC. (#7300) (#7305)

This commit is contained in:
Jiaming Yuan
2021-10-11 18:00:36 +08:00
committed by GitHub
parent c4aff733bb
commit 36e247aca4
6 changed files with 41 additions and 17 deletions

View File

@@ -291,7 +291,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
// labels is a vector of size n_samples.
float label = labels[idx % n_samples] == class_id;
float w = get_weight(i % n_samples);
float w = weights.empty() ? 1.0f : weights[d_sorted_idx[i] % n_samples];
float fp = (1.0 - label) * w;
float tp = label * w;
return thrust::make_pair(fp, tp);