diff --git a/include/xgboost/metric.h b/include/xgboost/metric.h index b1d69b99a..42d517819 100644 --- a/include/xgboost/metric.h +++ b/include/xgboost/metric.h @@ -58,9 +58,8 @@ class Metric : public Configurable { * the average statistics across all the node, * this is only supported by some metrics */ - virtual bst_float Eval(const HostDeviceVector& preds, - const MetaInfo& info, - bool distributed) = 0; + virtual double Eval(const HostDeviceVector &preds, + const MetaInfo &info, bool distributed) = 0; /*! \return name of metric */ virtual const char* Name() const = 0; /*! \brief virtual destructor */ diff --git a/src/learner.cc b/src/learner.cc index c6baf5b5e..399d299f5 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1110,6 +1110,7 @@ class LearnerImpl : public LearnerIO { this->Configure(); std::ostringstream os; + os.precision(std::numeric_limits::max_digits10); os << '[' << iter << ']' << std::setiosflags(std::ios::fixed); if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) { auto warn_default_eval_metric = [](const std::string& objective, const std::string& before, diff --git a/src/metric/auc.cc b/src/metric/auc.cc index d58aefca3..b657c72ea 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -30,7 +30,7 @@ namespace metric { * handle the normalization. */ template -std::tuple +std::tuple BinaryAUC(common::Span predts, common::Span labels, OptionalWeights weights, std::vector const &sorted_idx, Fn &&area_fn) { @@ -39,12 +39,12 @@ BinaryAUC(common::Span predts, common::Span labels, auto p_predts = predts.data(); auto p_labels = labels.data(); - float auc{0}; + double auc{0}; float label = p_labels[sorted_idx.front()]; float w = weights[sorted_idx[0]]; - float fp = (1.0 - label) * w, tp = label * w; - float tp_prev = 0, fp_prev = 0; + double fp = (1.0 - label) * w, tp = label * w; + double tp_prev = 0, fp_prev = 0; // TODO(jiaming): We can parallize this if we have a parallel scan for CPU. for (size_t i = 1; i < sorted_idx.size(); ++i) { if (p_predts[sorted_idx[i]] != p_predts[sorted_idx[i - 1]]) { @@ -77,17 +77,19 @@ BinaryAUC(common::Span predts, common::Span labels, * Machine Learning Models */ template -float MultiClassOVR(common::Span predts, MetaInfo const &info, - size_t n_classes, int32_t n_threads, - BinaryAUC &&binary_auc) { +double MultiClassOVR(common::Span predts, MetaInfo const &info, + size_t n_classes, int32_t n_threads, + BinaryAUC &&binary_auc) { CHECK_NE(n_classes, 0); auto const &labels = info.labels_.ConstHostVector(); - std::vector results(n_classes * 3, 0); - auto s_results = common::Span(results); + std::vector results(n_classes * 3, 0); + auto s_results = common::Span(results); + auto local_area = s_results.subspan(0, n_classes); auto tp = s_results.subspan(n_classes, n_classes); auto auc = s_results.subspan(2 * n_classes, n_classes); + auto weights = OptionalWeights{info.weights_.ConstHostSpan()}; if (!info.labels_.Empty()) { @@ -98,7 +100,7 @@ float MultiClassOVR(common::Span predts, MetaInfo const &info, proba[i] = predts[i * n_classes + c]; response[i] = labels[i] == c ? 1.0f : 0.0; } - float fp; + double fp; std::tie(fp, tp[c], auc[c]) = binary_auc(proba, response, weights); local_area[c] = fp * tp[c]; }); @@ -107,8 +109,8 @@ float MultiClassOVR(common::Span predts, MetaInfo const &info, // we have 2 averages going in here, first is among workers, second is among // classes. allreduce sums up fp/tp auc for each class. rabit::Allreduce(results.data(), results.size()); - float auc_sum{0}; - float tp_sum{0}; + double auc_sum{0}; + double tp_sum{0}; for (size_t c = 0; c < n_classes; ++c) { if (local_area[c] != 0) { // normalize and weight it by prevalence. After allreduce, `local_area` @@ -117,21 +119,21 @@ float MultiClassOVR(common::Span predts, MetaInfo const &info, auc_sum += auc[c] / local_area[c] * tp[c]; tp_sum += tp[c]; } else { - auc_sum = std::numeric_limits::quiet_NaN(); + auc_sum = std::numeric_limits::quiet_NaN(); break; } } if (tp_sum == 0 || std::isnan(auc_sum)) { - auc_sum = std::numeric_limits::quiet_NaN(); + auc_sum = std::numeric_limits::quiet_NaN(); } else { auc_sum /= tp_sum; } return auc_sum; } -std::tuple BinaryROCAUC(common::Span predts, - common::Span labels, - OptionalWeights weights) { +std::tuple +BinaryROCAUC(common::Span predts, common::Span labels, + OptionalWeights weights) { auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea); } @@ -139,14 +141,14 @@ std::tuple BinaryROCAUC(common::Span predts, /** * Calculate AUC for 1 ranking group; */ -float GroupRankingROC(common::Span predts, - common::Span labels, float w) { +double GroupRankingROC(common::Span predts, + common::Span labels, float w) { // on ranking, we just count all pairs. - float auc{0}; + double auc{0}; auto const sorted_idx = common::ArgSort(labels, std::greater<>{}); w = common::Sqr(w); - float sum_w = 0.0f; + double sum_w = 0.0f; for (size_t i = 0; i < labels.size(); ++i) { for (size_t j = i + 1; j < labels.size(); ++j) { auto predt = predts[sorted_idx[i]] - predts[sorted_idx[j]]; @@ -173,11 +175,11 @@ float GroupRankingROC(common::Span predts, * * https://doi.org/10.1371/journal.pone.0092209 */ -std::tuple BinaryPRAUC(common::Span predts, - common::Span labels, - OptionalWeights weights) { +std::tuple BinaryPRAUC(common::Span predts, + common::Span labels, + OptionalWeights weights) { auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); - float total_pos{0}, total_neg{0}; + double total_pos{0}, total_neg{0}; for (size_t i = 0; i < labels.size(); ++i) { auto w = weights[i]; total_pos += w * labels[i]; @@ -186,22 +188,22 @@ std::tuple BinaryPRAUC(common::Span predts, if (total_pos <= 0 || total_neg <= 0) { return {1.0f, 1.0f, std::numeric_limits::quiet_NaN()}; } - auto fn = [total_pos](float fp_prev, float fp, float tp_prev, float tp) { + auto fn = [total_pos](double fp_prev, double fp, double tp_prev, double tp) { return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos); }; - float tp{0}, fp{0}, auc{0}; + double tp{0}, fp{0}, auc{0}; std::tie(fp, tp, auc) = BinaryAUC(predts, labels, weights, sorted_idx, fn); return std::make_tuple(1.0, 1.0, auc); } - /** * Cast LTR problem to binary classification problem by comparing pairs. */ template -std::pair RankingAUC(std::vector const &predts, - MetaInfo const &info, int32_t n_threads) { +std::pair RankingAUC(std::vector const &predts, + MetaInfo const &info, + int32_t n_threads) { CHECK_GE(info.group_ptr_.size(), 2); uint32_t n_groups = info.group_ptr_.size() - 1; auto s_predts = common::Span{predts}; @@ -217,7 +219,7 @@ std::pair RankingAUC(std::vector const &predts, float w = s_weights.empty() ? 1.0f : s_weights[g - 1]; auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt); auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt); - float auc; + double auc; if (is_roc && g_labels.size() < 3) { // With 2 documents, there's only 1 comparison can be made. So either // TP or FP will be zero. @@ -236,16 +238,16 @@ std::pair RankingAUC(std::vector const &predts, } auc_tloc[omp_get_thread_num()] += auc; }); - float sum_auc = std::accumulate(auc_tloc.cbegin(), auc_tloc.cend(), 0.0); + double sum_auc = std::accumulate(auc_tloc.cbegin(), auc_tloc.cend(), 0.0); return std::make_pair(sum_auc, n_groups - invalid_groups); } template class EvalAUC : public Metric { - float Eval(const HostDeviceVector &preds, const MetaInfo &info, - bool distributed) override { - float auc {0}; + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { + double auc {0}; if (tparam_->gpu_id != GenericParameter::kCpuId) { preds.SetDevice(tparam_->gpu_id); info.labels_.SetDevice(tparam_->gpu_id); @@ -256,7 +258,7 @@ class EvalAUC : public Metric { rabit::Allreduce(meta.data(), meta.size()); if (meta[0] == 0) { // Empty across all workers, which is not supported. - auc = std::numeric_limits::quiet_NaN(); + auc = std::numeric_limits::quiet_NaN(); } else if (!info.group_ptr_.empty()) { /** * learning to rank @@ -274,13 +276,13 @@ class EvalAUC : public Metric { InvalidGroupAUC(); } - std::array results{auc, static_cast(valid_groups)}; + std::array results{auc, static_cast(valid_groups)}; rabit::Allreduce(results.data(), results.size()); auc = results[0]; valid_groups = static_cast(results[1]); if (valid_groups <= 0) { - auc = std::numeric_limits::quiet_NaN(); + auc = std::numeric_limits::quiet_NaN(); } else { auc /= valid_groups; CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups @@ -297,18 +299,18 @@ class EvalAUC : public Metric { /** * binary classification */ - float fp{0}, tp{0}; + double fp{0}, tp{0}; if (!(preds.Empty() || info.labels_.Empty())) { std::tie(fp, tp, auc) = static_cast(this)->EvalBinary(preds, info); } - float local_area = fp * tp; - std::array result{auc, local_area}; + double local_area = fp * tp; + std::array result{auc, local_area}; rabit::Allreduce(result.data(), result.size()); std::tie(auc, local_area) = common::UnpackArr(std::move(result)); if (local_area <= 0) { // the dataset across all workers have only positive or negative sample - auc = std::numeric_limits::quiet_NaN(); + auc = std::numeric_limits::quiet_NaN(); } else { CHECK_LE(auc, local_area); // normalization @@ -326,9 +328,9 @@ class EvalROCAUC : public EvalAUC { std::shared_ptr d_cache_; public: - std::pair EvalRanking(HostDeviceVector const &predts, - MetaInfo const &info) { - float auc{0}; + std::pair EvalRanking(HostDeviceVector const &predts, + MetaInfo const &info) { + double auc{0}; uint32_t valid_groups = 0; auto n_threads = tparam_->Threads(); if (tparam_->gpu_id == GenericParameter::kCpuId) { @@ -341,9 +343,9 @@ class EvalROCAUC : public EvalAUC { return std::make_pair(auc, valid_groups); } - float EvalMultiClass(HostDeviceVector const &predts, - MetaInfo const &info, size_t n_classes) { - float auc{0}; + double EvalMultiClass(HostDeviceVector const &predts, + MetaInfo const &info, size_t n_classes) { + double auc{0}; auto n_threads = tparam_->Threads(); CHECK_NE(n_classes, 0); if (tparam_->gpu_id == GenericParameter::kCpuId) { @@ -356,9 +358,9 @@ class EvalROCAUC : public EvalAUC { return auc; } - std::tuple + std::tuple EvalBinary(HostDeviceVector const &predts, MetaInfo const &info) { - float fp, tp, auc; + double fp, tp, auc; if (tparam_->gpu_id == GenericParameter::kCpuId) { std::tie(fp, tp, auc) = BinaryROCAUC(predts.ConstHostVector(), info.labels_.ConstHostVector(), @@ -381,37 +383,37 @@ XGBOOST_REGISTER_METRIC(EvalAUC, "auc") .set_body([](const char*) { return new EvalROCAUC(); }); #if !defined(XGBOOST_USE_CUDA) -std::tuple +std::tuple GPUBinaryROCAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache) { common::AssertGPUSupport(); - return std::make_tuple(0.0f, 0.0f, 0.0f); + return {}; } -float GPUMultiClassROCAUC(common::Span predts, - MetaInfo const &info, int32_t device, - std::shared_ptr *cache, - size_t n_classes) { +double GPUMultiClassROCAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *cache, + size_t n_classes) { common::AssertGPUSupport(); - return 0; + return 0.0; } -std::pair +std::pair GPURankingAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache) { common::AssertGPUSupport(); - return std::make_pair(0.0f, 0u); + return {}; } struct DeviceAUCCache {}; #endif // !defined(XGBOOST_USE_CUDA) -class EvalAUCPR : public EvalAUC { +class EvalPRAUC : public EvalAUC { std::shared_ptr d_cache_; public: - std::tuple + std::tuple EvalBinary(HostDeviceVector const &predts, MetaInfo const &info) { - float pr, re, auc; + double pr, re, auc; if (tparam_->gpu_id == GenericParameter::kCpuId) { std::tie(pr, re, auc) = BinaryPRAUC(predts.ConstHostSpan(), info.labels_.ConstHostSpan(), @@ -423,7 +425,7 @@ class EvalAUCPR : public EvalAUC { return std::make_tuple(pr, re, auc); } - float EvalMultiClass(HostDeviceVector const &predts, + double EvalMultiClass(HostDeviceVector const &predts, MetaInfo const &info, size_t n_classes) { if (tparam_->gpu_id == GenericParameter::kCpuId) { auto n_threads = this->tparam_->Threads(); @@ -435,9 +437,9 @@ class EvalAUCPR : public EvalAUC { } } - std::pair EvalRanking(HostDeviceVector const &predts, - MetaInfo const &info) { - float auc{0}; + std::pair EvalRanking(HostDeviceVector const &predts, + MetaInfo const &info) { + double auc{0}; uint32_t valid_groups = 0; auto n_threads = tparam_->Threads(); if (tparam_->gpu_id == GenericParameter::kCpuId) { @@ -460,24 +462,25 @@ class EvalAUCPR : public EvalAUC { XGBOOST_REGISTER_METRIC(AUCPR, "aucpr") .describe("Area under PR curve for both classification and rank.") - .set_body([](char const *) { return new EvalAUCPR{}; }); + .set_body([](char const *) { return new EvalPRAUC{}; }); #if !defined(XGBOOST_USE_CUDA) -std::tuple +std::tuple GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache) { common::AssertGPUSupport(); return {}; } -float GPUMultiClassPRAUC(common::Span predts, MetaInfo const &info, - int32_t device, std::shared_ptr *cache, - size_t n_classes) { +double GPUMultiClassPRAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *cache, + size_t n_classes) { common::AssertGPUSupport(); return {}; } -std::pair +std::pair GPURankingPRAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *cache) { common::AssertGPUSupport(); diff --git a/src/metric/auc.cu b/src/metric/auc.cu index cd30ead95..153a0290a 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -22,7 +22,7 @@ namespace xgboost { namespace metric { namespace { // Pair of FP/TP -using Pair = thrust::pair; +using Pair = thrust::pair; template > struct PairPlus : public thrust::binary_function { @@ -38,9 +38,9 @@ struct PairPlus : public thrust::binary_function { struct DeviceAUCCache { // index sorted by prediction value dh::device_vector sorted_idx; - // track FP/TP for computation on trapesoid area + // track FP/TP for computation on trapezoid area dh::device_vector fptp; - // track FP_PREV/TP_PREV for computation on trapesoid area + // track FP_PREV/TP_PREV for computation on trapezoid area dh::device_vector neg_pos; // index of unique prediction values. dh::device_vector unique_idx; @@ -79,13 +79,13 @@ void InitCacheOnce(common::Span predts, int32_t device, * The GPU implementation uses same calculation as CPU with a few more steps to distribute * work across threads: * - * - Run scan to obtain TP/FP values, which are right coordinates of trapesoid. + * - Run scan to obtain TP/FP values, which are right coordinates of trapezoid. * - Find distinct prediction values and get the corresponding FP_PREV/TP_PREV value, - * which are left coordinates of trapesoids. + * which are left coordinates of trapezoids. * - Reduce the scan array into 1 AUC value. */ template -std::tuple +std::tuple GPUBinaryAUC(common::Span predts, MetaInfo const &info, int32_t device, common::Span d_sorted_idx, Fn area_fn, std::shared_ptr cache) { @@ -129,7 +129,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx)); dh::InclusiveScan(dh::tbegin(d_fptp), dh::tbegin(d_fptp), - PairPlus{}, d_fptp.size()); + PairPlus{}, d_fptp.size()); auto d_neg_pos = dh::ToSpan(cache->neg_pos); // scatter unique negaive/positive values @@ -149,10 +149,10 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, } }); - auto in = dh::MakeTransformIterator( + auto in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { - float fp, tp; - float fp_prev, tp_prev; + double fp, tp; + double fp_prev, tp_prev; if (i == 0) { // handle the last element thrust::tie(fp, tp) = d_fptp.back(); @@ -165,11 +165,11 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, }); Pair last = cache->fptp.back(); - float auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size()); + double auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size()); return std::make_tuple(last.first, last.second, auc); } -std::tuple +std::tuple GPUBinaryROCAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache) { auto &cache = *p_cache; @@ -183,7 +183,7 @@ GPUBinaryROCAUC(common::Span predts, MetaInfo const &info, // Create lambda to avoid pass function pointer. return GPUBinaryAUC( predts, info, device, d_sorted_idx, - [] XGBOOST_DEVICE(float x0, float x1, float y0, float y1) { + [] XGBOOST_DEVICE(double x0, double x1, double y0, double y1) -> double { return TrapezoidArea(x0, x1, y0, y1); }, cache); @@ -209,33 +209,32 @@ XGBOOST_DEVICE size_t LastOf(size_t group, common::Span indptr) { return indptr[group + 1] - 1; } - -float ScaleClasses(common::Span results, common::Span local_area, - common::Span fp, common::Span tp, - common::Span auc, std::shared_ptr cache, - size_t n_classes) { +double ScaleClasses(common::Span results, + common::Span local_area, common::Span fp, + common::Span tp, common::Span auc, + std::shared_ptr cache, size_t n_classes) { dh::XGBDeviceAllocator alloc; if (rabit::IsDistributed()) { CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice()); cache->reducer->AllReduceSum(results.data(), results.data(), results.size()); } - auto reduce_in = dh::MakeTransformIterator>( + auto reduce_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { if (local_area[i] > 0) { return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]); } - return thrust::make_pair(std::numeric_limits::quiet_NaN(), 0.0f); + return thrust::make_pair(std::numeric_limits::quiet_NaN(), 0.0); }); - float tp_sum; - float auc_sum; + double tp_sum; + double auc_sum; thrust::tie(auc_sum, tp_sum) = thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes, - Pair{0.0f, 0.0f}, PairPlus{}); + Pair{0.0, 0.0}, PairPlus{}); if (tp_sum != 0 && !std::isnan(auc_sum)) { auc_sum /= tp_sum; } else { - return std::numeric_limits::quiet_NaN(); + return std::numeric_limits::quiet_NaN(); } return auc_sum; } @@ -246,7 +245,7 @@ float ScaleClasses(common::Span results, common::Span local_area, */ template void SegmentedFPTP(common::Span d_fptp, Fn segment_id) { - using Triple = thrust::tuple; + using Triple = thrust::tuple; // expand to tuple to include idx auto fptp_it_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { @@ -285,7 +284,7 @@ void SegmentedReduceAUC(common::Span d_unique_idx, std::shared_ptr cache, Area area_fn, Seg segment_id, - common::Span d_auc) { + common::Span d_auc) { auto d_fptp = dh::ToSpan(cache->fptp); auto d_neg_pos = dh::ToSpan(cache->neg_pos); dh::XGBDeviceAllocator alloc; @@ -294,11 +293,11 @@ void SegmentedReduceAUC(common::Span d_unique_idx, size_t class_id = segment_id(d_unique_idx[i]); return class_id; }); - auto val_in = dh::MakeTransformIterator( + auto val_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { size_t class_id = segment_id(d_unique_idx[i]); - float fp, tp, fp_prev, tp_prev; + double fp, tp, fp_prev, tp_prev; if (i == d_unique_class_ptr[class_id]) { // first item is ignored, we use this thread to calculate the last item thrust::tie(fp, tp) = d_fptp[LastOf(class_id, d_class_ptr)]; @@ -308,7 +307,7 @@ void SegmentedReduceAUC(common::Span d_unique_idx, thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; } - float auc = area_fn(fp_prev, fp, tp_prev, tp, class_id); + double auc = area_fn(fp_prev, fp, tp_prev, tp, class_id); return auc; }); thrust::reduce_by_key(thrust::cuda::par(alloc), key_in, @@ -321,10 +320,10 @@ void SegmentedReduceAUC(common::Span d_unique_idx, * up each class in all kernels. */ template -float GPUMultiClassAUCOVR(common::Span predts, - MetaInfo const &info, int32_t device, - common::Span d_class_ptr, size_t n_classes, - std::shared_ptr cache, Fn area_fn) { +double GPUMultiClassAUCOVR(common::Span predts, + MetaInfo const &info, int32_t device, + common::Span d_class_ptr, size_t n_classes, + std::shared_ptr cache, Fn area_fn) { dh::safe_cuda(cudaSetDevice(device)); /** * Sorted idx @@ -339,7 +338,7 @@ float GPUMultiClassAUCOVR(common::Span predts, size_t n_samples = labels.size(); if (n_samples == 0) { - dh::TemporaryArray resutls(n_classes * 4, 0.0f); + dh::TemporaryArray resutls(n_classes * 4, 0.0f); auto d_results = dh::ToSpan(resutls); dh::LaunchN(n_classes * 4, [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; }); @@ -353,7 +352,7 @@ float GPUMultiClassAUCOVR(common::Span predts, /** * Linear scan */ - dh::caching_device_vector d_auc(n_classes, 0); + dh::caching_device_vector d_auc(n_classes, 0); auto get_weight = OptionalWeights{weights}; auto d_fptp = dh::ToSpan(cache->fptp); auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { @@ -432,7 +431,7 @@ float GPUMultiClassAUCOVR(common::Span predts, /** * Scale the classes with number of samples for each class. */ - dh::TemporaryArray resutls(n_classes * 4); + dh::TemporaryArray resutls(n_classes * 4); auto d_results = dh::ToSpan(resutls); auto local_area = d_results.subspan(0, n_classes); auto fp = d_results.subspan(n_classes, n_classes); @@ -470,10 +469,10 @@ void MultiClassSortedIdx(common::Span predts, dh::SegmentedArgSort(d_predts_t, d_class_ptr, d_sorted_idx); } -float GPUMultiClassROCAUC(common::Span predts, - MetaInfo const &info, int32_t device, - std::shared_ptr *p_cache, - size_t n_classes) { +double GPUMultiClassROCAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *p_cache, + size_t n_classes) { auto& cache = *p_cache; InitCacheOnce(predts, device, p_cache); @@ -483,8 +482,8 @@ float GPUMultiClassROCAUC(common::Span predts, dh::TemporaryArray class_ptr(n_classes + 1, 0); MultiClassSortedIdx(predts, dh::ToSpan(class_ptr), cache); - auto fn = [] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, float tp, - size_t /*class_id*/) { + auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev, + double tp, size_t /*class_id*/) { return TrapezoidArea(fp_prev, fp, tp_prev, tp); }; return GPUMultiClassAUCOVR(predts, info, device, dh::ToSpan(class_ptr), @@ -494,13 +493,13 @@ float GPUMultiClassROCAUC(common::Span predts, namespace { struct RankScanItem { size_t idx; - float predt; - float w; + double predt; + double w; bst_group_t group_id; }; } // anonymous namespace -std::pair +std::pair GPURankingAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache) { auto& cache = *p_cache; @@ -523,7 +522,7 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, InvalidGroupAUC(); } if (n_valid == 0) { - return std::make_pair(0.0f, 0); + return std::make_pair(0.0, 0); } /** @@ -583,7 +582,7 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, return RankScanItem{idx, predt, w, query_group_idx}; }); - dh::TemporaryArray d_auc(group_ptr.size() - 1); + dh::TemporaryArray d_auc(group_ptr.size() - 1); auto s_d_auc = dh::ToSpan(d_auc); auto out = thrust::make_transform_output_iterator( dh::TypedDiscard{}, @@ -615,12 +614,12 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, /** * Scale the AUC with number of items in each group. */ - float auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc), - dh::tend(s_d_auc), 0.0f); + double auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc), + dh::tend(s_d_auc), 0.0); return std::make_pair(auc, n_valid); } -std::tuple +std::tuple GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache) { auto& cache = *p_cache; @@ -635,32 +634,32 @@ GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, auto labels = info.labels_.ConstDeviceSpan(); auto d_weights = info.weights_.ConstDeviceSpan(); auto get_weight = OptionalWeights{d_weights}; - auto it = dh::MakeTransformIterator>( + auto it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { auto w = get_weight[d_sorted_idx[i]]; return thrust::make_pair(labels[d_sorted_idx[i]] * w, (1.0f - labels[d_sorted_idx[i]]) * w); }); dh::XGBCachingDeviceAllocator alloc; - float total_pos, total_neg; + double total_pos, total_neg; thrust::tie(total_pos, total_neg) = thrust::reduce(thrust::cuda::par(alloc), it, it + labels.size(), - Pair{0.0f, 0.0f}, PairPlus{}); + Pair{0.0, 0.0}, PairPlus{}); if (total_pos <= 0.0 || total_neg <= 0.0) { return {0.0f, 0.0f, 0.0f}; } - auto fn = [total_pos] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, - float tp) { + auto fn = [total_pos] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev, + double tp) { return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos); }; - float fp, tp, auc; + double fp, tp, auc; std::tie(fp, tp, auc) = GPUBinaryAUC(predts, info, device, d_sorted_idx, fn, cache); return std::make_tuple(1.0, 1.0, auc); } -float GPUMultiClassPRAUC(common::Span predts, +double GPUMultiClassPRAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache, size_t n_classes) { @@ -682,14 +681,14 @@ float GPUMultiClassPRAUC(common::Span predts, */ auto labels = info.labels_.ConstDeviceSpan(); auto n_samples = info.num_row_; - dh::caching_device_vector> totals(n_classes); + dh::caching_device_vector totals(n_classes); auto key_it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), [n_samples] XGBOOST_DEVICE(size_t i) { return i / n_samples; // class id }); auto get_weight = OptionalWeights{d_weights}; - auto val_it = dh::MakeTransformIterator>( + auto val_it = dh::MakeTransformIterator>( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { auto idx = d_sorted_idx[i] % n_samples; auto w = get_weight[idx]; @@ -701,14 +700,14 @@ float GPUMultiClassPRAUC(common::Span predts, thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, key_it + predts.size(), val_it, thrust::make_discard_iterator(), totals.begin(), - thrust::equal_to{}, PairPlus{}); + thrust::equal_to{}, PairPlus{}); /** * Calculate AUC */ auto d_totals = dh::ToSpan(totals); - auto fn = [d_totals] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, - float tp, size_t class_id) { + auto fn = [d_totals] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev, + double tp, size_t class_id) { auto total_pos = d_totals[class_id].first; return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, d_totals[class_id].first); @@ -718,7 +717,7 @@ float GPUMultiClassPRAUC(common::Span predts, } template -std::pair +std::pair GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, common::Span d_group_ptr, int32_t device, std::shared_ptr cache, Fn area_fn) { @@ -736,7 +735,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, * Linear scan */ size_t n_samples = labels.size(); - dh::caching_device_vector d_auc(n_groups, 0); + dh::caching_device_vector d_auc(n_groups, 0); auto get_weight = OptionalWeights{weights}; auto d_fptp = dh::ToSpan(cache->fptp); auto get_fp_tp = [=] XGBOOST_DEVICE(size_t i) { @@ -816,33 +815,33 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, /** * Scale the groups with number of samples for each group. */ - float auc; + double auc; uint32_t invalid_groups; { - auto it = dh::MakeTransformIterator>( + auto it = dh::MakeTransformIterator>( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t g) { - float fp, tp; + double fp, tp; thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)]; - float area = fp * tp; + double area = fp * tp; auto n_documents = d_group_ptr[g + 1] - d_group_ptr[g]; if (area > 0 && n_documents >= 2) { return thrust::make_pair(s_d_auc[g], static_cast(0)); } - return thrust::make_pair(0.0f, static_cast(1)); + return thrust::make_pair(0.0, static_cast(1)); }); thrust::tie(auc, invalid_groups) = thrust::reduce( thrust::cuda::par(alloc), it, it + n_groups, - thrust::pair(0.0f, 0), PairPlus{}); + thrust::pair(0.0, 0), PairPlus{}); } return std::make_pair(auc, n_groups - invalid_groups); } -std::pair +std::pair GPURankingPRAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache) { dh::safe_cuda(cudaSetDevice(device)); if (predts.empty()) { - return std::make_pair(0.0f, static_cast(0)); + return std::make_pair(0.0, static_cast(0)); } auto &cache = *p_cache; @@ -870,11 +869,11 @@ GPURankingPRAUC(common::Span predts, MetaInfo const &info, * Get total positive/negative for each group. */ auto d_weights = info.weights_.ConstDeviceSpan(); - dh::caching_device_vector> totals(n_groups); + dh::caching_device_vector> totals(n_groups); auto key_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(d_group_ptr, i); }); - auto val_it = dh::MakeTransformIterator>( + auto val_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { float w = 1.0f; if (!d_weights.empty()) { @@ -883,19 +882,19 @@ GPURankingPRAUC(common::Span predts, MetaInfo const &info, w = d_weights[g]; } auto y = labels[i]; - return thrust::make_pair(y * w, (1.0f - y) * w); + return thrust::make_pair(y * w, (1.0 - y) * w); }); thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, key_it + predts.size(), val_it, thrust::make_discard_iterator(), totals.begin(), - thrust::equal_to{}, PairPlus{}); + thrust::equal_to{}, PairPlus{}); /** * Calculate AUC */ auto d_totals = dh::ToSpan(totals); - auto fn = [d_totals] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, - float tp, size_t group_id) { + auto fn = [d_totals] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev, + double tp, size_t group_id) { auto total_pos = d_totals[group_id].first; return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, d_totals[group_id].first); diff --git a/src/metric/auc.h b/src/metric/auc.h index db399a060..cde8febf2 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -23,59 +23,60 @@ namespace metric { /*********** * ROC AUC * ***********/ -XGBOOST_DEVICE inline float TrapezoidArea(float x0, float x1, float y0, float y1) { +XGBOOST_DEVICE inline double TrapezoidArea(double x0, double x1, double y0, double y1) { return std::abs(x0 - x1) * (y0 + y1) * 0.5f; } struct DeviceAUCCache; -std::tuple +std::tuple GPUBinaryROCAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache); -float GPUMultiClassROCAUC(common::Span predts, - MetaInfo const &info, int32_t device, - std::shared_ptr *cache, - size_t n_classes); +double GPUMultiClassROCAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *cache, + size_t n_classes); -std::pair +std::pair GPURankingAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *cache); /********** * PR AUC * **********/ -std::tuple +std::tuple GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *p_cache); -float GPUMultiClassPRAUC(common::Span predts, MetaInfo const &info, - int32_t device, std::shared_ptr *cache, - size_t n_classes); +double GPUMultiClassPRAUC(common::Span predts, + MetaInfo const &info, int32_t device, + std::shared_ptr *cache, + size_t n_classes); -std::pair +std::pair GPURankingPRAUC(common::Span predts, MetaInfo const &info, int32_t device, std::shared_ptr *cache); namespace detail { -XGBOOST_DEVICE inline float CalcH(float fp_a, float fp_b, float tp_a, - float tp_b) { +XGBOOST_DEVICE inline double CalcH(double fp_a, double fp_b, double tp_a, + double tp_b) { return (fp_b - fp_a) / (tp_b - tp_a); } -XGBOOST_DEVICE inline float CalcB(float fp_a, float h, float tp_a, float total_pos) { +XGBOOST_DEVICE inline double CalcB(double fp_a, double h, double tp_a, double total_pos) { return (fp_a - h * tp_a) / total_pos; } -XGBOOST_DEVICE inline float CalcA(float h) { return h + 1; } +XGBOOST_DEVICE inline double CalcA(double h) { return h + 1; } -XGBOOST_DEVICE inline float CalcDeltaPRAUC(float fp_prev, float fp, - float tp_prev, float tp, - float total_pos) { - float pr_prev = tp_prev / total_pos; - float pr = tp / total_pos; +XGBOOST_DEVICE inline double CalcDeltaPRAUC(double fp_prev, double fp, + double tp_prev, double tp, + double total_pos) { + double pr_prev = tp_prev / total_pos; + double pr = tp / total_pos; - float h{0}, a{0}, b{0}; + double h{0}, a{0}, b{0}; if (tp == tp_prev) { a = 1.0; @@ -86,7 +87,7 @@ XGBOOST_DEVICE inline float CalcDeltaPRAUC(float fp_prev, float fp, b = detail::CalcB(fp_prev, h, tp_prev, total_pos); } - float area = 0; + double area = 0; if (b != 0.0) { area = (pr - pr_prev - b / a * (std::log(a * pr + b) - std::log(a * pr_prev + b))) / diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 29130c89e..ddc955768 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -86,9 +86,9 @@ class ElementWiseMetricsReduction { thrust::cuda::par(alloc), begin, end, [=] XGBOOST_DEVICE(size_t idx) { - bst_float weight = is_null_weight ? 1.0f : s_weights[idx]; + float weight = is_null_weight ? 1.0f : s_weights[idx]; - bst_float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]); + float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]); residue *= weight; return PackedReduceResult{ residue, weight }; }, @@ -141,7 +141,7 @@ struct EvalRowRMSE { bst_float diff = label - pred; return diff * diff; } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum); } }; @@ -155,7 +155,7 @@ struct EvalRowRMSLE { bst_float diff = std::log1p(label) - std::log1p(pred); return diff * diff; } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum); } }; @@ -168,7 +168,7 @@ struct EvalRowMAE { XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const { return std::abs(label - pred); } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } }; @@ -180,7 +180,7 @@ struct EvalRowMAPE { XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const { return std::abs((label - pred) / label); } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } }; @@ -202,7 +202,7 @@ struct EvalRowLogLoss { } } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } }; @@ -215,7 +215,7 @@ struct EvalRowMPHE { bst_float diff = label - pred; return std::sqrt( 1 + diff * diff) - 1; } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } }; @@ -244,13 +244,12 @@ struct EvalError { } } - XGBOOST_DEVICE bst_float EvalRow( - bst_float label, bst_float pred) const { + XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const { // assume label is in [0,1] return pred > threshold_ ? 1.0f - label : label; } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } @@ -270,7 +269,7 @@ struct EvalPoissonNegLogLik { return common::LogGamma(y + 1.0f) + py - std::log(py) * y; } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } }; @@ -291,7 +290,7 @@ struct EvalGammaDeviance { return std::log(predt / label) + label / predt - 1; } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { if (wsum <= 0) { wsum = kRtEps; } @@ -317,7 +316,7 @@ struct EvalGammaNLogLik { // general form for exponential family. return -((y * theta - b) / a + c); } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } }; @@ -343,7 +342,7 @@ struct EvalTweedieNLogLik { bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_); return -a + b; } - static bst_float GetFinal(bst_float esum, bst_float wsum) { + static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } @@ -360,9 +359,8 @@ struct EvalEWiseBase : public Metric { explicit EvalEWiseBase(char const* policy_param) : policy_{policy_param}, reducer_{policy_} {} - bst_float Eval(const HostDeviceVector& preds, - const MetaInfo& info, - bool distributed) override { + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { CHECK_EQ(preds.Size(), info.labels_.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 580edf453..3a42c46e7 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -167,9 +167,8 @@ class MultiClassMetricsReduction { */ template struct EvalMClassBase : public Metric { - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { if (info.labels_.Size() == 0) { CHECK_EQ(preds.Size(), 0); } else { @@ -206,7 +205,7 @@ struct EvalMClassBase : public Metric { * \param esum the sum statistics returned by EvalRow * \param wsum sum of weight */ - inline static bst_float GetFinal(bst_float esum, bst_float wsum) { + inline static double GetFinal(double esum, double wsum) { return esum / wsum; } diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index fd219c39d..f57d13926 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -102,9 +102,8 @@ struct EvalAMS : public Metric { name_ = os.str(); } - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { CHECK(!distributed) << "metric AMS do not support distributed evaluation"; using namespace std; // NOLINT(*) @@ -163,9 +162,8 @@ struct EvalRank : public Metric, public EvalRankConfig { std::unique_ptr rank_gpu_; public: - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { CHECK_EQ(preds.Size(), info.labels_.Size()) << "label size predict size not match"; @@ -222,14 +220,12 @@ struct EvalRank : public Metric, public EvalRankConfig { } if (distributed) { - bst_float dat[2]; - dat[0] = static_cast(sum_metric); - dat[1] = static_cast(ngroups); + double dat[2]{sum_metric, static_cast(ngroups)}; // approximately estimate the metric using mean rabit::Allreduce(dat, 2); return dat[0] / dat[1]; } else { - return static_cast(sum_metric) / ngroups; + return sum_metric / ngroups; } } @@ -335,9 +331,9 @@ struct EvalMAP : public EvalRank { return sumap; } else { if (this->minus) { - return 0.0f; + return 0.0; } else { - return 1.0f; + return 1.0; } } } @@ -347,9 +343,8 @@ struct EvalMAP : public EvalRank { struct EvalCox : public Metric { public: EvalCox() = default; - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { CHECK(!distributed) << "Cox metric does not support distributed evaluation"; using namespace std; // NOLINT(*) diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 57757b22b..0e7f9cc15 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -29,9 +29,8 @@ DMLC_REGISTRY_FILE_TAG(rank_metric_gpu); template struct EvalRankGpu : public Metric, public EvalRankConfig { public: - bst_float Eval(const HostDeviceVector &preds, - const MetaInfo &info, - bool distributed) override { + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { // Sanity check is done by the caller std::vector tgptr(2, 0); tgptr[1] = static_cast(preds.Size()); diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu index d0263456d..69ce2d943 100644 --- a/src/metric/survival_metric.cu +++ b/src/metric/survival_metric.cu @@ -206,9 +206,8 @@ template struct EvalEWiseSurvivalBase : public Metric { CHECK(tparam_); } - bst_float Eval(const HostDeviceVector& preds, - const MetaInfo& info, - bool distributed) override { + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size()); CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size()); CHECK(tparam_); @@ -221,7 +220,7 @@ template struct EvalEWiseSurvivalBase : public Metric { if (distributed) { rabit::Allreduce(dat, 2); } - return static_cast(Policy::GetFinal(dat[0], dat[1])); + return Policy::GetFinal(dat[0], dat[1]); } const char* Name() const override { @@ -241,9 +240,8 @@ struct AFTNLogLikDispatcher : public Metric { return "aft-nloglik"; } - bst_float Eval(const HostDeviceVector& preds, - const MetaInfo& info, - bool distributed) override { + double Eval(const HostDeviceVector &preds, const MetaInfo &info, + bool distributed) override { CHECK(metric_) << "AFT metric must be configured first, with distribution type and scale"; return metric_->Eval(preds, info, distributed); } diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 2a400871d..6a0d0fc19 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1331,8 +1331,11 @@ def test_evaluation_metric(): ) clf.fit(X, y, eval_set=[(X, y)]) internal = clf.evals_result() + np.testing.assert_allclose( - custom["validation_0"]["merror"], internal["validation_0"]["merror"] + custom["validation_0"]["merror"], + internal["validation_0"]["merror"], + atol=1e-6 ) clf = xgb.XGBRFClassifier(