diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 111975870..630965e38 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -91,6 +91,7 @@ OBJECTS= \ $(PKGROOT)/src/common/survival_util.o \ $(PKGROOT)/src/common/threading_utils.o \ $(PKGROOT)/src/common/ranking_utils.o \ + $(PKGROOT)/src/common/quantile_loss_utils.o \ $(PKGROOT)/src/common/timer.o \ $(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/c_api/c_api.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index d89aadc3d..09f09598a 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -91,6 +91,7 @@ OBJECTS= \ $(PKGROOT)/src/common/survival_util.o \ $(PKGROOT)/src/common/threading_utils.o \ $(PKGROOT)/src/common/ranking_utils.o \ + $(PKGROOT)/src/common/quantile_loss_utils.o \ $(PKGROOT)/src/common/timer.o \ $(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/c_api/c_api.o \ diff --git a/python-package/xgboost/testing/metrics.py b/python-package/xgboost/testing/metrics.py new file mode 100644 index 000000000..6edbe0e3d --- /dev/null +++ b/python-package/xgboost/testing/metrics.py @@ -0,0 +1,27 @@ +"""Tests for evaluation metrics.""" +from typing import Dict + +import numpy as np + +import xgboost as xgb + + +def check_quantile_error(tree_method: str) -> None: + """Test for the `quantile` loss.""" + from sklearn.datasets import make_regression + from sklearn.metrics import mean_pinball_loss + + rng = np.random.RandomState(19) + # pylint: disable=unbalanced-tuple-unpacking + X, y = make_regression(128, 3, random_state=rng) + Xy = xgb.QuantileDMatrix(X, y) + evals_result: Dict[str, Dict] = {} + booster = xgb.train( + {"tree_method": tree_method, "eval_metric": "quantile", "quantile_alpha": 0.3}, + Xy, + evals=[(Xy, "Train")], + evals_result=evals_result, + ) + predt = booster.inplace_predict(X) + loss = mean_pinball_loss(y, predt, alpha=0.3) + np.testing.assert_allclose(evals_result["Train"]["quantile"][-1], loss) diff --git a/src/common/quantile_loss_utils.cc b/src/common/quantile_loss_utils.cc new file mode 100644 index 000000000..59397b701 --- /dev/null +++ b/src/common/quantile_loss_utils.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2023 by XGBoost contributors + */ +#include "quantile_loss_utils.h" + +#include // std::isspace +#include // std::istream +#include // std::ostream +#include // std::string +#include // std::vector + +#include "xgboost/json.h" // F32Array,TypeCheck,get,Number +#include "xgboost/json_io.h" // JsonWriter + +namespace xgboost { +namespace common { +std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) { + auto const& t = array.Get(); + xgboost::F32Array arr{t.size()}; + for (std::size_t i = 0; i < t.size(); ++i) { + arr.Set(i, t[i]); + } + std::vector stream; + xgboost::JsonWriter writer{&stream}; + arr.Save(&writer); + for (auto c : stream) { + os << c; + } + return os; +} + +std::istream& operator>>(std::istream& is, ParamFloatArray& array) { + auto& t = array.Get(); + t.clear(); + std::string str; + while (!is.eof()) { + std::string tmp; + is >> tmp; + str += tmp; + } + std::size_t head{0}; + // unify notation for parsing. + while (std::isspace(str[head])) { + ++head; + } + if (str[head] == '(') { + str[head] = '['; + } + auto tail = str.size() - 1; + while (std::isspace(str[tail])) { + --tail; + } + if (str[tail] == ')') { + str[tail] = ']'; + } + + auto jarr = xgboost::Json::Load(xgboost::StringView{str}); + // return if there's only one element + if (xgboost::IsA(jarr)) { + t.emplace_back(xgboost::get(jarr)); + return is; + } + + auto jvec = xgboost::get(jarr); + for (auto v : jvec) { + xgboost::TypeCheck(v, "alpha"); + t.emplace_back(get(v)); + } + return is; +} + +DMLC_REGISTER_PARAMETER(QuantileLossParam); +} // namespace common +} // namespace xgboost diff --git a/src/common/quantile_loss_utils.h b/src/common/quantile_loss_utils.h new file mode 100644 index 000000000..bc781de25 --- /dev/null +++ b/src/common/quantile_loss_utils.h @@ -0,0 +1,51 @@ +/** + * Copyright 2023 by XGBoost contributors + */ +#ifndef XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_ +#define XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_ + +#include // std::all_of +#include // std::istream +#include // std::ostream +#include // std::vector + +#include "xgboost/logging.h" // CHECK +#include "xgboost/parameter.h" // XGBoostParameter + +namespace xgboost { +namespace common { +// A shim to enable ADL for parameter parsing. Alternatively, we can put the stream +// operators in std namespace, which seems to be less ideal. +class ParamFloatArray { + std::vector values_; + + public: + std::vector& Get() { return values_; } + std::vector const& Get() const { return values_; } + decltype(values_)::const_reference operator[](decltype(values_)::size_type i) const { + return values_[i]; + } +}; + +// For parsing quantile parameters. Input can be a string to a single float or a list of +// floats. +std::ostream& operator<<(std::ostream& os, const ParamFloatArray& t); +std::istream& operator>>(std::istream& is, ParamFloatArray& t); + +struct QuantileLossParam : public XGBoostParameter { + ParamFloatArray quantile_alpha; + DMLC_DECLARE_PARAMETER(QuantileLossParam) { + DMLC_DECLARE_FIELD(quantile_alpha).describe("List of quantiles for quantile loss."); + } + void Validate() const { + CHECK(GetInitialised()); + CHECK(!quantile_alpha.Get().empty()); + auto const& array = quantile_alpha.Get(); + auto valid = + std::all_of(array.cbegin(), array.cend(), [](auto q) { return q >= 0.0 && q <= 1.0; }); + CHECK(valid) << "quantile alpha must be in the range [0.0, 1.0]."; + } +}; +} // namespace common +} // namespace xgboost +#endif // XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_ diff --git a/src/learner.cc b/src/learner.cc index 669566e45..ee2b1528b 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -695,6 +695,11 @@ class LearnerConfiguration : public Learner { }); } else if (IsA(kv.second)) { stack.push(kv.second); + } else if (kv.first == "metrics") { + auto const& array = get(kv.second); + for (auto const& v : array) { + stack.push(v); + } } } } diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 5f55d85e7..3df3bb65d 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -7,7 +7,6 @@ * The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset. */ #include -#include #include @@ -16,8 +15,10 @@ #include "../common/math.h" #include "../common/optional_weight.h" // OptionalWeights #include "../common/pseudo_huber.h" +#include "../common/quantile_loss_utils.h" // QuantileLossParam #include "../common/threading_utils.h" #include "metric_common.h" +#include "xgboost/metric.h" #if defined(XGBOOST_USE_CUDA) #include // thrust::cuda::par @@ -421,5 +422,82 @@ XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik") .set_body([](const char* param) { return new EvalEWiseBase(param); }); + +class QuantileError : public Metric { + HostDeviceVector alpha_; + common::QuantileLossParam param_; + + public: + void Configure(Args const& args) override { + param_.UpdateAllowUnknown(args); + param_.Validate(); + alpha_.HostVector() = param_.quantile_alpha.Get(); + } + + double Eval(HostDeviceVector const& preds, const MetaInfo& info) override { + CHECK(!alpha_.Empty()); + if (info.num_row_ == 0) { + // empty DMatrix on distributed env + double dat[2]{0.0, 0.0}; + collective::Allreduce(dat, 2); + CHECK_GT(dat[1], 0); + return dat[0] / dat[1]; + } + + auto const* ctx = ctx_; + auto y_true = info.labels.View(ctx->gpu_id); + preds.SetDevice(ctx->gpu_id); + alpha_.SetDevice(ctx->gpu_id); + auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan(); + std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size(); + CHECK_NE(n_targets, 0); + auto y_predt = linalg::MakeTensorView( + ctx->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(), + {static_cast(info.num_row_), alpha_.Size(), n_targets}, ctx->gpu_id); + + info.weights_.SetDevice(ctx->gpu_id); + common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan() + : info.weights_.ConstDeviceSpan()}; + + auto result = Reduce( + ctx, info, [=] XGBOOST_DEVICE(std::size_t i, std::size_t sample_id, std::size_t target_id) { + auto idx = linalg::UnravelIndex(i, y_predt.Shape()); + sample_id = std::get<0>(idx); + std::size_t quantile_id = std::get<1>(idx); + target_id = std::get<2>(idx); + + auto loss = [a = alpha[quantile_id]](float p, float y) { + auto d = y - p; + float sign = d >= 0.0f; + auto res = (a * sign * d) - (1.0f - a) * (1.0f - sign) * d; + return res; + }; + auto w = weight[sample_id]; + auto l = + loss(y_predt(sample_id, quantile_id, target_id), y_true(sample_id, target_id)) * w; + return std::make_tuple(l, w); + }); + double dat[2]{result.Residue(), result.Weights()}; + collective::Allreduce(dat, 2); + CHECK_GT(dat[1], 0); + return dat[0] / dat[1]; + } + + const char* Name() const override { return "quantile"; } + void LoadConfig(Json const& in) override { + auto const& name = get(in["name"]); + CHECK_EQ(name, "quantile"); + FromJson(in["quantile_loss_param"], ¶m_); + } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(this->Name()); + out["quantile_loss_param"] = ToJson(param_); + } +}; + +XGBOOST_REGISTER_METRIC(QuantileError, "quantile") + .describe("Quantile regression error.") + .set_body([](const char*) { return new QuantileError{}; }); } // namespace metric } // namespace xgboost diff --git a/tests/cpp/common/test_quantile_utils.cc b/tests/cpp/common/test_quantile_utils.cc new file mode 100644 index 000000000..18a816121 --- /dev/null +++ b/tests/cpp/common/test_quantile_utils.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2023 by XGBoost contributors + */ +#include + +#include "../../../src/common/quantile_loss_utils.h" +#include "xgboost/base.h" // Args + +namespace xgboost { +namespace common { +TEST(QuantileLossParam, Basic) { + QuantileLossParam param; + auto& ref = param.quantile_alpha.Get(); + + param.UpdateAllowUnknown(Args{{"quantile_alpha", "0.3"}}); + ASSERT_EQ(ref.size(), 1); + ASSERT_NEAR(ref[0], 0.3, kRtEps); + + param.UpdateAllowUnknown(Args{{"quantile_alpha", "[0.3, 0.6]"}}); + ASSERT_EQ(param.quantile_alpha.Get().size(), 2); + ASSERT_NEAR(ref[0], 0.3, kRtEps); + ASSERT_NEAR(ref[1], 0.6, kRtEps); + + param.UpdateAllowUnknown(Args{{"quantile_alpha", "(0.6, 0.3)"}}); + ASSERT_EQ(param.quantile_alpha.Get().size(), 2); + ASSERT_NEAR(ref[0], 0.6, kRtEps); + ASSERT_NEAR(ref[1], 0.3, kRtEps); +} +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc index fde9e42f2..5ac50be2a 100644 --- a/tests/cpp/metric/test_elementwise_metric.cc +++ b/tests/cpp/metric/test_elementwise_metric.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2018-2022 by XGBoost contributors +/** + * Copyright 2018-2023 by XGBoost contributors */ #include #include @@ -311,5 +311,36 @@ TEST(Metric, DeclareUnifiedTest(MultiRMSE)) { ASSERT_FLOAT_EQ(ret, loss); ASSERT_FLOAT_EQ(ret, loss_w); } + +TEST(Metric, DeclareUnifiedTest(Quantile)) { + auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); + std::unique_ptr metric{Metric::Create("quantile", &ctx)}; + + HostDeviceVector predts{0.1f, 0.9f, 0.1f, 0.9f}; + std::vector labels{0.5f, 0.5f, 0.9f, 0.1f}; + std::vector weights{0.2f, 0.4f,0.6f, 0.8f}; + + metric->Configure(Args{{"quantile_alpha", "[0.0]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.400f, 0.001f); + metric->Configure(Args{{"quantile_alpha", "[0.2]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.376f, 0.001f); + metric->Configure(Args{{"quantile_alpha", "[0.4]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.352f, 0.001f); + metric->Configure(Args{{"quantile_alpha", "[0.8]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.304f, 0.001f); + metric->Configure(Args{{"quantile_alpha", "[1.0]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.28f, 0.001f); + + metric->Configure(Args{{"quantile_alpha", "[0.0]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f); + metric->Configure(Args{{"quantile_alpha", "[0.2]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f); + metric->Configure(Args{{"quantile_alpha", "[0.4]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f); + metric->Configure(Args{{"quantile_alpha", "[0.8]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f); + metric->Configure(Args{{"quantile_alpha", "[1.0]"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f); +} } // namespace metric } // namespace xgboost diff --git a/tests/python-gpu/test_gpu_eval_metrics.py b/tests/python-gpu/test_gpu_eval_metrics.py index cb4d8eb6c..2e3b29f99 100644 --- a/tests/python-gpu/test_gpu_eval_metrics.py +++ b/tests/python-gpu/test_gpu_eval_metrics.py @@ -1,8 +1,10 @@ import sys import pytest +from xgboost.testing.metrics import check_quantile_error import xgboost +from xgboost import testing as tm sys.path.append("tests/python") import test_eval_metrics as test_em # noqa @@ -58,3 +60,7 @@ class TestGPUEvalMetrics: def test_pr_auc_ltr(self): self.cpu_test.run_pr_auc_ltr("gpu_hist") + + @pytest.mark.skipif(**tm.no_sklearn()) + def test_quantile_error(self) -> None: + check_quantile_error("gpu_hist") diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py index 24e3817ce..5b4e5751c 100644 --- a/tests/python/test_eval_metrics.py +++ b/tests/python/test_eval_metrics.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from xgboost.testing.metrics import check_quantile_error import xgboost as xgb from xgboost import testing as tm @@ -306,10 +307,14 @@ class TestEvalMetrics: group=groups, eval_set=[(X, y)], eval_group=[groups], - eval_metric="aucpr" + eval_metric="aucpr", ) results = ltr.evals_result()["validation_0"]["aucpr"] assert results[-1] >= 0.99 def test_pr_auc_ltr(self): self.run_pr_auc_ltr("hist") + + @pytest.mark.skipif(**tm.no_sklearn()) + def test_quantile_error(self) -> None: + check_quantile_error("hist")