Initial support for quantile loss. (#8750)

- Add support for Python.
- Add objective.
This commit is contained in:
Jiaming Yuan
2023-02-16 02:30:18 +08:00
committed by GitHub
parent 282b1729da
commit cce4af4acf
26 changed files with 701 additions and 70 deletions

View File

@@ -151,6 +151,7 @@ def main(args: argparse.Namespace) -> None:
"demo/guide-python/sklearn_parallel.py",
"demo/guide-python/spark_estimator_examples.py",
"demo/guide-python/individual_trees.py",
"demo/guide-python/quantile_regression.py",
# CI
"tests/ci_build/lint_python.py",
"tests/ci_build/test_r_package.py",
@@ -193,6 +194,7 @@ def main(args: argparse.Namespace) -> None:
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/feature_weights.py",
"demo/guide-python/individual_trees.py",
"demo/guide-python/quantile_regression.py",
# tests
"tests/python/test_dt.py",
"tests/python/test_data_iterator.py",

View File

@@ -11,19 +11,20 @@
namespace xgboost {
namespace common {
TEST(Stats, Quantile) {
Context ctx;
{
linalg::Tensor<float, 1> arr({20.f, 0.f, 15.f, 50.f, 40.f, 0.f, 35.f}, {7}, Context::kCpuId);
std::vector<size_t> index{0, 2, 3, 4, 6};
auto h_arr = arr.HostView();
auto beg = MakeIndexTransformIter([&](size_t i) { return h_arr(index[i]); });
auto end = beg + index.size();
auto q = Quantile(0.40f, beg, end);
auto q = Quantile(&ctx, 0.40f, beg, end);
ASSERT_EQ(q, 26.0);
q = Quantile(0.20f, beg, end);
q = Quantile(&ctx, 0.20f, beg, end);
ASSERT_EQ(q, 16.0);
q = Quantile(0.10f, beg, end);
q = Quantile(&ctx, 0.10f, beg, end);
ASSERT_EQ(q, 15.0);
}
@@ -31,12 +32,13 @@ TEST(Stats, Quantile) {
std::vector<float> vec{1., 2., 3., 4., 5.};
auto beg = MakeIndexTransformIter([&](size_t i) { return vec[i]; });
auto end = beg + vec.size();
auto q = Quantile(0.5f, beg, end);
auto q = Quantile(&ctx, 0.5f, beg, end);
ASSERT_EQ(q, 3.);
}
}
TEST(Stats, WeightedQuantile) {
Context ctx;
linalg::Tensor<float, 1> arr({1.f, 2.f, 3.f, 4.f, 5.f}, {5}, Context::kCpuId);
linalg::Tensor<float, 1> weight({1.f, 1.f, 1.f, 1.f, 1.f}, {5}, Context::kCpuId);
@@ -47,13 +49,13 @@ TEST(Stats, WeightedQuantile) {
auto end = beg + arr.Size();
auto w = MakeIndexTransformIter([&](size_t i) { return h_weight(i); });
auto q = WeightedQuantile(0.50f, beg, end, w);
auto q = WeightedQuantile(&ctx, 0.50f, beg, end, w);
ASSERT_EQ(q, 3);
q = WeightedQuantile(0.0, beg, end, w);
q = WeightedQuantile(&ctx, 0.0, beg, end, w);
ASSERT_EQ(q, 1);
q = WeightedQuantile(1.0, beg, end, w);
q = WeightedQuantile(&ctx, 1.0, beg, end, w);
ASSERT_EQ(q, 5);
}

View File

@@ -1,4 +1,6 @@
// Copyright by Contributors
/**
* Copyright 2016-2023 by XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h>
#include <xgboost/objective.h>
@@ -25,11 +27,14 @@ TEST(Objective, PredTransform) {
tparam.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
size_t n = 100;
for (const auto &entry :
::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
std::unique_ptr<xgboost::ObjFunction> obj{
xgboost::ObjFunction::Create(entry->name, &tparam)};
obj->Configure(Args{{"num_class", "2"}});
for (const auto& entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create(entry->name, &tparam)};
if (entry->name.find("multi") != std::string::npos) {
obj->Configure(Args{{"num_class", "2"}});
}
if (entry->name.find("quantile") != std::string::npos) {
obj->Configure(Args{{"quantile_alpha", "0.5"}});
}
HostDeviceVector<float> predts;
predts.Resize(n, 3.14f); // prediction is performed on host.
ASSERT_FALSE(predts.DeviceCanRead());

View File

@@ -0,0 +1,74 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/base.h> // Args
#include <xgboost/context.h> // Context
#include <xgboost/objective.h> // ObjFunction
#include <xgboost/span.h> // Span
#include <memory> // std::unique_ptr
#include <vector> // std::vector
#include "../helpers.h" // CheckConfigReload,CreateEmptyGenericParam,DeclareUnifiedTest
namespace xgboost {
/**
 * Quantile objective: a list-valued and a scalar `quantile_alpha` must both survive a
 * configuration reload, and the scalar case (alpha = 0.6) must yield the expected
 * gradient/hessian pairs.
 */
TEST(Objective, DeclareUnifiedTest(Quantile)) {
  Context ctx = CreateEmptyGenericParam(GPUIDX);
  constexpr auto kObjName = "reg:quantileerror";

  // Multiple quantiles: only the configuration round trip is exercised here.
  {
    Args multi_alpha{{"quantile_alpha", "[0.6, 0.8]"}};
    std::unique_ptr<ObjFunction> multi_obj{ObjFunction::Create(kObjName, &ctx)};
    multi_obj->Configure(multi_alpha);
    CheckConfigReload(multi_obj, kObjName);
  }

  // Single quantile: configuration round trip followed by a gradient check.
  Args single_alpha{{"quantile_alpha", "0.6"}};
  std::unique_ptr<ObjFunction> obj{ObjFunction::Create(kObjName, &ctx)};
  obj->Configure(single_alpha);
  CheckConfigReload(obj, kObjName);

  std::vector<float> predts{1.0f, 2.0f, 3.0f};
  std::vector<float> labels{3.0f, 2.0f, 1.0f};
  std::vector<float> weights{1.0f, 1.0f, 1.0f};
  // The expected gradient is -alpha for the sample where predt < label and 1 - alpha for
  // the two samples where predt >= label.
  std::vector<float> expected_grad{-0.6f, 0.4f, 0.4f};
  // The expected hessian is simply the sample weight.
  std::vector<float> expected_hess(weights);
  CheckObjFunction(obj, predts, labels, weights, expected_grad, expected_hess);
}
/**
 * Intercept estimation for the quantile objective with two quantiles (0.6 and 0.8):
 * labels are 0..9, first without weights and then with weights 9..0. In both cases a
 * single base score is expected.
 */
TEST(Objective, DeclareUnifiedTest(QuantileIntercept)) {
  Context ctx = CreateEmptyGenericParam(GPUIDX);
  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:quantileerror", &ctx)};
  obj->Configure(Args{{"quantile_alpha", "[0.6, 0.8]"}});

  MetaInfo info;
  info.num_row_ = 10;
  // Fill the labels with 0, 1, ..., num_row_ - 1 as a (num_row_, 1) matrix.
  info.labels.ModifyInplace([&](HostDeviceVector<float>* storage, common::Span<std::size_t> dims) {
    auto const n_rows = info.num_row_;
    storage->SetDevice(ctx.gpu_id);
    storage->Resize(n_rows);
    dims[0] = n_rows;
    dims[1] = 1;
    auto& h_labels = storage->HostVector();
    for (std::size_t ridx = 0; ridx < n_rows; ++ridx) {
      h_labels[ridx] = static_cast<float>(ridx);
    }
  });

  // Unweighted estimation.
  linalg::Vector<float> base_scores;
  obj->InitEstimation(info, &base_scores);
  ASSERT_EQ(base_scores.Size(), 1) << "Vector is not yet supported.";
  // mean([5.6, 7.8])
  ASSERT_NEAR(base_scores(0), 6.7, kRtEps);

  // Weighted estimation: weights decrease from num_row_ - 1 down to 0.
  auto& h_weights = info.weights_.HostVector();
  for (std::size_t ridx = 0; ridx < info.num_row_; ++ridx) {
    h_weights.emplace_back(info.num_row_ - ridx - 1.0);
  }
  obj->InitEstimation(info, &base_scores);
  ASSERT_EQ(base_scores.Size(), 1) << "Vector is not yet supported.";
  // mean([3, 5])
  ASSERT_NEAR(base_scores(0), 4.0, kRtEps);
}
} // namespace xgboost

View File

@@ -0,0 +1,5 @@
/**
* Copyright 2023 XGBoost contributors
*/
// Dummy file to enable the CUDA tests.
#include "test_quantile_obj.cc"

View File

@@ -5,7 +5,7 @@ import numpy as np
import pytest
from hypothesis import assume, given, note, settings, strategies
from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy
from xgboost.testing.updater import check_init_estimation
from xgboost.testing.updater import check_init_estimation, check_quantile_loss
import xgboost as xgb
from xgboost import testing as tm
@@ -209,3 +209,7 @@ class TestGPUUpdaters:
def test_init_estimation(self) -> None:
check_init_estimation("gpu_hist")
@pytest.mark.parametrize("weighted", [True, False])
def test_quantile_loss(self, weighted: bool) -> None:
    """Run the shared quantile-loss check with the ``gpu_hist`` tree method.

    Parametrized over ``weighted`` being True and False.
    """
    tree_method = "gpu_hist"
    check_quantile_loss(tree_method, weighted)

View File

@@ -146,6 +146,13 @@ def test_multioutput_reg() -> None:
subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_sklearn())
def test_quantile_reg() -> None:
    """Run the quantile regression demo script in a subprocess.

    ``subprocess.check_call`` raises if the script exits with a non-zero status,
    which fails the test.
    """
    demo_path = os.path.join(PYTHON_DEMO_DIR, "quantile_regression.py")
    subprocess.check_call(['python', demo_path])
@pytest.mark.skipif(**tm.no_ubjson())
def test_json_model() -> None:
script = os.path.join(DEMO_DIR, "json-model", "json_parser.py")

View File

@@ -10,7 +10,7 @@ from xgboost.testing.params import (
exact_parameter_strategy,
hist_parameter_strategy,
)
from xgboost.testing.updater import check_init_estimation
from xgboost.testing.updater import check_init_estimation, check_quantile_loss
import xgboost as xgb
from xgboost import testing as tm
@@ -469,3 +469,7 @@ class TestTreeMethod:
def test_init_estimation(self) -> None:
check_init_estimation("hist")
@pytest.mark.parametrize("weighted", [True, False])
def test_quantile_loss(self, weighted: bool) -> None:
    """Run the shared quantile-loss check with the ``hist`` tree method.

    Parametrized over ``weighted`` being True and False.
    """
    tree_method = "hist"
    check_quantile_loss(tree_method, weighted)