Rename context in Metric. (#8686)
This commit is contained in:
parent
d6018eb4b9
commit
9f598efc3e
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
* \file metric.h
|
||||
* \brief interface of evaluation metric function supported in xgboost.
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
@ -27,7 +27,7 @@ struct Context;
|
||||
*/
|
||||
class Metric : public Configurable {
|
||||
protected:
|
||||
Context const* tparam_;
|
||||
Context const* ctx_;
|
||||
|
||||
public:
|
||||
/*!
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "auc.h"
|
||||
|
||||
@ -255,10 +255,10 @@ template <typename Curve>
|
||||
class EvalAUC : public Metric {
|
||||
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
|
||||
double auc {0};
|
||||
if (tparam_->gpu_id != Context::kCpuId) {
|
||||
preds.SetDevice(tparam_->gpu_id);
|
||||
info.labels.SetDevice(tparam_->gpu_id);
|
||||
info.weights_.SetDevice(tparam_->gpu_id);
|
||||
if (ctx_->gpu_id != Context::kCpuId) {
|
||||
preds.SetDevice(ctx_->gpu_id);
|
||||
info.labels.SetDevice(ctx_->gpu_id);
|
||||
info.weights_.SetDevice(ctx_->gpu_id);
|
||||
}
|
||||
// We use the global size to handle empty dataset.
|
||||
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
||||
@ -339,13 +339,13 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
||||
MetaInfo const &info) {
|
||||
double auc{0};
|
||||
uint32_t valid_groups = 0;
|
||||
auto n_threads = tparam_->Threads();
|
||||
if (tparam_->gpu_id == Context::kCpuId) {
|
||||
auto n_threads = ctx_->Threads();
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
std::tie(auc, valid_groups) =
|
||||
RankingAUC<true>(predts.ConstHostVector(), info, n_threads);
|
||||
} else {
|
||||
std::tie(auc, valid_groups) = GPURankingAUC(
|
||||
predts.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_);
|
||||
predts.ConstDeviceSpan(), info, ctx_->gpu_id, &this->d_cache_);
|
||||
}
|
||||
return std::make_pair(auc, valid_groups);
|
||||
}
|
||||
@ -353,13 +353,13 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
||||
double EvalMultiClass(HostDeviceVector<float> const &predts,
|
||||
MetaInfo const &info, size_t n_classes) {
|
||||
double auc{0};
|
||||
auto n_threads = tparam_->Threads();
|
||||
auto n_threads = ctx_->Threads();
|
||||
CHECK_NE(n_classes, 0);
|
||||
if (tparam_->gpu_id == Context::kCpuId) {
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads,
|
||||
BinaryROCAUC);
|
||||
} else {
|
||||
auc = GPUMultiClassROCAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id,
|
||||
auc = GPUMultiClassROCAUC(predts.ConstDeviceSpan(), info, ctx_->gpu_id,
|
||||
&this->d_cache_, n_classes);
|
||||
}
|
||||
return auc;
|
||||
@ -368,13 +368,13 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
||||
std::tuple<double, double, double>
|
||||
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
|
||||
double fp, tp, auc;
|
||||
if (tparam_->gpu_id == Context::kCpuId) {
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
std::tie(fp, tp, auc) =
|
||||
BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0),
|
||||
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
||||
} else {
|
||||
std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
|
||||
tparam_->gpu_id, &this->d_cache_);
|
||||
ctx_->gpu_id, &this->d_cache_);
|
||||
}
|
||||
return std::make_tuple(fp, tp, auc);
|
||||
}
|
||||
@ -418,25 +418,25 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
||||
std::tuple<double, double, double>
|
||||
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
|
||||
double pr, re, auc;
|
||||
if (tparam_->gpu_id == Context::kCpuId) {
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
std::tie(pr, re, auc) =
|
||||
BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
|
||||
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
||||
} else {
|
||||
std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
|
||||
tparam_->gpu_id, &this->d_cache_);
|
||||
ctx_->gpu_id, &this->d_cache_);
|
||||
}
|
||||
return std::make_tuple(pr, re, auc);
|
||||
}
|
||||
|
||||
double EvalMultiClass(HostDeviceVector<float> const &predts, MetaInfo const &info,
|
||||
size_t n_classes) {
|
||||
if (tparam_->gpu_id == Context::kCpuId) {
|
||||
auto n_threads = this->tparam_->Threads();
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
auto n_threads = this->ctx_->Threads();
|
||||
return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads,
|
||||
BinaryPRAUC);
|
||||
} else {
|
||||
return GPUMultiClassPRAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id,
|
||||
return GPUMultiClassPRAUC(predts.ConstDeviceSpan(), info, ctx_->gpu_id,
|
||||
&d_cache_, n_classes);
|
||||
}
|
||||
}
|
||||
@ -445,8 +445,8 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
||||
MetaInfo const &info) {
|
||||
double auc{0};
|
||||
uint32_t valid_groups = 0;
|
||||
auto n_threads = tparam_->Threads();
|
||||
if (tparam_->gpu_id == Context::kCpuId) {
|
||||
auto n_threads = ctx_->Threads();
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
auto labels = info.labels.Data()->ConstHostSpan();
|
||||
if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) {
|
||||
InvalidLabels();
|
||||
@ -455,7 +455,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
||||
RankingAUC<false>(predts.ConstHostVector(), info, n_threads);
|
||||
} else {
|
||||
std::tie(auc, valid_groups) = GPURankingPRAUC(
|
||||
predts.ConstDeviceSpan(), info, tparam_->gpu_id, &d_cache_);
|
||||
predts.ConstDeviceSpan(), info, ctx_->gpu_id, &d_cache_);
|
||||
}
|
||||
return std::make_pair(auc, valid_groups);
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
* \file elementwise_metric.cc
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file elementwise_metric.cu
|
||||
* \brief evaluation metrics for elementwise binary or regression.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*
|
||||
@ -180,16 +180,16 @@ class PseudoErrorLoss : public Metric {
|
||||
|
||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||
CHECK_EQ(info.labels.Shape(0), info.num_row_);
|
||||
auto labels = info.labels.View(tparam_->gpu_id);
|
||||
preds.SetDevice(tparam_->gpu_id);
|
||||
auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
||||
info.weights_.SetDevice(tparam_->gpu_id);
|
||||
common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan()
|
||||
auto labels = info.labels.View(ctx_->gpu_id);
|
||||
preds.SetDevice(ctx_->gpu_id);
|
||||
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
||||
info.weights_.SetDevice(ctx_->gpu_id);
|
||||
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
|
||||
: info.weights_.ConstDeviceSpan());
|
||||
float slope = this->param_.huber_slope;
|
||||
CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
|
||||
PackedReduceResult result =
|
||||
Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
|
||||
Reduce(ctx_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
|
||||
float wt = weights[sample_id];
|
||||
auto a = labels(sample_id, target_id) - predts[i];
|
||||
auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt;
|
||||
@ -348,16 +348,16 @@ struct EvalEWiseBase : public Metric {
|
||||
if (info.labels.Size() != 0) {
|
||||
CHECK_NE(info.labels.Shape(1), 0);
|
||||
}
|
||||
auto labels = info.labels.View(tparam_->gpu_id);
|
||||
info.weights_.SetDevice(tparam_->gpu_id);
|
||||
common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan()
|
||||
auto labels = info.labels.View(ctx_->gpu_id);
|
||||
info.weights_.SetDevice(ctx_->gpu_id);
|
||||
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
|
||||
: info.weights_.ConstDeviceSpan());
|
||||
preds.SetDevice(tparam_->gpu_id);
|
||||
auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
||||
preds.SetDevice(ctx_->gpu_id);
|
||||
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
||||
|
||||
auto d_policy = policy_;
|
||||
auto result =
|
||||
Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
|
||||
Reduce(ctx_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
|
||||
float wt = weights[sample_id];
|
||||
float residue = d_policy.EvalRow(labels(sample_id, target_id), predts[i]);
|
||||
residue *= wt;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2020 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file metric_registry.cc
|
||||
* \brief Registry of objective functions.
|
||||
*/
|
||||
@ -43,18 +43,18 @@ Metric* CreateMetricImpl(const std::string& name) {
|
||||
}
|
||||
|
||||
Metric *
|
||||
Metric::Create(const std::string& name, Context const* tparam) {
|
||||
Metric::Create(const std::string& name, Context const* ctx) {
|
||||
auto metric = CreateMetricImpl<MetricReg>(name);
|
||||
if (metric == nullptr) {
|
||||
LOG(FATAL) << "Unknown metric function " << name;
|
||||
}
|
||||
|
||||
metric->tparam_ = tparam;
|
||||
metric->ctx_ = ctx;
|
||||
return metric;
|
||||
}
|
||||
|
||||
Metric *
|
||||
GPUMetric::CreateGPUMetric(const std::string& name, Context const* tparam) {
|
||||
GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
|
||||
auto metric = CreateMetricImpl<MetricGPUReg>(name);
|
||||
if (metric == nullptr) {
|
||||
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
|
||||
@ -65,7 +65,7 @@ GPUMetric::CreateGPUMetric(const std::string& name, Context const* tparam) {
|
||||
// Narrowing reference only for the compiler to allow assignment to a base class member.
|
||||
// As such, using this narrowed reference to refer to derived members will be an illegal op.
|
||||
// This is moot, as this type is stateless.
|
||||
static_cast<GPUMetric *>(metric)->tparam_ = tparam;
|
||||
static_cast<GPUMetric *>(metric)->ctx_ = ctx;
|
||||
return metric;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2019 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file multiclass_metric.cc
|
||||
* \brief evaluation metrics for multiclass classification.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
@ -175,9 +175,9 @@ struct EvalMClassBase : public Metric {
|
||||
CHECK_GE(nclass, 1U)
|
||||
<< "mlogloss and merror are only used for multi-class classification,"
|
||||
<< " use logloss for binary classification";
|
||||
int device = tparam_->gpu_id;
|
||||
int device = ctx_->gpu_id;
|
||||
auto result =
|
||||
reducer_.Reduce(*tparam_, device, nclass, info.weights_, *info.labels.Data(), preds);
|
||||
reducer_.Reduce(*ctx_, device, nclass, info.weights_, *info.labels.Data(), preds);
|
||||
dat[0] = result.Residue();
|
||||
dat[1] = result.Weights();
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2020 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost contributors
|
||||
*/
|
||||
// When device ordinal is present, we would want to build the metrics on the GPU. It is *not*
|
||||
// possible for a valid device ordinal to be present for non GPU builds. However, it is possible
|
||||
@ -110,7 +110,7 @@ struct EvalAMS : public Metric {
|
||||
PredIndPairContainer rec(ndata);
|
||||
|
||||
const auto &h_preds = preds.ConstHostVector();
|
||||
common::ParallelFor(ndata, tparam_->Threads(),
|
||||
common::ParallelFor(ndata, ctx_->Threads(),
|
||||
[&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); });
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||
@ -178,24 +178,24 @@ struct EvalRank : public Metric, public EvalRankConfig {
|
||||
double sum_metric = 0.0f;
|
||||
|
||||
// Check and see if we have the GPU metric registered in the internal registry
|
||||
if (tparam_->gpu_id >= 0) {
|
||||
if (ctx_->gpu_id >= 0) {
|
||||
if (!rank_gpu_) {
|
||||
rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), tparam_));
|
||||
rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), ctx_));
|
||||
}
|
||||
if (rank_gpu_) {
|
||||
sum_metric = rank_gpu_->Eval(preds, info);
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(tparam_);
|
||||
std::vector<double> sum_tloc(tparam_->Threads(), 0.0);
|
||||
CHECK(ctx_);
|
||||
std::vector<double> sum_tloc(ctx_->Threads(), 0.0);
|
||||
|
||||
if (!rank_gpu_ || tparam_->gpu_id < 0) {
|
||||
if (!rank_gpu_ || ctx_->gpu_id < 0) {
|
||||
const auto& labels = info.labels.View(Context::kCpuId);
|
||||
const auto &h_preds = preds.ConstHostVector();
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel num_threads(tparam_->Threads())
|
||||
#pragma omp parallel num_threads(ctx_->Threads())
|
||||
{
|
||||
exc.Run([&]() {
|
||||
// each thread takes a local rec
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file rank_metric.cc
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
* \file rank_metric.cu
|
||||
* \brief prediction rank based metrics.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
@ -34,7 +34,7 @@ struct EvalRankGpu : public GPUMetric, public EvalRankConfig {
|
||||
|
||||
const auto ngroups = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
|
||||
auto device = tparam_->gpu_id;
|
||||
auto device = ctx_->gpu_id;
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
|
||||
info.labels.SetDevice(device);
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2020 by Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by Contributors
|
||||
* \file survival_metric.cu
|
||||
* \brief Metrics for survival analysis
|
||||
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
|
||||
@ -196,21 +196,21 @@ struct EvalAFTNLogLik {
|
||||
|
||||
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
explicit EvalEWiseSurvivalBase(Context const *ctx) {
|
||||
tparam_ = ctx;
|
||||
ctx_ = ctx;
|
||||
}
|
||||
EvalEWiseSurvivalBase() = default;
|
||||
|
||||
void Configure(const Args& args) override {
|
||||
policy_.Configure(args);
|
||||
reducer_.Configure(policy_);
|
||||
CHECK(tparam_);
|
||||
CHECK(ctx_);
|
||||
}
|
||||
|
||||
double Eval(const HostDeviceVector<float>& preds, const MetaInfo& info) override {
|
||||
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
||||
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
||||
CHECK(tparam_);
|
||||
auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels_lower_bound_,
|
||||
CHECK(ctx_);
|
||||
auto result = reducer_.Reduce(*ctx_, info.weights_, info.labels_lower_bound_,
|
||||
info.labels_upper_bound_, preds);
|
||||
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
@ -244,17 +244,13 @@ struct AFTNLogLikDispatcher : public Metric {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
switch (param_.aft_loss_distribution) {
|
||||
case common::ProbabilityDistributionType::kNormal:
|
||||
metric_.reset(
|
||||
new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::NormalDistribution>>(
|
||||
tparam_));
|
||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::NormalDistribution>>(ctx_));
|
||||
break;
|
||||
case common::ProbabilityDistributionType::kLogistic:
|
||||
metric_.reset(new EvalEWiseSurvivalBase<
|
||||
EvalAFTNLogLik<common::LogisticDistribution>>(tparam_));
|
||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::LogisticDistribution>>(ctx_));
|
||||
break;
|
||||
case common::ProbabilityDistributionType::kExtreme:
|
||||
metric_.reset(new EvalEWiseSurvivalBase<
|
||||
EvalAFTNLogLik<common::ExtremeDistribution>>(tparam_));
|
||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::ExtremeDistribution>>(ctx_));
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown probability distribution";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user