tests/cpp: Add tests for Metric RMSE
This commit is contained in:
parent
582c373274
commit
03abd47f49
@ -49,3 +49,14 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
|
||||
<< " weight=" << weights[i];
|
||||
}
|
||||
}
|
||||
|
||||
xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights) {
|
||||
xgboost::MetaInfo info;
|
||||
info.num_row = labels.size();
|
||||
info.labels = labels;
|
||||
info.weights = weights;
|
||||
return metric->Eval(preds, info, false);
|
||||
}
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
std::string TempFileName();
|
||||
|
||||
@ -29,4 +30,10 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
|
||||
std::vector<xgboost::bst_float> out_grad,
|
||||
std::vector<xgboost::bst_float> out_hess);
|
||||
|
||||
xgboost::bst_float GetMetricEval(
|
||||
xgboost::Metric * metric,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float> ());
|
||||
|
||||
#endif
|
||||
|
||||
14
tests/cpp/metric/test_elementwise_metric.cc
Normal file
14
tests/cpp/metric/test_elementwise_metric.cc
Normal file
@ -0,0 +1,14 @@
|
||||
// Copyright by Contributors
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(Metric, RMSE) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
||||
ASSERT_STREQ(metric->Name(), "rmse");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1, 0.9, 0.1, 0.9},
|
||||
{ 0, 0, 1, 1}),
|
||||
0.6403, 0.001);
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user