tests/cpp: Add tests for Metric RMSE
This commit is contained in:
parent
582c373274
commit
03abd47f49
@ -49,3 +49,14 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
|
|||||||
<< " weight=" << weights[i];
|
<< " weight=" << weights[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
|
||||||
|
std::vector<xgboost::bst_float> preds,
|
||||||
|
std::vector<xgboost::bst_float> labels,
|
||||||
|
std::vector<xgboost::bst_float> weights) {
|
||||||
|
xgboost::MetaInfo info;
|
||||||
|
info.num_row = labels.size();
|
||||||
|
info.labels = labels;
|
||||||
|
info.weights = weights;
|
||||||
|
return metric->Eval(preds, info, false);
|
||||||
|
}
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
#include <xgboost/objective.h>
|
#include <xgboost/objective.h>
|
||||||
|
#include <xgboost/metric.h>
|
||||||
|
|
||||||
std::string TempFileName();
|
std::string TempFileName();
|
||||||
|
|
||||||
@ -29,4 +30,10 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
|
|||||||
std::vector<xgboost::bst_float> out_grad,
|
std::vector<xgboost::bst_float> out_grad,
|
||||||
std::vector<xgboost::bst_float> out_hess);
|
std::vector<xgboost::bst_float> out_hess);
|
||||||
|
|
||||||
|
xgboost::bst_float GetMetricEval(
|
||||||
|
xgboost::Metric * metric,
|
||||||
|
std::vector<xgboost::bst_float> preds,
|
||||||
|
std::vector<xgboost::bst_float> labels,
|
||||||
|
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float> ());
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
14
tests/cpp/metric/test_elementwise_metric.cc
Normal file
14
tests/cpp/metric/test_elementwise_metric.cc
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
// Copyright by Contributors
|
||||||
|
#include <xgboost/metric.h>
|
||||||
|
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
|
TEST(Metric, RMSE) {
|
||||||
|
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
||||||
|
ASSERT_STREQ(metric->Name(), "rmse");
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
|
{0.1, 0.9, 0.1, 0.9},
|
||||||
|
{ 0, 0, 1, 1}),
|
||||||
|
0.6403, 0.001);
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user