Rename context in Metric. (#8686)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2020 by Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by Contributors
|
||||
* \file survival_metric.cu
|
||||
* \brief Metrics for survival analysis
|
||||
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
|
||||
@@ -196,21 +196,21 @@ struct EvalAFTNLogLik {
|
||||
|
||||
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
explicit EvalEWiseSurvivalBase(Context const *ctx) {
|
||||
tparam_ = ctx;
|
||||
ctx_ = ctx;
|
||||
}
|
||||
EvalEWiseSurvivalBase() = default;
|
||||
|
||||
void Configure(const Args& args) override {
|
||||
policy_.Configure(args);
|
||||
reducer_.Configure(policy_);
|
||||
CHECK(tparam_);
|
||||
CHECK(ctx_);
|
||||
}
|
||||
|
||||
double Eval(const HostDeviceVector<float>& preds, const MetaInfo& info) override {
|
||||
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
||||
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
||||
CHECK(tparam_);
|
||||
auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels_lower_bound_,
|
||||
CHECK(ctx_);
|
||||
auto result = reducer_.Reduce(*ctx_, info.weights_, info.labels_lower_bound_,
|
||||
info.labels_upper_bound_, preds);
|
||||
|
||||
double dat[2]{result.Residue(), result.Weights()};
|
||||
@@ -244,17 +244,13 @@ struct AFTNLogLikDispatcher : public Metric {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
switch (param_.aft_loss_distribution) {
|
||||
case common::ProbabilityDistributionType::kNormal:
|
||||
metric_.reset(
|
||||
new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::NormalDistribution>>(
|
||||
tparam_));
|
||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::NormalDistribution>>(ctx_));
|
||||
break;
|
||||
case common::ProbabilityDistributionType::kLogistic:
|
||||
metric_.reset(new EvalEWiseSurvivalBase<
|
||||
EvalAFTNLogLik<common::LogisticDistribution>>(tparam_));
|
||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::LogisticDistribution>>(ctx_));
|
||||
break;
|
||||
case common::ProbabilityDistributionType::kExtreme:
|
||||
metric_.reset(new EvalEWiseSurvivalBase<
|
||||
EvalAFTNLogLik<common::ExtremeDistribution>>(tparam_));
|
||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::ExtremeDistribution>>(ctx_));
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown probability distribution";
|
||||
|
||||
Reference in New Issue
Block a user