Add quantile metric. (#8761)

This commit is contained in:
Jiaming Yuan 2023-02-13 19:07:40 +08:00 committed by GitHub
parent d11a0044cf
commit 457f704e3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 313 additions and 4 deletions

View File

@ -91,6 +91,7 @@ OBJECTS= \
$(PKGROOT)/src/common/survival_util.o \ $(PKGROOT)/src/common/survival_util.o \
$(PKGROOT)/src/common/threading_utils.o \ $(PKGROOT)/src/common/threading_utils.o \
$(PKGROOT)/src/common/ranking_utils.o \ $(PKGROOT)/src/common/ranking_utils.o \
$(PKGROOT)/src/common/quantile_loss_utils.o \
$(PKGROOT)/src/common/timer.o \ $(PKGROOT)/src/common/timer.o \
$(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \ $(PKGROOT)/src/c_api/c_api.o \

View File

@ -91,6 +91,7 @@ OBJECTS= \
$(PKGROOT)/src/common/survival_util.o \ $(PKGROOT)/src/common/survival_util.o \
$(PKGROOT)/src/common/threading_utils.o \ $(PKGROOT)/src/common/threading_utils.o \
$(PKGROOT)/src/common/ranking_utils.o \ $(PKGROOT)/src/common/ranking_utils.o \
$(PKGROOT)/src/common/quantile_loss_utils.o \
$(PKGROOT)/src/common/timer.o \ $(PKGROOT)/src/common/timer.o \
$(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \ $(PKGROOT)/src/c_api/c_api.o \

View File

@ -0,0 +1,27 @@
"""Tests for evaluation metrics."""
from typing import Dict
import numpy as np
import xgboost as xgb
def check_quantile_error(tree_method: str) -> None:
"""Test for the `quantile` loss."""
from sklearn.datasets import make_regression
from sklearn.metrics import mean_pinball_loss
rng = np.random.RandomState(19)
# pylint: disable=unbalanced-tuple-unpacking
X, y = make_regression(128, 3, random_state=rng)
Xy = xgb.QuantileDMatrix(X, y)
evals_result: Dict[str, Dict] = {}
booster = xgb.train(
{"tree_method": tree_method, "eval_metric": "quantile", "quantile_alpha": 0.3},
Xy,
evals=[(Xy, "Train")],
evals_result=evals_result,
)
predt = booster.inplace_predict(X)
loss = mean_pinball_loss(y, predt, alpha=0.3)
np.testing.assert_allclose(evals_result["Train"]["quantile"][-1], loss)

View File

@ -0,0 +1,74 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#include "quantile_loss_utils.h"
#include <cctype> // std::isspace
#include <istream> // std::istream
#include <ostream> // std::ostream
#include <string> // std::string
#include <vector> // std::vector
#include "xgboost/json.h" // F32Array,TypeCheck,get,Number
#include "xgboost/json_io.h" // JsonWriter
namespace xgboost {
namespace common {
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) {
auto const& t = array.Get();
xgboost::F32Array arr{t.size()};
for (std::size_t i = 0; i < t.size(); ++i) {
arr.Set(i, t[i]);
}
std::vector<char> stream;
xgboost::JsonWriter writer{&stream};
arr.Save(&writer);
for (auto c : stream) {
os << c;
}
return os;
}
std::istream& operator>>(std::istream& is, ParamFloatArray& array) {
auto& t = array.Get();
t.clear();
std::string str;
while (!is.eof()) {
std::string tmp;
is >> tmp;
str += tmp;
}
std::size_t head{0};
// unify notation for parsing.
while (std::isspace(str[head])) {
++head;
}
if (str[head] == '(') {
str[head] = '[';
}
auto tail = str.size() - 1;
while (std::isspace(str[tail])) {
--tail;
}
if (str[tail] == ')') {
str[tail] = ']';
}
auto jarr = xgboost::Json::Load(xgboost::StringView{str});
// return if there's only one element
if (xgboost::IsA<xgboost::Number>(jarr)) {
t.emplace_back(xgboost::get<xgboost::Number const>(jarr));
return is;
}
auto jvec = xgboost::get<xgboost::Array const>(jarr);
for (auto v : jvec) {
xgboost::TypeCheck<xgboost::Number>(v, "alpha");
t.emplace_back(get<xgboost::Number const>(v));
}
return is;
}
DMLC_REGISTER_PARAMETER(QuantileLossParam);
} // namespace common
} // namespace xgboost

View File

@ -0,0 +1,51 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#ifndef XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_
#define XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_
#include <algorithm> // std::all_of
#include <istream> // std::istream
#include <ostream> // std::ostream
#include <vector> // std::vector
#include "xgboost/logging.h" // CHECK
#include "xgboost/parameter.h" // XGBoostParameter
namespace xgboost {
namespace common {
// A shim to enable ADL for parameter parsing. Alternatively, we can put the stream
// operators in std namespace, which seems to be less ideal.
class ParamFloatArray {
std::vector<float> values_;
public:
std::vector<float>& Get() { return values_; }
std::vector<float> const& Get() const { return values_; }
decltype(values_)::const_reference operator[](decltype(values_)::size_type i) const {
return values_[i];
}
};
// For parsing quantile parameters. Input can be a string to a single float or a list of
// floats.
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& t);
std::istream& operator>>(std::istream& is, ParamFloatArray& t);
struct QuantileLossParam : public XGBoostParameter<QuantileLossParam> {
ParamFloatArray quantile_alpha;
DMLC_DECLARE_PARAMETER(QuantileLossParam) {
DMLC_DECLARE_FIELD(quantile_alpha).describe("List of quantiles for quantile loss.");
}
void Validate() const {
CHECK(GetInitialised());
CHECK(!quantile_alpha.Get().empty());
auto const& array = quantile_alpha.Get();
auto valid =
std::all_of(array.cbegin(), array.cend(), [](auto q) { return q >= 0.0 && q <= 1.0; });
CHECK(valid) << "quantile alpha must be in the range [0.0, 1.0].";
}
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_

View File

@ -695,6 +695,11 @@ class LearnerConfiguration : public Learner {
}); });
} else if (IsA<Object>(kv.second)) { } else if (IsA<Object>(kv.second)) {
stack.push(kv.second); stack.push(kv.second);
} else if (kv.first == "metrics") {
auto const& array = get<Array const>(kv.second);
for (auto const& v : array) {
stack.push(v);
}
} }
} }
} }

View File

@ -7,7 +7,6 @@
* The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset. * The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset.
*/ */
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <xgboost/metric.h>
#include <cmath> #include <cmath>
@ -16,8 +15,10 @@
#include "../common/math.h" #include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights #include "../common/optional_weight.h" // OptionalWeights
#include "../common/pseudo_huber.h" #include "../common/pseudo_huber.h"
#include "../common/quantile_loss_utils.h" // QuantileLossParam
#include "../common/threading_utils.h" #include "../common/threading_utils.h"
#include "metric_common.h" #include "metric_common.h"
#include "xgboost/metric.h"
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
#include <thrust/execution_policy.h> // thrust::cuda::par #include <thrust/execution_policy.h> // thrust::cuda::par
@ -421,5 +422,82 @@ XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
.set_body([](const char* param) { .set_body([](const char* param) {
return new EvalEWiseBase<EvalTweedieNLogLik>(param); return new EvalEWiseBase<EvalTweedieNLogLik>(param);
}); });
class QuantileError : public Metric {
HostDeviceVector<float> alpha_;
common::QuantileLossParam param_;
public:
void Configure(Args const& args) override {
param_.UpdateAllowUnknown(args);
param_.Validate();
alpha_.HostVector() = param_.quantile_alpha.Get();
}
double Eval(HostDeviceVector<bst_float> const& preds, const MetaInfo& info) override {
CHECK(!alpha_.Empty());
if (info.num_row_ == 0) {
// empty DMatrix on distributed env
double dat[2]{0.0, 0.0};
collective::Allreduce<collective::Operation::kSum>(dat, 2);
CHECK_GT(dat[1], 0);
return dat[0] / dat[1];
}
auto const* ctx = ctx_;
auto y_true = info.labels.View(ctx->gpu_id);
preds.SetDevice(ctx->gpu_id);
alpha_.SetDevice(ctx->gpu_id);
auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size();
CHECK_NE(n_targets, 0);
auto y_predt = linalg::MakeTensorView(
ctx->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(),
{static_cast<std::size_t>(info.num_row_), alpha_.Size(), n_targets}, ctx->gpu_id);
info.weights_.SetDevice(ctx->gpu_id);
common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()};
auto result = Reduce(
ctx, info, [=] XGBOOST_DEVICE(std::size_t i, std::size_t sample_id, std::size_t target_id) {
auto idx = linalg::UnravelIndex(i, y_predt.Shape());
sample_id = std::get<0>(idx);
std::size_t quantile_id = std::get<1>(idx);
target_id = std::get<2>(idx);
auto loss = [a = alpha[quantile_id]](float p, float y) {
auto d = y - p;
float sign = d >= 0.0f;
auto res = (a * sign * d) - (1.0f - a) * (1.0f - sign) * d;
return res;
};
auto w = weight[sample_id];
auto l =
loss(y_predt(sample_id, quantile_id, target_id), y_true(sample_id, target_id)) * w;
return std::make_tuple(l, w);
});
double dat[2]{result.Residue(), result.Weights()};
collective::Allreduce<collective::Operation::kSum>(dat, 2);
CHECK_GT(dat[1], 0);
return dat[0] / dat[1];
}
const char* Name() const override { return "quantile"; }
void LoadConfig(Json const& in) override {
auto const& name = get<String const>(in["name"]);
CHECK_EQ(name, "quantile");
FromJson(in["quantile_loss_param"], &param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String(this->Name());
out["quantile_loss_param"] = ToJson(param_);
}
};
XGBOOST_REGISTER_METRIC(QuantileError, "quantile")
.describe("Quantile regression error.")
.set_body([](const char*) { return new QuantileError{}; });
} // namespace metric } // namespace metric
} // namespace xgboost } // namespace xgboost

View File

@ -0,0 +1,30 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../../../src/common/quantile_loss_utils.h"
#include "xgboost/base.h" // Args
namespace xgboost {
namespace common {
TEST(QuantileLossParam, Basic) {
QuantileLossParam param;
auto& ref = param.quantile_alpha.Get();
param.UpdateAllowUnknown(Args{{"quantile_alpha", "0.3"}});
ASSERT_EQ(ref.size(), 1);
ASSERT_NEAR(ref[0], 0.3, kRtEps);
param.UpdateAllowUnknown(Args{{"quantile_alpha", "[0.3, 0.6]"}});
ASSERT_EQ(param.quantile_alpha.Get().size(), 2);
ASSERT_NEAR(ref[0], 0.3, kRtEps);
ASSERT_NEAR(ref[1], 0.6, kRtEps);
param.UpdateAllowUnknown(Args{{"quantile_alpha", "(0.6, 0.3)"}});
ASSERT_EQ(param.quantile_alpha.Get().size(), 2);
ASSERT_NEAR(ref[0], 0.6, kRtEps);
ASSERT_NEAR(ref[1], 0.3, kRtEps);
}
} // namespace common
} // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2018-2022 by XGBoost contributors * Copyright 2018-2023 by XGBoost contributors
*/ */
#include <xgboost/json.h> #include <xgboost/json.h>
#include <xgboost/metric.h> #include <xgboost/metric.h>
@ -311,5 +311,36 @@ TEST(Metric, DeclareUnifiedTest(MultiRMSE)) {
ASSERT_FLOAT_EQ(ret, loss); ASSERT_FLOAT_EQ(ret, loss);
ASSERT_FLOAT_EQ(ret, loss_w); ASSERT_FLOAT_EQ(ret, loss_w);
} }
TEST(Metric, DeclareUnifiedTest(Quantile)) {
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Metric> metric{Metric::Create("quantile", &ctx)};
HostDeviceVector<float> predts{0.1f, 0.9f, 0.1f, 0.9f};
std::vector<float> labels{0.5f, 0.5f, 0.9f, 0.1f};
std::vector<float> weights{0.2f, 0.4f,0.6f, 0.8f};
metric->Configure(Args{{"quantile_alpha", "[0.0]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.400f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.2]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.376f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.4]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.352f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.8]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.304f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[1.0]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.28f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.0]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.2]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.4]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[0.8]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
metric->Configure(Args{{"quantile_alpha", "[1.0]"}});
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
}
} // namespace metric } // namespace metric
} // namespace xgboost } // namespace xgboost

View File

@ -1,8 +1,10 @@
import sys import sys
import pytest import pytest
from xgboost.testing.metrics import check_quantile_error
import xgboost import xgboost
from xgboost import testing as tm
sys.path.append("tests/python") sys.path.append("tests/python")
import test_eval_metrics as test_em # noqa import test_eval_metrics as test_em # noqa
@ -58,3 +60,7 @@ class TestGPUEvalMetrics:
def test_pr_auc_ltr(self): def test_pr_auc_ltr(self):
self.cpu_test.run_pr_auc_ltr("gpu_hist") self.cpu_test.run_pr_auc_ltr("gpu_hist")
@pytest.mark.skipif(**tm.no_sklearn())
def test_quantile_error(self) -> None:
check_quantile_error("gpu_hist")

View File

@ -1,5 +1,6 @@
import numpy as np import numpy as np
import pytest import pytest
from xgboost.testing.metrics import check_quantile_error
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
@ -306,10 +307,14 @@ class TestEvalMetrics:
group=groups, group=groups,
eval_set=[(X, y)], eval_set=[(X, y)],
eval_group=[groups], eval_group=[groups],
eval_metric="aucpr" eval_metric="aucpr",
) )
results = ltr.evals_result()["validation_0"]["aucpr"] results = ltr.evals_result()["validation_0"]["aucpr"]
assert results[-1] >= 0.99 assert results[-1] >= 0.99
def test_pr_auc_ltr(self): def test_pr_auc_ltr(self):
self.run_pr_auc_ltr("hist") self.run_pr_auc_ltr("hist")
@pytest.mark.skipif(**tm.no_sklearn())
def test_quantile_error(self) -> None:
check_quantile_error("hist")