Rework the precision metric. (#9222)

- Rework the precision metric for both CPU and GPU.
- Mention it in the document.
- Cleanup old support code for GPU ranking metric.
- Deterministic GPU implementation.

* Drop support for classification.

* type.

* use batch shape.

* lint.

* cpu build.

* cpu build.

* lint.

* Tests.

* Fix.

* Cleanup error message.
Jiaming Yuan 2023-06-02 20:49:43 +08:00 committed by GitHub
parent db8288121d
commit 9fbde21e9d
22 changed files with 312 additions and 502 deletions

View File

@@ -319,7 +319,7 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
   # maximize is usually NULL when not set in xgb.train and built-in metrics
   if (is.null(maximize))
-    maximize <<- grepl('(_auc|_map|_ndcg)', metric_name)
+    maximize <<- grepl('(_auc|_map|_ndcg|_pre)', metric_name)
   if (verbose && NVL(env$rank, 0) == 0)
     cat("Will train until ", metric_name, " hasn't improved in ",

View File

@@ -424,6 +424,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
    After XGBoost 1.6, both of the requirements and restrictions for using ``aucpr`` in classification problem are similar to ``auc``. For ranking task, only binary relevance label :math:`y \in [0, 1]` is supported. Different from ``map (mean average precision)``, ``aucpr`` calculates the *interpolated* area under precision recall curve using continuous interpolation.

+  - ``pre``: Precision at :math:`k`. Supports only learning to rank task.
   - ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
   - ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_
@@ -435,7 +436,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
    where :math:`I_{(k)}` is an indicator function that equals to :math:`1` when the document at :math:`k` is relevant and :math:`0` otherwise. The :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries.

-  - ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
+  - ``ndcg@n``, ``map@n``, ``pre@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
   - ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions.
   - ``poisson-nloglik``: negative log-likelihood for Poisson regression
   - ``gamma-nloglik``: negative log-likelihood for gamma regression
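A minimal usage sketch of the metric documented above, via the sklearn interface; the data, estimator settings, and cut-off value here are illustrative and not part of this change:

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification

# Binary relevance labels in a single query group.
X, y = make_classification(n_samples=256, n_features=8, random_state=0)
qid = np.zeros(y.shape[0], dtype=np.uint64)

# "pre@10" scores precision among the 10 highest-ranked documents of each query.
ranker = xgb.XGBRanker(n_estimators=4, tree_method="hist", eval_metric="pre@10")
ranker.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
print(ranker.evals_result())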

View File

@@ -372,6 +372,8 @@ class EarlyStopping(TrainingCallback):
        maximize_metrics = (
            "auc",
            "aucpr",
+           "pre",
+           "pre@",
            "map",
            "ndcg",
            "auc@",

View File

@@ -1,9 +1,61 @@
 """Tests for evaluation metrics."""
-from typing import Dict
+from typing import Dict, List

 import numpy as np
+import pytest

 import xgboost as xgb
+from xgboost.compat import concat
+from xgboost.core import _parse_eval_str
+
+
+def check_precision_score(tree_method: str) -> None:
+    """Test for precision with ranking and classification."""
+    datasets = pytest.importorskip("sklearn.datasets")
+
+    X, y = datasets.make_classification(
+        n_samples=1024, n_features=4, n_classes=2, random_state=2023
+    )
+    qid = np.zeros(shape=y.shape)  # same group
+
+    ltr = xgb.XGBRanker(n_estimators=2, tree_method=tree_method)
+    ltr.fit(X, y, qid=qid)
+
+    # re-generate so that XGBoost doesn't evaluate the result to 1.0
+    X, y = datasets.make_classification(
+        n_samples=512, n_features=4, n_classes=2, random_state=1994
+    )
+
+    ltr.set_params(eval_metric="pre@32")
+    result = _parse_eval_str(
+        ltr.get_booster().eval_set(evals=[(xgb.DMatrix(X, y), "Xy")])
+    )
+    score_0 = result[1][1]
+
+    X_list = []
+    y_list = []
+    n_query_groups = 3
+    q_list: List[np.ndarray] = []
+    for i in range(n_query_groups):
+        # same for all groups
+        X, y = datasets.make_classification(
+            n_samples=512, n_features=4, n_classes=2, random_state=1994
+        )
+        X_list.append(X)
+        y_list.append(y)
+        q = np.full(shape=y.shape, fill_value=i, dtype=np.uint64)
+        q_list.append(q)
+
+    qid = concat(q_list)
+    X = concat(X_list)
+    y = concat(y_list)
+    result = _parse_eval_str(
+        ltr.get_booster().eval_set(evals=[(xgb.DMatrix(X, y, qid=qid), "Xy")])
+    )
+    assert result[1][0].endswith("pre@32")
+    score_1 = result[1][1]
+    # All query groups are identical, so the grouped precision@32 must equal
+    # the single-group score computed above.
+    assert score_1 == score_0
+
+
 def check_quantile_error(tree_method: str) -> None:

View File

@@ -825,176 +825,6 @@ XGBOOST_DEVICE auto tcrend(xgboost::common::Span<T> const &span) {  // NOLINT
   return tcrbegin(span) + span.size();
 }
// This type sorts an array which is divided into multiple groups. The sorting is influenced
// by the function object 'Comparator'
template <typename T>
class SegmentSorter {
private:
// Items sorted within the group
caching_device_vector<T> ditems_;
// Original position of the items before they are sorted descending within their groups
caching_device_vector<uint32_t> doriginal_pos_;
// Segments within the original list that delineates the different groups
caching_device_vector<uint32_t> group_segments_;
// Need this on the device as it is used in the kernels
caching_device_vector<uint32_t> dgroups_; // Group information on device
// Where did the item that was originally present at position 'x' move to after they are sorted
caching_device_vector<uint32_t> dindexable_sorted_pos_;
// Initialize everything but the segments
void Init(uint32_t num_elems) {
ditems_.resize(num_elems);
doriginal_pos_.resize(num_elems);
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
}
// Initialize all with group info
void Init(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
this->Init(num_elems);
this->CreateGroupSegments(groups);
}
public:
// This needs to be public due to device lambda
void CreateGroupSegments(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
group_segments_.resize(num_elems, 0);
dgroups_ = groups;
if (GetNumGroups() == 1) return; // There are no segments; hence, no need to compute them
// Define the segments by assigning a group ID to each element
const uint32_t *dgroups = dgroups_.data().get();
uint32_t ngroups = dgroups_.size();
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) -
dgroups - 1;
}; // NOLINT
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(num_elems),
group_segments_.begin(),
ComputeGroupIDLambda);
}
// Accessors that returns device pointer
inline uint32_t GetNumItems() const { return ditems_.size(); }
inline const xgboost::common::Span<const T> GetItemsSpan() const {
return { ditems_.data().get(), ditems_.size() };
}
inline const xgboost::common::Span<const uint32_t> GetOriginalPositionsSpan() const {
return { doriginal_pos_.data().get(), doriginal_pos_.size() };
}
inline const xgboost::common::Span<const uint32_t> GetGroupSegmentsSpan() const {
return { group_segments_.data().get(), group_segments_.size() };
}
inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; }
inline const xgboost::common::Span<const uint32_t> GetGroupsSpan() const {
return { dgroups_.data().get(), dgroups_.size() };
}
inline const xgboost::common::Span<const uint32_t> GetIndexableSortedPositionsSpan() const {
return { dindexable_sorted_pos_.data().get(), dindexable_sorted_pos_.size() };
}
// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the host.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size, const std::vector<uint32_t> &groups,
const Comparator &comp = Comparator()) {
this->Init(groups);
this->SortItems(ditems, item_size, this->GetGroupSegmentsSpan(), comp);
}
// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the device.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size,
const xgboost::common::Span<const uint32_t> &group_segments,
const Comparator &comp = Comparator()) {
this->Init(item_size);
// Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
// when comparators are used. Hence, the following algorithm is used. This is done so that
// we can grab the appropriate related values from the original list later, after the
// items are sorted.
//
// Here is the internal representation:
// dgroups_: [ 0, 3, 5, 8, 10 ]
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
// ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
//
// Sort the items first and make a note of the original positions in doriginal_pos_
// based on the sort
// ditems_: 4 4 3 3 2 1 1 1 1 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
// in kernel, sorting using predicates etc.
ditems_.assign(thrust::device_ptr<const T>(ditems),
thrust::device_ptr<const T>(ditems) + item_size);
// Allocator to be used by sort for managing space overhead while sorting
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
ditems_.begin(), ditems_.end(),
doriginal_pos_.begin(), comp);
if (GetNumGroups() == 1) return; // The entire array is sorted, as it isn't segmented
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
// holisitic item sort order on the segments
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
caching_device_vector<uint32_t> group_segments_c(item_size);
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
dh::tcbegin(group_segments), group_segments_c.begin());
// Now, sort the group segments so that you may bring the items within the group together,
// in the process also noting the relative changes to the doriginal_pos_ while that happens
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
group_segments_c.begin(), group_segments_c.end(),
doriginal_pos_.begin(), thrust::less<uint32_t>());
// Finally, gather the original items based on doriginal_pos_ to sort the input and
// to store them in ditems_
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
// ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
thrust::device_ptr<const T>(ditems), ditems_.begin());
}
// Determine where an item that was originally present at position 'x' has been relocated to
// after a sort. Creation of such an index has to be explicitly requested after a sort
void CreateIndexableSortedPositions() {
dindexable_sorted_pos_.resize(GetNumItems());
thrust::scatter(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(GetNumItems()), // Rearrange indices...
// ...based on this map
dh::tcbegin(GetOriginalPositionsSpan()),
dindexable_sorted_pos_.begin()); // Write results into this
}
};
 // Atomic add function for gradients
 template <typename OutputGradientT, typename InputGradientT>
 XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,

View File

@@ -8,8 +8,7 @@
 #include "xgboost/host_device_vector.h"  // HostDeviceVector
 #include "xgboost/span.h"                // Span

-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 struct OptionalWeights {
   Span<float const> weights;
   float dft{1.0f};  // fixme: make this compile time constant
@@ -18,7 +17,8 @@ struct OptionalWeights {
   explicit OptionalWeights(float w) : dft{w} {}

   XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
-  auto Empty() const { return weights.empty(); }
+  [[nodiscard]] auto Empty() const { return weights.empty(); }
+  [[nodiscard]] auto Size() const { return weights.size(); }
 };

 inline OptionalWeights MakeOptionalWeights(Context const* ctx,
@@ -28,6 +28,5 @@ inline OptionalWeights MakeOptionalWeights(Context const* ctx,
   }
   return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()};
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_

View File

@@ -90,6 +90,9 @@ void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid,
                                            MetaInfo const &info, float missing) {
   auto const &h_weights =
       (use_group_ind_ ? detail::UnrollGroupWeights(info) : info.weights_.HostVector());
+  if (!use_group_ind_ && !h_weights.empty()) {
+    CHECK_EQ(h_weights.size(), batch.Size()) << "Invalid size of sample weight.";
+  }

   auto is_valid = data::IsValidFunctor{missing};
   auto weights = OptionalWeights{Span<float const>{h_weights}};

View File

@@ -19,12 +19,12 @@
 #include "categorical.h"
 #include "common.h"
+#include "error_msg.h"  // GroupWeight
 #include "optional_weight.h"  // OptionalWeights
 #include "threading_utils.h"
 #include "timer.h"

-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 /*!
  * \brief experimental wsummary
  * \tparam DType type of data content
@@ -695,13 +695,18 @@ inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
     return group_weights;
   }

-  size_t n_samples = info.num_row_;
   auto const &group_ptr = info.group_ptr_;
-  std::vector<float> results(n_samples);
   CHECK_GE(group_ptr.size(), 2);
-  CHECK_EQ(group_ptr.back(), n_samples);
+  auto n_groups = group_ptr.size() - 1;
+  CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
+  bst_row_t n_samples = info.num_row_;
+  std::vector<float> results(n_samples);
+  CHECK_EQ(group_ptr.back(), n_samples)
+      << error::GroupSize() << " the number of rows from the data.";
   size_t cur_group = 0;
-  for (size_t i = 0; i < n_samples; ++i) {
+  for (bst_row_t i = 0; i < n_samples; ++i) {
     results[i] = group_weights[cur_group];
     if (i == group_ptr[cur_group + 1]) {
       cur_group++;
@@ -1010,6 +1015,5 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
   */
   void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
 };
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_QUANTILE_H_
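The helper touched above expands one weight per query group into one weight per row; in NumPy terms (illustrative only, not code from the repository):

import numpy as np

group_ptr = np.array([0, 3, 5, 8])         # CSR-style group boundaries over 8 rows
group_weights = np.array([0.5, 2.0, 1.0])  # one weight per group
row_weights = np.repeat(group_weights, np.diff(group_ptr))
# row_weights -> [0.5, 0.5, 0.5, 2.0, 2.0, 1.0, 1.0, 1.0]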

View File

@@ -114,9 +114,20 @@ void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUS
 DMLC_REGISTER_PARAMETER(LambdaRankParam);

+void PreCache::InitOnCPU(Context const*, MetaInfo const& info) {
+  auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
+  CheckPreLabels("pre", h_label,
+                 [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
+}
+
+#if !defined(XGBOOST_USE_CUDA)
+void PreCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
+#endif  // !defined(XGBOOST_USE_CUDA)
+
 void MAPCache::InitOnCPU(Context const*, MetaInfo const& info) {
   auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
-  CheckMapLabels(h_label, [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
+  CheckPreLabels("map", h_label,
+                 [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
 }

 #if !defined(XGBOOST_USE_CUDA)
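CheckPreLabels restricts pre (and map) to binary relevance. A quick illustration of the constraint from Python; based on the check above this is expected to raise an XGBoostError, though the exact message is not shown in the diff:

import numpy as np
import xgboost as xgb

X = np.random.default_rng(0).normal(size=(6, 2))
dtrain = xgb.DMatrix(X, label=[0, 1, 2, 0, 1, 2], qid=[0, 0, 0, 1, 1, 1])
# Graded relevance (label 2) is accepted by ndcg but not by the precision metric.
xgb.train(
    {"objective": "rank:ndcg", "eval_metric": "pre@2"},
    dtrain,
    num_boost_round=1,
    evals=[(dtrain, "train")],
)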

View File

@@ -205,8 +205,13 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
                      [=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); });
 }

+void PreCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
+  auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+  CheckPreLabels("pre", d_label, CheckMAPOp{ctx->CUDACtx()});
+}
+
 void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
   auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
-  CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()});
+  CheckPreLabels("map", d_label, CheckMAPOp{ctx->CUDACtx()});
 }
 }  // namespace xgboost::ltr

View File

@@ -366,18 +366,43 @@ bool IsBinaryRel(linalg::VectorView<float const> label, AllOf all_of) {
   });
 }
 /**
- * \brief Validate label for MAP
+ * \brief Validate label for precision-based metric.
  *
  * \tparam Implementation of std::all_of. Specified as a parameter to reuse the check for
  *         both CPU and GPU.
  */
 template <typename AllOf>
-void CheckMapLabels(linalg::VectorView<float const> label, AllOf all_of) {
+void CheckPreLabels(StringView name, linalg::VectorView<float const> label, AllOf all_of) {
   auto s_label = label.Values();
   auto is_binary = IsBinaryRel(label, all_of);
-  CHECK(is_binary) << "MAP can only be used with binary labels.";
+  CHECK(is_binary) << name << " can only be used with binary labels.";
 }
+
+class PreCache : public RankingCache {
+  HostDeviceVector<double> pre_;
+
+  void InitOnCPU(Context const* ctx, MetaInfo const& info);
+  void InitOnCUDA(Context const* ctx, MetaInfo const& info);
+
+ public:
+  PreCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
+      : RankingCache{ctx, info, p} {
+    if (ctx->IsCPU()) {
+      this->InitOnCPU(ctx, info);
+    } else {
+      this->InitOnCUDA(ctx, info);
+    }
+  }
+
+  common::Span<double> Pre(Context const* ctx) {
+    if (pre_.Empty()) {
+      pre_.SetDevice(ctx->gpu_id);
+      pre_.Resize(this->Groups());
+    }
+    return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan();
+  }
+};
+
 class MAPCache : public RankingCache {
   // Total number of relevant documents for each group
   HostDeviceVector<double> n_rel_;

View File

@@ -366,8 +366,8 @@ inline void IterativeDMatrix::InitFromCUDA(Context const*, BatchParam const&, Da
   common::AssertGPUSupport();
 }

-inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
-                                                                 BatchParam const& param) {
+inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const*,
+                                                                 BatchParam const&) {
   common::AssertGPUSupport();
   auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
   return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));

View File

@@ -52,32 +52,13 @@ Metric::Create(const std::string& name, Context const* ctx) {
   metric->ctx_ = ctx;
   return metric;
 }
GPUMetric* GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
auto metric = CreateMetricImpl<MetricGPUReg>(name);
if (metric == nullptr) {
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
<< ". Resorting to the CPU builder";
return nullptr;
}
// Narrowing reference only for the compiler to allow assignment to a base class member.
// As such, using this narrowed reference to refer to derived members will be an illegal op.
// This is moot, as this type is stateless.
auto casted = static_cast<GPUMetric*>(metric);
CHECK(casted);
casted->ctx_ = ctx;
return casted;
}
 }  // namespace xgboost

 namespace dmlc {
 DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
-DMLC_REGISTRY_ENABLE(::xgboost::MetricGPUReg);
 }

-namespace xgboost {
-namespace metric {
+namespace xgboost::metric {
 // List of files that will be force linked in static links.
 DMLC_REGISTRY_LINK_TAG(auc);
 DMLC_REGISTRY_LINK_TAG(elementwise_metric);
@@ -88,5 +69,4 @@ DMLC_REGISTRY_LINK_TAG(rank_metric);
 DMLC_REGISTRY_LINK_TAG(auc_gpu);
 DMLC_REGISTRY_LINK_TAG(rank_metric_gpu);
 #endif
-}  // namespace metric
-}  // namespace xgboost
+}  // namespace xgboost::metric

View File

@@ -24,52 +24,13 @@ class MetricNoCache : public Metric {
   double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
     double result{0.0};
     auto const &info = p_fmat->Info();
-    collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
-      result = this->Eval(predts, info);
-    });
+    collective::ApplyWithLabels(info, &result, sizeof(double),
+                                [&] { result = this->Eval(predts, info); });
     return result;
   }
 };
// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not
// present already. This is created when there is a device ordinal present and if xgboost
// is compiled with CUDA support
struct GPUMetric : public MetricNoCache {
static GPUMetric *CreateGPUMetric(const std::string &name, Context const *tparam);
};
/*!
* \brief Internal registry entries for GPU Metric factory functions.
* The additional parameter const char* param gives the value after @, can be null.
* For example, metric map@3, then: param == "3".
*/
struct MetricGPUReg
: public dmlc::FunctionRegEntryBase<MetricGPUReg,
std::function<Metric * (const char*)> > {
};
/*!
* \brief Macro to register metric computed on GPU.
*
* \code
* // example of registering a objective ndcg@k
* XGBOOST_REGISTER_GPU_METRIC(NDCG_GPU, "ndcg")
* .describe("NDCG metric computer on GPU.")
* .set_body([](const char* param) {
* int at_k = atoi(param);
* return new NDCG(at_k);
* });
* \endcode
*/
// Note: Metric names registered in the GPU registry should follow this convention:
// - GPU metric types should be registered with the same name as the non GPU metric types
#define XGBOOST_REGISTER_GPU_METRIC(UniqueId, Name) \
::xgboost::MetricGPUReg& __make_ ## MetricGPUReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::MetricGPUReg>::Get()->__REGISTER__(Name)
 namespace metric {
 // Ranking config to be used on device and host
 struct EvalRankConfig {
  public:
@@ -91,16 +52,15 @@ class PackedReduceResult {
   XGBOOST_DEVICE
   PackedReduceResult operator+(PackedReduceResult const &other) const {
-    return PackedReduceResult{residue_sum_ + other.residue_sum_,
-                              weights_sum_ + other.weights_sum_};
+    return PackedReduceResult{residue_sum_ + other.residue_sum_, weights_sum_ + other.weights_sum_};
   }
   PackedReduceResult &operator+=(PackedReduceResult const &other) {
     this->residue_sum_ += other.residue_sum_;
     this->weights_sum_ += other.weights_sum_;
     return *this;
   }

-  double Residue() const { return residue_sum_; }
-  double Weights() const { return weights_sum_; }
+  [[nodiscard]] double Residue() const { return residue_sum_; }
+  [[nodiscard]] double Weights() const { return weights_sum_; }
 };
 }  // namespace metric

View File

@@ -1,25 +1,6 @@
 /**
  * Copyright 2020-2023 by XGBoost contributors
  */
// When device ordinal is present, we would want to build the metrics on the GPU. It is *not*
// possible for a valid device ordinal to be present for non GPU builds. However, it is possible
// for an invalid device ordinal to be specified in GPU builds - to train/predict and/or compute
// the metrics on CPU. To accommodate these scenarios, the following is done for the metrics
// accelerated on the GPU.
// - An internal GPU registry holds all the GPU metric types (defined in the .cu file)
// - An instance of the appropriate GPU metric type is created when a device ordinal is present
// - If the creation is successful, the metric computation is done on the device
// - else, it falls back on the CPU
// - The GPU metric types are *only* registered when xgboost is built for GPUs
//
// This is done for 2 reasons:
// - Clear separation of CPU and GPU logic
// - Sorting datasets containing large number of rows is (much) faster when parallel sort
// semantics is used on the CPU. The __gnu_parallel/concurrency primitives needed to perform
// this cannot be used when the translation unit is compiled using the 'nvcc' compiler (as the
// corresponding headers that brings in those function declaration can't be included with CUDA).
// This precludes the CPU and GPU logic to coexist inside a .cu file
#include "rank_metric.h" #include "rank_metric.h"
#include <dmlc/omp.h> #include <dmlc/omp.h>
@ -57,55 +38,8 @@
#include "xgboost/string_view.h" // for StringView #include "xgboost/string_view.h" // for StringView
namespace { namespace {
using PredIndPair = std::pair<xgboost::bst_float, xgboost::ltr::rel_degree_t>; using PredIndPair = std::pair<xgboost::bst_float, xgboost::ltr::rel_degree_t>;
using PredIndPairContainer = std::vector<PredIndPair>; using PredIndPairContainer = std::vector<PredIndPair>;
/*
* Adapter to access instance weights.
*
* - For ranking task, weights are per-group
* - For binary classification task, weights are per-instance
*
* WeightPolicy::GetWeightOfInstance() :
* get weight associated with an individual instance, using index into
* `info.weights`
* WeightPolicy::GetWeightOfSortedRecord() :
* get weight associated with an individual instance, using index into
* sorted records `rec` (in ascending order of predicted labels). `rec` is
* of type PredIndPairContainer
*/
class PerInstanceWeightPolicy {
public:
inline static xgboost::bst_float
GetWeightOfInstance(const xgboost::MetaInfo& info,
unsigned instance_id, unsigned) {
return info.GetWeight(instance_id);
}
inline static xgboost::bst_float
GetWeightOfSortedRecord(const xgboost::MetaInfo& info,
const PredIndPairContainer& rec,
unsigned record_id, unsigned) {
return info.GetWeight(rec[record_id].second);
}
};
class PerGroupWeightPolicy {
public:
inline static xgboost::bst_float
GetWeightOfInstance(const xgboost::MetaInfo& info,
unsigned, unsigned group_id) {
return info.GetWeight(group_id);
}
inline static xgboost::bst_float
GetWeightOfSortedRecord(const xgboost::MetaInfo& info,
const PredIndPairContainer&,
unsigned, unsigned group_id) {
return info.GetWeight(group_id);
}
};
 }  // anonymous namespace

 namespace xgboost::metric {
@@ -177,10 +111,6 @@ struct EvalAMS : public MetricNoCache {
 /*! \brief Evaluate rank list */
 struct EvalRank : public MetricNoCache, public EvalRankConfig {
- private:
-  // This is used to compute the ranking metrics on the GPU - for training jobs that run on the GPU.
-  std::unique_ptr<MetricNoCache> rank_gpu_;
-
  public:
   double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
     CHECK_EQ(preds.Size(), info.labels.Size())
@@ -199,20 +129,10 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
     // sum statistics
     double sum_metric = 0.0f;
// Check and see if we have the GPU metric registered in the internal registry
if (ctx_->gpu_id >= 0) {
if (!rank_gpu_) {
rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), ctx_));
}
if (rank_gpu_) {
sum_metric = rank_gpu_->Eval(preds, info);
}
}
     CHECK(ctx_);
     std::vector<double> sum_tloc(ctx_->Threads(), 0.0);

-    if (!rank_gpu_ || ctx_->gpu_id < 0) {
+    {
       const auto& labels = info.labels.View(Context::kCpuId);
       const auto &h_preds = preds.ConstHostVector();
@@ -253,23 +173,6 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
   virtual double EvalGroup(PredIndPairContainer *recptr) const = 0;
 };
/*! \brief Precision at N, for both classification and rank */
struct EvalPrecision : public EvalRank {
public:
explicit EvalPrecision(const char* name, const char* param) : EvalRank(name, param) {}
double EvalGroup(PredIndPairContainer *recptr) const override {
PredIndPairContainer &rec(*recptr);
// calculate Precision
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhit = 0;
for (size_t j = 0; j < rec.size() && j < this->topn; ++j) {
nhit += (rec[j].second != 0);
}
return static_cast<double>(nhit) / this->topn;
}
};
 /*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
 struct EvalCox : public MetricNoCache {
  public:
@@ -312,7 +215,7 @@ struct EvalCox : public MetricNoCache {
     return out/num_events;  // normalize by the number of events
   }

-  const char* Name() const override {
+  [[nodiscard]] const char* Name() const override {
     return "cox-nloglik";
   }
 };
@@ -321,10 +224,6 @@ XGBOOST_REGISTER_METRIC(AMS, "ams")
 .describe("AMS metric for higgs.")
 .set_body([](const char* param) { return new EvalAMS(param); });

-XGBOOST_REGISTER_METRIC(Precision, "pre")
-.describe("precision@k for rank.")
-.set_body([](const char* param) { return new EvalPrecision("pre", param); });
-
 XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
 .describe("Negative log partial likelihood of Cox proportional hazards model.")
 .set_body([](const char*) { return new EvalCox(); });
@@ -387,6 +286,8 @@ class EvalRankWithCache : public Metric {
     return result;
   }

+  [[nodiscard]] const char* Name() const override { return name_.c_str(); }
+
   virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
                       std::shared_ptr<Cache> p_cache) = 0;
 };
@@ -408,6 +309,52 @@ double Finalize(MetaInfo const& info, double score, double sw) {
 }
 }  // namespace
+
+class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {
+ public:
+  using EvalRankWithCache::EvalRankWithCache;
+
+  double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
+              std::shared_ptr<ltr::PreCache> p_cache) final {
+    auto n_groups = p_cache->Groups();
+    if (!info.weights_.Empty()) {
+      CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
+    }
+
+    if (ctx_->IsCUDA()) {
+      auto pre = cuda_impl::PreScore(ctx_, info, predt, p_cache);
+      return Finalize(info, pre.Residue(), pre.Weights());
+    }
+
+    auto gptr = p_cache->DataGroupPtr(ctx_);
+    auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
+    auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());
+    auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());
+    auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
+
+    auto pre = p_cache->Pre(ctx_);
+    common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
+      auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
+      auto g_rank = rank_idx.subspan(gptr[g], gptr[g + 1] - gptr[g]);
+
+      auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
+      double n_hits{0.0};
+      for (std::size_t i = 0; i < n; ++i) {
+        n_hits += g_label(g_rank[i]) * weight[g];
+      }
+      pre[g] = n_hits / static_cast<double>(n);
+    });
+
+    auto sw = 0.0;
+    for (std::size_t i = 0; i < pre.size(); ++i) {
+      sw += weight[i];
+    }
+    auto sum = std::accumulate(pre.cbegin(), pre.cend(), 0.0);
+    return Finalize(info, sum, sw);
+  }
+};
+
 /**
  * \brief Implement the NDCG score function for learning to rank.
  *
@@ -416,7 +363,6 @@ double Finalize(MetaInfo const& info, double score, double sw) {
 class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
  public:
   using EvalRankWithCache::EvalRankWithCache;
-  const char* Name() const override { return name_.c_str(); }

   double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
               std::shared_ptr<ltr::NDCGCache> p_cache) override {
@@ -475,7 +421,6 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
 class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
  public:
   using EvalRankWithCache::EvalRankWithCache;
-  const char* Name() const override { return name_.c_str(); }

   double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
               std::shared_ptr<ltr::MAPCache> p_cache) override {
@@ -494,7 +439,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
     common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
       auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
-      auto g_rank = rank_idx.subspan(gptr[g]);
+      auto g_rank = rank_idx.subspan(gptr[g], gptr[g + 1] - gptr[g]);

       auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
       double n_hits{0.0};
@@ -527,6 +472,10 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
   }
 };

+XGBOOST_REGISTER_METRIC(Precision, "pre")
+    .describe("precision@k for rank.")
+    .set_body([](const char* param) { return new EvalPrecision("pre", param); });
+
 XGBOOST_REGISTER_METRIC(EvalMAP, "map")
     .describe("map@k for ranking.")
     .set_body([](char const* param) {
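An illustrative NumPy transcription of the per-group loop in the new EvalPrecision above; names are mine, and unweighted data defaults to a group weight of 1, mirroring OptionalWeights:

import numpy as np

def precision_at_k(labels, scores, group_ptr, k, group_weights=None):
    # Weighted precision@k averaged over query groups.
    n_groups = len(group_ptr) - 1
    w = np.ones(n_groups) if group_weights is None else np.asarray(group_weights)
    per_group = np.empty(n_groups)
    for g in range(n_groups):
        beg, end = group_ptr[g], group_ptr[g + 1]
        order = np.argsort(-scores[beg:end], kind="stable")  # rank by descending prediction
        n = min(k, end - beg)
        hits = labels[beg:end][order[:n]].sum() * w[g]
        per_group[g] = hits / n
    return per_group.sum() / w.sum()

labels = np.array([1, 0, 1, 0, 0, 1, 1, 0])
scores = np.array([0.9, 0.8, 0.1, 0.4, 0.3, 0.7, 0.2, 0.6])
print(precision_at_k(labels, scores, group_ptr=[0, 4, 8], k=2))  # 0.5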

View File

@@ -28,108 +28,57 @@ namespace xgboost::metric {
 // tag the this file, used by force static link later.
 DMLC_REGISTRY_FILE_TAG(rank_metric_gpu);

-/*! \brief Evaluate rank list on GPU */
-template <typename EvalMetricT>
-struct EvalRankGpu : public GPUMetric, public EvalRankConfig {
- public:
-  double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
-    // Sanity check is done by the caller
-    std::vector<unsigned> tgptr(2, 0);
-    tgptr[1] = static_cast<unsigned>(preds.Size());
-    const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
-
-    const auto ngroups = static_cast<bst_omp_uint>(gptr.size() - 1);
-
-    auto device = ctx_->gpu_id;
-    dh::safe_cuda(cudaSetDevice(device));
-
-    info.labels.SetDevice(device);
-    preds.SetDevice(device);
-
-    auto dpreds = preds.ConstDevicePointer();
-    auto dlabels = info.labels.View(device);
-
-    // Sort all the predictions
-    dh::SegmentSorter<float> segment_pred_sorter;
-    segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr);
-
-    // Compute individual group metric and sum them up
-    return EvalMetricT::EvalMetric(segment_pred_sorter, dlabels.Values().data(), *this);
-  }
-
-  const char* Name() const override {
-    return name.c_str();
-  }
-
-  explicit EvalRankGpu(const char* name, const char* param) {
-    using namespace std;  // NOLINT(*)
-    if (param != nullptr) {
-      std::ostringstream os;
-      if (sscanf(param, "%u[-]?", &this->topn) == 1) {
-        os << name << '@' << param;
-        this->name = os.str();
-      } else {
-        os << name << param;
-        this->name = os.str();
-      }
-      if (param[strlen(param) - 1] == '-') {
-        this->minus = true;
-      }
-    } else {
-      this->name = name;
-    }
-  }
-};
-
-/*! \brief Precision at N, for both classification and rank */
-struct EvalPrecisionGpu {
- public:
-  static double EvalMetric(const dh::SegmentSorter<float> &pred_sorter,
-                           const float *dlabels,
-                           const EvalRankConfig &ecfg) {
-    // Group info on device
-    const auto &dgroups = pred_sorter.GetGroupsSpan();
-    const auto ngroups = pred_sorter.GetNumGroups();
-    const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan();
-
-    // Original positions of the predictions after they have been sorted
-    const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan();
-
-    // First, determine non zero labels in the dataset individually
-    auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) {
-      return (static_cast<unsigned>(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0;
-    };  // NOLINT
-
-    // Find each group's metric sum
-    dh::caching_device_vector<uint32_t> hits(ngroups, 0);
-    const auto nitems = pred_sorter.GetNumItems();
-    auto *dhits = hits.data().get();
-
-    int device_id = -1;
-    dh::safe_cuda(cudaGetDevice(&device_id));
-    // For each group item compute the aggregated precision
-    dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) {
-      const auto group_idx = dgroup_idx[idx];
-      const auto group_begin = dgroups[group_idx];
-      const auto ridx = idx - group_begin;
-      if (ridx < ecfg.topn && DetermineNonTrivialLabelLambda(idx)) {
-        atomicAdd(&dhits[group_idx], 1);
-      }
-    });
-
-    // Allocator to be used for managing space overhead while performing reductions
-    dh::XGBCachingDeviceAllocator<char> alloc;
-    return static_cast<double>(thrust::reduce(thrust::cuda::par(alloc),
-                                              hits.begin(), hits.end())) / ecfg.topn;
-  }
-};
-
-XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
-.describe("precision@k for rank computed on GPU.")
-.set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });
-
-namespace cuda_impl {
+namespace cuda_impl {
+PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info,
+                            HostDeviceVector<float> const &predt,
+                            std::shared_ptr<ltr::PreCache> p_cache) {
+  auto d_gptr = p_cache->DataGroupPtr(ctx);
+  auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+
+  predt.SetDevice(ctx->gpu_id);
+  auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
+  auto topk = p_cache->Param().TopK();
+  auto d_weight = common::MakeOptionalWeights(ctx, info.weights_);
+
+  auto it = dh::MakeTransformIterator<double>(
+      thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) {
+        auto g = dh::SegmentId(d_gptr, i);
+        auto g_begin = d_gptr[g];
+        auto g_end = d_gptr[g + 1];
+        i -= g_begin;
+        auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
+        auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
+        double y = g_label(g_rank[i]);
+        auto n = std::min(static_cast<std::size_t>(topk), g_label.Size());
+        double w{d_weight[g]};
+        if (i >= n) {
+          return 0.0;
+        }
+        return y / static_cast<double>(n) * w;
+      });
+
+  auto cuctx = ctx->CUDACtx();
+  auto pre = p_cache->Pre(ctx);
+  thrust::fill_n(cuctx->CTP(), pre.data(), pre.size(), 0.0);
+
+  std::size_t bytes;
+  cub::DeviceSegmentedReduce::Sum(nullptr, bytes, it, pre.data(), p_cache->Groups(), d_gptr.data(),
+                                  d_gptr.data() + 1, cuctx->Stream());
+  dh::TemporaryArray<char> temp(bytes);
+  cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, it, pre.data(), p_cache->Groups(),
+                                  d_gptr.data(), d_gptr.data() + 1, cuctx->Stream());
+
+  auto w_it =
+      dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul),
+                                        [=] XGBOOST_DEVICE(std::size_t g) { return d_weight[g]; });
+  auto n_weights = p_cache->Groups();
+  auto sw = dh::Reduce(cuctx->CTP(), w_it, w_it + n_weights, 0.0, thrust::plus<double>{});
+  auto sum =
+      dh::Reduce(cuctx->CTP(), dh::tcbegin(pre), dh::tcend(pre), 0.0, thrust::plus<double>{});
+  auto result = PackedReduceResult{sum, sw};
+  return result;
+}
+
 PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
                              HostDeviceVector<float> const &predt, bool minus,
                              std::shared_ptr<ltr::NDCGCache> p_cache) {
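The new GPU path computes a per-document contribution through a transform iterator and then runs one segmented sum per query group (cub::DeviceSegmentedReduce::Sum) instead of the old atomicAdd kernel, so the result no longer depends on thread scheduling. A rough CPU-side analogue of that reduction, for orientation only:

import numpy as np

def pre_score_segmented(labels, scores, group_ptr, k, group_weights):
    group_ptr = np.asarray(group_ptr)
    contrib = np.zeros(len(labels))
    for g, (beg, end) in enumerate(zip(group_ptr[:-1], group_ptr[1:])):
        order = np.argsort(-scores[beg:end], kind="stable")
        n = min(k, end - beg)
        # Transform-iterator analogue: y / n * w for the top-n ranked docs, 0 elsewhere.
        contrib[beg + order[:n]] = labels[beg:end][order[:n]] / n * group_weights[g]
    # Deterministic segmented sum, one value per query group.
    per_group = np.add.reduceat(contrib, group_ptr[:-1])
    return per_group.sum() / np.sum(group_weights)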

View File

@@ -12,9 +12,7 @@
 #include "xgboost/data.h"                // for MetaInfo
 #include "xgboost/host_device_vector.h"  // for HostDeviceVector

-namespace xgboost {
-namespace metric {
-namespace cuda_impl {
+namespace xgboost::metric::cuda_impl {
 PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
                              HostDeviceVector<float> const &predt, bool minus,
                              std::shared_ptr<ltr::NDCGCache> p_cache);
@@ -23,6 +21,10 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
                             HostDeviceVector<float> const &predt, bool minus,
                             std::shared_ptr<ltr::MAPCache> p_cache);

+PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info,
+                            HostDeviceVector<float> const &predt,
+                            std::shared_ptr<ltr::PreCache> p_cache);
+
 #if !defined(XGBOOST_USE_CUDA)
 inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
                                     HostDeviceVector<float> const &, bool,
@@ -37,8 +39,13 @@ inline PackedReduceResult MAPScore(Context const *, MetaInfo const &,
   common::AssertGPUSupport();
   return {};
 }
+
+inline PackedReduceResult PreScore(Context const *, MetaInfo const &,
+                                   HostDeviceVector<float> const &,
+                                   std::shared_ptr<ltr::PreCache>) {
+  common::AssertGPUSupport();
+  return {};
+}
 #endif
-}  // namespace cuda_impl
-}  // namespace metric
-}  // namespace xgboost
+}  // namespace xgboost::metric::cuda_impl
 #endif  // XGBOOST_METRIC_RANK_METRIC_H_

View File

@@ -90,7 +90,7 @@ def check_cmd_print_failure_assistance(cmd: List[str]) -> bool:
     subprocess.run([cmd[0], "--version"])

     msg = """
-Please run the following command on your machine to address the formatting error:
+Please run the following command on your machine to address the error:
 """
     msg += " ".join(cmd)

View File

@@ -17,34 +17,30 @@
 #include "xgboost/host_device_vector.h"  // for HostDeviceVector
 #include "xgboost/json.h"                // for Json, String, Object

-namespace xgboost {
-namespace metric {
+namespace xgboost::metric {

 inline void VerifyPrecision(DataSplitMode data_split_mode = DataSplitMode::kRow) {
-  // When the limit for precision is not given, it takes the limit at
-  // std::numeric_limits<unsigned>::max(); hence all values are very small
-  // NOTE(AbdealiJK): Maybe this should be fixed to be num_row by default.
   auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
-  xgboost::Metric * metric = xgboost::Metric::Create("pre", &ctx);
+  std::unique_ptr<xgboost::Metric> metric{Metric::Create("pre", &ctx)};
   ASSERT_STREQ(metric->Name(), "pre");
-  EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 0, 1e-7);
-  EXPECT_NEAR(GetMetricEval(metric,
-                            {0.1f, 0.9f, 0.1f, 0.9f},
-                            {   0,    0,    1,    1}, {}, {}, data_split_mode),
-              0, 1e-7);
+  EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5, 1e-7);
+  EXPECT_NEAR(
+      GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode),
+      0.5, 1e-7);

-  delete metric;
-  metric = xgboost::Metric::Create("pre@2", &ctx);
+  metric.reset(xgboost::Metric::Create("pre@2", &ctx));
   ASSERT_STREQ(metric->Name(), "pre@2");
-  EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5f, 1e-7);
-  EXPECT_NEAR(GetMetricEval(metric,
-                            {0.1f, 0.9f, 0.1f, 0.9f},
-                            {   0,    0,    1,    1}, {}, {}, data_split_mode),
-              0.5f, 0.001f);
+  EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5f, 1e-7);
+  EXPECT_NEAR(
+      GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode),
+      0.5f, 0.001f);

-  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}, {}, {}, data_split_mode));
-  delete metric;
+  EXPECT_ANY_THROW(GetMetricEval(metric.get(), {0, 1}, {}, {}, {}, data_split_mode));
+  metric.reset(xgboost::Metric::Create("pre@4", &ctx));
+  EXPECT_NEAR(GetMetricEval(metric.get(), {0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f},
+                            {0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f}, {}, {}, data_split_mode),
+              0.5f, 1e-7);
 }

 inline void VerifyNDCG(DataSplitMode data_split_mode = DataSplitMode::kRow) {
@@ -187,5 +183,4 @@ inline void VerifyNDCGExpGain(DataSplitMode data_split_mode = DataSplitMode::kRo
   ndcg = metric->Evaluate(predt, p_fmat);
   ASSERT_NEAR(ndcg, 1.0, kRtEps);
 }
-}  // namespace metric
-}  // namespace xgboost
+}  // namespace xgboost::metric

View File

@@ -5,7 +5,7 @@ import pytest
 import xgboost
 from xgboost import testing as tm
-from xgboost.testing.metrics import check_quantile_error
+from xgboost.testing.metrics import check_precision_score, check_quantile_error

 sys.path.append("tests/python")
 import test_eval_metrics as test_em  # noqa
@@ -59,6 +59,9 @@ class TestGPUEvalMetrics:
     def test_pr_auc_ltr(self):
        self.cpu_test.run_pr_auc_ltr("gpu_hist")

+    def test_precision_score(self):
+        check_precision_score("gpu_hist")
+
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_quantile_error(self) -> None:
         check_quantile_error("gpu_hist")

View File

@@ -3,7 +3,7 @@ import pytest
 import xgboost as xgb
 from xgboost import testing as tm
-from xgboost.testing.metrics import check_quantile_error
+from xgboost.testing.metrics import check_precision_score, check_quantile_error

 rng = np.random.RandomState(1337)
@@ -315,6 +315,9 @@ class TestEvalMetrics:
     def test_pr_auc_ltr(self):
         self.run_pr_auc_ltr("hist")

+    def test_precision_score(self):
+        check_precision_score("hist")
+
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_quantile_error(self) -> None:
         check_quantile_error("hist")

View File

@@ -55,6 +55,38 @@ class TestQuantileDMatrix:
         r = np.arange(1.0, n_samples)
         np.testing.assert_allclose(Xy.get_data().toarray()[1:, 0], r)

+    def test_error(self):
+        from sklearn.model_selection import train_test_split
+
+        rng = np.random.default_rng(1994)
+        X, y = make_categorical(
+            n_samples=128, n_features=2, n_categories=3, onehot=False
+        )
+        reg = xgb.XGBRegressor(tree_method="hist", enable_categorical=True)
+        w = rng.uniform(0, 1, size=y.shape[0])
+        X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
+            X, y, w, random_state=1994
+        )
+        with pytest.raises(ValueError, match="sample weight"):
+            reg.fit(
+                X,
+                y,
+                sample_weight=w_train,
+                eval_set=[(X_test, y_test)],
+                sample_weight_eval_set=[w_test],
+            )
+        with pytest.raises(ValueError, match="sample weight"):
+            reg.fit(
+                X_train,
+                y_train,
+                sample_weight=w,
+                eval_set=[(X_test, y_test)],
+                sample_weight_eval_set=[w_test],
+            )
+
     @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9])
     def test_with_iterator(self, sparsity: float) -> None:
         n_samples_per_batch = 317