Pass DMatrix into metric for caching. (#8790)

This commit is contained in:
Jiaming Yuan
2023-02-13 22:15:05 +08:00
committed by GitHub
parent 31d3ec07af
commit 81b2ee1153
17 changed files with 95 additions and 70 deletions

View File

@@ -18,7 +18,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
metric->Configure(Args{});
HostDeviceVector<float> predts;
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
auto &h_predts = predts.HostVector();
SimpleLCG lcg;
@@ -40,9 +41,9 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
h_upper[i] = 10;
}
auto result = metric->Eval(predts, info);
auto result = metric->Evaluate(predts, p_fmat);
for (size_t i = 0; i < 8; ++i) {
ASSERT_EQ(metric->Eval(predts, info), result);
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
}
}
} // anonymous namespace
@@ -54,7 +55,8 @@ TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) {
* Test aggregate output from the AFT metric over a small test data set.
* This is unlike AFTLoss.* tests, which verify metric values over individual data points.
**/
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
info.num_row_ = 4;
info.labels_lower_bound_.HostVector()
= { 100.0f, 0.0f, 60.0f, 16.0f };
@@ -72,14 +74,15 @@ TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) {
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
{"aft_loss_distribution_scale", "1.0"} });
EXPECT_NEAR(metric->Eval(preds, info), test_case.reference_value, 1e-4);
EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
}
}
TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) {
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
info.num_row_ = 4;
info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f };
info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f };
@@ -87,15 +90,15 @@ TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) {
HostDeviceVector<bst_float> preds(4, std::log(60.0f));
std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.75f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
info.labels_lower_bound_.HostVector()[2] = 70.0f;
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
info.labels_lower_bound_.HostVector()[0] = 70.0f;
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.25f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
}