Use double precision in metric calculation. (#7364)
This commit is contained in:
@@ -206,9 +206,8 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
CHECK(tparam_);
|
||||
}
|
||||
|
||||
bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
bool distributed) override {
|
||||
double Eval(const HostDeviceVector<float> &preds, const MetaInfo &info,
|
||||
bool distributed) override {
|
||||
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
||||
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
||||
CHECK(tparam_);
|
||||
@@ -221,7 +220,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return static_cast<bst_float>(Policy::GetFinal(dat[0], dat[1]));
|
||||
return Policy::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
|
||||
const char* Name() const override {
|
||||
@@ -241,9 +240,8 @@ struct AFTNLogLikDispatcher : public Metric {
|
||||
return "aft-nloglik";
|
||||
}
|
||||
|
||||
bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
bool distributed) override {
|
||||
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
|
||||
bool distributed) override {
|
||||
CHECK(metric_) << "AFT metric must be configured first, with distribution type and scale";
|
||||
return metric_->Eval(preds, info, distributed);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user