Add quantile metric. (#8761)
This commit is contained in:
parent
d11a0044cf
commit
457f704e3d
@ -91,6 +91,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/common/survival_util.o \
|
$(PKGROOT)/src/common/survival_util.o \
|
||||||
$(PKGROOT)/src/common/threading_utils.o \
|
$(PKGROOT)/src/common/threading_utils.o \
|
||||||
$(PKGROOT)/src/common/ranking_utils.o \
|
$(PKGROOT)/src/common/ranking_utils.o \
|
||||||
|
$(PKGROOT)/src/common/quantile_loss_utils.o \
|
||||||
$(PKGROOT)/src/common/timer.o \
|
$(PKGROOT)/src/common/timer.o \
|
||||||
$(PKGROOT)/src/common/version.o \
|
$(PKGROOT)/src/common/version.o \
|
||||||
$(PKGROOT)/src/c_api/c_api.o \
|
$(PKGROOT)/src/c_api/c_api.o \
|
||||||
|
|||||||
@ -91,6 +91,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/common/survival_util.o \
|
$(PKGROOT)/src/common/survival_util.o \
|
||||||
$(PKGROOT)/src/common/threading_utils.o \
|
$(PKGROOT)/src/common/threading_utils.o \
|
||||||
$(PKGROOT)/src/common/ranking_utils.o \
|
$(PKGROOT)/src/common/ranking_utils.o \
|
||||||
|
$(PKGROOT)/src/common/quantile_loss_utils.o \
|
||||||
$(PKGROOT)/src/common/timer.o \
|
$(PKGROOT)/src/common/timer.o \
|
||||||
$(PKGROOT)/src/common/version.o \
|
$(PKGROOT)/src/common/version.o \
|
||||||
$(PKGROOT)/src/c_api/c_api.o \
|
$(PKGROOT)/src/c_api/c_api.o \
|
||||||
|
|||||||
27
python-package/xgboost/testing/metrics.py
Normal file
27
python-package/xgboost/testing/metrics.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
"""Tests for evaluation metrics."""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
|
||||||
|
def check_quantile_error(tree_method: str) -> None:
|
||||||
|
"""Test for the `quantile` loss."""
|
||||||
|
from sklearn.datasets import make_regression
|
||||||
|
from sklearn.metrics import mean_pinball_loss
|
||||||
|
|
||||||
|
rng = np.random.RandomState(19)
|
||||||
|
# pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
X, y = make_regression(128, 3, random_state=rng)
|
||||||
|
Xy = xgb.QuantileDMatrix(X, y)
|
||||||
|
evals_result: Dict[str, Dict] = {}
|
||||||
|
booster = xgb.train(
|
||||||
|
{"tree_method": tree_method, "eval_metric": "quantile", "quantile_alpha": 0.3},
|
||||||
|
Xy,
|
||||||
|
evals=[(Xy, "Train")],
|
||||||
|
evals_result=evals_result,
|
||||||
|
)
|
||||||
|
predt = booster.inplace_predict(X)
|
||||||
|
loss = mean_pinball_loss(y, predt, alpha=0.3)
|
||||||
|
np.testing.assert_allclose(evals_result["Train"]["quantile"][-1], loss)
|
||||||
74
src/common/quantile_loss_utils.cc
Normal file
74
src/common/quantile_loss_utils.cc
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "quantile_loss_utils.h"
|
||||||
|
|
||||||
|
#include <cctype> // std::isspace
|
||||||
|
#include <istream> // std::istream
|
||||||
|
#include <ostream> // std::ostream
|
||||||
|
#include <string> // std::string
|
||||||
|
#include <vector> // std::vector
|
||||||
|
|
||||||
|
#include "xgboost/json.h" // F32Array,TypeCheck,get,Number
|
||||||
|
#include "xgboost/json_io.h" // JsonWriter
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& array) {
|
||||||
|
auto const& t = array.Get();
|
||||||
|
xgboost::F32Array arr{t.size()};
|
||||||
|
for (std::size_t i = 0; i < t.size(); ++i) {
|
||||||
|
arr.Set(i, t[i]);
|
||||||
|
}
|
||||||
|
std::vector<char> stream;
|
||||||
|
xgboost::JsonWriter writer{&stream};
|
||||||
|
arr.Save(&writer);
|
||||||
|
for (auto c : stream) {
|
||||||
|
os << c;
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::istream& operator>>(std::istream& is, ParamFloatArray& array) {
|
||||||
|
auto& t = array.Get();
|
||||||
|
t.clear();
|
||||||
|
std::string str;
|
||||||
|
while (!is.eof()) {
|
||||||
|
std::string tmp;
|
||||||
|
is >> tmp;
|
||||||
|
str += tmp;
|
||||||
|
}
|
||||||
|
std::size_t head{0};
|
||||||
|
// unify notation for parsing.
|
||||||
|
while (std::isspace(str[head])) {
|
||||||
|
++head;
|
||||||
|
}
|
||||||
|
if (str[head] == '(') {
|
||||||
|
str[head] = '[';
|
||||||
|
}
|
||||||
|
auto tail = str.size() - 1;
|
||||||
|
while (std::isspace(str[tail])) {
|
||||||
|
--tail;
|
||||||
|
}
|
||||||
|
if (str[tail] == ')') {
|
||||||
|
str[tail] = ']';
|
||||||
|
}
|
||||||
|
|
||||||
|
auto jarr = xgboost::Json::Load(xgboost::StringView{str});
|
||||||
|
// return if there's only one element
|
||||||
|
if (xgboost::IsA<xgboost::Number>(jarr)) {
|
||||||
|
t.emplace_back(xgboost::get<xgboost::Number const>(jarr));
|
||||||
|
return is;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto jvec = xgboost::get<xgboost::Array const>(jarr);
|
||||||
|
for (auto v : jvec) {
|
||||||
|
xgboost::TypeCheck<xgboost::Number>(v, "alpha");
|
||||||
|
t.emplace_back(get<xgboost::Number const>(v));
|
||||||
|
}
|
||||||
|
return is;
|
||||||
|
}
|
||||||
|
|
||||||
|
DMLC_REGISTER_PARAMETER(QuantileLossParam);
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
51
src/common/quantile_loss_utils.h
Normal file
51
src/common/quantile_loss_utils.h
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_
|
||||||
|
#define XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_
|
||||||
|
|
||||||
|
#include <algorithm> // std::all_of
|
||||||
|
#include <istream> // std::istream
|
||||||
|
#include <ostream> // std::ostream
|
||||||
|
#include <vector> // std::vector
|
||||||
|
|
||||||
|
#include "xgboost/logging.h" // CHECK
|
||||||
|
#include "xgboost/parameter.h" // XGBoostParameter
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
// A shim to enable ADL for parameter parsing. Alternatively, we can put the stream
|
||||||
|
// operators in std namespace, which seems to be less ideal.
|
||||||
|
class ParamFloatArray {
|
||||||
|
std::vector<float> values_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
std::vector<float>& Get() { return values_; }
|
||||||
|
std::vector<float> const& Get() const { return values_; }
|
||||||
|
decltype(values_)::const_reference operator[](decltype(values_)::size_type i) const {
|
||||||
|
return values_[i];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// For parsing quantile parameters. Input can be a string to a single float or a list of
|
||||||
|
// floats.
|
||||||
|
std::ostream& operator<<(std::ostream& os, const ParamFloatArray& t);
|
||||||
|
std::istream& operator>>(std::istream& is, ParamFloatArray& t);
|
||||||
|
|
||||||
|
struct QuantileLossParam : public XGBoostParameter<QuantileLossParam> {
|
||||||
|
ParamFloatArray quantile_alpha;
|
||||||
|
DMLC_DECLARE_PARAMETER(QuantileLossParam) {
|
||||||
|
DMLC_DECLARE_FIELD(quantile_alpha).describe("List of quantiles for quantile loss.");
|
||||||
|
}
|
||||||
|
void Validate() const {
|
||||||
|
CHECK(GetInitialised());
|
||||||
|
CHECK(!quantile_alpha.Get().empty());
|
||||||
|
auto const& array = quantile_alpha.Get();
|
||||||
|
auto valid =
|
||||||
|
std::all_of(array.cbegin(), array.cend(), [](auto q) { return q >= 0.0 && q <= 1.0; });
|
||||||
|
CHECK(valid) << "quantile alpha must be in the range [0.0, 1.0].";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_COMMON_QUANTILE_LOSS_UTILS_H_
|
||||||
@ -695,6 +695,11 @@ class LearnerConfiguration : public Learner {
|
|||||||
});
|
});
|
||||||
} else if (IsA<Object>(kv.second)) {
|
} else if (IsA<Object>(kv.second)) {
|
||||||
stack.push(kv.second);
|
stack.push(kv.second);
|
||||||
|
} else if (kv.first == "metrics") {
|
||||||
|
auto const& array = get<Array const>(kv.second);
|
||||||
|
for (auto const& v : array) {
|
||||||
|
stack.push(v);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,7 +7,6 @@
|
|||||||
* The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset.
|
* The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset.
|
||||||
*/
|
*/
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
#include <xgboost/metric.h>
|
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
@ -16,8 +15,10 @@
|
|||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
#include "../common/optional_weight.h" // OptionalWeights
|
#include "../common/optional_weight.h" // OptionalWeights
|
||||||
#include "../common/pseudo_huber.h"
|
#include "../common/pseudo_huber.h"
|
||||||
|
#include "../common/quantile_loss_utils.h" // QuantileLossParam
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "metric_common.h"
|
#include "metric_common.h"
|
||||||
|
#include "xgboost/metric.h"
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||||
@ -421,5 +422,82 @@ XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
|
|||||||
.set_body([](const char* param) {
|
.set_body([](const char* param) {
|
||||||
return new EvalEWiseBase<EvalTweedieNLogLik>(param);
|
return new EvalEWiseBase<EvalTweedieNLogLik>(param);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
class QuantileError : public Metric {
|
||||||
|
HostDeviceVector<float> alpha_;
|
||||||
|
common::QuantileLossParam param_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
void Configure(Args const& args) override {
|
||||||
|
param_.UpdateAllowUnknown(args);
|
||||||
|
param_.Validate();
|
||||||
|
alpha_.HostVector() = param_.quantile_alpha.Get();
|
||||||
|
}
|
||||||
|
|
||||||
|
double Eval(HostDeviceVector<bst_float> const& preds, const MetaInfo& info) override {
|
||||||
|
CHECK(!alpha_.Empty());
|
||||||
|
if (info.num_row_ == 0) {
|
||||||
|
// empty DMatrix on distributed env
|
||||||
|
double dat[2]{0.0, 0.0};
|
||||||
|
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||||
|
CHECK_GT(dat[1], 0);
|
||||||
|
return dat[0] / dat[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const* ctx = ctx_;
|
||||||
|
auto y_true = info.labels.View(ctx->gpu_id);
|
||||||
|
preds.SetDevice(ctx->gpu_id);
|
||||||
|
alpha_.SetDevice(ctx->gpu_id);
|
||||||
|
auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
|
||||||
|
std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size();
|
||||||
|
CHECK_NE(n_targets, 0);
|
||||||
|
auto y_predt = linalg::MakeTensorView(
|
||||||
|
ctx->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(),
|
||||||
|
{static_cast<std::size_t>(info.num_row_), alpha_.Size(), n_targets}, ctx->gpu_id);
|
||||||
|
|
||||||
|
info.weights_.SetDevice(ctx->gpu_id);
|
||||||
|
common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan()
|
||||||
|
: info.weights_.ConstDeviceSpan()};
|
||||||
|
|
||||||
|
auto result = Reduce(
|
||||||
|
ctx, info, [=] XGBOOST_DEVICE(std::size_t i, std::size_t sample_id, std::size_t target_id) {
|
||||||
|
auto idx = linalg::UnravelIndex(i, y_predt.Shape());
|
||||||
|
sample_id = std::get<0>(idx);
|
||||||
|
std::size_t quantile_id = std::get<1>(idx);
|
||||||
|
target_id = std::get<2>(idx);
|
||||||
|
|
||||||
|
auto loss = [a = alpha[quantile_id]](float p, float y) {
|
||||||
|
auto d = y - p;
|
||||||
|
float sign = d >= 0.0f;
|
||||||
|
auto res = (a * sign * d) - (1.0f - a) * (1.0f - sign) * d;
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
auto w = weight[sample_id];
|
||||||
|
auto l =
|
||||||
|
loss(y_predt(sample_id, quantile_id, target_id), y_true(sample_id, target_id)) * w;
|
||||||
|
return std::make_tuple(l, w);
|
||||||
|
});
|
||||||
|
double dat[2]{result.Residue(), result.Weights()};
|
||||||
|
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
||||||
|
CHECK_GT(dat[1], 0);
|
||||||
|
return dat[0] / dat[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* Name() const override { return "quantile"; }
|
||||||
|
void LoadConfig(Json const& in) override {
|
||||||
|
auto const& name = get<String const>(in["name"]);
|
||||||
|
CHECK_EQ(name, "quantile");
|
||||||
|
FromJson(in["quantile_loss_param"], ¶m_);
|
||||||
|
}
|
||||||
|
void SaveConfig(Json* p_out) const override {
|
||||||
|
auto& out = *p_out;
|
||||||
|
out["name"] = String(this->Name());
|
||||||
|
out["quantile_loss_param"] = ToJson(param_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(QuantileError, "quantile")
|
||||||
|
.describe("Quantile regression error.")
|
||||||
|
.set_body([](const char*) { return new QuantileError{}; });
|
||||||
} // namespace metric
|
} // namespace metric
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
30
tests/cpp/common/test_quantile_utils.cc
Normal file
30
tests/cpp/common/test_quantile_utils.cc
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "../../../src/common/quantile_loss_utils.h"
|
||||||
|
#include "xgboost/base.h" // Args
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
TEST(QuantileLossParam, Basic) {
|
||||||
|
QuantileLossParam param;
|
||||||
|
auto& ref = param.quantile_alpha.Get();
|
||||||
|
|
||||||
|
param.UpdateAllowUnknown(Args{{"quantile_alpha", "0.3"}});
|
||||||
|
ASSERT_EQ(ref.size(), 1);
|
||||||
|
ASSERT_NEAR(ref[0], 0.3, kRtEps);
|
||||||
|
|
||||||
|
param.UpdateAllowUnknown(Args{{"quantile_alpha", "[0.3, 0.6]"}});
|
||||||
|
ASSERT_EQ(param.quantile_alpha.Get().size(), 2);
|
||||||
|
ASSERT_NEAR(ref[0], 0.3, kRtEps);
|
||||||
|
ASSERT_NEAR(ref[1], 0.6, kRtEps);
|
||||||
|
|
||||||
|
param.UpdateAllowUnknown(Args{{"quantile_alpha", "(0.6, 0.3)"}});
|
||||||
|
ASSERT_EQ(param.quantile_alpha.Get().size(), 2);
|
||||||
|
ASSERT_NEAR(ref[0], 0.6, kRtEps);
|
||||||
|
ASSERT_NEAR(ref[1], 0.3, kRtEps);
|
||||||
|
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2018-2022 by XGBoost contributors
|
* Copyright 2018-2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <xgboost/json.h>
|
#include <xgboost/json.h>
|
||||||
#include <xgboost/metric.h>
|
#include <xgboost/metric.h>
|
||||||
@ -311,5 +311,36 @@ TEST(Metric, DeclareUnifiedTest(MultiRMSE)) {
|
|||||||
ASSERT_FLOAT_EQ(ret, loss);
|
ASSERT_FLOAT_EQ(ret, loss);
|
||||||
ASSERT_FLOAT_EQ(ret, loss_w);
|
ASSERT_FLOAT_EQ(ret, loss_w);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Metric, DeclareUnifiedTest(Quantile)) {
|
||||||
|
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||||
|
std::unique_ptr<Metric> metric{Metric::Create("quantile", &ctx)};
|
||||||
|
|
||||||
|
HostDeviceVector<float> predts{0.1f, 0.9f, 0.1f, 0.9f};
|
||||||
|
std::vector<float> labels{0.5f, 0.5f, 0.9f, 0.1f};
|
||||||
|
std::vector<float> weights{0.2f, 0.4f,0.6f, 0.8f};
|
||||||
|
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[0.0]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.400f, 0.001f);
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[0.2]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.376f, 0.001f);
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[0.4]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.352f, 0.001f);
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[0.8]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.304f, 0.001f);
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[1.0]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, weights), 0.28f, 0.001f);
|
||||||
|
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[0.0]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[0.2]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[0.4]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[0.8]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
|
||||||
|
metric->Configure(Args{{"quantile_alpha", "[1.0]"}});
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels), 0.3f, 0.001f);
|
||||||
|
}
|
||||||
} // namespace metric
|
} // namespace metric
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from xgboost.testing.metrics import check_quantile_error
|
||||||
|
|
||||||
import xgboost
|
import xgboost
|
||||||
|
from xgboost import testing as tm
|
||||||
|
|
||||||
sys.path.append("tests/python")
|
sys.path.append("tests/python")
|
||||||
import test_eval_metrics as test_em # noqa
|
import test_eval_metrics as test_em # noqa
|
||||||
@ -58,3 +60,7 @@ class TestGPUEvalMetrics:
|
|||||||
|
|
||||||
def test_pr_auc_ltr(self):
|
def test_pr_auc_ltr(self):
|
||||||
self.cpu_test.run_pr_auc_ltr("gpu_hist")
|
self.cpu_test.run_pr_auc_ltr("gpu_hist")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
|
def test_quantile_error(self) -> None:
|
||||||
|
check_quantile_error("gpu_hist")
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
from xgboost.testing.metrics import check_quantile_error
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
@ -306,10 +307,14 @@ class TestEvalMetrics:
|
|||||||
group=groups,
|
group=groups,
|
||||||
eval_set=[(X, y)],
|
eval_set=[(X, y)],
|
||||||
eval_group=[groups],
|
eval_group=[groups],
|
||||||
eval_metric="aucpr"
|
eval_metric="aucpr",
|
||||||
)
|
)
|
||||||
results = ltr.evals_result()["validation_0"]["aucpr"]
|
results = ltr.evals_result()["validation_0"]["aucpr"]
|
||||||
assert results[-1] >= 0.99
|
assert results[-1] >= 0.99
|
||||||
|
|
||||||
def test_pr_auc_ltr(self):
|
def test_pr_auc_ltr(self):
|
||||||
self.run_pr_auc_ltr("hist")
|
self.run_pr_auc_ltr("hist")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
|
def test_quantile_error(self) -> None:
|
||||||
|
check_quantile_error("hist")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user