Pass DMatrix into metric for caching. (#8790)

This commit is contained in:
Jiaming Yuan
2023-02-13 22:15:05 +08:00
committed by GitHub
parent 31d3ec07af
commit 81b2ee1153
17 changed files with 95 additions and 70 deletions

View File

@@ -10,7 +10,8 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device)
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
HostDeviceVector<float> predts;
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
auto &h_predts = predts.HostVector();
SimpleLCG lcg;
@@ -35,9 +36,9 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device)
}
}
auto result = metric->Eval(predts, info);
auto result = metric->Evaluate(predts, p_fmat);
for (size_t i = 0; i < 8; ++i) {
ASSERT_EQ(metric->Eval(predts, info), result);
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
}
}
} // namespace xgboost