Add quantile metric. (#8761)

This commit is contained in:
Jiaming Yuan
2023-02-13 19:07:40 +08:00
committed by GitHub
parent d11a0044cf
commit 457f704e3d
11 changed files with 313 additions and 4 deletions

View File

@@ -0,0 +1,30 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../../../src/common/quantile_loss_utils.h"
#include "xgboost/base.h" // Args
namespace xgboost {
namespace common {
TEST(QuantileLossParam, Basic) {
QuantileLossParam param;
auto& ref = param.quantile_alpha.Get();
param.UpdateAllowUnknown(Args{{"quantile_alpha", "0.3"}});
ASSERT_EQ(ref.size(), 1);
ASSERT_NEAR(ref[0], 0.3, kRtEps);
param.UpdateAllowUnknown(Args{{"quantile_alpha", "[0.3, 0.6]"}});
ASSERT_EQ(param.quantile_alpha.Get().size(), 2);
ASSERT_NEAR(ref[0], 0.3, kRtEps);
ASSERT_NEAR(ref[1], 0.6, kRtEps);
param.UpdateAllowUnknown(Args{{"quantile_alpha", "(0.6, 0.3)"}});
ASSERT_EQ(param.quantile_alpha.Get().size(), 2);
ASSERT_NEAR(ref[0], 0.6, kRtEps);
ASSERT_NEAR(ref[1], 0.3, kRtEps);
}
} // namespace common
} // namespace xgboost

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2018-2022 by XGBoost contributors
/**
* Copyright 2018-2023 by XGBoost contributors
*/
#include <xgboost/json.h>
#include <xgboost/metric.h>
@@ -311,5 +311,36 @@ TEST(Metric, DeclareUnifiedTest(MultiRMSE)) {
ASSERT_FLOAT_EQ(ret, loss);
ASSERT_FLOAT_EQ(ret, loss_w);
}
TEST(Metric, DeclareUnifiedTest(Quantile)) {
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Metric> metric{Metric::Create("quantile", &ctx)};
HostDeviceVector<float> predts{0.1f, 0.9f, 0.1f, 0.9f};
std::vector<float> labels{0.5f, 0.5f, 0.9f, 0.1f};
std::vector<float> weights{0.2f, 0.4f,0.6f, 0.8f};
metric->Configure(Args{{"quantile_alpha", "[0.0]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.400f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.2]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.376f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.4]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.352f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.8]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.304f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[1.0]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.28f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.0]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.2]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.4]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.8]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[1.0]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
}
} // namespace metric
} // namespace xgboost