Pass DMatrix into metric for caching. (#8790)
This commit is contained in:
@@ -156,14 +156,15 @@ double GetMultiMetricEval(xgboost::Metric* metric,
|
||||
xgboost::linalg::Tensor<float, 2> const& labels,
|
||||
std::vector<xgboost::bst_float> weights,
|
||||
std::vector<xgboost::bst_uint> groups) {
|
||||
xgboost::MetaInfo info;
|
||||
std::shared_ptr<xgboost::DMatrix> p_fmat{xgboost::RandomDataGenerator{0, 0, 0}.GenerateDMatrix()};
|
||||
auto& info = p_fmat->Info();
|
||||
info.num_row_ = labels.Shape(0);
|
||||
info.labels.Reshape(labels.Shape()[0], labels.Shape()[1]);
|
||||
info.labels.Data()->Copy(*labels.Data());
|
||||
info.weights_.HostVector() = weights;
|
||||
info.group_ptr_ = groups;
|
||||
|
||||
return metric->Eval(preds, info);
|
||||
return metric->Evaluate(preds, p_fmat);
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
@@ -661,4 +662,4 @@ void DeleteRMMResource(RMMAllocator*) {}
|
||||
|
||||
RMMAllocatorPtr SetUpRMMResourceForCppTests(int, char**) { return {nullptr, DeleteRMMResource}; }
|
||||
#endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user