Rework the precision metric. (#9222)
- Rework the precision metric for both CPU and GPU. - Mention it in the document. - Cleanup old support code for GPU ranking metric. - Deterministic GPU implementation. * Drop support for classification. * type. * use batch shape. * lint. * cpu build. * cpu build. * lint. * Tests. * Fix. * Cleanup error message.
This commit is contained in:
parent
db8288121d
commit
9fbde21e9d
@ -319,7 +319,7 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
|
||||
|
||||
# maximize is usually NULL when not set in xgb.train and built-in metrics
|
||||
if (is.null(maximize))
|
||||
maximize <<- grepl('(_auc|_map|_ndcg)', metric_name)
|
||||
maximize <<- grepl('(_auc|_map|_ndcg|_pre)', metric_name)
|
||||
|
||||
if (verbose && NVL(env$rank, 0) == 0)
|
||||
cat("Will train until ", metric_name, " hasn't improved in ",
|
||||
|
||||
@ -424,6 +424,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
||||
|
||||
After XGBoost 1.6, both of the requirements and restrictions for using ``aucpr`` in classification problem are similar to ``auc``. For ranking task, only binary relevance label :math:`y \in [0, 1]` is supported. Different from ``map (mean average precision)``, ``aucpr`` calculates the *interpolated* area under precision recall curve using continuous interpolation.
|
||||
|
||||
- ``pre``: Precision at :math:`k`. Supports only learning to rank task.
|
||||
- ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
|
||||
- ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_
|
||||
|
||||
@ -435,7 +436,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
||||
|
||||
where :math:`I_{(k)}` is an indicator function that equals to :math:`1` when the document at :math:`k` is relevant and :math:`0` otherwise. The :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries.
|
||||
|
||||
- ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
|
||||
- ``ndcg@n``, ``map@n``, ``pre@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
|
||||
- ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions.
|
||||
- ``poisson-nloglik``: negative log-likelihood for Poisson regression
|
||||
- ``gamma-nloglik``: negative log-likelihood for gamma regression
|
||||
|
||||
@ -372,6 +372,8 @@ class EarlyStopping(TrainingCallback):
|
||||
maximize_metrics = (
|
||||
"auc",
|
||||
"aucpr",
|
||||
"pre",
|
||||
"pre@",
|
||||
"map",
|
||||
"ndcg",
|
||||
"auc@",
|
||||
|
||||
@ -1,9 +1,61 @@
|
||||
"""Tests for evaluation metrics."""
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost.compat import concat
|
||||
from xgboost.core import _parse_eval_str
|
||||
|
||||
|
||||
def check_precision_score(tree_method: str) -> None:
|
||||
"""Test for precision with ranking and classification."""
|
||||
datasets = pytest.importorskip("sklearn.datasets")
|
||||
|
||||
X, y = datasets.make_classification(
|
||||
n_samples=1024, n_features=4, n_classes=2, random_state=2023
|
||||
)
|
||||
qid = np.zeros(shape=y.shape) # same group
|
||||
|
||||
ltr = xgb.XGBRanker(n_estimators=2, tree_method=tree_method)
|
||||
ltr.fit(X, y, qid=qid)
|
||||
|
||||
# re-generate so that XGBoost doesn't evaluate the result to 1.0
|
||||
X, y = datasets.make_classification(
|
||||
n_samples=512, n_features=4, n_classes=2, random_state=1994
|
||||
)
|
||||
|
||||
ltr.set_params(eval_metric="pre@32")
|
||||
result = _parse_eval_str(
|
||||
ltr.get_booster().eval_set(evals=[(xgb.DMatrix(X, y), "Xy")])
|
||||
)
|
||||
score_0 = result[1][1]
|
||||
|
||||
X_list = []
|
||||
y_list = []
|
||||
n_query_groups = 3
|
||||
q_list: List[np.ndarray] = []
|
||||
for i in range(n_query_groups):
|
||||
# same for all groups
|
||||
X, y = datasets.make_classification(
|
||||
n_samples=512, n_features=4, n_classes=2, random_state=1994
|
||||
)
|
||||
X_list.append(X)
|
||||
y_list.append(y)
|
||||
q = np.full(shape=y.shape, fill_value=i, dtype=np.uint64)
|
||||
q_list.append(q)
|
||||
|
||||
qid = concat(q_list)
|
||||
X = concat(X_list)
|
||||
y = concat(y_list)
|
||||
|
||||
result = _parse_eval_str(
|
||||
ltr.get_booster().eval_set(evals=[(xgb.DMatrix(X, y, qid=qid), "Xy")])
|
||||
)
|
||||
assert result[1][0].endswith("pre@32")
|
||||
score_1 = result[1][1]
|
||||
assert score_1 == score_0
|
||||
|
||||
|
||||
def check_quantile_error(tree_method: str) -> None:
|
||||
|
||||
@ -825,176 +825,6 @@ XGBOOST_DEVICE auto tcrend(xgboost::common::Span<T> const &span) { // NOLINT
|
||||
return tcrbegin(span) + span.size();
|
||||
}
|
||||
|
||||
// This type sorts an array which is divided into multiple groups. The sorting is influenced
|
||||
// by the function object 'Comparator'
|
||||
template <typename T>
|
||||
class SegmentSorter {
|
||||
private:
|
||||
// Items sorted within the group
|
||||
caching_device_vector<T> ditems_;
|
||||
|
||||
// Original position of the items before they are sorted descending within their groups
|
||||
caching_device_vector<uint32_t> doriginal_pos_;
|
||||
|
||||
// Segments within the original list that delineates the different groups
|
||||
caching_device_vector<uint32_t> group_segments_;
|
||||
|
||||
// Need this on the device as it is used in the kernels
|
||||
caching_device_vector<uint32_t> dgroups_; // Group information on device
|
||||
|
||||
// Where did the item that was originally present at position 'x' move to after they are sorted
|
||||
caching_device_vector<uint32_t> dindexable_sorted_pos_;
|
||||
|
||||
// Initialize everything but the segments
|
||||
void Init(uint32_t num_elems) {
|
||||
ditems_.resize(num_elems);
|
||||
|
||||
doriginal_pos_.resize(num_elems);
|
||||
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
|
||||
}
|
||||
|
||||
// Initialize all with group info
|
||||
void Init(const std::vector<uint32_t> &groups) {
|
||||
uint32_t num_elems = groups.back();
|
||||
this->Init(num_elems);
|
||||
this->CreateGroupSegments(groups);
|
||||
}
|
||||
|
||||
public:
|
||||
// This needs to be public due to device lambda
|
||||
void CreateGroupSegments(const std::vector<uint32_t> &groups) {
|
||||
uint32_t num_elems = groups.back();
|
||||
group_segments_.resize(num_elems, 0);
|
||||
|
||||
dgroups_ = groups;
|
||||
|
||||
if (GetNumGroups() == 1) return; // There are no segments; hence, no need to compute them
|
||||
|
||||
// Define the segments by assigning a group ID to each element
|
||||
const uint32_t *dgroups = dgroups_.data().get();
|
||||
uint32_t ngroups = dgroups_.size();
|
||||
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
|
||||
return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) -
|
||||
dgroups - 1;
|
||||
}; // NOLINT
|
||||
|
||||
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
|
||||
thrust::make_counting_iterator(num_elems),
|
||||
group_segments_.begin(),
|
||||
ComputeGroupIDLambda);
|
||||
}
|
||||
|
||||
// Accessors that returns device pointer
|
||||
inline uint32_t GetNumItems() const { return ditems_.size(); }
|
||||
inline const xgboost::common::Span<const T> GetItemsSpan() const {
|
||||
return { ditems_.data().get(), ditems_.size() };
|
||||
}
|
||||
|
||||
inline const xgboost::common::Span<const uint32_t> GetOriginalPositionsSpan() const {
|
||||
return { doriginal_pos_.data().get(), doriginal_pos_.size() };
|
||||
}
|
||||
|
||||
inline const xgboost::common::Span<const uint32_t> GetGroupSegmentsSpan() const {
|
||||
return { group_segments_.data().get(), group_segments_.size() };
|
||||
}
|
||||
|
||||
inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; }
|
||||
inline const xgboost::common::Span<const uint32_t> GetGroupsSpan() const {
|
||||
return { dgroups_.data().get(), dgroups_.size() };
|
||||
}
|
||||
|
||||
inline const xgboost::common::Span<const uint32_t> GetIndexableSortedPositionsSpan() const {
|
||||
return { dindexable_sorted_pos_.data().get(), dindexable_sorted_pos_.size() };
|
||||
}
|
||||
|
||||
// Sort an array that is divided into multiple groups. The array is sorted within each group.
|
||||
// This version provides the group information that is on the host.
|
||||
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
|
||||
// is used.
|
||||
template <typename Comparator = thrust::greater<T>>
|
||||
void SortItems(const T *ditems, uint32_t item_size, const std::vector<uint32_t> &groups,
|
||||
const Comparator &comp = Comparator()) {
|
||||
this->Init(groups);
|
||||
this->SortItems(ditems, item_size, this->GetGroupSegmentsSpan(), comp);
|
||||
}
|
||||
|
||||
// Sort an array that is divided into multiple groups. The array is sorted within each group.
|
||||
// This version provides the group information that is on the device.
|
||||
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
|
||||
// is used.
|
||||
template <typename Comparator = thrust::greater<T>>
|
||||
void SortItems(const T *ditems, uint32_t item_size,
|
||||
const xgboost::common::Span<const uint32_t> &group_segments,
|
||||
const Comparator &comp = Comparator()) {
|
||||
this->Init(item_size);
|
||||
|
||||
// Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
|
||||
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
|
||||
// when comparators are used. Hence, the following algorithm is used. This is done so that
|
||||
// we can grab the appropriate related values from the original list later, after the
|
||||
// items are sorted.
|
||||
//
|
||||
// Here is the internal representation:
|
||||
// dgroups_: [ 0, 3, 5, 8, 10 ]
|
||||
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
|
||||
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
|
||||
// ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
|
||||
//
|
||||
// Sort the items first and make a note of the original positions in doriginal_pos_
|
||||
// based on the sort
|
||||
// ditems_: 4 4 3 3 2 1 1 1 1 0
|
||||
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
|
||||
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
|
||||
// in kernel, sorting using predicates etc.
|
||||
|
||||
ditems_.assign(thrust::device_ptr<const T>(ditems),
|
||||
thrust::device_ptr<const T>(ditems) + item_size);
|
||||
|
||||
// Allocator to be used by sort for managing space overhead while sorting
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
|
||||
ditems_.begin(), ditems_.end(),
|
||||
doriginal_pos_.begin(), comp);
|
||||
|
||||
if (GetNumGroups() == 1) return; // The entire array is sorted, as it isn't segmented
|
||||
|
||||
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
|
||||
// holisitic item sort order on the segments
|
||||
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
|
||||
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
|
||||
caching_device_vector<uint32_t> group_segments_c(item_size);
|
||||
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
|
||||
dh::tcbegin(group_segments), group_segments_c.begin());
|
||||
|
||||
// Now, sort the group segments so that you may bring the items within the group together,
|
||||
// in the process also noting the relative changes to the doriginal_pos_ while that happens
|
||||
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
|
||||
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
|
||||
group_segments_c.begin(), group_segments_c.end(),
|
||||
doriginal_pos_.begin(), thrust::less<uint32_t>());
|
||||
|
||||
// Finally, gather the original items based on doriginal_pos_ to sort the input and
|
||||
// to store them in ditems_
|
||||
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
|
||||
// ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
|
||||
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
|
||||
thrust::device_ptr<const T>(ditems), ditems_.begin());
|
||||
}
|
||||
|
||||
// Determine where an item that was originally present at position 'x' has been relocated to
|
||||
// after a sort. Creation of such an index has to be explicitly requested after a sort
|
||||
void CreateIndexableSortedPositions() {
|
||||
dindexable_sorted_pos_.resize(GetNumItems());
|
||||
thrust::scatter(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
|
||||
thrust::make_counting_iterator(GetNumItems()), // Rearrange indices...
|
||||
// ...based on this map
|
||||
dh::tcbegin(GetOriginalPositionsSpan()),
|
||||
dindexable_sorted_pos_.begin()); // Write results into this
|
||||
}
|
||||
};
|
||||
|
||||
// Atomic add function for gradients
|
||||
template <typename OutputGradientT, typename InputGradientT>
|
||||
XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
||||
|
||||
@ -8,8 +8,7 @@
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/span.h" // Span
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
struct OptionalWeights {
|
||||
Span<float const> weights;
|
||||
float dft{1.0f}; // fixme: make this compile time constant
|
||||
@ -18,7 +17,8 @@ struct OptionalWeights {
|
||||
explicit OptionalWeights(float w) : dft{w} {}
|
||||
|
||||
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
|
||||
auto Empty() const { return weights.empty(); }
|
||||
[[nodiscard]] auto Empty() const { return weights.empty(); }
|
||||
[[nodiscard]] auto Size() const { return weights.size(); }
|
||||
};
|
||||
|
||||
inline OptionalWeights MakeOptionalWeights(Context const* ctx,
|
||||
@ -28,6 +28,5 @@ inline OptionalWeights MakeOptionalWeights(Context const* ctx,
|
||||
}
|
||||
return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()};
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
#endif // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
|
||||
|
||||
@ -90,6 +90,9 @@ void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid
|
||||
MetaInfo const &info, float missing) {
|
||||
auto const &h_weights =
|
||||
(use_group_ind_ ? detail::UnrollGroupWeights(info) : info.weights_.HostVector());
|
||||
if (!use_group_ind_ && !h_weights.empty()) {
|
||||
CHECK_EQ(h_weights.size(), batch.Size()) << "Invalid size of sample weight.";
|
||||
}
|
||||
|
||||
auto is_valid = data::IsValidFunctor{missing};
|
||||
auto weights = OptionalWeights{Span<float const>{h_weights}};
|
||||
|
||||
@ -19,12 +19,12 @@
|
||||
|
||||
#include "categorical.h"
|
||||
#include "common.h"
|
||||
#include "error_msg.h" // GroupWeight
|
||||
#include "optional_weight.h" // OptionalWeights
|
||||
#include "threading_utils.h"
|
||||
#include "timer.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
/*!
|
||||
* \brief experimental wsummary
|
||||
* \tparam DType type of data content
|
||||
@ -695,13 +695,18 @@ inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
||||
return group_weights;
|
||||
}
|
||||
|
||||
size_t n_samples = info.num_row_;
|
||||
auto const &group_ptr = info.group_ptr_;
|
||||
std::vector<float> results(n_samples);
|
||||
CHECK_GE(group_ptr.size(), 2);
|
||||
CHECK_EQ(group_ptr.back(), n_samples);
|
||||
|
||||
auto n_groups = group_ptr.size() - 1;
|
||||
CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
|
||||
|
||||
bst_row_t n_samples = info.num_row_;
|
||||
std::vector<float> results(n_samples);
|
||||
CHECK_EQ(group_ptr.back(), n_samples)
|
||||
<< error::GroupSize() << " the number of rows from the data.";
|
||||
size_t cur_group = 0;
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
for (bst_row_t i = 0; i < n_samples; ++i) {
|
||||
results[i] = group_weights[cur_group];
|
||||
if (i == group_ptr[cur_group + 1]) {
|
||||
cur_group++;
|
||||
@ -1010,6 +1015,5 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
|
||||
*/
|
||||
void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
#endif // XGBOOST_COMMON_QUANTILE_H_
|
||||
|
||||
@ -114,9 +114,20 @@ void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUS
|
||||
|
||||
DMLC_REGISTER_PARAMETER(LambdaRankParam);
|
||||
|
||||
void PreCache::InitOnCPU(Context const*, MetaInfo const& info) {
|
||||
auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
|
||||
CheckPreLabels("pre", h_label,
|
||||
[](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
void PreCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
void MAPCache::InitOnCPU(Context const*, MetaInfo const& info) {
|
||||
auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
|
||||
CheckMapLabels(h_label, [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
|
||||
CheckPreLabels("map", h_label,
|
||||
[](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
@ -205,8 +205,13 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
|
||||
[=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); });
|
||||
}
|
||||
|
||||
void PreCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
|
||||
auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
|
||||
CheckPreLabels("pre", d_label, CheckMAPOp{ctx->CUDACtx()});
|
||||
}
|
||||
|
||||
void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
|
||||
auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
|
||||
CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()});
|
||||
CheckPreLabels("map", d_label, CheckMAPOp{ctx->CUDACtx()});
|
||||
}
|
||||
} // namespace xgboost::ltr
|
||||
|
||||
@ -366,18 +366,43 @@ bool IsBinaryRel(linalg::VectorView<float const> label, AllOf all_of) {
|
||||
});
|
||||
}
|
||||
/**
|
||||
* \brief Validate label for MAP
|
||||
* \brief Validate label for precision-based metric.
|
||||
*
|
||||
* \tparam Implementation of std::all_of. Specified as a parameter to reuse the check for
|
||||
* both CPU and GPU.
|
||||
*/
|
||||
template <typename AllOf>
|
||||
void CheckMapLabels(linalg::VectorView<float const> label, AllOf all_of) {
|
||||
void CheckPreLabels(StringView name, linalg::VectorView<float const> label, AllOf all_of) {
|
||||
auto s_label = label.Values();
|
||||
auto is_binary = IsBinaryRel(label, all_of);
|
||||
CHECK(is_binary) << "MAP can only be used with binary labels.";
|
||||
CHECK(is_binary) << name << " can only be used with binary labels.";
|
||||
}
|
||||
|
||||
class PreCache : public RankingCache {
|
||||
HostDeviceVector<double> pre_;
|
||||
|
||||
void InitOnCPU(Context const* ctx, MetaInfo const& info);
|
||||
void InitOnCUDA(Context const* ctx, MetaInfo const& info);
|
||||
|
||||
public:
|
||||
PreCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
|
||||
: RankingCache{ctx, info, p} {
|
||||
if (ctx->IsCPU()) {
|
||||
this->InitOnCPU(ctx, info);
|
||||
} else {
|
||||
this->InitOnCUDA(ctx, info);
|
||||
}
|
||||
}
|
||||
|
||||
common::Span<double> Pre(Context const* ctx) {
|
||||
if (pre_.Empty()) {
|
||||
pre_.SetDevice(ctx->gpu_id);
|
||||
pre_.Resize(this->Groups());
|
||||
}
|
||||
return ctx->IsCPU() ? pre_.HostSpan() : pre_.DeviceSpan();
|
||||
}
|
||||
};
|
||||
|
||||
class MAPCache : public RankingCache {
|
||||
// Total number of relevant documents for each group
|
||||
HostDeviceVector<double> n_rel_;
|
||||
|
||||
@ -366,8 +366,8 @@ inline void IterativeDMatrix::InitFromCUDA(Context const*, BatchParam const&, Da
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
|
||||
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
BatchParam const& param) {
|
||||
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const*,
|
||||
BatchParam const&) {
|
||||
common::AssertGPUSupport();
|
||||
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
|
||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
||||
|
||||
@ -52,32 +52,13 @@ Metric::Create(const std::string& name, Context const* ctx) {
|
||||
metric->ctx_ = ctx;
|
||||
return metric;
|
||||
}
|
||||
|
||||
GPUMetric* GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
|
||||
auto metric = CreateMetricImpl<MetricGPUReg>(name);
|
||||
if (metric == nullptr) {
|
||||
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
|
||||
<< ". Resorting to the CPU builder";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Narrowing reference only for the compiler to allow assignment to a base class member.
|
||||
// As such, using this narrowed reference to refer to derived members will be an illegal op.
|
||||
// This is moot, as this type is stateless.
|
||||
auto casted = static_cast<GPUMetric*>(metric);
|
||||
CHECK(casted);
|
||||
casted->ctx_ = ctx;
|
||||
return casted;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricGPUReg);
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
namespace xgboost::metric {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(auc);
|
||||
DMLC_REGISTRY_LINK_TAG(elementwise_metric);
|
||||
@ -88,5 +69,4 @@ DMLC_REGISTRY_LINK_TAG(rank_metric);
|
||||
DMLC_REGISTRY_LINK_TAG(auc_gpu);
|
||||
DMLC_REGISTRY_LINK_TAG(rank_metric_gpu);
|
||||
#endif
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::metric
|
||||
|
||||
@ -23,53 +23,14 @@ class MetricNoCache : public Metric {
|
||||
|
||||
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
|
||||
double result{0.0};
|
||||
auto const& info = p_fmat->Info();
|
||||
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
|
||||
result = this->Eval(predts, info);
|
||||
});
|
||||
auto const &info = p_fmat->Info();
|
||||
collective::ApplyWithLabels(info, &result, sizeof(double),
|
||||
[&] { result = this->Eval(predts, info); });
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not
|
||||
// present already. This is created when there is a device ordinal present and if xgboost
|
||||
// is compiled with CUDA support
|
||||
struct GPUMetric : public MetricNoCache {
|
||||
static GPUMetric *CreateGPUMetric(const std::string &name, Context const *tparam);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Internal registry entries for GPU Metric factory functions.
|
||||
* The additional parameter const char* param gives the value after @, can be null.
|
||||
* For example, metric map@3, then: param == "3".
|
||||
*/
|
||||
struct MetricGPUReg
|
||||
: public dmlc::FunctionRegEntryBase<MetricGPUReg,
|
||||
std::function<Metric * (const char*)> > {
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register metric computed on GPU.
|
||||
*
|
||||
* \code
|
||||
* // example of registering a objective ndcg@k
|
||||
* XGBOOST_REGISTER_GPU_METRIC(NDCG_GPU, "ndcg")
|
||||
* .describe("NDCG metric computer on GPU.")
|
||||
* .set_body([](const char* param) {
|
||||
* int at_k = atoi(param);
|
||||
* return new NDCG(at_k);
|
||||
* });
|
||||
* \endcode
|
||||
*/
|
||||
|
||||
// Note: Metric names registered in the GPU registry should follow this convention:
|
||||
// - GPU metric types should be registered with the same name as the non GPU metric types
|
||||
#define XGBOOST_REGISTER_GPU_METRIC(UniqueId, Name) \
|
||||
::xgboost::MetricGPUReg& __make_ ## MetricGPUReg ## _ ## UniqueId ## __ = \
|
||||
::dmlc::Registry< ::xgboost::MetricGPUReg>::Get()->__REGISTER__(Name)
|
||||
|
||||
namespace metric {
|
||||
|
||||
// Ranking config to be used on device and host
|
||||
struct EvalRankConfig {
|
||||
public:
|
||||
@ -81,8 +42,8 @@ struct EvalRankConfig {
|
||||
};
|
||||
|
||||
class PackedReduceResult {
|
||||
double residue_sum_ { 0 };
|
||||
double weights_sum_ { 0 };
|
||||
double residue_sum_{0};
|
||||
double weights_sum_{0};
|
||||
|
||||
public:
|
||||
XGBOOST_DEVICE PackedReduceResult() {} // NOLINT
|
||||
@ -91,16 +52,15 @@ class PackedReduceResult {
|
||||
|
||||
XGBOOST_DEVICE
|
||||
PackedReduceResult operator+(PackedReduceResult const &other) const {
|
||||
return PackedReduceResult{residue_sum_ + other.residue_sum_,
|
||||
weights_sum_ + other.weights_sum_};
|
||||
return PackedReduceResult{residue_sum_ + other.residue_sum_, weights_sum_ + other.weights_sum_};
|
||||
}
|
||||
PackedReduceResult &operator+=(PackedReduceResult const &other) {
|
||||
this->residue_sum_ += other.residue_sum_;
|
||||
this->weights_sum_ += other.weights_sum_;
|
||||
return *this;
|
||||
}
|
||||
double Residue() const { return residue_sum_; }
|
||||
double Weights() const { return weights_sum_; }
|
||||
[[nodiscard]] double Residue() const { return residue_sum_; }
|
||||
[[nodiscard]] double Weights() const { return weights_sum_; }
|
||||
};
|
||||
|
||||
} // namespace metric
|
||||
|
||||
@ -1,25 +1,6 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost contributors
|
||||
*/
|
||||
// When device ordinal is present, we would want to build the metrics on the GPU. It is *not*
|
||||
// possible for a valid device ordinal to be present for non GPU builds. However, it is possible
|
||||
// for an invalid device ordinal to be specified in GPU builds - to train/predict and/or compute
|
||||
// the metrics on CPU. To accommodate these scenarios, the following is done for the metrics
|
||||
// accelerated on the GPU.
|
||||
// - An internal GPU registry holds all the GPU metric types (defined in the .cu file)
|
||||
// - An instance of the appropriate GPU metric type is created when a device ordinal is present
|
||||
// - If the creation is successful, the metric computation is done on the device
|
||||
// - else, it falls back on the CPU
|
||||
// - The GPU metric types are *only* registered when xgboost is built for GPUs
|
||||
//
|
||||
// This is done for 2 reasons:
|
||||
// - Clear separation of CPU and GPU logic
|
||||
// - Sorting datasets containing large number of rows is (much) faster when parallel sort
|
||||
// semantics is used on the CPU. The __gnu_parallel/concurrency primitives needed to perform
|
||||
// this cannot be used when the translation unit is compiled using the 'nvcc' compiler (as the
|
||||
// corresponding headers that brings in those function declaration can't be included with CUDA).
|
||||
// This precludes the CPU and GPU logic to coexist inside a .cu file
|
||||
|
||||
#include "rank_metric.h"
|
||||
|
||||
#include <dmlc/omp.h>
|
||||
@ -57,55 +38,8 @@
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace {
|
||||
|
||||
using PredIndPair = std::pair<xgboost::bst_float, xgboost::ltr::rel_degree_t>;
|
||||
using PredIndPairContainer = std::vector<PredIndPair>;
|
||||
|
||||
/*
|
||||
* Adapter to access instance weights.
|
||||
*
|
||||
* - For ranking task, weights are per-group
|
||||
* - For binary classification task, weights are per-instance
|
||||
*
|
||||
* WeightPolicy::GetWeightOfInstance() :
|
||||
* get weight associated with an individual instance, using index into
|
||||
* `info.weights`
|
||||
* WeightPolicy::GetWeightOfSortedRecord() :
|
||||
* get weight associated with an individual instance, using index into
|
||||
* sorted records `rec` (in ascending order of predicted labels). `rec` is
|
||||
* of type PredIndPairContainer
|
||||
*/
|
||||
|
||||
class PerInstanceWeightPolicy {
|
||||
public:
|
||||
inline static xgboost::bst_float
|
||||
GetWeightOfInstance(const xgboost::MetaInfo& info,
|
||||
unsigned instance_id, unsigned) {
|
||||
return info.GetWeight(instance_id);
|
||||
}
|
||||
inline static xgboost::bst_float
|
||||
GetWeightOfSortedRecord(const xgboost::MetaInfo& info,
|
||||
const PredIndPairContainer& rec,
|
||||
unsigned record_id, unsigned) {
|
||||
return info.GetWeight(rec[record_id].second);
|
||||
}
|
||||
};
|
||||
|
||||
class PerGroupWeightPolicy {
|
||||
public:
|
||||
inline static xgboost::bst_float
|
||||
GetWeightOfInstance(const xgboost::MetaInfo& info,
|
||||
unsigned, unsigned group_id) {
|
||||
return info.GetWeight(group_id);
|
||||
}
|
||||
|
||||
inline static xgboost::bst_float
|
||||
GetWeightOfSortedRecord(const xgboost::MetaInfo& info,
|
||||
const PredIndPairContainer&,
|
||||
unsigned, unsigned group_id) {
|
||||
return info.GetWeight(group_id);
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
namespace xgboost::metric {
|
||||
@ -177,10 +111,6 @@ struct EvalAMS : public MetricNoCache {
|
||||
|
||||
/*! \brief Evaluate rank list */
|
||||
struct EvalRank : public MetricNoCache, public EvalRankConfig {
|
||||
private:
|
||||
// This is used to compute the ranking metrics on the GPU - for training jobs that run on the GPU.
|
||||
std::unique_ptr<MetricNoCache> rank_gpu_;
|
||||
|
||||
public:
|
||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||
CHECK_EQ(preds.Size(), info.labels.Size())
|
||||
@ -199,20 +129,10 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
|
||||
// sum statistics
|
||||
double sum_metric = 0.0f;
|
||||
|
||||
// Check and see if we have the GPU metric registered in the internal registry
|
||||
if (ctx_->gpu_id >= 0) {
|
||||
if (!rank_gpu_) {
|
||||
rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), ctx_));
|
||||
}
|
||||
if (rank_gpu_) {
|
||||
sum_metric = rank_gpu_->Eval(preds, info);
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(ctx_);
|
||||
std::vector<double> sum_tloc(ctx_->Threads(), 0.0);
|
||||
|
||||
if (!rank_gpu_ || ctx_->gpu_id < 0) {
|
||||
{
|
||||
const auto& labels = info.labels.View(Context::kCpuId);
|
||||
const auto &h_preds = preds.ConstHostVector();
|
||||
|
||||
@ -253,23 +173,6 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
|
||||
virtual double EvalGroup(PredIndPairContainer *recptr) const = 0;
|
||||
};
|
||||
|
||||
/*! \brief Precision at N, for both classification and rank */
|
||||
struct EvalPrecision : public EvalRank {
|
||||
public:
|
||||
explicit EvalPrecision(const char* name, const char* param) : EvalRank(name, param) {}
|
||||
|
||||
double EvalGroup(PredIndPairContainer *recptr) const override {
|
||||
PredIndPairContainer &rec(*recptr);
|
||||
// calculate Precision
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned nhit = 0;
|
||||
for (size_t j = 0; j < rec.size() && j < this->topn; ++j) {
|
||||
nhit += (rec[j].second != 0);
|
||||
}
|
||||
return static_cast<double>(nhit) / this->topn;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
|
||||
struct EvalCox : public MetricNoCache {
|
||||
public:
|
||||
@ -312,7 +215,7 @@ struct EvalCox : public MetricNoCache {
|
||||
return out/num_events; // normalize by the number of events
|
||||
}
|
||||
|
||||
const char* Name() const override {
|
||||
[[nodiscard]] const char* Name() const override {
|
||||
return "cox-nloglik";
|
||||
}
|
||||
};
|
||||
@ -321,10 +224,6 @@ XGBOOST_REGISTER_METRIC(AMS, "ams")
|
||||
.describe("AMS metric for higgs.")
|
||||
.set_body([](const char* param) { return new EvalAMS(param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Precision, "pre")
|
||||
.describe("precision@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalPrecision("pre", param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
|
||||
.describe("Negative log partial likelihood of Cox proportional hazards model.")
|
||||
.set_body([](const char*) { return new EvalCox(); });
|
||||
@ -387,6 +286,8 @@ class EvalRankWithCache : public Metric {
|
||||
return result;
|
||||
}
|
||||
|
||||
[[nodiscard]] const char* Name() const override { return name_.c_str(); }
|
||||
|
||||
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
|
||||
std::shared_ptr<Cache> p_cache) = 0;
|
||||
};
|
||||
@ -408,6 +309,52 @@ double Finalize(MetaInfo const& info, double score, double sw) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {
|
||||
public:
|
||||
using EvalRankWithCache::EvalRankWithCache;
|
||||
|
||||
double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
|
||||
std::shared_ptr<ltr::PreCache> p_cache) final {
|
||||
auto n_groups = p_cache->Groups();
|
||||
if (!info.weights_.Empty()) {
|
||||
CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
|
||||
}
|
||||
|
||||
if (ctx_->IsCUDA()) {
|
||||
auto pre = cuda_impl::PreScore(ctx_, info, predt, p_cache);
|
||||
return Finalize(info, pre.Residue(), pre.Weights());
|
||||
}
|
||||
|
||||
auto gptr = p_cache->DataGroupPtr(ctx_);
|
||||
auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
|
||||
auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());
|
||||
auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());
|
||||
|
||||
auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
|
||||
auto pre = p_cache->Pre(ctx_);
|
||||
|
||||
common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
|
||||
auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
|
||||
auto g_rank = rank_idx.subspan(gptr[g], gptr[g + 1] - gptr[g]);
|
||||
|
||||
auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
|
||||
double n_hits{0.0};
|
||||
for (std::size_t i = 0; i < n; ++i) {
|
||||
n_hits += g_label(g_rank[i]) * weight[g];
|
||||
}
|
||||
pre[g] = n_hits / static_cast<double>(n);
|
||||
});
|
||||
|
||||
auto sw = 0.0;
|
||||
for (std::size_t i = 0; i < pre.size(); ++i) {
|
||||
sw += weight[i];
|
||||
}
|
||||
|
||||
auto sum = std::accumulate(pre.cbegin(), pre.cend(), 0.0);
|
||||
return Finalize(info, sum, sw);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Implement the NDCG score function for learning to rank.
|
||||
*
|
||||
@ -416,7 +363,6 @@ double Finalize(MetaInfo const& info, double score, double sw) {
|
||||
class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
|
||||
public:
|
||||
using EvalRankWithCache::EvalRankWithCache;
|
||||
const char* Name() const override { return name_.c_str(); }
|
||||
|
||||
double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
|
||||
std::shared_ptr<ltr::NDCGCache> p_cache) override {
|
||||
@ -475,7 +421,6 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
|
||||
class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
|
||||
public:
|
||||
using EvalRankWithCache::EvalRankWithCache;
|
||||
const char* Name() const override { return name_.c_str(); }
|
||||
|
||||
double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
|
||||
std::shared_ptr<ltr::MAPCache> p_cache) override {
|
||||
@ -494,7 +439,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
|
||||
|
||||
common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
|
||||
auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
|
||||
auto g_rank = rank_idx.subspan(gptr[g]);
|
||||
auto g_rank = rank_idx.subspan(gptr[g], gptr[g + 1] - gptr[g]);
|
||||
|
||||
auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
|
||||
double n_hits{0.0};
|
||||
@ -527,6 +472,10 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
|
||||
}
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Precision, "pre")
|
||||
.describe("precision@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalPrecision("pre", param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(EvalMAP, "map")
|
||||
.describe("map@k for ranking.")
|
||||
.set_body([](char const* param) {
|
||||
|
||||
@ -28,108 +28,57 @@ namespace xgboost::metric {
|
||||
// tag the this file, used by force static link later.
|
||||
DMLC_REGISTRY_FILE_TAG(rank_metric_gpu);
|
||||
|
||||
/*! \brief Evaluate rank list on GPU */
|
||||
template <typename EvalMetricT>
|
||||
struct EvalRankGpu : public GPUMetric, public EvalRankConfig {
|
||||
public:
|
||||
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
|
||||
// Sanity check is done by the caller
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(preds.Size());
|
||||
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
|
||||
namespace cuda_impl {
|
||||
PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt,
|
||||
std::shared_ptr<ltr::PreCache> p_cache) {
|
||||
auto d_gptr = p_cache->DataGroupPtr(ctx);
|
||||
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
|
||||
|
||||
const auto ngroups = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
predt.SetDevice(ctx->gpu_id);
|
||||
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
|
||||
auto topk = p_cache->Param().TopK();
|
||||
auto d_weight = common::MakeOptionalWeights(ctx, info.weights_);
|
||||
|
||||
auto device = ctx_->gpu_id;
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
|
||||
info.labels.SetDevice(device);
|
||||
preds.SetDevice(device);
|
||||
|
||||
auto dpreds = preds.ConstDevicePointer();
|
||||
auto dlabels = info.labels.View(device);
|
||||
|
||||
// Sort all the predictions
|
||||
dh::SegmentSorter<float> segment_pred_sorter;
|
||||
segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr);
|
||||
|
||||
// Compute individual group metric and sum them up
|
||||
return EvalMetricT::EvalMetric(segment_pred_sorter, dlabels.Values().data(), *this);
|
||||
}
|
||||
|
||||
const char* Name() const override {
|
||||
return name.c_str();
|
||||
}
|
||||
|
||||
explicit EvalRankGpu(const char* name, const char* param) {
|
||||
using namespace std; // NOLINT(*)
|
||||
if (param != nullptr) {
|
||||
std::ostringstream os;
|
||||
if (sscanf(param, "%u[-]?", &this->topn) == 1) {
|
||||
os << name << '@' << param;
|
||||
this->name = os.str();
|
||||
} else {
|
||||
os << name << param;
|
||||
this->name = os.str();
|
||||
}
|
||||
if (param[strlen(param) - 1] == '-') {
|
||||
this->minus = true;
|
||||
}
|
||||
} else {
|
||||
this->name = name;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Precision at N, for both classification and rank */
|
||||
struct EvalPrecisionGpu {
|
||||
public:
|
||||
static double EvalMetric(const dh::SegmentSorter<float> &pred_sorter,
|
||||
const float *dlabels,
|
||||
const EvalRankConfig &ecfg) {
|
||||
// Group info on device
|
||||
const auto &dgroups = pred_sorter.GetGroupsSpan();
|
||||
const auto ngroups = pred_sorter.GetNumGroups();
|
||||
const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan();
|
||||
|
||||
// Original positions of the predictions after they have been sorted
|
||||
const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan();
|
||||
|
||||
// First, determine non zero labels in the dataset individually
|
||||
auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) {
|
||||
return (static_cast<unsigned>(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0;
|
||||
}; // NOLINT
|
||||
|
||||
// Find each group's metric sum
|
||||
dh::caching_device_vector<uint32_t> hits(ngroups, 0);
|
||||
const auto nitems = pred_sorter.GetNumItems();
|
||||
auto *dhits = hits.data().get();
|
||||
|
||||
int device_id = -1;
|
||||
dh::safe_cuda(cudaGetDevice(&device_id));
|
||||
// For each group item compute the aggregated precision
|
||||
dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) {
|
||||
const auto group_idx = dgroup_idx[idx];
|
||||
const auto group_begin = dgroups[group_idx];
|
||||
const auto ridx = idx - group_begin;
|
||||
if (ridx < ecfg.topn && DetermineNonTrivialLabelLambda(idx)) {
|
||||
atomicAdd(&dhits[group_idx], 1);
|
||||
auto it = dh::MakeTransformIterator<double>(
|
||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) {
|
||||
auto g = dh::SegmentId(d_gptr, i);
|
||||
auto g_begin = d_gptr[g];
|
||||
auto g_end = d_gptr[g + 1];
|
||||
i -= g_begin;
|
||||
auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
|
||||
auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
|
||||
double y = g_label(g_rank[i]);
|
||||
auto n = std::min(static_cast<std::size_t>(topk), g_label.Size());
|
||||
double w{d_weight[g]};
|
||||
if (i >= n) {
|
||||
return 0.0;
|
||||
}
|
||||
return y / static_cast<double>(n) * w;
|
||||
});
|
||||
|
||||
// Allocator to be used for managing space overhead while performing reductions
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
return static_cast<double>(thrust::reduce(thrust::cuda::par(alloc),
|
||||
hits.begin(), hits.end())) / ecfg.topn;
|
||||
}
|
||||
};
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
auto pre = p_cache->Pre(ctx);
|
||||
thrust::fill_n(cuctx->CTP(), pre.data(), pre.size(), 0.0);
|
||||
|
||||
std::size_t bytes;
|
||||
cub::DeviceSegmentedReduce::Sum(nullptr, bytes, it, pre.data(), p_cache->Groups(), d_gptr.data(),
|
||||
d_gptr.data() + 1, cuctx->Stream());
|
||||
dh::TemporaryArray<char> temp(bytes);
|
||||
cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, it, pre.data(), p_cache->Groups(),
|
||||
d_gptr.data(), d_gptr.data() + 1, cuctx->Stream());
|
||||
|
||||
XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
|
||||
.describe("precision@k for rank computed on GPU.")
|
||||
.set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });
|
||||
auto w_it =
|
||||
dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(std::size_t g) { return d_weight[g]; });
|
||||
auto n_weights = p_cache->Groups();
|
||||
auto sw = dh::Reduce(cuctx->CTP(), w_it, w_it + n_weights, 0.0, thrust::plus<double>{});
|
||||
auto sum =
|
||||
dh::Reduce(cuctx->CTP(), dh::tcbegin(pre), dh::tcend(pre), 0.0, thrust::plus<double>{});
|
||||
auto result = PackedReduceResult{sum, sw};
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace cuda_impl {
|
||||
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt, bool minus,
|
||||
std::shared_ptr<ltr::NDCGCache> p_cache) {
|
||||
|
||||
@ -12,9 +12,7 @@
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
namespace cuda_impl {
|
||||
namespace xgboost::metric::cuda_impl {
|
||||
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt, bool minus,
|
||||
std::shared_ptr<ltr::NDCGCache> p_cache);
|
||||
@ -23,6 +21,10 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt, bool minus,
|
||||
std::shared_ptr<ltr::MAPCache> p_cache);
|
||||
|
||||
PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt,
|
||||
std::shared_ptr<ltr::PreCache> p_cache);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
|
||||
HostDeviceVector<float> const &, bool,
|
||||
@ -37,8 +39,13 @@ inline PackedReduceResult MAPScore(Context const *, MetaInfo const &,
|
||||
common::AssertGPUSupport();
|
||||
return {};
|
||||
}
|
||||
|
||||
inline PackedReduceResult PreScore(Context const *, MetaInfo const &,
|
||||
HostDeviceVector<float> const &,
|
||||
std::shared_ptr<ltr::PreCache>) {
|
||||
common::AssertGPUSupport();
|
||||
return {};
|
||||
}
|
||||
#endif
|
||||
} // namespace cuda_impl
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::metric::cuda_impl
|
||||
#endif // XGBOOST_METRIC_RANK_METRIC_H_
|
||||
|
||||
@ -90,7 +90,7 @@ def check_cmd_print_failure_assistance(cmd: List[str]) -> bool:
|
||||
|
||||
subprocess.run([cmd[0], "--version"])
|
||||
msg = """
|
||||
Please run the following command on your machine to address the formatting error:
|
||||
Please run the following command on your machine to address the error:
|
||||
|
||||
"""
|
||||
msg += " ".join(cmd)
|
||||
|
||||
@ -17,34 +17,30 @@
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/json.h" // for Json, String, Object
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
namespace xgboost::metric {
|
||||
|
||||
inline void VerifyPrecision(DataSplitMode data_split_mode = DataSplitMode::kRow) {
|
||||
// When the limit for precision is not given, it takes the limit at
|
||||
// std::numeric_limits<unsigned>::max(); hence all values are very small
|
||||
// NOTE(AbdealiJK): Maybe this should be fixed to be num_row by default.
|
||||
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("pre", &ctx);
|
||||
std::unique_ptr<xgboost::Metric> metric{Metric::Create("pre", &ctx)};
|
||||
ASSERT_STREQ(metric->Name(), "pre");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 0, 1e-7);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
{ 0, 0, 1, 1}, {}, {}, data_split_mode),
|
||||
0, 1e-7);
|
||||
EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5, 1e-7);
|
||||
EXPECT_NEAR(
|
||||
GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode),
|
||||
0.5, 1e-7);
|
||||
|
||||
delete metric;
|
||||
metric = xgboost::Metric::Create("pre@2", &ctx);
|
||||
metric.reset(xgboost::Metric::Create("pre@2", &ctx));
|
||||
ASSERT_STREQ(metric->Name(), "pre@2");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5f, 1e-7);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
{ 0, 0, 1, 1}, {}, {}, data_split_mode),
|
||||
EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5f, 1e-7);
|
||||
EXPECT_NEAR(
|
||||
GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode),
|
||||
0.5f, 0.001f);
|
||||
|
||||
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}, {}, {}, data_split_mode));
|
||||
EXPECT_ANY_THROW(GetMetricEval(metric.get(), {0, 1}, {}, {}, {}, data_split_mode));
|
||||
|
||||
delete metric;
|
||||
metric.reset(xgboost::Metric::Create("pre@4", &ctx));
|
||||
EXPECT_NEAR(GetMetricEval(metric.get(), {0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f},
|
||||
{0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f}, {}, {}, data_split_mode),
|
||||
0.5f, 1e-7);
|
||||
}
|
||||
|
||||
inline void VerifyNDCG(DataSplitMode data_split_mode = DataSplitMode::kRow) {
|
||||
@ -187,5 +183,4 @@ inline void VerifyNDCGExpGain(DataSplitMode data_split_mode = DataSplitMode::kRo
|
||||
ndcg = metric->Evaluate(predt, p_fmat);
|
||||
ASSERT_NEAR(ndcg, 1.0, kRtEps);
|
||||
}
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::metric
|
||||
|
||||
@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
import xgboost
|
||||
from xgboost import testing as tm
|
||||
from xgboost.testing.metrics import check_quantile_error
|
||||
from xgboost.testing.metrics import check_precision_score, check_quantile_error
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import test_eval_metrics as test_em # noqa
|
||||
@ -59,6 +59,9 @@ class TestGPUEvalMetrics:
|
||||
def test_pr_auc_ltr(self):
|
||||
self.cpu_test.run_pr_auc_ltr("gpu_hist")
|
||||
|
||||
def test_precision_score(self):
|
||||
check_precision_score("gpu_hist")
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_quantile_error(self) -> None:
|
||||
check_quantile_error("gpu_hist")
|
||||
|
||||
@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.testing.metrics import check_quantile_error
|
||||
from xgboost.testing.metrics import check_precision_score, check_quantile_error
|
||||
|
||||
rng = np.random.RandomState(1337)
|
||||
|
||||
@ -315,6 +315,9 @@ class TestEvalMetrics:
|
||||
def test_pr_auc_ltr(self):
|
||||
self.run_pr_auc_ltr("hist")
|
||||
|
||||
def test_precision_score(self):
|
||||
check_precision_score("hist")
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_quantile_error(self) -> None:
|
||||
check_quantile_error("hist")
|
||||
|
||||
@ -55,6 +55,38 @@ class TestQuantileDMatrix:
|
||||
r = np.arange(1.0, n_samples)
|
||||
np.testing.assert_allclose(Xy.get_data().toarray()[1:, 0], r)
|
||||
|
||||
def test_error(self):
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
rng = np.random.default_rng(1994)
|
||||
X, y = make_categorical(
|
||||
n_samples=128, n_features=2, n_categories=3, onehot=False
|
||||
)
|
||||
reg = xgb.XGBRegressor(tree_method="hist", enable_categorical=True)
|
||||
w = rng.uniform(0, 1, size=y.shape[0])
|
||||
|
||||
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
|
||||
X, y, w, random_state=1994
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="sample weight"):
|
||||
reg.fit(
|
||||
X,
|
||||
y,
|
||||
sample_weight=w_train,
|
||||
eval_set=[(X_test, y_test)],
|
||||
sample_weight_eval_set=[w_test],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="sample weight"):
|
||||
reg.fit(
|
||||
X_train,
|
||||
y_train,
|
||||
sample_weight=w,
|
||||
eval_set=[(X_test, y_test)],
|
||||
sample_weight_eval_set=[w_test],
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9])
|
||||
def test_with_iterator(self, sparsity: float) -> None:
|
||||
n_samples_per_batch = 317
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user