Rename context in Metric. (#8686)
This commit is contained in:
parent
d6018eb4b9
commit
9f598efc3e
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2014 by Contributors
|
* Copyright 2014-2023 by XGBoost Contributors
|
||||||
* \file metric.h
|
* \file metric.h
|
||||||
* \brief interface of evaluation metric function supported in xgboost.
|
* \brief interface of evaluation metric function supported in xgboost.
|
||||||
* \author Tianqi Chen, Kailong Chen
|
* \author Tianqi Chen, Kailong Chen
|
||||||
@ -27,7 +27,7 @@ struct Context;
|
|||||||
*/
|
*/
|
||||||
class Metric : public Configurable {
|
class Metric : public Configurable {
|
||||||
protected:
|
protected:
|
||||||
Context const* tparam_;
|
Context const* ctx_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 by XGBoost Contributors
|
* Copyright 2021-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include "auc.h"
|
#include "auc.h"
|
||||||
|
|
||||||
@ -255,10 +255,10 @@ template <typename Curve>
|
|||||||
class EvalAUC : public Metric {
|
class EvalAUC : public Metric {
|
||||||
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
|
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
|
||||||
double auc {0};
|
double auc {0};
|
||||||
if (tparam_->gpu_id != Context::kCpuId) {
|
if (ctx_->gpu_id != Context::kCpuId) {
|
||||||
preds.SetDevice(tparam_->gpu_id);
|
preds.SetDevice(ctx_->gpu_id);
|
||||||
info.labels.SetDevice(tparam_->gpu_id);
|
info.labels.SetDevice(ctx_->gpu_id);
|
||||||
info.weights_.SetDevice(tparam_->gpu_id);
|
info.weights_.SetDevice(ctx_->gpu_id);
|
||||||
}
|
}
|
||||||
// We use the global size to handle empty dataset.
|
// We use the global size to handle empty dataset.
|
||||||
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
||||||
@ -339,13 +339,13 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
|||||||
MetaInfo const &info) {
|
MetaInfo const &info) {
|
||||||
double auc{0};
|
double auc{0};
|
||||||
uint32_t valid_groups = 0;
|
uint32_t valid_groups = 0;
|
||||||
auto n_threads = tparam_->Threads();
|
auto n_threads = ctx_->Threads();
|
||||||
if (tparam_->gpu_id == Context::kCpuId) {
|
if (ctx_->gpu_id == Context::kCpuId) {
|
||||||
std::tie(auc, valid_groups) =
|
std::tie(auc, valid_groups) =
|
||||||
RankingAUC<true>(predts.ConstHostVector(), info, n_threads);
|
RankingAUC<true>(predts.ConstHostVector(), info, n_threads);
|
||||||
} else {
|
} else {
|
||||||
std::tie(auc, valid_groups) = GPURankingAUC(
|
std::tie(auc, valid_groups) = GPURankingAUC(
|
||||||
predts.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_);
|
predts.ConstDeviceSpan(), info, ctx_->gpu_id, &this->d_cache_);
|
||||||
}
|
}
|
||||||
return std::make_pair(auc, valid_groups);
|
return std::make_pair(auc, valid_groups);
|
||||||
}
|
}
|
||||||
@ -353,13 +353,13 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
|||||||
double EvalMultiClass(HostDeviceVector<float> const &predts,
|
double EvalMultiClass(HostDeviceVector<float> const &predts,
|
||||||
MetaInfo const &info, size_t n_classes) {
|
MetaInfo const &info, size_t n_classes) {
|
||||||
double auc{0};
|
double auc{0};
|
||||||
auto n_threads = tparam_->Threads();
|
auto n_threads = ctx_->Threads();
|
||||||
CHECK_NE(n_classes, 0);
|
CHECK_NE(n_classes, 0);
|
||||||
if (tparam_->gpu_id == Context::kCpuId) {
|
if (ctx_->gpu_id == Context::kCpuId) {
|
||||||
auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads,
|
auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads,
|
||||||
BinaryROCAUC);
|
BinaryROCAUC);
|
||||||
} else {
|
} else {
|
||||||
auc = GPUMultiClassROCAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id,
|
auc = GPUMultiClassROCAUC(predts.ConstDeviceSpan(), info, ctx_->gpu_id,
|
||||||
&this->d_cache_, n_classes);
|
&this->d_cache_, n_classes);
|
||||||
}
|
}
|
||||||
return auc;
|
return auc;
|
||||||
@ -368,13 +368,13 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
|||||||
std::tuple<double, double, double>
|
std::tuple<double, double, double>
|
||||||
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
|
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
|
||||||
double fp, tp, auc;
|
double fp, tp, auc;
|
||||||
if (tparam_->gpu_id == Context::kCpuId) {
|
if (ctx_->gpu_id == Context::kCpuId) {
|
||||||
std::tie(fp, tp, auc) =
|
std::tie(fp, tp, auc) =
|
||||||
BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0),
|
BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0),
|
||||||
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
||||||
} else {
|
} else {
|
||||||
std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
|
std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
|
||||||
tparam_->gpu_id, &this->d_cache_);
|
ctx_->gpu_id, &this->d_cache_);
|
||||||
}
|
}
|
||||||
return std::make_tuple(fp, tp, auc);
|
return std::make_tuple(fp, tp, auc);
|
||||||
}
|
}
|
||||||
@ -418,25 +418,25 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
|||||||
std::tuple<double, double, double>
|
std::tuple<double, double, double>
|
||||||
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
|
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
|
||||||
double pr, re, auc;
|
double pr, re, auc;
|
||||||
if (tparam_->gpu_id == Context::kCpuId) {
|
if (ctx_->gpu_id == Context::kCpuId) {
|
||||||
std::tie(pr, re, auc) =
|
std::tie(pr, re, auc) =
|
||||||
BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
|
BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
|
||||||
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
||||||
} else {
|
} else {
|
||||||
std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
|
std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
|
||||||
tparam_->gpu_id, &this->d_cache_);
|
ctx_->gpu_id, &this->d_cache_);
|
||||||
}
|
}
|
||||||
return std::make_tuple(pr, re, auc);
|
return std::make_tuple(pr, re, auc);
|
||||||
}
|
}
|
||||||
|
|
||||||
double EvalMultiClass(HostDeviceVector<float> const &predts, MetaInfo const &info,
|
double EvalMultiClass(HostDeviceVector<float> const &predts, MetaInfo const &info,
|
||||||
size_t n_classes) {
|
size_t n_classes) {
|
||||||
if (tparam_->gpu_id == Context::kCpuId) {
|
if (ctx_->gpu_id == Context::kCpuId) {
|
||||||
auto n_threads = this->tparam_->Threads();
|
auto n_threads = this->ctx_->Threads();
|
||||||
return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads,
|
return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads,
|
||||||
BinaryPRAUC);
|
BinaryPRAUC);
|
||||||
} else {
|
} else {
|
||||||
return GPUMultiClassPRAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id,
|
return GPUMultiClassPRAUC(predts.ConstDeviceSpan(), info, ctx_->gpu_id,
|
||||||
&d_cache_, n_classes);
|
&d_cache_, n_classes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -445,8 +445,8 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
|||||||
MetaInfo const &info) {
|
MetaInfo const &info) {
|
||||||
double auc{0};
|
double auc{0};
|
||||||
uint32_t valid_groups = 0;
|
uint32_t valid_groups = 0;
|
||||||
auto n_threads = tparam_->Threads();
|
auto n_threads = ctx_->Threads();
|
||||||
if (tparam_->gpu_id == Context::kCpuId) {
|
if (ctx_->gpu_id == Context::kCpuId) {
|
||||||
auto labels = info.labels.Data()->ConstHostSpan();
|
auto labels = info.labels.Data()->ConstHostSpan();
|
||||||
if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) {
|
if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) {
|
||||||
InvalidLabels();
|
InvalidLabels();
|
||||||
@ -455,7 +455,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
|||||||
RankingAUC<false>(predts.ConstHostVector(), info, n_threads);
|
RankingAUC<false>(predts.ConstHostVector(), info, n_threads);
|
||||||
} else {
|
} else {
|
||||||
std::tie(auc, valid_groups) = GPURankingPRAUC(
|
std::tie(auc, valid_groups) = GPURankingPRAUC(
|
||||||
predts.ConstDeviceSpan(), info, tparam_->gpu_id, &d_cache_);
|
predts.ConstDeviceSpan(), info, ctx_->gpu_id, &d_cache_);
|
||||||
}
|
}
|
||||||
return std::make_pair(auc, valid_groups);
|
return std::make_pair(auc, valid_groups);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2015-2022 by XGBoost Contributors
|
* Copyright 2015-2023 by XGBoost Contributors
|
||||||
* \file elementwise_metric.cc
|
* \file elementwise_metric.cu
|
||||||
* \brief evaluation metrics for elementwise binary or regression.
|
* \brief evaluation metrics for elementwise binary or regression.
|
||||||
* \author Kailong Chen, Tianqi Chen
|
* \author Kailong Chen, Tianqi Chen
|
||||||
*
|
*
|
||||||
@ -180,16 +180,16 @@ class PseudoErrorLoss : public Metric {
|
|||||||
|
|
||||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||||
CHECK_EQ(info.labels.Shape(0), info.num_row_);
|
CHECK_EQ(info.labels.Shape(0), info.num_row_);
|
||||||
auto labels = info.labels.View(tparam_->gpu_id);
|
auto labels = info.labels.View(ctx_->gpu_id);
|
||||||
preds.SetDevice(tparam_->gpu_id);
|
preds.SetDevice(ctx_->gpu_id);
|
||||||
auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
||||||
info.weights_.SetDevice(tparam_->gpu_id);
|
info.weights_.SetDevice(ctx_->gpu_id);
|
||||||
common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan()
|
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
|
||||||
: info.weights_.ConstDeviceSpan());
|
: info.weights_.ConstDeviceSpan());
|
||||||
float slope = this->param_.huber_slope;
|
float slope = this->param_.huber_slope;
|
||||||
CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
|
CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
|
||||||
PackedReduceResult result =
|
PackedReduceResult result =
|
||||||
Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
|
Reduce(ctx_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
|
||||||
float wt = weights[sample_id];
|
float wt = weights[sample_id];
|
||||||
auto a = labels(sample_id, target_id) - predts[i];
|
auto a = labels(sample_id, target_id) - predts[i];
|
||||||
auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt;
|
auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt;
|
||||||
@ -348,16 +348,16 @@ struct EvalEWiseBase : public Metric {
|
|||||||
if (info.labels.Size() != 0) {
|
if (info.labels.Size() != 0) {
|
||||||
CHECK_NE(info.labels.Shape(1), 0);
|
CHECK_NE(info.labels.Shape(1), 0);
|
||||||
}
|
}
|
||||||
auto labels = info.labels.View(tparam_->gpu_id);
|
auto labels = info.labels.View(ctx_->gpu_id);
|
||||||
info.weights_.SetDevice(tparam_->gpu_id);
|
info.weights_.SetDevice(ctx_->gpu_id);
|
||||||
common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan()
|
common::OptionalWeights weights(ctx_->IsCPU() ? info.weights_.ConstHostSpan()
|
||||||
: info.weights_.ConstDeviceSpan());
|
: info.weights_.ConstDeviceSpan());
|
||||||
preds.SetDevice(tparam_->gpu_id);
|
preds.SetDevice(ctx_->gpu_id);
|
||||||
auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
auto predts = ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
|
||||||
|
|
||||||
auto d_policy = policy_;
|
auto d_policy = policy_;
|
||||||
auto result =
|
auto result =
|
||||||
Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
|
Reduce(ctx_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
|
||||||
float wt = weights[sample_id];
|
float wt = weights[sample_id];
|
||||||
float residue = d_policy.EvalRow(labels(sample_id, target_id), predts[i]);
|
float residue = d_policy.EvalRow(labels(sample_id, target_id), predts[i]);
|
||||||
residue *= wt;
|
residue *= wt;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2015-2020 by Contributors
|
* Copyright 2015-2023 by XGBoost Contributors
|
||||||
* \file metric_registry.cc
|
* \file metric_registry.cc
|
||||||
* \brief Registry of objective functions.
|
* \brief Registry of objective functions.
|
||||||
*/
|
*/
|
||||||
@ -43,18 +43,18 @@ Metric* CreateMetricImpl(const std::string& name) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Metric *
|
Metric *
|
||||||
Metric::Create(const std::string& name, Context const* tparam) {
|
Metric::Create(const std::string& name, Context const* ctx) {
|
||||||
auto metric = CreateMetricImpl<MetricReg>(name);
|
auto metric = CreateMetricImpl<MetricReg>(name);
|
||||||
if (metric == nullptr) {
|
if (metric == nullptr) {
|
||||||
LOG(FATAL) << "Unknown metric function " << name;
|
LOG(FATAL) << "Unknown metric function " << name;
|
||||||
}
|
}
|
||||||
|
|
||||||
metric->tparam_ = tparam;
|
metric->ctx_ = ctx;
|
||||||
return metric;
|
return metric;
|
||||||
}
|
}
|
||||||
|
|
||||||
Metric *
|
Metric *
|
||||||
GPUMetric::CreateGPUMetric(const std::string& name, Context const* tparam) {
|
GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
|
||||||
auto metric = CreateMetricImpl<MetricGPUReg>(name);
|
auto metric = CreateMetricImpl<MetricGPUReg>(name);
|
||||||
if (metric == nullptr) {
|
if (metric == nullptr) {
|
||||||
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
|
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
|
||||||
@ -65,7 +65,7 @@ GPUMetric::CreateGPUMetric(const std::string& name, Context const* tparam) {
|
|||||||
// Narrowing reference only for the compiler to allow assignment to a base class member.
|
// Narrowing reference only for the compiler to allow assignment to a base class member.
|
||||||
// As such, using this narrowed reference to refer to derived members will be an illegal op.
|
// As such, using this narrowed reference to refer to derived members will be an illegal op.
|
||||||
// This is moot, as this type is stateless.
|
// This is moot, as this type is stateless.
|
||||||
static_cast<GPUMetric *>(metric)->tparam_ = tparam;
|
static_cast<GPUMetric *>(metric)->ctx_ = ctx;
|
||||||
return metric;
|
return metric;
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2015-2019 by Contributors
|
* Copyright 2015-2023 by XGBoost Contributors
|
||||||
* \file multiclass_metric.cc
|
* \file multiclass_metric.cc
|
||||||
* \brief evaluation metrics for multiclass classification.
|
* \brief evaluation metrics for multiclass classification.
|
||||||
* \author Kailong Chen, Tianqi Chen
|
* \author Kailong Chen, Tianqi Chen
|
||||||
@ -175,9 +175,9 @@ struct EvalMClassBase : public Metric {
|
|||||||
CHECK_GE(nclass, 1U)
|
CHECK_GE(nclass, 1U)
|
||||||
<< "mlogloss and merror are only used for multi-class classification,"
|
<< "mlogloss and merror are only used for multi-class classification,"
|
||||||
<< " use logloss for binary classification";
|
<< " use logloss for binary classification";
|
||||||
int device = tparam_->gpu_id;
|
int device = ctx_->gpu_id;
|
||||||
auto result =
|
auto result =
|
||||||
reducer_.Reduce(*tparam_, device, nclass, info.weights_, *info.labels.Data(), preds);
|
reducer_.Reduce(*ctx_, device, nclass, info.weights_, *info.labels.Data(), preds);
|
||||||
dat[0] = result.Residue();
|
dat[0] = result.Residue();
|
||||||
dat[1] = result.Weights();
|
dat[1] = result.Weights();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2020 XGBoost contributors
|
* Copyright 2020-2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
// When device ordinal is present, we would want to build the metrics on the GPU. It is *not*
|
// When device ordinal is present, we would want to build the metrics on the GPU. It is *not*
|
||||||
// possible for a valid device ordinal to be present for non GPU builds. However, it is possible
|
// possible for a valid device ordinal to be present for non GPU builds. However, it is possible
|
||||||
@ -110,7 +110,7 @@ struct EvalAMS : public Metric {
|
|||||||
PredIndPairContainer rec(ndata);
|
PredIndPairContainer rec(ndata);
|
||||||
|
|
||||||
const auto &h_preds = preds.ConstHostVector();
|
const auto &h_preds = preds.ConstHostVector();
|
||||||
common::ParallelFor(ndata, tparam_->Threads(),
|
common::ParallelFor(ndata, ctx_->Threads(),
|
||||||
[&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); });
|
[&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); });
|
||||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||||
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||||
@ -178,24 +178,24 @@ struct EvalRank : public Metric, public EvalRankConfig {
|
|||||||
double sum_metric = 0.0f;
|
double sum_metric = 0.0f;
|
||||||
|
|
||||||
// Check and see if we have the GPU metric registered in the internal registry
|
// Check and see if we have the GPU metric registered in the internal registry
|
||||||
if (tparam_->gpu_id >= 0) {
|
if (ctx_->gpu_id >= 0) {
|
||||||
if (!rank_gpu_) {
|
if (!rank_gpu_) {
|
||||||
rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), tparam_));
|
rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), ctx_));
|
||||||
}
|
}
|
||||||
if (rank_gpu_) {
|
if (rank_gpu_) {
|
||||||
sum_metric = rank_gpu_->Eval(preds, info);
|
sum_metric = rank_gpu_->Eval(preds, info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK(tparam_);
|
CHECK(ctx_);
|
||||||
std::vector<double> sum_tloc(tparam_->Threads(), 0.0);
|
std::vector<double> sum_tloc(ctx_->Threads(), 0.0);
|
||||||
|
|
||||||
if (!rank_gpu_ || tparam_->gpu_id < 0) {
|
if (!rank_gpu_ || ctx_->gpu_id < 0) {
|
||||||
const auto& labels = info.labels.View(Context::kCpuId);
|
const auto& labels = info.labels.View(Context::kCpuId);
|
||||||
const auto &h_preds = preds.ConstHostVector();
|
const auto &h_preds = preds.ConstHostVector();
|
||||||
|
|
||||||
dmlc::OMPException exc;
|
dmlc::OMPException exc;
|
||||||
#pragma omp parallel num_threads(tparam_->Threads())
|
#pragma omp parallel num_threads(ctx_->Threads())
|
||||||
{
|
{
|
||||||
exc.Run([&]() {
|
exc.Run([&]() {
|
||||||
// each thread takes a local rec
|
// each thread takes a local rec
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2020 by Contributors
|
* Copyright 2020-2023 by XGBoost Contributors
|
||||||
* \file rank_metric.cc
|
* \file rank_metric.cu
|
||||||
* \brief prediction rank based metrics.
|
* \brief prediction rank based metrics.
|
||||||
* \author Kailong Chen, Tianqi Chen
|
* \author Kailong Chen, Tianqi Chen
|
||||||
*/
|
*/
|
||||||
@ -34,7 +34,7 @@ struct EvalRankGpu : public GPUMetric, public EvalRankConfig {
|
|||||||
|
|
||||||
const auto ngroups = static_cast<bst_omp_uint>(gptr.size() - 1);
|
const auto ngroups = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||||
|
|
||||||
auto device = tparam_->gpu_id;
|
auto device = ctx_->gpu_id;
|
||||||
dh::safe_cuda(cudaSetDevice(device));
|
dh::safe_cuda(cudaSetDevice(device));
|
||||||
|
|
||||||
info.labels.SetDevice(device);
|
info.labels.SetDevice(device);
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2020 by Contributors
|
* Copyright 2019-2023 by Contributors
|
||||||
* \file survival_metric.cu
|
* \file survival_metric.cu
|
||||||
* \brief Metrics for survival analysis
|
* \brief Metrics for survival analysis
|
||||||
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
|
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
|
||||||
@ -196,21 +196,21 @@ struct EvalAFTNLogLik {
|
|||||||
|
|
||||||
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||||
explicit EvalEWiseSurvivalBase(Context const *ctx) {
|
explicit EvalEWiseSurvivalBase(Context const *ctx) {
|
||||||
tparam_ = ctx;
|
ctx_ = ctx;
|
||||||
}
|
}
|
||||||
EvalEWiseSurvivalBase() = default;
|
EvalEWiseSurvivalBase() = default;
|
||||||
|
|
||||||
void Configure(const Args& args) override {
|
void Configure(const Args& args) override {
|
||||||
policy_.Configure(args);
|
policy_.Configure(args);
|
||||||
reducer_.Configure(policy_);
|
reducer_.Configure(policy_);
|
||||||
CHECK(tparam_);
|
CHECK(ctx_);
|
||||||
}
|
}
|
||||||
|
|
||||||
double Eval(const HostDeviceVector<float>& preds, const MetaInfo& info) override {
|
double Eval(const HostDeviceVector<float>& preds, const MetaInfo& info) override {
|
||||||
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
||||||
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
||||||
CHECK(tparam_);
|
CHECK(ctx_);
|
||||||
auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels_lower_bound_,
|
auto result = reducer_.Reduce(*ctx_, info.weights_, info.labels_lower_bound_,
|
||||||
info.labels_upper_bound_, preds);
|
info.labels_upper_bound_, preds);
|
||||||
|
|
||||||
double dat[2]{result.Residue(), result.Weights()};
|
double dat[2]{result.Residue(), result.Weights()};
|
||||||
@ -244,17 +244,13 @@ struct AFTNLogLikDispatcher : public Metric {
|
|||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
switch (param_.aft_loss_distribution) {
|
switch (param_.aft_loss_distribution) {
|
||||||
case common::ProbabilityDistributionType::kNormal:
|
case common::ProbabilityDistributionType::kNormal:
|
||||||
metric_.reset(
|
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::NormalDistribution>>(ctx_));
|
||||||
new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::NormalDistribution>>(
|
|
||||||
tparam_));
|
|
||||||
break;
|
break;
|
||||||
case common::ProbabilityDistributionType::kLogistic:
|
case common::ProbabilityDistributionType::kLogistic:
|
||||||
metric_.reset(new EvalEWiseSurvivalBase<
|
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::LogisticDistribution>>(ctx_));
|
||||||
EvalAFTNLogLik<common::LogisticDistribution>>(tparam_));
|
|
||||||
break;
|
break;
|
||||||
case common::ProbabilityDistributionType::kExtreme:
|
case common::ProbabilityDistributionType::kExtreme:
|
||||||
metric_.reset(new EvalEWiseSurvivalBase<
|
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::ExtremeDistribution>>(ctx_));
|
||||||
EvalAFTNLogLik<common::ExtremeDistribution>>(tparam_));
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unknown probability distribution";
|
LOG(FATAL) << "Unknown probability distribution";
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user