Initial support for quantile loss. (#8750)

- Add support for Python.
- Add objective.
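
For orientation, a minimal sketch of the new objective from the Python side (hedged: the objective and parameter names are taken from the docs and demo in this diff; the data and settings are illustrative only):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 4))
y = X @ np.ones(4)
booster = xgb.train(
    {"objective": "reg:quantileerror", "quantile_alpha": [0.1, 0.5, 0.9]},
    xgb.DMatrix(X, label=y),
    num_boost_round=8,
)
predt = booster.inplace_predict(X)  # one column per quantile: shape (256, 3)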
Jiaming Yuan 2023-02-16 02:30:18 +08:00 committed by GitHub
parent 282b1729da
commit cce4af4acf
26 changed files with 701 additions and 70 deletions


@@ -37,6 +37,7 @@ OBJECTS= \
 $(PKGROOT)/src/objective/aft_obj.o \
 $(PKGROOT)/src/objective/adaptive.o \
 $(PKGROOT)/src/objective/init_estimation.o \
+$(PKGROOT)/src/objective/quantile_obj.o \
 $(PKGROOT)/src/gbm/gbm.o \
 $(PKGROOT)/src/gbm/gbtree.o \
 $(PKGROOT)/src/gbm/gbtree_model.o \


@@ -37,6 +37,7 @@ OBJECTS= \
 $(PKGROOT)/src/objective/aft_obj.o \
 $(PKGROOT)/src/objective/adaptive.o \
 $(PKGROOT)/src/objective/init_estimation.o \
+$(PKGROOT)/src/objective/quantile_obj.o \
 $(PKGROOT)/src/gbm/gbm.o \
 $(PKGROOT)/src/gbm/gbtree.o \
 $(PKGROOT)/src/gbm/gbtree_model.o \


@@ -0,0 +1,124 @@
"""
Quantile Regression
===================

The script is inspired by this awesome example in sklearn:
https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html
"""
import argparse
from typing import Dict

import numpy as np
from sklearn.model_selection import train_test_split

import xgboost as xgb


def f(x: np.ndarray) -> np.ndarray:
    """The function to predict."""
    return x * np.sin(x)


def quantile_loss(args: argparse.Namespace) -> None:
    """Train a quantile regression model."""
    rng = np.random.RandomState(1994)
    # Generate a synthetic dataset for the demo; the generating process is taken
    # from the sklearn example.
    X = np.atleast_2d(rng.uniform(0, 10.0, size=1000)).T
    expected_y = f(X).ravel()
    sigma = 0.5 + X.ravel() / 10.0
    noise = rng.lognormal(sigma=sigma) - np.exp(sigma**2.0 / 2.0)
    y = expected_y + noise

    # Train on the 0.05, 0.5, and 0.95 quantiles. The model is similar to
    # multi-class and multi-target models.
    alpha = np.array([0.05, 0.5, 0.95])
    evals_result: Dict[str, Dict] = {}

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
    # We will be using the `hist` tree method; QuantileDMatrix can be used to
    # reduce memory usage.
    # Do not use the `exact` tree method for quantile regression, otherwise the
    # performance might drop.
    Xy = xgb.QuantileDMatrix(X_train, y_train)
    # use Xy as a reference
    Xy_test = xgb.QuantileDMatrix(X_test, y_test, ref=Xy)

    booster = xgb.train(
        {
            # Use the quantile objective function.
            "objective": "reg:quantileerror",
            "tree_method": "hist",
            "quantile_alpha": alpha,
            # Let's try not to overfit.
            "learning_rate": 0.01,
            "max_depth": 3,
            "min_child_weight": 16.0,
        },
        Xy,
        num_boost_round=32,
        early_stopping_rounds=2,
        # The evaluation result is a weighted average across multiple quantiles.
        evals=[(Xy, "Train"), (Xy_test, "Test")],
        evals_result=evals_result,
    )
    xx = np.atleast_2d(np.linspace(0, 10, 1000)).T
    scores = booster.inplace_predict(xx)
    # dim 1 is the quantiles
    assert scores.shape[0] == xx.shape[0]
    assert scores.shape[1] == alpha.shape[0]

    y_lower = scores[:, 0]  # alpha=0.05
    y_med = scores[:, 1]  # alpha=0.5, median
    y_upper = scores[:, 2]  # alpha=0.95

    # Train a MSE model for comparison.
    booster = xgb.train(
        {
            "objective": "reg:squarederror",
            "tree_method": "hist",
            # Let's try not to overfit.
            "learning_rate": 0.01,
            "max_depth": 3,
            "min_child_weight": 16.0,
        },
        Xy,
        num_boost_round=32,
        early_stopping_rounds=2,
        evals=[(Xy, "Train"), (Xy_test, "Test")],
        evals_result=evals_result,
    )
    xx = np.atleast_2d(np.linspace(0, 10, 1000)).T
    y_pred = booster.inplace_predict(xx)

    if args.plot:
        from matplotlib import pyplot as plt

        fig = plt.figure(figsize=(10, 10))
        plt.plot(xx, f(xx), "g:", linewidth=3, label=r"$f(x) = x\,\sin(x)$")
        plt.plot(X_test, y_test, "b.", markersize=10, label="Test observations")
        plt.plot(xx, y_med, "r-", label="Predicted median")
        plt.plot(xx, y_pred, "m-", label="Predicted mean")
        plt.plot(xx, y_upper, "k-")
        plt.plot(xx, y_lower, "k-")
        plt.fill_between(
            xx.ravel(), y_lower, y_upper, alpha=0.4, label="Predicted 90% interval"
        )
        plt.xlabel("$x$")
        plt.ylabel("$f(x)$")
        plt.ylim(-10, 25)
        plt.legend(loc="upper left")
        plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--plot",
        action="store_true",
        help="Specify it to enable plotting the outputs.",
    )
    args = parser.parse_args()
    quantile_loss(args)
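
The demo is wired into the linter and demo test suites later in this commit; once merged it can be run directly, e.g. `python demo/guide-python/quantile_regression.py --plot`, where `--plot` is the argparse switch defined above.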


@@ -440,6 +440,20 @@
       },
       "type": "object"
     },
+    {
+      "properties": {
+        "name": {
+          "const": "reg:quantileerror"
+        },
+        "quantile_loss_param": {
+          "type": "object",
+          "properties": {
+            "quantile_alpha": {"type": "array"}
+          }
+        }
+      },
+      "type": "object"
+    },
     {
       "type": "object",
       "properties": {


@@ -348,6 +348,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
 - ``reg:logistic``: logistic regression.
 - ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss.
 - ``reg:absoluteerror``: Regression with L1 error. When tree model is used, leaf value is refreshed after tree construction. If used in distributed training, the leaf value is calculated as the mean value from all workers, which is not guaranteed to be optimal.
+- ``reg:quantileerror``: Quantile loss, also known as ``pinball loss``. See later sections for its parameter and :ref:`sphx_glr_python_examples_quantile_regression.py` for a worked example.
 - ``binary:logistic``: logistic regression for binary classification, output probability
 - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
 - ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities.

@@ -441,6 +442,11 @@ Parameter for using Pseudo-Huber (``reg:pseudohubererror``)

 * ``huber_slope`` : A parameter used for Pseudo-Huber loss to define the :math:`\delta` term. [default = 1.0]

+Parameter for using Quantile Loss (``reg:quantileerror``)
+=========================================================
+
+* ``quantile_alpha``: A scalar or a list of targeted quantiles.
+
 ***********************
 Command Line Parameters
 ***********************
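
A small illustration of the new parameter (a hedged sketch; the list form is what the demo above uses, and it gives predictions a trailing dimension of size len(quantile_alpha)):

params_scalar = {"objective": "reg:quantileerror", "quantile_alpha": 0.5}
params_list = {"objective": "reg:quantileerror", "quantile_alpha": [0.05, 0.5, 0.95]}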


@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cinttypes>  // std::int32_t
+#include <cstddef>    // std::size_t
 #include <limits>
 #include <string>
 #include <tuple>

@@ -552,6 +553,11 @@ LINALG_HD auto UnravelIndex(size_t idx, common::Span<size_t const, D> shape) {
   }
 }

+template <size_t D>
+LINALG_HD auto UnravelIndex(size_t idx, std::size_t const (&shape)[D]) {
+  return UnravelIndex(idx, common::Span<std::size_t const, D>(shape));
+}
+
 /**
  * \brief A view over a vector, specialization of Tensor
  *
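
The new overload accepts a plain C array as the shape, so call sites (such as the quantile objective below) can write UnravelIndex(i, {a, b, c}) directly. Behaviorally this is row-major unraveling, i.e. NumPy's unravel_index with C order (a hedged analogy, not the C++ API itself):

import numpy as np

# flat index 7 in a (2, 2, 2) tensor -> coordinates (1, 1, 1)
print(np.unravel_index(7, (2, 2, 2)))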


@@ -1926,6 +1926,8 @@ class Booster:
         elif isinstance(params, str) and value is not None:
             params = [(params, value)]
         for key, val in cast(Iterable[Tuple[str, str]], params):
+            if isinstance(val, np.ndarray):
+                val = val.tolist()
             if val is not None:
                 _check_call(
                     _LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val)))
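
Why the conversion above matters: str() of a NumPy array prints without commas, so the backend cannot parse it as a list-valued parameter such as quantile_alpha, while str() of the equivalent Python list round-trips cleanly:

import numpy as np

alpha = np.array([0.05, 0.5, 0.95])
print(str(alpha))           # '[0.05 0.5  0.95]' -- no commas
print(str(alpha.tolist()))  # '[0.05, 0.5, 0.95]' -- what XGBoosterSetParam now receives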


@@ -1,7 +1,10 @@
 """Tests for updaters."""
 import json
+from functools import partial, update_wrapper
+from typing import Dict

 import numpy as np

+import xgboost.testing as tm
 import xgboost as xgb
@@ -68,3 +71,90 @@ def check_init_estimation(tree_method: str) -> None:
         n_samples=4096, n_labels=3, n_classes=5, random_state=17
     )
     run_clf(X, y)
+
+
+# pylint: disable=too-many-locals
+def check_quantile_loss(tree_method: str, weighted: bool) -> None:
+    """Test for quantile loss."""
+    from sklearn.datasets import make_regression
+    from sklearn.metrics import mean_pinball_loss
+
+    from xgboost.sklearn import _metric_decorator
+
+    n_samples = 4096
+    n_features = 8
+    n_estimators = 8
+    # A non-zero base score can cause a floating point difference with the GPU
+    # predictor; multi-class has a small difference from the single-target case in
+    # the prediction kernel.
+    base_score = 0.0
+    rng = np.random.RandomState(1994)
+    # pylint: disable=unbalanced-tuple-unpacking
+    X, y = make_regression(
+        n_samples=n_samples,
+        n_features=n_features,
+        random_state=rng,
+    )
+    if weighted:
+        weight = rng.random(size=n_samples)
+    else:
+        weight = None
+
+    Xy = xgb.QuantileDMatrix(X, y, weight=weight)
+    alpha = np.array([0.1, 0.5])
+    evals_result: Dict[str, Dict] = {}
+    booster_multi = xgb.train(
+        {
+            "objective": "reg:quantileerror",
+            "tree_method": tree_method,
+            "quantile_alpha": alpha,
+            "base_score": base_score,
+        },
+        Xy,
+        num_boost_round=n_estimators,
+        evals=[(Xy, "Train")],
+        evals_result=evals_result,
+    )
+    predt_multi = booster_multi.predict(Xy, strict_shape=True)
+
+    assert tm.non_increasing(evals_result["Train"]["quantile"])
+    assert evals_result["Train"]["quantile"][-1] < 20.0
+
+    # check that there's a way to use a custom metric and compare the results.
+    metrics = [
+        _metric_decorator(
+            update_wrapper(
+                partial(mean_pinball_loss, sample_weight=weight, alpha=alpha[i]),
+                mean_pinball_loss,
+            )
+        )
+        for i in range(alpha.size)
+    ]
+    predts = np.empty(predt_multi.shape)
+    for i in range(alpha.shape[0]):
+        a = alpha[i]
+        booster_i = xgb.train(
+            {
+                "objective": "reg:quantileerror",
+                "tree_method": tree_method,
+                "quantile_alpha": a,
+                "base_score": base_score,
+            },
+            Xy,
+            num_boost_round=n_estimators,
+            evals=[(Xy, "Train")],
+            custom_metric=metrics[i],
+            evals_result=evals_result,
+        )
+        assert tm.non_increasing(evals_result["Train"]["quantile"])
+        assert evals_result["Train"]["quantile"][-1] < 30.0
+        np.testing.assert_allclose(
+            np.array(evals_result["Train"]["quantile"]),
+            np.array(evals_result["Train"]["mean_pinball_loss"]),
+            atol=1e-6,
+            rtol=1e-6,
+        )
+        predts[:, i] = booster_i.predict(Xy)
+
+    for i in range(alpha.shape[0]):
+        np.testing.assert_allclose(predts[:, i], predt_multi[:, i])


@@ -35,11 +35,11 @@ void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
     auto iter = linalg::cbegin(ti_v);
     float q{0};
     if (opt_weights.Empty()) {
-      q = common::Quantile(0.5, iter, iter + ti_v.Size());
+      q = common::Quantile(ctx, 0.5, iter, iter + ti_v.Size());
     } else {
       CHECK_NE(t_v.Shape(1), 0);
       auto w_it = common::MakeIndexTransformIter([&](std::size_t i) { return opt_weights[i]; });
-      q = common::WeightedQuantile(0.5, iter, iter + ti_v.Size(), w_it);
+      q = common::WeightedQuantile(ctx, 0.5, iter, iter + ti_v.Size(), w_it);
     }
     h_out(i) = q;
   }


@@ -4,43 +4,49 @@
 #ifndef XGBOOST_COMMON_STATS_H_
 #define XGBOOST_COMMON_STATS_H_
 #include <algorithm>
-#include <iterator>
+#include <iterator>  // for distance
 #include <limits>
 #include <vector>

-#include "common.h"              // AssertGPUSupport, OptionalWeights
+#include "algorithm.h"           // for StableSort
+#include "common.h"              // AssertGPUSupport
 #include "optional_weight.h"     // OptionalWeights
 #include "transform_iterator.h"  // MakeIndexTransformIter
 #include "xgboost/context.h"     // Context
-#include "xgboost/linalg.h"
+#include "xgboost/linalg.h"      // TensorView,VectorView
 #include "xgboost/logging.h"     // CHECK_GE

 namespace xgboost {
 namespace common {

 /**
- * \brief Percentile with masked array using linear interpolation.
+ * @brief Quantile using linear interpolation.
  *
  *     https://www.itl.nist.gov/div898/handbook/prc/section2/prc262.htm
  *
- * \param alpha Percentile, must be in range [0, 1].
+ * \param alpha Quantile, must be in range [0, 1].
  * \param begin Iterator begin for input array.
  * \param end   Iterator end for input array.
  *
  * \return The result of interpolation.
  */
 template <typename Iter>
-float Quantile(double alpha, Iter const& begin, Iter const& end) {
+float Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const& end) {
   CHECK(alpha >= 0 && alpha <= 1);
   auto n = static_cast<double>(std::distance(begin, end));
   if (n == 0) {
     return std::numeric_limits<float>::quiet_NaN();
   }

-  std::vector<size_t> sorted_idx(n);
+  std::vector<std::size_t> sorted_idx(n);
   std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
-  std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
-                   [&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });
+  if (omp_in_parallel()) {
+    std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
+                     [&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
+  } else {
+    StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
+               [&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
+  }

   auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
   static_assert(std::is_same<decltype(val(0)), float>::value, "");

@@ -51,7 +57,7 @@ float Quantile(double alpha, Iter const& begin, Iter const& end) {
   if (alpha >= (n / (n + 1))) {
     return val(sorted_idx.size() - 1);
   }
-  assert(n != 0 && "The number of rows in a leaf can not be zero.");
+
   double x = alpha * static_cast<double>((n + 1));
   double k = std::floor(x) - 1;
   CHECK_GE(k, 0);

@@ -66,30 +72,35 @@ float Quantile(double alpha, Iter const& begin, Iter const& end) {
  * \brief Calculate the weighted quantile with step function. Unlike the unweighted
  *        version, no interpolation is used.
  *
- * See https://aakinshin.net/posts/weighted-quantiles/ for some discussion on computing
+ * See https://aakinshin.net/posts/weighted-quantiles/ for some discussions on computing
  *     weighted quantile with interpolation.
  */
 template <typename Iter, typename WeightIter>
-float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
+float WeightedQuantile(Context const* ctx, double alpha, Iter begin, Iter end, WeightIter w_begin) {
   auto n = static_cast<double>(std::distance(begin, end));
   if (n == 0) {
     return std::numeric_limits<float>::quiet_NaN();
   }
   std::vector<size_t> sorted_idx(n);
   std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
-  std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
-                   [&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });
+  if (omp_in_parallel()) {
+    std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
+                     [&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
+  } else {
+    StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
+               [&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
+  }

   auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };

   std::vector<float> weight_cdf(n);  // S_n
   // weighted cdf is sorted during construction
-  weight_cdf[0] = *(weights + sorted_idx[0]);
+  weight_cdf[0] = *(w_begin + sorted_idx[0]);
   for (size_t i = 1; i < n; ++i) {
-    weight_cdf[i] = weight_cdf[i - 1] + *(weights + sorted_idx[i]);
+    weight_cdf[i] = weight_cdf[i - 1] + w_begin[sorted_idx[i]];
   }
   float thresh = weight_cdf.back() * alpha;
-  size_t idx =
+  std::size_t idx =
       std::lower_bound(weight_cdf.cbegin(), weight_cdf.cend(), thresh) - weight_cdf.cbegin();
   idx = std::min(idx, static_cast<size_t>(n - 1));
   return val(idx);
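
For readers who prefer the math without templates, here is a hedged NumPy restatement of the two routines above: linear interpolation on (n + 1)-based ranks per the NIST handbook, and the step-function weighted variant. It reproduces the expected values in the unit tests later in this commit.

import numpy as np

def quantile(alpha, values):
    # mirrors common::Quantile (linear interpolation)
    v = np.sort(np.asarray(values, dtype=float))
    n = v.size
    if n == 0:
        return float("nan")
    if alpha <= 1.0 / (n + 1):
        return v[0]
    if alpha >= n / (n + 1.0):
        return v[-1]
    x = alpha * (n + 1)        # 1-based rank
    k = int(np.floor(x)) - 1   # 0-based lower neighbour
    d = x - np.floor(x)        # interpolation fraction
    return v[k] + d * (v[k + 1] - v[k])

def weighted_quantile(alpha, values, weights):
    # mirrors common::WeightedQuantile (step function, no interpolation)
    values = np.asarray(values, dtype=float)
    order = np.argsort(values, kind="stable")
    cdf = np.cumsum(np.asarray(weights, dtype=float)[order])  # S_n
    idx = np.searchsorted(cdf, cdf[-1] * alpha, side="left")
    return values[order][min(idx, values.size - 1)]

data = [20.0, 15.0, 50.0, 40.0, 35.0]
print(quantile(0.40, data), quantile(0.20, data), quantile(0.10, data))  # 26.0 16.0 15.0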


@@ -3,17 +3,25 @@
  */
 #include "adaptive.h"

-#include <limits>
-#include <vector>
+#include <algorithm>  // std::transform,std::find_if,std::copy,std::unique
+#include <cmath>      // std::isnan
+#include <cstddef>    // std::size_t
+#include <iterator>   // std::distance
+#include <vector>     // std::vector

 #include "../common/algorithm.h"           // ArgSort
+#include "../common/common.h"              // AssertGPUSupport
 #include "../common/numeric.h"             // RunLengthEncode
 #include "../common/stats.h"               // Quantile,WeightedQuantile
 #include "../common/threading_utils.h"     // ParallelFor
 #include "../common/transform_iterator.h"  // MakeIndexTransformIter
+#include "xgboost/base.h"                  // bst_node_t
 #include "xgboost/context.h"               // Context
-#include "xgboost/linalg.h"
-#include "xgboost/tree_model.h"
+#include "xgboost/data.h"                  // MetaInfo
+#include "xgboost/host_device_vector.h"    // HostDeviceVector
+#include "xgboost/linalg.h"                // MakeTensorView
+#include "xgboost/span.h"                  // Span
+#include "xgboost/tree_model.h"            // RegTree

 namespace xgboost {
 namespace obj {

@@ -100,8 +108,8 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
     CHECK_LT(k + 1, h_node_ptr.size());
     size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
     auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
-    CHECK_LE(group_idx, info.labels.Shape(1));
-    auto h_labels = info.labels.HostView().Slice(linalg::All(), group_idx);
+    auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
     auto h_weights = linalg::MakeVec(&info.weights_);

     auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {

@@ -115,9 +123,9 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
     float q{0};
     if (info.weights_.Empty()) {
-      q = common::Quantile(alpha, iter, iter + h_row_set.size());
+      q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
     } else {
-      q = common::WeightedQuantile(alpha, iter, iter + h_row_set.size(), w_it);
+      q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
     }
     if (std::isnan(q)) {
       CHECK(h_row_set.empty());

@@ -127,6 +135,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
   UpdateLeafValues(&quantiles, nidx, p_tree);
 }

+#if !defined(XGBOOST_USE_CUDA)
+void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
+                          MetaInfo const&, HostDeviceVector<float> const&, float, RegTree*) {
+  common::AssertGPUSupport();
+}
+#endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace detail
 }  // namespace obj
 }  // namespace xgboost


@@ -20,20 +20,19 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
                           HostDeviceVector<bst_node_t>* p_nidx, RegTree const& tree) {
   // copy position to buffer
   dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
+  auto cuctx = ctx->CUDACtx();
   size_t n_samples = position.size();
-  dh::XGBDeviceAllocator<char> alloc;
   dh::device_vector<bst_node_t> sorted_position(position.size());
   dh::safe_cuda(cudaMemcpyAsync(sorted_position.data().get(), position.data(),
-                                position.size_bytes(), cudaMemcpyDeviceToDevice));
+                                position.size_bytes(), cudaMemcpyDeviceToDevice, cuctx->Stream()));

   p_ridx->resize(position.size());
   dh::Iota(dh::ToSpan(*p_ridx));
   // sort row index according to node index
-  thrust::stable_sort_by_key(thrust::cuda::par(alloc), sorted_position.begin(),
+  thrust::stable_sort_by_key(cuctx->TP(), sorted_position.begin(),
                              sorted_position.begin() + n_samples, p_ridx->begin());
-  dh::XGBCachingDeviceAllocator<char> caching;
   size_t beg_pos =
-      thrust::find_if(thrust::cuda::par(caching), sorted_position.cbegin(), sorted_position.cend(),
+      thrust::find_if(cuctx->CTP(), sorted_position.cbegin(), sorted_position.cend(),
                       [] XGBOOST_DEVICE(bst_node_t nidx) { return nidx >= 0; }) -
       sorted_position.cbegin();
   if (beg_pos == sorted_position.size()) {

@@ -72,7 +71,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
   size_t* h_num_runs = reinterpret_cast<size_t*>(pinned.subspan(0, sizeof(size_t)).data());

   dh::CUDAEvent e;
-  e.Record(dh::DefaultStream());
+  e.Record(cuctx->Stream());
   copy_stream.View().Wait(e);
   // flag for whether there's ignored position
   bst_node_t* h_first_unique =

@@ -108,7 +107,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
       d_node_ptr[0] = beg_pos;
     }
   });
-  thrust::inclusive_scan(thrust::cuda::par(caching), dh::tbegin(d_node_ptr), dh::tend(d_node_ptr),
+  thrust::inclusive_scan(cuctx->CTP(), dh::tbegin(d_node_ptr), dh::tend(d_node_ptr),
                          dh::tbegin(d_node_ptr));
   copy_stream.View().Sync();
   CHECK_GT(*h_num_runs, 0);

@@ -162,7 +161,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
                                        {info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id);
   CHECK_LT(group_idx, d_predt.Shape(1));
   auto t_predt = d_predt.Slice(linalg::All(), group_idx);
-  auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), group_idx);
+  auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), IdxY(info, group_idx));

   auto d_row_index = dh::ToSpan(ridx);
   auto seg_beg = nptr.DevicePointer();


@@ -6,13 +6,15 @@
 #include <algorithm>
 #include <cstdint>  // std::int32_t
 #include <limits>
-#include <vector>
+#include <vector>  // std::vector

 #include "../collective/communicator-inl.h"
 #include "../common/common.h"
-#include "xgboost/context.h"
-#include "xgboost/host_device_vector.h"
-#include "xgboost/tree_model.h"
+#include "xgboost/base.h"                // bst_node_t
+#include "xgboost/context.h"             // Context
+#include "xgboost/data.h"                // MetaInfo
+#include "xgboost/host_device_vector.h"  // HostDeviceVector
+#include "xgboost/tree_model.h"          // RegTree

 namespace xgboost {
 namespace obj {

@@ -73,6 +75,15 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
   }
 }

+inline std::size_t IdxY(MetaInfo const& info, bst_group_t group_idx) {
+  std::size_t y_idx{0};
+  if (info.labels.Shape(1) > 1) {
+    y_idx = group_idx;
+  }
+  CHECK_LE(y_idx, info.labels.Shape(1));
+  return y_idx;
+}
+
 void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
                           std::int32_t group_idx, MetaInfo const& info,
                           HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);

@@ -81,5 +92,18 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
                         std::int32_t group_idx, MetaInfo const& info,
                         HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
 }  // namespace detail
+
+inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
+                           std::int32_t group_idx, MetaInfo const& info,
+                           HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
+  if (ctx->IsCPU()) {
+    detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, predt, alpha,
+                               p_tree);
+  } else {
+    position.SetDevice(ctx->gpu_id);
+    detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, predt, alpha,
+                                 p_tree);
+  }
+}
 }  // namespace obj
 }  // namespace xgboost
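
A hedged one-liner restating IdxY above: every quantile/group reads label column 0 unless the labels are genuinely multi-column.

def idx_y(n_label_cols: int, group_idx: int) -> int:
    # mirrors obj::detail::IdxY
    return group_idx if n_label_cols > 1 else 0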


@@ -44,11 +44,13 @@ namespace obj {
 // List of files that will be force linked in static links.
 #ifdef XGBOOST_USE_CUDA
 DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
+DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu);
 DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu);
 DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu);
 DMLC_REGISTRY_LINK_TAG(rank_obj_gpu);
 #else
 DMLC_REGISTRY_LINK_TAG(regression_obj);
+DMLC_REGISTRY_LINK_TAG(quantile_obj);
 DMLC_REGISTRY_LINK_TAG(hinge_obj);
 DMLC_REGISTRY_LINK_TAG(multiclass_obj);
 DMLC_REGISTRY_LINK_TAG(rank_obj);


@@ -0,0 +1,18 @@
/**
 * Copyright 2023 by XGBoost Contributors
 */
// Dummy file to enable the CUDA conditional compile trick.
#include <dmlc/registry.h>

namespace xgboost {
namespace obj {

DMLC_REGISTRY_FILE_TAG(quantile_obj);

}  // namespace obj
}  // namespace xgboost

#ifndef XGBOOST_USE_CUDA
#include "quantile_obj.cu"
#endif  // !defined(XGBOOST_USE_CUDA)


@@ -0,0 +1,226 @@
/**
 * Copyright 2023 by XGBoost contributors
 */
#include <cstddef>  // std::size_t
#include <cstdint>  // std::int32_t
#include <vector>   // std::vector

#include "../common/linalg_op.h"            // ElementWiseKernel,cbegin,cend
#include "../common/quantile_loss_utils.h"  // QuantileLossParam
#include "../common/stats.h"                // Quantile,WeightedQuantile
#include "adaptive.h"                       // UpdateTreeLeaf
#include "dmlc/parameter.h"                 // DMLC_DECLARE_PARAMETER
#include "init_estimation.h"                // CheckInitInputs
#include "xgboost/base.h"                   // GradientPair,XGBOOST_DEVICE,bst_target_t
#include "xgboost/data.h"                   // MetaInfo
#include "xgboost/host_device_vector.h"     // HostDeviceVector
#include "xgboost/json.h"                   // Json,String,ToJson,FromJson
#include "xgboost/linalg.h"                 // Tensor,MakeTensorView,MakeVec
#include "xgboost/objective.h"              // ObjFunction
#include "xgboost/parameter.h"              // XGBoostParameter

#if defined(XGBOOST_USE_CUDA)
#include "../common/linalg_op.cuh"  // ElementWiseKernel
#include "../common/stats.cuh"      // SegmentedQuantile
#endif                              // defined(XGBOOST_USE_CUDA)

namespace xgboost {
namespace obj {
class QuantileRegression : public ObjFunction {
  common::QuantileLossParam param_;
  HostDeviceVector<float> alpha_;

  bst_target_t Targets(MetaInfo const& info) const override {
    auto const& alpha = param_.quantile_alpha.Get();
    CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured.";
    CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target is not yet supported by the quantile loss.";
    CHECK(!alpha.empty());
    // We have some placeholders for multi-target in the quantile loss. But it's not
    // supported as the gbtree doesn't know how to slice the gradient and there's no 3-dim
    // model shape in general.
    auto n_y = std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
    return alpha_.Size() * n_y;
  }

 public:
  void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info, std::int32_t iter,
                   HostDeviceVector<GradientPair>* out_gpair) override {
    if (iter == 0) {
      CheckInitInputs(info);
    }
    CHECK_EQ(param_.quantile_alpha.Get().size(), alpha_.Size());

    using SizeT = decltype(info.num_row_);
    SizeT n_targets = this->Targets(info);
    SizeT n_alphas = alpha_.Size();
    CHECK_NE(n_alphas, 0);
    CHECK_GE(n_targets, n_alphas);
    CHECK_EQ(preds.Size(), info.num_row_ * n_targets);

    auto labels = info.labels.View(ctx_->gpu_id);

    out_gpair->SetDevice(ctx_->gpu_id);
    out_gpair->Resize(n_targets * info.num_row_);
    auto gpair =
        linalg::MakeTensorView(ctx_->IsCPU() ? out_gpair->HostSpan() : out_gpair->DeviceSpan(),
                               {info.num_row_, n_alphas, n_targets / n_alphas}, ctx_->gpu_id);

    info.weights_.SetDevice(ctx_->gpu_id);
    common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
                                                 : info.weights_.ConstDeviceSpan()};

    preds.SetDevice(ctx_->gpu_id);
    auto predt = linalg::MakeVec(&preds);
    auto n_samples = info.num_row_;

    alpha_.SetDevice(ctx_->gpu_id);
    auto alpha = ctx_->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();

    linalg::ElementWiseKernel(
        ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
          auto idx = linalg::UnravelIndex(
              i, {n_samples, static_cast<SizeT>(alpha.size()), n_targets / alpha.size()});
          // std::tie is not available for cuda kernel.
          std::size_t sample_id = std::get<0>(idx);
          std::size_t quantile_id = std::get<1>(idx);
          std::size_t target_id = std::get<2>(idx);

          auto d = predt(i) - labels(sample_id, target_id);
          auto h = weight[sample_id];
          if (d >= 0) {
            auto g = (1.0f - alpha[quantile_id]) * weight[sample_id];
            gpair(sample_id, quantile_id, target_id) = GradientPair{g, h};
          } else {
            auto g = (-alpha[quantile_id] * weight[sample_id]);
            gpair(sample_id, quantile_id, target_id) = GradientPair{g, h};
          }
        });
  }

  void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override {
    CHECK(!alpha_.Empty());

    auto n_targets = this->Targets(info);
    base_score->SetDevice(ctx_->gpu_id);
    base_score->Reshape(n_targets);

    double sw{0};
    if (ctx_->IsCPU()) {
      auto quantiles = base_score->HostView();
      auto h_weights = info.weights_.ConstHostVector();
      if (info.weights_.Empty()) {
        sw = info.num_row_;
      } else {
        sw = std::accumulate(std::cbegin(h_weights), std::cend(h_weights), 0.0);
      }
      for (bst_target_t t{0}; t < n_targets; ++t) {
        auto alpha = param_.quantile_alpha[t];
        auto h_labels = info.labels.HostView();
        if (h_weights.empty()) {
          quantiles(t) =
              common::Quantile(ctx_, alpha, linalg::cbegin(h_labels), linalg::cend(h_labels));
        } else {
          CHECK_EQ(h_weights.size(), h_labels.Size());
          quantiles(t) = common::WeightedQuantile(ctx_, alpha, linalg::cbegin(h_labels),
                                                  linalg::cend(h_labels), std::cbegin(h_weights));
        }
      }
    } else {
#if defined(XGBOOST_USE_CUDA)
      alpha_.SetDevice(ctx_->gpu_id);
      auto d_alpha = alpha_.ConstDeviceSpan();
      auto d_labels = info.labels.View(ctx_->gpu_id);
      auto seg_it = dh::MakeTransformIterator<std::size_t>(
          thrust::make_counting_iterator(0ul),
          [=] XGBOOST_DEVICE(std::size_t i) { return i * d_labels.Shape(0); });
      CHECK_EQ(d_labels.Shape(1), 1);
      auto val_it = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
                                                     [=] XGBOOST_DEVICE(std::size_t i) {
                                                       auto sample_idx = i % d_labels.Shape(0);
                                                       return d_labels(sample_idx, 0);
                                                     });
      auto n = d_labels.Size() * d_alpha.size();
      CHECK_EQ(base_score->Size(), d_alpha.size());
      if (info.weights_.Empty()) {
        common::SegmentedQuantile(ctx_, d_alpha.data(), seg_it, seg_it + d_alpha.size() + 1,
                                  val_it, val_it + n, base_score->Data());
        sw = info.num_row_;
      } else {
        info.weights_.SetDevice(ctx_->gpu_id);
        auto d_weights = info.weights_.ConstDeviceSpan();
        auto weight_it = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
                                                          [=] XGBOOST_DEVICE(std::size_t i) {
                                                            auto sample_idx = i % d_labels.Shape(0);
                                                            return d_weights[sample_idx];
                                                          });
        common::SegmentedWeightedQuantile(ctx_, d_alpha.data(), seg_it, seg_it + d_alpha.size() + 1,
                                          val_it, val_it + n, weight_it, weight_it + n,
                                          base_score->Data());
        sw = dh::Reduce(ctx_->CUDACtx()->CTP(), dh::tcbegin(d_weights), dh::tcend(d_weights), 0.0,
                        thrust::plus<double>{});
      }
#else
      common::AssertGPUSupport();
#endif  // defined(XGBOOST_USE_CUDA)
    }

    // For multiple quantiles, we should extend the base score to a vector instead of
    // computing the average. For now, this is a workaround.
    linalg::Vector<float> temp;
    common::Mean(ctx_, *base_score, &temp);
    double meanq = temp(0) * sw;

    collective::Allreduce<collective::Operation::kSum>(&meanq, 1);
    collective::Allreduce<collective::Operation::kSum>(&sw, 1);
    meanq /= (sw + kRtEps);
    base_score->Reshape(1);
    base_score->Data()->Fill(meanq);
  }

  void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
                      HostDeviceVector<float> const& prediction, std::int32_t group_idx,
                      RegTree* p_tree) const override {
    auto alpha = param_.quantile_alpha[group_idx];
    ::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, alpha, p_tree);
  }

  void Configure(Args const& args) override {
    param_.UpdateAllowUnknown(args);
    param_.Validate();
    this->alpha_.HostVector() = param_.quantile_alpha.Get();
  }
  ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; }
  static char const* Name() { return "reg:quantileerror"; }

  void SaveConfig(Json* p_out) const override {
    auto& out = *p_out;
    out["name"] = String(Name());
    out["quantile_loss_param"] = ToJson(param_);
  }

  void LoadConfig(Json const& in) override {
    CHECK_EQ(get<String const>(in["name"]), Name());
    FromJson(in["quantile_loss_param"], &param_);
    alpha_.HostVector() = param_.quantile_alpha.Get();
  }

  const char* DefaultEvalMetric() const override { return "quantile"; }
  Json DefaultMetricConfig() const override {
    CHECK(param_.GetInitialised());
    Json config{Object{}};
    config["name"] = String{this->DefaultEvalMetric()};
    config["quantile_loss_param"] = ToJson(param_);
    return config;
  }
};

XGBOOST_REGISTER_OBJECTIVE(QuantileRegression, QuantileRegression::Name())
    .describe("Regression with quantile loss.")
    .set_body([]() { return new QuantileRegression(); });

#if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu);
#endif  // defined(XGBOOST_USE_CUDA)
}  // namespace obj
}  // namespace xgboost
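
The kernel in GetGradient above is the subgradient of the pinball loss with a constant, weight-valued Hessian. A hedged NumPy restatement for a single quantile, reproducing the expectation of the unit test later in this commit:

import numpy as np

def pinball_grad_hess(predt, y, alpha, weight=None):
    y = np.asarray(y, dtype=float)
    w = np.ones_like(y) if weight is None else np.asarray(weight, dtype=float)
    d = np.asarray(predt, dtype=float) - y
    grad = np.where(d >= 0, (1.0 - alpha) * w, -alpha * w)
    hess = w  # constant Hessian, as in the kernel above
    return grad, hess

# matches TEST(Objective, Quantile): grad = [-0.6, 0.4, 0.4], hess = [1, 1, 1]
print(pinball_grad_hess([1.0, 2.0, 3.0], [3.0, 2.0, 1.0], alpha=0.6))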


@@ -1,15 +1,16 @@
-/*!
- * Copyright 2017-2022 XGBoost contributors
+/**
+ * Copyright 2017-2023 by XGBoost contributors
  */
 #ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
 #define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
 #include <dmlc/omp.h>
-#include <xgboost/logging.h>

 #include <cmath>

 #include "../common/math.h"
+#include "xgboost/data.h"  // MetaInfo
+#include "xgboost/logging.h"
 #include "xgboost/task.h"  // ObjInfo

 namespace xgboost {

@@ -105,7 +106,6 @@ struct LogisticRaw : public LogisticRegression {
   static ObjInfo Info() { return ObjInfo::kRegression; }
 };

 }  // namespace obj
 }  // namespace xgboost


@@ -744,18 +744,7 @@ class MeanAbsoluteError : public ObjFunction {
   void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
                       HostDeviceVector<float> const& prediction, std::int32_t group_idx,
                       RegTree* p_tree) const override {
-    if (ctx_->IsCPU()) {
-      auto const& h_position = position.ConstHostVector();
-      detail::UpdateTreeLeafHost(ctx_, h_position, group_idx, info, prediction, 0.5, p_tree);
-    } else {
-#if defined(XGBOOST_USE_CUDA)
-      position.SetDevice(ctx_->gpu_id);
-      auto d_position = position.ConstDeviceSpan();
-      detail::UpdateTreeLeafDevice(ctx_, d_position, group_idx, info, prediction, 0.5, p_tree);
-#else
-      common::AssertGPUSupport();
-#endif  // defined(XGBOOST_USE_CUDA)
-    }
+    ::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, 0.5, p_tree);
   }

   const char* DefaultEvalMetric() const override { return "mae"; }


@@ -151,6 +151,7 @@ def main(args: argparse.Namespace) -> None:
         "demo/guide-python/sklearn_parallel.py",
         "demo/guide-python/spark_estimator_examples.py",
         "demo/guide-python/individual_trees.py",
+        "demo/guide-python/quantile_regression.py",
         # CI
         "tests/ci_build/lint_python.py",
         "tests/ci_build/test_r_package.py",

@@ -193,6 +194,7 @@ def main(args: argparse.Namespace) -> None:
         "demo/guide-python/cat_in_the_dat.py",
         "demo/guide-python/feature_weights.py",
         "demo/guide-python/individual_trees.py",
+        "demo/guide-python/quantile_regression.py",
         # tests
         "tests/python/test_dt.py",
         "tests/python/test_data_iterator.py",


@@ -11,19 +11,20 @@
 namespace xgboost {
 namespace common {

 TEST(Stats, Quantile) {
+  Context ctx;
   {
     linalg::Tensor<float, 1> arr({20.f, 0.f, 15.f, 50.f, 40.f, 0.f, 35.f}, {7}, Context::kCpuId);
     std::vector<size_t> index{0, 2, 3, 4, 6};
     auto h_arr = arr.HostView();
     auto beg = MakeIndexTransformIter([&](size_t i) { return h_arr(index[i]); });
     auto end = beg + index.size();
-    auto q = Quantile(0.40f, beg, end);
+    auto q = Quantile(&ctx, 0.40f, beg, end);
     ASSERT_EQ(q, 26.0);

-    q = Quantile(0.20f, beg, end);
+    q = Quantile(&ctx, 0.20f, beg, end);
     ASSERT_EQ(q, 16.0);

-    q = Quantile(0.10f, beg, end);
+    q = Quantile(&ctx, 0.10f, beg, end);
     ASSERT_EQ(q, 15.0);
   }

@@ -31,12 +32,13 @@ TEST(Stats, Quantile) {
     std::vector<float> vec{1., 2., 3., 4., 5.};
     auto beg = MakeIndexTransformIter([&](size_t i) { return vec[i]; });
     auto end = beg + vec.size();
-    auto q = Quantile(0.5f, beg, end);
+    auto q = Quantile(&ctx, 0.5f, beg, end);
     ASSERT_EQ(q, 3.);
   }
 }

 TEST(Stats, WeightedQuantile) {
+  Context ctx;
   linalg::Tensor<float, 1> arr({1.f, 2.f, 3.f, 4.f, 5.f}, {5}, Context::kCpuId);
   linalg::Tensor<float, 1> weight({1.f, 1.f, 1.f, 1.f, 1.f}, {5}, Context::kCpuId);

@@ -47,13 +49,13 @@ TEST(Stats, WeightedQuantile) {
   auto end = beg + arr.Size();
   auto w = MakeIndexTransformIter([&](size_t i) { return h_weight(i); });

-  auto q = WeightedQuantile(0.50f, beg, end, w);
+  auto q = WeightedQuantile(&ctx, 0.50f, beg, end, w);
   ASSERT_EQ(q, 3);

-  q = WeightedQuantile(0.0, beg, end, w);
+  q = WeightedQuantile(&ctx, 0.0, beg, end, w);
   ASSERT_EQ(q, 1);

-  q = WeightedQuantile(1.0, beg, end, w);
+  q = WeightedQuantile(&ctx, 1.0, beg, end, w);
   ASSERT_EQ(q, 5);
 }


@@ -1,4 +1,6 @@
-// Copyright by Contributors
+/**
+ * Copyright 2016-2023 by XGBoost contributors
+ */
 #include <gtest/gtest.h>
 #include <xgboost/context.h>
 #include <xgboost/objective.h>

@@ -25,11 +27,14 @@ TEST(Objective, PredTransform) {
   tparam.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
   size_t n = 100;

-  for (const auto &entry :
-       ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
-    std::unique_ptr<xgboost::ObjFunction> obj{
-        xgboost::ObjFunction::Create(entry->name, &tparam)};
-    obj->Configure(Args{{"num_class", "2"}});
+  for (const auto& entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
+    std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create(entry->name, &tparam)};
+    if (entry->name.find("multi") != std::string::npos) {
+      obj->Configure(Args{{"num_class", "2"}});
+    }
+    if (entry->name.find("quantile") != std::string::npos) {
+      obj->Configure(Args{{"quantile_alpha", "0.5"}});
+    }
     HostDeviceVector<float> predts;
     predts.Resize(n, 3.14f);  // prediction is performed on host.
     ASSERT_FALSE(predts.DeviceCanRead());


@@ -0,0 +1,74 @@
/**
 * Copyright 2023 by XGBoost contributors
 */
#include <gtest/gtest.h>
#include <xgboost/base.h>       // Args
#include <xgboost/context.h>    // Context
#include <xgboost/objective.h>  // ObjFunction
#include <xgboost/span.h>       // Span

#include <memory>  // std::unique_ptr
#include <vector>  // std::vector

#include "../helpers.h"  // CheckConfigReload,CreateEmptyGenericParam,DeclareUnifiedTest

namespace xgboost {
TEST(Objective, DeclareUnifiedTest(Quantile)) {
  Context ctx = CreateEmptyGenericParam(GPUIDX);

  {
    Args args{{"quantile_alpha", "[0.6, 0.8]"}};
    std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:quantileerror", &ctx)};
    obj->Configure(args);
    CheckConfigReload(obj, "reg:quantileerror");
  }

  Args args{{"quantile_alpha", "0.6"}};
  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:quantileerror", &ctx)};
  obj->Configure(args);
  CheckConfigReload(obj, "reg:quantileerror");

  std::vector<float> predts{1.0f, 2.0f, 3.0f};
  std::vector<float> labels{3.0f, 2.0f, 1.0f};
  std::vector<float> weights{1.0f, 1.0f, 1.0f};
  std::vector<float> grad{-0.6f, 0.4f, 0.4f};
  std::vector<float> hess = weights;
  CheckObjFunction(obj, predts, labels, weights, grad, hess);
}

TEST(Objective, DeclareUnifiedTest(QuantileIntercept)) {
  Context ctx = CreateEmptyGenericParam(GPUIDX);
  Args args{{"quantile_alpha", "[0.6, 0.8]"}};
  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:quantileerror", &ctx)};
  obj->Configure(args);

  MetaInfo info;
  info.num_row_ = 10;
  info.labels.ModifyInplace([&](HostDeviceVector<float>* data, common::Span<std::size_t> shape) {
    data->SetDevice(ctx.gpu_id);
    data->Resize(info.num_row_);
    shape[0] = info.num_row_;
    shape[1] = 1;
    auto& h_labels = data->HostVector();
    for (std::size_t i = 0; i < info.num_row_; ++i) {
      h_labels[i] = i;
    }
  });

  linalg::Vector<float> base_scores;
  obj->InitEstimation(info, &base_scores);
  ASSERT_EQ(base_scores.Size(), 1) << "Vector is not yet supported.";
  // mean([5.6, 7.8])
  ASSERT_NEAR(base_scores(0), 6.7, kRtEps);

  for (std::size_t i = 0; i < info.num_row_; ++i) {
    info.weights_.HostVector().emplace_back(info.num_row_ - i - 1.0);
  }

  obj->InitEstimation(info, &base_scores);
  ASSERT_EQ(base_scores.Size(), 1) << "Vector is not yet supported.";
  // mean([3, 5])
  ASSERT_NEAR(base_scores(0), 4.0, kRtEps);
}
}  // namespace xgboost
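
A hedged NumPy re-derivation of the intercept expectations in QuantileIntercept above (assuming NumPy >= 1.22: method="weibull" is the same (n + 1)-based interpolation as common::Quantile; the weighted case uses the step function):

import numpy as np

labels = np.arange(10, dtype=float)
alphas = [0.6, 0.8]

# unweighted: quantiles [5.6, 7.8] -> mean 6.7
q = [np.quantile(labels, a, method="weibull") for a in alphas]
print(q, np.mean(q))

# weights 9, 8, ..., 0: step-function quantiles [3.0, 5.0] -> mean 4.0
w = np.arange(9, -1, -1, dtype=float)
cdf = np.cumsum(w)  # labels are already sorted
qw = [labels[min(np.searchsorted(cdf, cdf[-1] * a, side="left"), 9)] for a in alphas]
print(qw, np.mean(qw))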


@@ -0,0 +1,5 @@
/**
 * Copyright 2023 XGBoost contributors
 */
// Dummy file to enable the CUDA tests.
#include "test_quantile_obj.cc"


@@ -5,7 +5,7 @@ import numpy as np
 import pytest
 from hypothesis import assume, given, note, settings, strategies
 from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy
-from xgboost.testing.updater import check_init_estimation
+from xgboost.testing.updater import check_init_estimation, check_quantile_loss

 import xgboost as xgb
 from xgboost import testing as tm

@@ -209,3 +209,7 @@ class TestGPUUpdaters:
     def test_init_estimation(self) -> None:
         check_init_estimation("gpu_hist")
+
+    @pytest.mark.parametrize("weighted", [True, False])
+    def test_quantile_loss(self, weighted: bool) -> None:
+        check_quantile_loss("gpu_hist", weighted)


@@ -146,6 +146,13 @@ def test_multioutput_reg() -> None:
     subprocess.check_call(cmd)


+@pytest.mark.skipif(**tm.no_sklearn())
+def test_quantile_reg() -> None:
+    script = os.path.join(PYTHON_DEMO_DIR, "quantile_regression.py")
+    cmd = ['python', script]
+    subprocess.check_call(cmd)
+
+
 @pytest.mark.skipif(**tm.no_ubjson())
 def test_json_model() -> None:
     script = os.path.join(DEMO_DIR, "json-model", "json_parser.py")


@@ -10,7 +10,7 @@ from xgboost.testing.params import (
     exact_parameter_strategy,
     hist_parameter_strategy,
 )
-from xgboost.testing.updater import check_init_estimation
+from xgboost.testing.updater import check_init_estimation, check_quantile_loss

 import xgboost as xgb
 from xgboost import testing as tm

@@ -469,3 +469,7 @@ class TestTreeMethod:
     def test_init_estimation(self) -> None:
         check_init_estimation("hist")
+
+    @pytest.mark.parametrize("weighted", [True, False])
+    def test_quantile_loss(self, weighted: bool) -> None:
+        check_quantile_loss("hist", weighted)