Use double precision in metric calculation. (#7364)

This commit is contained in:
Jiaming Yuan 2021-11-02 12:00:32 +08:00 committed by GitHub
parent 239dbb3c0a
commit 0f7a9b42f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 219 additions and 224 deletions

View File

@ -58,9 +58,8 @@ class Metric : public Configurable {
* the average statistics across all the node, * the average statistics across all the node,
* this is only supported by some metrics * this is only supported by some metrics
*/ */
virtual bst_float Eval(const HostDeviceVector<bst_float>& preds, virtual double Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo& info, const MetaInfo &info, bool distributed) = 0;
bool distributed) = 0;
/*! \return name of metric */ /*! \return name of metric */
virtual const char* Name() const = 0; virtual const char* Name() const = 0;
/*! \brief virtual destructor */ /*! \brief virtual destructor */

View File

@ -1110,6 +1110,7 @@ class LearnerImpl : public LearnerIO {
this->Configure(); this->Configure();
std::ostringstream os; std::ostringstream os;
os.precision(std::numeric_limits<double>::max_digits10);
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed); os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) { if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
auto warn_default_eval_metric = [](const std::string& objective, const std::string& before, auto warn_default_eval_metric = [](const std::string& objective, const std::string& before,

View File

@ -30,7 +30,7 @@ namespace metric {
* handle the normalization. * handle the normalization.
*/ */
template <typename Fn> template <typename Fn>
std::tuple<float, float, float> std::tuple<double, double, double>
BinaryAUC(common::Span<float const> predts, common::Span<float const> labels, BinaryAUC(common::Span<float const> predts, common::Span<float const> labels,
OptionalWeights weights, OptionalWeights weights,
std::vector<size_t> const &sorted_idx, Fn &&area_fn) { std::vector<size_t> const &sorted_idx, Fn &&area_fn) {
@ -39,12 +39,12 @@ BinaryAUC(common::Span<float const> predts, common::Span<float const> labels,
auto p_predts = predts.data(); auto p_predts = predts.data();
auto p_labels = labels.data(); auto p_labels = labels.data();
float auc{0}; double auc{0};
float label = p_labels[sorted_idx.front()]; float label = p_labels[sorted_idx.front()];
float w = weights[sorted_idx[0]]; float w = weights[sorted_idx[0]];
float fp = (1.0 - label) * w, tp = label * w; double fp = (1.0 - label) * w, tp = label * w;
float tp_prev = 0, fp_prev = 0; double tp_prev = 0, fp_prev = 0;
// TODO(jiaming): We can parallize this if we have a parallel scan for CPU. // TODO(jiaming): We can parallize this if we have a parallel scan for CPU.
for (size_t i = 1; i < sorted_idx.size(); ++i) { for (size_t i = 1; i < sorted_idx.size(); ++i) {
if (p_predts[sorted_idx[i]] != p_predts[sorted_idx[i - 1]]) { if (p_predts[sorted_idx[i]] != p_predts[sorted_idx[i - 1]]) {
@ -77,17 +77,19 @@ BinaryAUC(common::Span<float const> predts, common::Span<float const> labels,
* Machine Learning Models * Machine Learning Models
*/ */
template <typename BinaryAUC> template <typename BinaryAUC>
float MultiClassOVR(common::Span<float const> predts, MetaInfo const &info, double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
size_t n_classes, int32_t n_threads, size_t n_classes, int32_t n_threads,
BinaryAUC &&binary_auc) { BinaryAUC &&binary_auc) {
CHECK_NE(n_classes, 0); CHECK_NE(n_classes, 0);
auto const &labels = info.labels_.ConstHostVector(); auto const &labels = info.labels_.ConstHostVector();
std::vector<float> results(n_classes * 3, 0); std::vector<double> results(n_classes * 3, 0);
auto s_results = common::Span<float>(results); auto s_results = common::Span<double>(results);
auto local_area = s_results.subspan(0, n_classes); auto local_area = s_results.subspan(0, n_classes);
auto tp = s_results.subspan(n_classes, n_classes); auto tp = s_results.subspan(n_classes, n_classes);
auto auc = s_results.subspan(2 * n_classes, n_classes); auto auc = s_results.subspan(2 * n_classes, n_classes);
auto weights = OptionalWeights{info.weights_.ConstHostSpan()}; auto weights = OptionalWeights{info.weights_.ConstHostSpan()};
if (!info.labels_.Empty()) { if (!info.labels_.Empty()) {
@ -98,7 +100,7 @@ float MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
proba[i] = predts[i * n_classes + c]; proba[i] = predts[i * n_classes + c];
response[i] = labels[i] == c ? 1.0f : 0.0; response[i] = labels[i] == c ? 1.0f : 0.0;
} }
float fp; double fp;
std::tie(fp, tp[c], auc[c]) = binary_auc(proba, response, weights); std::tie(fp, tp[c], auc[c]) = binary_auc(proba, response, weights);
local_area[c] = fp * tp[c]; local_area[c] = fp * tp[c];
}); });
@ -107,8 +109,8 @@ float MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
// we have 2 averages going in here, first is among workers, second is among // we have 2 averages going in here, first is among workers, second is among
// classes. allreduce sums up fp/tp auc for each class. // classes. allreduce sums up fp/tp auc for each class.
rabit::Allreduce<rabit::op::Sum>(results.data(), results.size()); rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
float auc_sum{0}; double auc_sum{0};
float tp_sum{0}; double tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) { for (size_t c = 0; c < n_classes; ++c) {
if (local_area[c] != 0) { if (local_area[c] != 0) {
// normalize and weight it by prevalence. After allreduce, `local_area` // normalize and weight it by prevalence. After allreduce, `local_area`
@ -117,21 +119,21 @@ float MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
auc_sum += auc[c] / local_area[c] * tp[c]; auc_sum += auc[c] / local_area[c] * tp[c];
tp_sum += tp[c]; tp_sum += tp[c];
} else { } else {
auc_sum = std::numeric_limits<float>::quiet_NaN(); auc_sum = std::numeric_limits<double>::quiet_NaN();
break; break;
} }
} }
if (tp_sum == 0 || std::isnan(auc_sum)) { if (tp_sum == 0 || std::isnan(auc_sum)) {
auc_sum = std::numeric_limits<float>::quiet_NaN(); auc_sum = std::numeric_limits<double>::quiet_NaN();
} else { } else {
auc_sum /= tp_sum; auc_sum /= tp_sum;
} }
return auc_sum; return auc_sum;
} }
std::tuple<float, float, float> BinaryROCAUC(common::Span<float const> predts, std::tuple<double, double, double>
common::Span<float const> labels, BinaryROCAUC(common::Span<float const> predts, common::Span<float const> labels,
OptionalWeights weights) { OptionalWeights weights) {
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{}); auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea); return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea);
} }
@ -139,14 +141,14 @@ std::tuple<float, float, float> BinaryROCAUC(common::Span<float const> predts,
/** /**
* Calculate AUC for 1 ranking group; * Calculate AUC for 1 ranking group;
*/ */
float GroupRankingROC(common::Span<float const> predts, double GroupRankingROC(common::Span<float const> predts,
common::Span<float const> labels, float w) { common::Span<float const> labels, float w) {
// on ranking, we just count all pairs. // on ranking, we just count all pairs.
float auc{0}; double auc{0};
auto const sorted_idx = common::ArgSort<size_t>(labels, std::greater<>{}); auto const sorted_idx = common::ArgSort<size_t>(labels, std::greater<>{});
w = common::Sqr(w); w = common::Sqr(w);
float sum_w = 0.0f; double sum_w = 0.0f;
for (size_t i = 0; i < labels.size(); ++i) { for (size_t i = 0; i < labels.size(); ++i) {
for (size_t j = i + 1; j < labels.size(); ++j) { for (size_t j = i + 1; j < labels.size(); ++j) {
auto predt = predts[sorted_idx[i]] - predts[sorted_idx[j]]; auto predt = predts[sorted_idx[i]] - predts[sorted_idx[j]];
@ -173,11 +175,11 @@ float GroupRankingROC(common::Span<float const> predts,
* *
* https://doi.org/10.1371/journal.pone.0092209 * https://doi.org/10.1371/journal.pone.0092209
*/ */
std::tuple<float, float, float> BinaryPRAUC(common::Span<float const> predts, std::tuple<double, double, double> BinaryPRAUC(common::Span<float const> predts,
common::Span<float const> labels, common::Span<float const> labels,
OptionalWeights weights) { OptionalWeights weights) {
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{}); auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
float total_pos{0}, total_neg{0}; double total_pos{0}, total_neg{0};
for (size_t i = 0; i < labels.size(); ++i) { for (size_t i = 0; i < labels.size(); ++i) {
auto w = weights[i]; auto w = weights[i];
total_pos += w * labels[i]; total_pos += w * labels[i];
@ -186,22 +188,22 @@ std::tuple<float, float, float> BinaryPRAUC(common::Span<float const> predts,
if (total_pos <= 0 || total_neg <= 0) { if (total_pos <= 0 || total_neg <= 0) {
return {1.0f, 1.0f, std::numeric_limits<float>::quiet_NaN()}; return {1.0f, 1.0f, std::numeric_limits<float>::quiet_NaN()};
} }
auto fn = [total_pos](float fp_prev, float fp, float tp_prev, float tp) { auto fn = [total_pos](double fp_prev, double fp, double tp_prev, double tp) {
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos); return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos);
}; };
float tp{0}, fp{0}, auc{0}; double tp{0}, fp{0}, auc{0};
std::tie(fp, tp, auc) = BinaryAUC(predts, labels, weights, sorted_idx, fn); std::tie(fp, tp, auc) = BinaryAUC(predts, labels, weights, sorted_idx, fn);
return std::make_tuple(1.0, 1.0, auc); return std::make_tuple(1.0, 1.0, auc);
} }
/** /**
* Cast LTR problem to binary classification problem by comparing pairs. * Cast LTR problem to binary classification problem by comparing pairs.
*/ */
template <bool is_roc> template <bool is_roc>
std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts, std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
MetaInfo const &info, int32_t n_threads) { MetaInfo const &info,
int32_t n_threads) {
CHECK_GE(info.group_ptr_.size(), 2); CHECK_GE(info.group_ptr_.size(), 2);
uint32_t n_groups = info.group_ptr_.size() - 1; uint32_t n_groups = info.group_ptr_.size() - 1;
auto s_predts = common::Span<float const>{predts}; auto s_predts = common::Span<float const>{predts};
@ -217,7 +219,7 @@ std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,
float w = s_weights.empty() ? 1.0f : s_weights[g - 1]; float w = s_weights.empty() ? 1.0f : s_weights[g - 1];
auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt); auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt);
auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt); auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt);
float auc; double auc;
if (is_roc && g_labels.size() < 3) { if (is_roc && g_labels.size() < 3) {
// With 2 documents, there's only 1 comparison can be made. So either // With 2 documents, there's only 1 comparison can be made. So either
// TP or FP will be zero. // TP or FP will be zero.
@ -236,16 +238,16 @@ std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,
} }
auc_tloc[omp_get_thread_num()] += auc; auc_tloc[omp_get_thread_num()] += auc;
}); });
float sum_auc = std::accumulate(auc_tloc.cbegin(), auc_tloc.cend(), 0.0); double sum_auc = std::accumulate(auc_tloc.cbegin(), auc_tloc.cend(), 0.0);
return std::make_pair(sum_auc, n_groups - invalid_groups); return std::make_pair(sum_auc, n_groups - invalid_groups);
} }
template <typename Curve> template <typename Curve>
class EvalAUC : public Metric { class EvalAUC : public Metric {
float Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info, double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
bool distributed) override { bool distributed) override {
float auc {0}; double auc {0};
if (tparam_->gpu_id != GenericParameter::kCpuId) { if (tparam_->gpu_id != GenericParameter::kCpuId) {
preds.SetDevice(tparam_->gpu_id); preds.SetDevice(tparam_->gpu_id);
info.labels_.SetDevice(tparam_->gpu_id); info.labels_.SetDevice(tparam_->gpu_id);
@ -256,7 +258,7 @@ class EvalAUC : public Metric {
rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size()); rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size());
if (meta[0] == 0) { if (meta[0] == 0) {
// Empty across all workers, which is not supported. // Empty across all workers, which is not supported.
auc = std::numeric_limits<float>::quiet_NaN(); auc = std::numeric_limits<double>::quiet_NaN();
} else if (!info.group_ptr_.empty()) { } else if (!info.group_ptr_.empty()) {
/** /**
* learning to rank * learning to rank
@ -274,13 +276,13 @@ class EvalAUC : public Metric {
InvalidGroupAUC(); InvalidGroupAUC();
} }
std::array<float, 2> results{auc, static_cast<float>(valid_groups)}; std::array<double, 2> results{auc, static_cast<double>(valid_groups)};
rabit::Allreduce<rabit::op::Sum>(results.data(), results.size()); rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
auc = results[0]; auc = results[0];
valid_groups = static_cast<uint32_t>(results[1]); valid_groups = static_cast<uint32_t>(results[1]);
if (valid_groups <= 0) { if (valid_groups <= 0) {
auc = std::numeric_limits<float>::quiet_NaN(); auc = std::numeric_limits<double>::quiet_NaN();
} else { } else {
auc /= valid_groups; auc /= valid_groups;
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
@ -297,18 +299,18 @@ class EvalAUC : public Metric {
/** /**
* binary classification * binary classification
*/ */
float fp{0}, tp{0}; double fp{0}, tp{0};
if (!(preds.Empty() || info.labels_.Empty())) { if (!(preds.Empty() || info.labels_.Empty())) {
std::tie(fp, tp, auc) = std::tie(fp, tp, auc) =
static_cast<Curve *>(this)->EvalBinary(preds, info); static_cast<Curve *>(this)->EvalBinary(preds, info);
} }
float local_area = fp * tp; double local_area = fp * tp;
std::array<float, 2> result{auc, local_area}; std::array<double, 2> result{auc, local_area};
rabit::Allreduce<rabit::op::Sum>(result.data(), result.size()); rabit::Allreduce<rabit::op::Sum>(result.data(), result.size());
std::tie(auc, local_area) = common::UnpackArr(std::move(result)); std::tie(auc, local_area) = common::UnpackArr(std::move(result));
if (local_area <= 0) { if (local_area <= 0) {
// the dataset across all workers have only positive or negative sample // the dataset across all workers have only positive or negative sample
auc = std::numeric_limits<float>::quiet_NaN(); auc = std::numeric_limits<double>::quiet_NaN();
} else { } else {
CHECK_LE(auc, local_area); CHECK_LE(auc, local_area);
// normalization // normalization
@ -326,9 +328,9 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
std::shared_ptr<DeviceAUCCache> d_cache_; std::shared_ptr<DeviceAUCCache> d_cache_;
public: public:
std::pair<float, uint32_t> EvalRanking(HostDeviceVector<float> const &predts, std::pair<double, uint32_t> EvalRanking(HostDeviceVector<float> const &predts,
MetaInfo const &info) { MetaInfo const &info) {
float auc{0}; double auc{0};
uint32_t valid_groups = 0; uint32_t valid_groups = 0;
auto n_threads = tparam_->Threads(); auto n_threads = tparam_->Threads();
if (tparam_->gpu_id == GenericParameter::kCpuId) { if (tparam_->gpu_id == GenericParameter::kCpuId) {
@ -341,9 +343,9 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
return std::make_pair(auc, valid_groups); return std::make_pair(auc, valid_groups);
} }
float EvalMultiClass(HostDeviceVector<float> const &predts, double EvalMultiClass(HostDeviceVector<float> const &predts,
MetaInfo const &info, size_t n_classes) { MetaInfo const &info, size_t n_classes) {
float auc{0}; double auc{0};
auto n_threads = tparam_->Threads(); auto n_threads = tparam_->Threads();
CHECK_NE(n_classes, 0); CHECK_NE(n_classes, 0);
if (tparam_->gpu_id == GenericParameter::kCpuId) { if (tparam_->gpu_id == GenericParameter::kCpuId) {
@ -356,9 +358,9 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
return auc; return auc;
} }
std::tuple<float, float, float> std::tuple<double, double, double>
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) { EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
float fp, tp, auc; double fp, tp, auc;
if (tparam_->gpu_id == GenericParameter::kCpuId) { if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(fp, tp, auc) = std::tie(fp, tp, auc) =
BinaryROCAUC(predts.ConstHostVector(), info.labels_.ConstHostVector(), BinaryROCAUC(predts.ConstHostVector(), info.labels_.ConstHostVector(),
@ -381,37 +383,37 @@ XGBOOST_REGISTER_METRIC(EvalAUC, "auc")
.set_body([](const char*) { return new EvalROCAUC(); }); .set_body([](const char*) { return new EvalROCAUC(); });
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
std::tuple<float, float, float> std::tuple<double, double, double>
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info, GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) { int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
common::AssertGPUSupport(); common::AssertGPUSupport();
return std::make_tuple(0.0f, 0.0f, 0.0f); return {};
} }
float GPUMultiClassROCAUC(common::Span<float const> predts, double GPUMultiClassROCAUC(common::Span<float const> predts,
MetaInfo const &info, int32_t device, MetaInfo const &info, int32_t device,
std::shared_ptr<DeviceAUCCache> *cache, std::shared_ptr<DeviceAUCCache> *cache,
size_t n_classes) { size_t n_classes) {
common::AssertGPUSupport(); common::AssertGPUSupport();
return 0; return 0.0;
} }
std::pair<float, uint32_t> std::pair<double, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info, GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) { int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
common::AssertGPUSupport(); common::AssertGPUSupport();
return std::make_pair(0.0f, 0u); return {};
} }
struct DeviceAUCCache {}; struct DeviceAUCCache {};
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA)
class EvalAUCPR : public EvalAUC<EvalAUCPR> { class EvalPRAUC : public EvalAUC<EvalPRAUC> {
std::shared_ptr<DeviceAUCCache> d_cache_; std::shared_ptr<DeviceAUCCache> d_cache_;
public: public:
std::tuple<float, float, float> std::tuple<double, double, double>
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) { EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
float pr, re, auc; double pr, re, auc;
if (tparam_->gpu_id == GenericParameter::kCpuId) { if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(pr, re, auc) = std::tie(pr, re, auc) =
BinaryPRAUC(predts.ConstHostSpan(), info.labels_.ConstHostSpan(), BinaryPRAUC(predts.ConstHostSpan(), info.labels_.ConstHostSpan(),
@ -423,7 +425,7 @@ class EvalAUCPR : public EvalAUC<EvalAUCPR> {
return std::make_tuple(pr, re, auc); return std::make_tuple(pr, re, auc);
} }
float EvalMultiClass(HostDeviceVector<float> const &predts, double EvalMultiClass(HostDeviceVector<float> const &predts,
MetaInfo const &info, size_t n_classes) { MetaInfo const &info, size_t n_classes) {
if (tparam_->gpu_id == GenericParameter::kCpuId) { if (tparam_->gpu_id == GenericParameter::kCpuId) {
auto n_threads = this->tparam_->Threads(); auto n_threads = this->tparam_->Threads();
@ -435,9 +437,9 @@ class EvalAUCPR : public EvalAUC<EvalAUCPR> {
} }
} }
std::pair<float, uint32_t> EvalRanking(HostDeviceVector<float> const &predts, std::pair<double, uint32_t> EvalRanking(HostDeviceVector<float> const &predts,
MetaInfo const &info) { MetaInfo const &info) {
float auc{0}; double auc{0};
uint32_t valid_groups = 0; uint32_t valid_groups = 0;
auto n_threads = tparam_->Threads(); auto n_threads = tparam_->Threads();
if (tparam_->gpu_id == GenericParameter::kCpuId) { if (tparam_->gpu_id == GenericParameter::kCpuId) {
@ -460,24 +462,25 @@ class EvalAUCPR : public EvalAUC<EvalAUCPR> {
XGBOOST_REGISTER_METRIC(AUCPR, "aucpr") XGBOOST_REGISTER_METRIC(AUCPR, "aucpr")
.describe("Area under PR curve for both classification and rank.") .describe("Area under PR curve for both classification and rank.")
.set_body([](char const *) { return new EvalAUCPR{}; }); .set_body([](char const *) { return new EvalPRAUC{}; });
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
std::tuple<float, float, float> std::tuple<double, double, double>
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info, GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) { int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
common::AssertGPUSupport(); common::AssertGPUSupport();
return {}; return {};
} }
float GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info, double GPUMultiClassPRAUC(common::Span<float const> predts,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache, MetaInfo const &info, int32_t device,
size_t n_classes) { std::shared_ptr<DeviceAUCCache> *cache,
size_t n_classes) {
common::AssertGPUSupport(); common::AssertGPUSupport();
return {}; return {};
} }
std::pair<float, uint32_t> std::pair<double, uint32_t>
GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info, GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache) { int32_t device, std::shared_ptr<DeviceAUCCache> *cache) {
common::AssertGPUSupport(); common::AssertGPUSupport();

View File

@ -22,7 +22,7 @@ namespace xgboost {
namespace metric { namespace metric {
namespace { namespace {
// Pair of FP/TP // Pair of FP/TP
using Pair = thrust::pair<float, float>; using Pair = thrust::pair<double, double>;
template <typename T, typename U, typename P = thrust::pair<T, U>> template <typename T, typename U, typename P = thrust::pair<T, U>>
struct PairPlus : public thrust::binary_function<P, P, P> { struct PairPlus : public thrust::binary_function<P, P, P> {
@ -38,9 +38,9 @@ struct PairPlus : public thrust::binary_function<P, P, P> {
struct DeviceAUCCache { struct DeviceAUCCache {
// index sorted by prediction value // index sorted by prediction value
dh::device_vector<size_t> sorted_idx; dh::device_vector<size_t> sorted_idx;
// track FP/TP for computation on trapesoid area // track FP/TP for computation on trapezoid area
dh::device_vector<Pair> fptp; dh::device_vector<Pair> fptp;
// track FP_PREV/TP_PREV for computation on trapesoid area // track FP_PREV/TP_PREV for computation on trapezoid area
dh::device_vector<Pair> neg_pos; dh::device_vector<Pair> neg_pos;
// index of unique prediction values. // index of unique prediction values.
dh::device_vector<size_t> unique_idx; dh::device_vector<size_t> unique_idx;
@ -79,13 +79,13 @@ void InitCacheOnce(common::Span<float const> predts, int32_t device,
* The GPU implementation uses same calculation as CPU with a few more steps to distribute * The GPU implementation uses same calculation as CPU with a few more steps to distribute
* work across threads: * work across threads:
* *
* - Run scan to obtain TP/FP values, which are right coordinates of trapesoid. * - Run scan to obtain TP/FP values, which are right coordinates of trapezoid.
* - Find distinct prediction values and get the corresponding FP_PREV/TP_PREV value, * - Find distinct prediction values and get the corresponding FP_PREV/TP_PREV value,
* which are left coordinates of trapesoids. * which are left coordinates of trapezoids.
* - Reduce the scan array into 1 AUC value. * - Reduce the scan array into 1 AUC value.
*/ */
template <typename Fn> template <typename Fn>
std::tuple<float, float, float> std::tuple<double, double, double>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info, GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, common::Span<size_t const> d_sorted_idx, int32_t device, common::Span<size_t const> d_sorted_idx,
Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) { Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
@ -129,7 +129,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx)); d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx));
dh::InclusiveScan(dh::tbegin(d_fptp), dh::tbegin(d_fptp), dh::InclusiveScan(dh::tbegin(d_fptp), dh::tbegin(d_fptp),
PairPlus<float, float>{}, d_fptp.size()); PairPlus<double, double>{}, d_fptp.size());
auto d_neg_pos = dh::ToSpan(cache->neg_pos); auto d_neg_pos = dh::ToSpan(cache->neg_pos);
// scatter unique negaive/positive values // scatter unique negaive/positive values
@ -149,10 +149,10 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
} }
}); });
auto in = dh::MakeTransformIterator<float>( auto in = dh::MakeTransformIterator<double>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
float fp, tp; double fp, tp;
float fp_prev, tp_prev; double fp_prev, tp_prev;
if (i == 0) { if (i == 0) {
// handle the last element // handle the last element
thrust::tie(fp, tp) = d_fptp.back(); thrust::tie(fp, tp) = d_fptp.back();
@ -165,11 +165,11 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
}); });
Pair last = cache->fptp.back(); Pair last = cache->fptp.back();
float auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size()); double auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size());
return std::make_tuple(last.first, last.second, auc); return std::make_tuple(last.first, last.second, auc);
} }
std::tuple<float, float, float> std::tuple<double, double, double>
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info, GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) { int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto &cache = *p_cache; auto &cache = *p_cache;
@ -183,7 +183,7 @@ GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
// Create lambda to avoid pass function pointer. // Create lambda to avoid pass function pointer.
return GPUBinaryAUC( return GPUBinaryAUC(
predts, info, device, d_sorted_idx, predts, info, device, d_sorted_idx,
[] XGBOOST_DEVICE(float x0, float x1, float y0, float y1) { [] XGBOOST_DEVICE(double x0, double x1, double y0, double y1) -> double {
return TrapezoidArea(x0, x1, y0, y1); return TrapezoidArea(x0, x1, y0, y1);
}, },
cache); cache);
@ -209,33 +209,32 @@ XGBOOST_DEVICE size_t LastOf(size_t group, common::Span<Idx> indptr) {
return indptr[group + 1] - 1; return indptr[group + 1] - 1;
} }
double ScaleClasses(common::Span<double> results,
float ScaleClasses(common::Span<float> results, common::Span<float> local_area, common::Span<double> local_area, common::Span<double> fp,
common::Span<float> fp, common::Span<float> tp, common::Span<double> tp, common::Span<double> auc,
common::Span<float> auc, std::shared_ptr<DeviceAUCCache> cache, std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
size_t n_classes) {
dh::XGBDeviceAllocator<char> alloc; dh::XGBDeviceAllocator<char> alloc;
if (rabit::IsDistributed()) { if (rabit::IsDistributed()) {
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice()); CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
cache->reducer->AllReduceSum(results.data(), results.data(), results.size()); cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
} }
auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>( auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
if (local_area[i] > 0) { if (local_area[i] > 0) {
return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]); return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
} }
return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f); return thrust::make_pair(std::numeric_limits<double>::quiet_NaN(), 0.0);
}); });
float tp_sum; double tp_sum;
float auc_sum; double auc_sum;
thrust::tie(auc_sum, tp_sum) = thrust::tie(auc_sum, tp_sum) =
thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes, thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
Pair{0.0f, 0.0f}, PairPlus<float, float>{}); Pair{0.0, 0.0}, PairPlus<double, double>{});
if (tp_sum != 0 && !std::isnan(auc_sum)) { if (tp_sum != 0 && !std::isnan(auc_sum)) {
auc_sum /= tp_sum; auc_sum /= tp_sum;
} else { } else {
return std::numeric_limits<float>::quiet_NaN(); return std::numeric_limits<double>::quiet_NaN();
} }
return auc_sum; return auc_sum;
} }
@ -246,7 +245,7 @@ float ScaleClasses(common::Span<float> results, common::Span<float> local_area,
*/ */
template <typename Fn> template <typename Fn>
void SegmentedFPTP(common::Span<Pair> d_fptp, Fn segment_id) { void SegmentedFPTP(common::Span<Pair> d_fptp, Fn segment_id) {
using Triple = thrust::tuple<uint32_t, float, float>; using Triple = thrust::tuple<uint32_t, double, double>;
// expand to tuple to include idx // expand to tuple to include idx
auto fptp_it_in = dh::MakeTransformIterator<Triple>( auto fptp_it_in = dh::MakeTransformIterator<Triple>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
@ -285,7 +284,7 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
std::shared_ptr<DeviceAUCCache> cache, std::shared_ptr<DeviceAUCCache> cache,
Area area_fn, Area area_fn,
Seg segment_id, Seg segment_id,
common::Span<float> d_auc) { common::Span<double> d_auc) {
auto d_fptp = dh::ToSpan(cache->fptp); auto d_fptp = dh::ToSpan(cache->fptp);
auto d_neg_pos = dh::ToSpan(cache->neg_pos); auto d_neg_pos = dh::ToSpan(cache->neg_pos);
dh::XGBDeviceAllocator<char> alloc; dh::XGBDeviceAllocator<char> alloc;
@ -294,11 +293,11 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
size_t class_id = segment_id(d_unique_idx[i]); size_t class_id = segment_id(d_unique_idx[i]);
return class_id; return class_id;
}); });
auto val_in = dh::MakeTransformIterator<float>( auto val_in = dh::MakeTransformIterator<double>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
size_t class_id = segment_id(d_unique_idx[i]); size_t class_id = segment_id(d_unique_idx[i]);
float fp, tp, fp_prev, tp_prev; double fp, tp, fp_prev, tp_prev;
if (i == d_unique_class_ptr[class_id]) { if (i == d_unique_class_ptr[class_id]) {
// first item is ignored, we use this thread to calculate the last item // first item is ignored, we use this thread to calculate the last item
thrust::tie(fp, tp) = d_fptp[LastOf(class_id, d_class_ptr)]; thrust::tie(fp, tp) = d_fptp[LastOf(class_id, d_class_ptr)];
@ -308,7 +307,7 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
} }
float auc = area_fn(fp_prev, fp, tp_prev, tp, class_id); double auc = area_fn(fp_prev, fp, tp_prev, tp, class_id);
return auc; return auc;
}); });
thrust::reduce_by_key(thrust::cuda::par(alloc), key_in, thrust::reduce_by_key(thrust::cuda::par(alloc), key_in,
@ -321,10 +320,10 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
* up each class in all kernels. * up each class in all kernels.
*/ */
template <bool scale, typename Fn> template <bool scale, typename Fn>
float GPUMultiClassAUCOVR(common::Span<float const> predts, double GPUMultiClassAUCOVR(common::Span<float const> predts,
MetaInfo const &info, int32_t device, MetaInfo const &info, int32_t device,
common::Span<uint32_t> d_class_ptr, size_t n_classes, common::Span<uint32_t> d_class_ptr, size_t n_classes,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) { std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
/** /**
* Sorted idx * Sorted idx
@ -339,7 +338,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts,
size_t n_samples = labels.size(); size_t n_samples = labels.size();
if (n_samples == 0) { if (n_samples == 0) {
dh::TemporaryArray<float> resutls(n_classes * 4, 0.0f); dh::TemporaryArray<double> resutls(n_classes * 4, 0.0f);
auto d_results = dh::ToSpan(resutls); auto d_results = dh::ToSpan(resutls);
dh::LaunchN(n_classes * 4, dh::LaunchN(n_classes * 4,
[=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; }); [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
@ -353,7 +352,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts,
/** /**
* Linear scan * Linear scan
*/ */
dh::caching_device_vector<float> d_auc(n_classes, 0); dh::caching_device_vector<double> d_auc(n_classes, 0);
auto get_weight = OptionalWeights{weights}; auto get_weight = OptionalWeights{weights};
auto d_fptp = dh::ToSpan(cache->fptp); auto d_fptp = dh::ToSpan(cache->fptp);
auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) {
@ -432,7 +431,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts,
/** /**
* Scale the classes with number of samples for each class. * Scale the classes with number of samples for each class.
*/ */
dh::TemporaryArray<float> resutls(n_classes * 4); dh::TemporaryArray<double> resutls(n_classes * 4);
auto d_results = dh::ToSpan(resutls); auto d_results = dh::ToSpan(resutls);
auto local_area = d_results.subspan(0, n_classes); auto local_area = d_results.subspan(0, n_classes);
auto fp = d_results.subspan(n_classes, n_classes); auto fp = d_results.subspan(n_classes, n_classes);
@ -470,10 +469,10 @@ void MultiClassSortedIdx(common::Span<float const> predts,
dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx); dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);
} }
float GPUMultiClassROCAUC(common::Span<float const> predts, double GPUMultiClassROCAUC(common::Span<float const> predts,
MetaInfo const &info, int32_t device, MetaInfo const &info, int32_t device,
std::shared_ptr<DeviceAUCCache> *p_cache, std::shared_ptr<DeviceAUCCache> *p_cache,
size_t n_classes) { size_t n_classes) {
auto& cache = *p_cache; auto& cache = *p_cache;
InitCacheOnce<true>(predts, device, p_cache); InitCacheOnce<true>(predts, device, p_cache);
@ -483,8 +482,8 @@ float GPUMultiClassROCAUC(common::Span<float const> predts,
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0); dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
MultiClassSortedIdx(predts, dh::ToSpan(class_ptr), cache); MultiClassSortedIdx(predts, dh::ToSpan(class_ptr), cache);
auto fn = [] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, float tp, auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev,
size_t /*class_id*/) { double tp, size_t /*class_id*/) {
return TrapezoidArea(fp_prev, fp, tp_prev, tp); return TrapezoidArea(fp_prev, fp, tp_prev, tp);
}; };
return GPUMultiClassAUCOVR<true>(predts, info, device, dh::ToSpan(class_ptr), return GPUMultiClassAUCOVR<true>(predts, info, device, dh::ToSpan(class_ptr),
@ -494,13 +493,13 @@ float GPUMultiClassROCAUC(common::Span<float const> predts,
namespace { namespace {
struct RankScanItem { struct RankScanItem {
size_t idx; size_t idx;
float predt; double predt;
float w; double w;
bst_group_t group_id; bst_group_t group_id;
}; };
} // anonymous namespace } // anonymous namespace
std::pair<float, uint32_t> std::pair<double, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info, GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) { int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache; auto& cache = *p_cache;
@ -523,7 +522,7 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
InvalidGroupAUC(); InvalidGroupAUC();
} }
if (n_valid == 0) { if (n_valid == 0) {
return std::make_pair(0.0f, 0); return std::make_pair(0.0, 0);
} }
/** /**
@ -583,7 +582,7 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
return RankScanItem{idx, predt, w, query_group_idx}; return RankScanItem{idx, predt, w, query_group_idx};
}); });
dh::TemporaryArray<float> d_auc(group_ptr.size() - 1); dh::TemporaryArray<double> d_auc(group_ptr.size() - 1);
auto s_d_auc = dh::ToSpan(d_auc); auto s_d_auc = dh::ToSpan(d_auc);
auto out = thrust::make_transform_output_iterator( auto out = thrust::make_transform_output_iterator(
dh::TypedDiscard<RankScanItem>{}, dh::TypedDiscard<RankScanItem>{},
@ -615,12 +614,12 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
/** /**
* Scale the AUC with number of items in each group. * Scale the AUC with number of items in each group.
*/ */
float auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc), double auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc),
dh::tend(s_d_auc), 0.0f); dh::tend(s_d_auc), 0.0);
return std::make_pair(auc, n_valid); return std::make_pair(auc, n_valid);
} }
std::tuple<float, float, float> std::tuple<double, double, double>
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info, GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) { int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache; auto& cache = *p_cache;
@ -635,32 +634,32 @@ GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
auto labels = info.labels_.ConstDeviceSpan(); auto labels = info.labels_.ConstDeviceSpan();
auto d_weights = info.weights_.ConstDeviceSpan(); auto d_weights = info.weights_.ConstDeviceSpan();
auto get_weight = OptionalWeights{d_weights}; auto get_weight = OptionalWeights{d_weights};
auto it = dh::MakeTransformIterator<thrust::pair<float, float>>( auto it = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
auto w = get_weight[d_sorted_idx[i]]; auto w = get_weight[d_sorted_idx[i]];
return thrust::make_pair(labels[d_sorted_idx[i]] * w, return thrust::make_pair(labels[d_sorted_idx[i]] * w,
(1.0f - labels[d_sorted_idx[i]]) * w); (1.0f - labels[d_sorted_idx[i]]) * w);
}); });
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
float total_pos, total_neg; double total_pos, total_neg;
thrust::tie(total_pos, total_neg) = thrust::tie(total_pos, total_neg) =
thrust::reduce(thrust::cuda::par(alloc), it, it + labels.size(), thrust::reduce(thrust::cuda::par(alloc), it, it + labels.size(),
Pair{0.0f, 0.0f}, PairPlus<float, float>{}); Pair{0.0, 0.0}, PairPlus<double, double>{});
if (total_pos <= 0.0 || total_neg <= 0.0) { if (total_pos <= 0.0 || total_neg <= 0.0) {
return {0.0f, 0.0f, 0.0f}; return {0.0f, 0.0f, 0.0f};
} }
auto fn = [total_pos] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, auto fn = [total_pos] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev,
float tp) { double tp) {
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos); return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos);
}; };
float fp, tp, auc; double fp, tp, auc;
std::tie(fp, tp, auc) = GPUBinaryAUC(predts, info, device, d_sorted_idx, fn, cache); std::tie(fp, tp, auc) = GPUBinaryAUC(predts, info, device, d_sorted_idx, fn, cache);
return std::make_tuple(1.0, 1.0, auc); return std::make_tuple(1.0, 1.0, auc);
} }
float GPUMultiClassPRAUC(common::Span<float const> predts, double GPUMultiClassPRAUC(common::Span<float const> predts,
MetaInfo const &info, int32_t device, MetaInfo const &info, int32_t device,
std::shared_ptr<DeviceAUCCache> *p_cache, std::shared_ptr<DeviceAUCCache> *p_cache,
size_t n_classes) { size_t n_classes) {
@ -682,14 +681,14 @@ float GPUMultiClassPRAUC(common::Span<float const> predts,
*/ */
auto labels = info.labels_.ConstDeviceSpan(); auto labels = info.labels_.ConstDeviceSpan();
auto n_samples = info.num_row_; auto n_samples = info.num_row_;
dh::caching_device_vector<thrust::pair<float, float>> totals(n_classes); dh::caching_device_vector<Pair> totals(n_classes);
auto key_it = auto key_it =
dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0ul), dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0ul),
[n_samples] XGBOOST_DEVICE(size_t i) { [n_samples] XGBOOST_DEVICE(size_t i) {
return i / n_samples; // class id return i / n_samples; // class id
}); });
auto get_weight = OptionalWeights{d_weights}; auto get_weight = OptionalWeights{d_weights};
auto val_it = dh::MakeTransformIterator<thrust::pair<float, float>>( auto val_it = dh::MakeTransformIterator<thrust::pair<double, double>>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
auto idx = d_sorted_idx[i] % n_samples; auto idx = d_sorted_idx[i] % n_samples;
auto w = get_weight[idx]; auto w = get_weight[idx];
@ -701,14 +700,14 @@ float GPUMultiClassPRAUC(common::Span<float const> predts,
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
key_it + predts.size(), val_it, key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(), thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<float, float>{}); thrust::equal_to<size_t>{}, PairPlus<double, double>{});
/** /**
* Calculate AUC * Calculate AUC
*/ */
auto d_totals = dh::ToSpan(totals); auto d_totals = dh::ToSpan(totals);
auto fn = [d_totals] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, auto fn = [d_totals] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev,
float tp, size_t class_id) { double tp, size_t class_id) {
auto total_pos = d_totals[class_id].first; auto total_pos = d_totals[class_id].first;
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[class_id].first); d_totals[class_id].first);
@ -718,7 +717,7 @@ float GPUMultiClassPRAUC(common::Span<float const> predts,
} }
template <typename Fn> template <typename Fn>
std::pair<float, uint32_t> std::pair<double, uint32_t>
GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info, GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
common::Span<uint32_t> d_group_ptr, int32_t device, common::Span<uint32_t> d_group_ptr, int32_t device,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) { std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
@ -736,7 +735,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
* Linear scan * Linear scan
*/ */
size_t n_samples = labels.size(); size_t n_samples = labels.size();
dh::caching_device_vector<float> d_auc(n_groups, 0); dh::caching_device_vector<double> d_auc(n_groups, 0);
auto get_weight = OptionalWeights{weights}; auto get_weight = OptionalWeights{weights};
auto d_fptp = dh::ToSpan(cache->fptp); auto d_fptp = dh::ToSpan(cache->fptp);
auto get_fp_tp = [=] XGBOOST_DEVICE(size_t i) { auto get_fp_tp = [=] XGBOOST_DEVICE(size_t i) {
@ -816,33 +815,33 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
/** /**
* Scale the groups with number of samples for each group. * Scale the groups with number of samples for each group.
*/ */
float auc; double auc;
uint32_t invalid_groups; uint32_t invalid_groups;
{ {
auto it = dh::MakeTransformIterator<thrust::pair<float, uint32_t>>( auto it = dh::MakeTransformIterator<thrust::pair<double, uint32_t>>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t g) { thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t g) {
float fp, tp; double fp, tp;
thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)]; thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)];
float area = fp * tp; double area = fp * tp;
auto n_documents = d_group_ptr[g + 1] - d_group_ptr[g]; auto n_documents = d_group_ptr[g + 1] - d_group_ptr[g];
if (area > 0 && n_documents >= 2) { if (area > 0 && n_documents >= 2) {
return thrust::make_pair(s_d_auc[g], static_cast<uint32_t>(0)); return thrust::make_pair(s_d_auc[g], static_cast<uint32_t>(0));
} }
return thrust::make_pair(0.0f, static_cast<uint32_t>(1)); return thrust::make_pair(0.0, static_cast<uint32_t>(1));
}); });
thrust::tie(auc, invalid_groups) = thrust::reduce( thrust::tie(auc, invalid_groups) = thrust::reduce(
thrust::cuda::par(alloc), it, it + n_groups, thrust::cuda::par(alloc), it, it + n_groups,
thrust::pair<float, uint32_t>(0.0f, 0), PairPlus<float, uint32_t>{}); thrust::pair<double, uint32_t>(0.0, 0), PairPlus<double, uint32_t>{});
} }
return std::make_pair(auc, n_groups - invalid_groups); return std::make_pair(auc, n_groups - invalid_groups);
} }
std::pair<float, uint32_t> std::pair<double, uint32_t>
GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info, GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) { int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
if (predts.empty()) { if (predts.empty()) {
return std::make_pair(0.0f, static_cast<uint32_t>(0)); return std::make_pair(0.0, static_cast<uint32_t>(0));
} }
auto &cache = *p_cache; auto &cache = *p_cache;
@ -870,11 +869,11 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
* Get total positive/negative for each group. * Get total positive/negative for each group.
*/ */
auto d_weights = info.weights_.ConstDeviceSpan(); auto d_weights = info.weights_.ConstDeviceSpan();
dh::caching_device_vector<thrust::pair<float, float>> totals(n_groups); dh::caching_device_vector<thrust::pair<double, double>> totals(n_groups);
auto key_it = dh::MakeTransformIterator<size_t>( auto key_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0ul), thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(d_group_ptr, i); }); [=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(d_group_ptr, i); });
auto val_it = dh::MakeTransformIterator<thrust::pair<float, float>>( auto val_it = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
float w = 1.0f; float w = 1.0f;
if (!d_weights.empty()) { if (!d_weights.empty()) {
@ -883,19 +882,19 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
w = d_weights[g]; w = d_weights[g];
} }
auto y = labels[i]; auto y = labels[i];
return thrust::make_pair(y * w, (1.0f - y) * w); return thrust::make_pair(y * w, (1.0 - y) * w);
}); });
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
key_it + predts.size(), val_it, key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(), thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<float, float>{}); thrust::equal_to<size_t>{}, PairPlus<double, double>{});
/** /**
* Calculate AUC * Calculate AUC
*/ */
auto d_totals = dh::ToSpan(totals); auto d_totals = dh::ToSpan(totals);
auto fn = [d_totals] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, auto fn = [d_totals] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev,
float tp, size_t group_id) { double tp, size_t group_id) {
auto total_pos = d_totals[group_id].first; auto total_pos = d_totals[group_id].first;
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[group_id].first); d_totals[group_id].first);

View File

@ -23,59 +23,60 @@ namespace metric {
/*********** /***********
* ROC AUC * * ROC AUC *
***********/ ***********/
XGBOOST_DEVICE inline float TrapezoidArea(float x0, float x1, float y0, float y1) { XGBOOST_DEVICE inline double TrapezoidArea(double x0, double x1, double y0, double y1) {
return std::abs(x0 - x1) * (y0 + y1) * 0.5f; return std::abs(x0 - x1) * (y0 + y1) * 0.5f;
} }
struct DeviceAUCCache; struct DeviceAUCCache;
std::tuple<float, float, float> std::tuple<double, double, double>
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info, GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache); int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache);
float GPUMultiClassROCAUC(common::Span<float const> predts, double GPUMultiClassROCAUC(common::Span<float const> predts,
MetaInfo const &info, int32_t device, MetaInfo const &info, int32_t device,
std::shared_ptr<DeviceAUCCache> *cache, std::shared_ptr<DeviceAUCCache> *cache,
size_t n_classes); size_t n_classes);
std::pair<float, uint32_t> std::pair<double, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info, GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache); int32_t device, std::shared_ptr<DeviceAUCCache> *cache);
/********** /**********
* PR AUC * * PR AUC *
**********/ **********/
std::tuple<float, float, float> std::tuple<double, double, double>
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info, GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache); int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache);
float GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info, double GPUMultiClassPRAUC(common::Span<float const> predts,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache, MetaInfo const &info, int32_t device,
size_t n_classes); std::shared_ptr<DeviceAUCCache> *cache,
size_t n_classes);
std::pair<float, uint32_t> std::pair<double, uint32_t>
GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info, GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache); int32_t device, std::shared_ptr<DeviceAUCCache> *cache);
namespace detail { namespace detail {
XGBOOST_DEVICE inline float CalcH(float fp_a, float fp_b, float tp_a, XGBOOST_DEVICE inline double CalcH(double fp_a, double fp_b, double tp_a,
float tp_b) { double tp_b) {
return (fp_b - fp_a) / (tp_b - tp_a); return (fp_b - fp_a) / (tp_b - tp_a);
} }
XGBOOST_DEVICE inline float CalcB(float fp_a, float h, float tp_a, float total_pos) { XGBOOST_DEVICE inline double CalcB(double fp_a, double h, double tp_a, double total_pos) {
return (fp_a - h * tp_a) / total_pos; return (fp_a - h * tp_a) / total_pos;
} }
XGBOOST_DEVICE inline float CalcA(float h) { return h + 1; } XGBOOST_DEVICE inline double CalcA(double h) { return h + 1; }
XGBOOST_DEVICE inline float CalcDeltaPRAUC(float fp_prev, float fp, XGBOOST_DEVICE inline double CalcDeltaPRAUC(double fp_prev, double fp,
float tp_prev, float tp, double tp_prev, double tp,
float total_pos) { double total_pos) {
float pr_prev = tp_prev / total_pos; double pr_prev = tp_prev / total_pos;
float pr = tp / total_pos; double pr = tp / total_pos;
float h{0}, a{0}, b{0}; double h{0}, a{0}, b{0};
if (tp == tp_prev) { if (tp == tp_prev) {
a = 1.0; a = 1.0;
@ -86,7 +87,7 @@ XGBOOST_DEVICE inline float CalcDeltaPRAUC(float fp_prev, float fp,
b = detail::CalcB(fp_prev, h, tp_prev, total_pos); b = detail::CalcB(fp_prev, h, tp_prev, total_pos);
} }
float area = 0; double area = 0;
if (b != 0.0) { if (b != 0.0) {
area = (pr - pr_prev - area = (pr - pr_prev -
b / a * (std::log(a * pr + b) - std::log(a * pr_prev + b))) / b / a * (std::log(a * pr + b) - std::log(a * pr_prev + b))) /

View File

@ -86,9 +86,9 @@ class ElementWiseMetricsReduction {
thrust::cuda::par(alloc), thrust::cuda::par(alloc),
begin, end, begin, end,
[=] XGBOOST_DEVICE(size_t idx) { [=] XGBOOST_DEVICE(size_t idx) {
bst_float weight = is_null_weight ? 1.0f : s_weights[idx]; float weight = is_null_weight ? 1.0f : s_weights[idx];
bst_float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]); float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]);
residue *= weight; residue *= weight;
return PackedReduceResult{ residue, weight }; return PackedReduceResult{ residue, weight };
}, },
@ -141,7 +141,7 @@ struct EvalRowRMSE {
bst_float diff = label - pred; bst_float diff = label - pred;
return diff * diff; return diff * diff;
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum); return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum);
} }
}; };
@ -155,7 +155,7 @@ struct EvalRowRMSLE {
bst_float diff = std::log1p(label) - std::log1p(pred); bst_float diff = std::log1p(label) - std::log1p(pred);
return diff * diff; return diff * diff;
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum); return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum);
} }
}; };
@ -168,7 +168,7 @@ struct EvalRowMAE {
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const { XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
return std::abs(label - pred); return std::abs(label - pred);
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? esum : esum / wsum; return wsum == 0 ? esum : esum / wsum;
} }
}; };
@ -180,7 +180,7 @@ struct EvalRowMAPE {
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const { XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
return std::abs((label - pred) / label); return std::abs((label - pred) / label);
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? esum : esum / wsum; return wsum == 0 ? esum : esum / wsum;
} }
}; };
@ -202,7 +202,7 @@ struct EvalRowLogLoss {
} }
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? esum : esum / wsum; return wsum == 0 ? esum : esum / wsum;
} }
}; };
@ -215,7 +215,7 @@ struct EvalRowMPHE {
bst_float diff = label - pred; bst_float diff = label - pred;
return std::sqrt( 1 + diff * diff) - 1; return std::sqrt( 1 + diff * diff) - 1;
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? esum : esum / wsum; return wsum == 0 ? esum : esum / wsum;
} }
}; };
@ -244,13 +244,12 @@ struct EvalError {
} }
} }
XGBOOST_DEVICE bst_float EvalRow( XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
bst_float label, bst_float pred) const {
// assume label is in [0,1] // assume label is in [0,1]
return pred > threshold_ ? 1.0f - label : label; return pred > threshold_ ? 1.0f - label : label;
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? esum : esum / wsum; return wsum == 0 ? esum : esum / wsum;
} }
@ -270,7 +269,7 @@ struct EvalPoissonNegLogLik {
return common::LogGamma(y + 1.0f) + py - std::log(py) * y; return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? esum : esum / wsum; return wsum == 0 ? esum : esum / wsum;
} }
}; };
@ -291,7 +290,7 @@ struct EvalGammaDeviance {
return std::log(predt / label) + label / predt - 1; return std::log(predt / label) + label / predt - 1;
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
if (wsum <= 0) { if (wsum <= 0) {
wsum = kRtEps; wsum = kRtEps;
} }
@ -317,7 +316,7 @@ struct EvalGammaNLogLik {
// general form for exponential family. // general form for exponential family.
return -((y * theta - b) / a + c); return -((y * theta - b) / a + c);
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? esum : esum / wsum; return wsum == 0 ? esum : esum / wsum;
} }
}; };
@ -343,7 +342,7 @@ struct EvalTweedieNLogLik {
bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_); bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
return -a + b; return -a + b;
} }
static bst_float GetFinal(bst_float esum, bst_float wsum) { static double GetFinal(double esum, double wsum) {
return wsum == 0 ? esum : esum / wsum; return wsum == 0 ? esum : esum / wsum;
} }
@ -360,9 +359,8 @@ struct EvalEWiseBase : public Metric {
explicit EvalEWiseBase(char const* policy_param) : explicit EvalEWiseBase(char const* policy_param) :
policy_{policy_param}, reducer_{policy_} {} policy_{policy_param}, reducer_{policy_} {}
bst_float Eval(const HostDeviceVector<bst_float>& preds, double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
const MetaInfo& info, bool distributed) override {
bool distributed) override {
CHECK_EQ(preds.Size(), info.labels_.Size()) CHECK_EQ(preds.Size(), info.labels_.Size())
<< "label and prediction size not match, " << "label and prediction size not match, "
<< "hint: use merror or mlogloss for multi-class classification"; << "hint: use merror or mlogloss for multi-class classification";

View File

@ -167,9 +167,8 @@ class MultiClassMetricsReduction {
*/ */
template<typename Derived> template<typename Derived>
struct EvalMClassBase : public Metric { struct EvalMClassBase : public Metric {
bst_float Eval(const HostDeviceVector<bst_float> &preds, double Eval(const HostDeviceVector<float> &preds, const MetaInfo &info,
const MetaInfo &info, bool distributed) override {
bool distributed) override {
if (info.labels_.Size() == 0) { if (info.labels_.Size() == 0) {
CHECK_EQ(preds.Size(), 0); CHECK_EQ(preds.Size(), 0);
} else { } else {
@ -206,7 +205,7 @@ struct EvalMClassBase : public Metric {
* \param esum the sum statistics returned by EvalRow * \param esum the sum statistics returned by EvalRow
* \param wsum sum of weight * \param wsum sum of weight
*/ */
inline static bst_float GetFinal(bst_float esum, bst_float wsum) { inline static double GetFinal(double esum, double wsum) {
return esum / wsum; return esum / wsum;
} }

View File

@ -102,9 +102,8 @@ struct EvalAMS : public Metric {
name_ = os.str(); name_ = os.str();
} }
bst_float Eval(const HostDeviceVector<bst_float> &preds, double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
const MetaInfo &info, bool distributed) override {
bool distributed) override {
CHECK(!distributed) << "metric AMS do not support distributed evaluation"; CHECK(!distributed) << "metric AMS do not support distributed evaluation";
using namespace std; // NOLINT(*) using namespace std; // NOLINT(*)
@ -163,9 +162,8 @@ struct EvalRank : public Metric, public EvalRankConfig {
std::unique_ptr<xgboost::Metric> rank_gpu_; std::unique_ptr<xgboost::Metric> rank_gpu_;
public: public:
bst_float Eval(const HostDeviceVector<bst_float> &preds, double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
const MetaInfo &info, bool distributed) override {
bool distributed) override {
CHECK_EQ(preds.Size(), info.labels_.Size()) CHECK_EQ(preds.Size(), info.labels_.Size())
<< "label size predict size not match"; << "label size predict size not match";
@ -222,14 +220,12 @@ struct EvalRank : public Metric, public EvalRankConfig {
} }
if (distributed) { if (distributed) {
bst_float dat[2]; double dat[2]{sum_metric, static_cast<double>(ngroups)};
dat[0] = static_cast<bst_float>(sum_metric);
dat[1] = static_cast<bst_float>(ngroups);
// approximately estimate the metric using mean // approximately estimate the metric using mean
rabit::Allreduce<rabit::op::Sum>(dat, 2); rabit::Allreduce<rabit::op::Sum>(dat, 2);
return dat[0] / dat[1]; return dat[0] / dat[1];
} else { } else {
return static_cast<bst_float>(sum_metric) / ngroups; return sum_metric / ngroups;
} }
} }
@ -335,9 +331,9 @@ struct EvalMAP : public EvalRank {
return sumap; return sumap;
} else { } else {
if (this->minus) { if (this->minus) {
return 0.0f; return 0.0;
} else { } else {
return 1.0f; return 1.0;
} }
} }
} }
@ -347,9 +343,8 @@ struct EvalMAP : public EvalRank {
struct EvalCox : public Metric { struct EvalCox : public Metric {
public: public:
EvalCox() = default; EvalCox() = default;
bst_float Eval(const HostDeviceVector<bst_float> &preds, double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
const MetaInfo &info, bool distributed) override {
bool distributed) override {
CHECK(!distributed) << "Cox metric does not support distributed evaluation"; CHECK(!distributed) << "Cox metric does not support distributed evaluation";
using namespace std; // NOLINT(*) using namespace std; // NOLINT(*)

View File

@ -29,9 +29,8 @@ DMLC_REGISTRY_FILE_TAG(rank_metric_gpu);
template <typename EvalMetricT> template <typename EvalMetricT>
struct EvalRankGpu : public Metric, public EvalRankConfig { struct EvalRankGpu : public Metric, public EvalRankConfig {
public: public:
bst_float Eval(const HostDeviceVector<bst_float> &preds, double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
const MetaInfo &info, bool distributed) override {
bool distributed) override {
// Sanity check is done by the caller // Sanity check is done by the caller
std::vector<unsigned> tgptr(2, 0); std::vector<unsigned> tgptr(2, 0);
tgptr[1] = static_cast<unsigned>(preds.Size()); tgptr[1] = static_cast<unsigned>(preds.Size());

View File

@ -206,9 +206,8 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
CHECK(tparam_); CHECK(tparam_);
} }
bst_float Eval(const HostDeviceVector<bst_float>& preds, double Eval(const HostDeviceVector<float> &preds, const MetaInfo &info,
const MetaInfo& info, bool distributed) override {
bool distributed) override {
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size()); CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size()); CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
CHECK(tparam_); CHECK(tparam_);
@ -221,7 +220,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
if (distributed) { if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2); rabit::Allreduce<rabit::op::Sum>(dat, 2);
} }
return static_cast<bst_float>(Policy::GetFinal(dat[0], dat[1])); return Policy::GetFinal(dat[0], dat[1]);
} }
const char* Name() const override { const char* Name() const override {
@ -241,9 +240,8 @@ struct AFTNLogLikDispatcher : public Metric {
return "aft-nloglik"; return "aft-nloglik";
} }
bst_float Eval(const HostDeviceVector<bst_float>& preds, double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
const MetaInfo& info, bool distributed) override {
bool distributed) override {
CHECK(metric_) << "AFT metric must be configured first, with distribution type and scale"; CHECK(metric_) << "AFT metric must be configured first, with distribution type and scale";
return metric_->Eval(preds, info, distributed); return metric_->Eval(preds, info, distributed);
} }

View File

@ -1331,8 +1331,11 @@ def test_evaluation_metric():
) )
clf.fit(X, y, eval_set=[(X, y)]) clf.fit(X, y, eval_set=[(X, y)])
internal = clf.evals_result() internal = clf.evals_result()
np.testing.assert_allclose( np.testing.assert_allclose(
custom["validation_0"]["merror"], internal["validation_0"]["merror"] custom["validation_0"]["merror"],
internal["validation_0"]["merror"],
atol=1e-6
) )
clf = xgb.XGBRFClassifier( clf = xgb.XGBRFClassifier(