Use double precision in metric calculation. (#7364)

This commit is contained in:
Jiaming Yuan
2021-11-02 12:00:32 +08:00
committed by GitHub
parent 239dbb3c0a
commit 0f7a9b42f1
11 changed files with 219 additions and 224 deletions

View File

@@ -206,9 +206,8 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
CHECK(tparam_);
}
bst_float Eval(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
bool distributed) override {
double Eval(const HostDeviceVector<float> &preds, const MetaInfo &info,
bool distributed) override {
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
CHECK(tparam_);
@@ -221,7 +220,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
return static_cast<bst_float>(Policy::GetFinal(dat[0], dat[1]));
return Policy::GetFinal(dat[0], dat[1]);
}
const char* Name() const override {
@@ -241,9 +240,8 @@ struct AFTNLogLikDispatcher : public Metric {
return "aft-nloglik";
}
bst_float Eval(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
bool distributed) override {
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
bool distributed) override {
CHECK(metric_) << "AFT metric must be configured first, with distribution type and scale";
return metric_->Eval(preds, info, distributed);
}