Rework MAP and Pairwise for LTR. (#9075)

This commit is contained in:
Jiaming Yuan 2023-04-28 02:39:12 +08:00 committed by GitHub
parent 0e470ef606
commit e206b899ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 612 additions and 1135 deletions

View File

@ -32,7 +32,6 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \

View File

@ -32,7 +32,6 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \

View File

@ -219,6 +219,16 @@
"num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" }
}
},
"lambdarank_param": {
"type": "object",
"properties": {
"lambdarank_num_pair_per_sample": { "type": "string" },
"lambdarank_pair_method": { "type": "string" },
"lambdarank_unbiased": {"type": "string" },
"lambdarank_bias_norm": {"type": "string" },
"ndcg_exp_gain": {"type": "string"}
}
}
},
"type": "object",
@ -477,22 +487,22 @@
"type": "object",
"properties": {
"name": { "const": "rank:pairwise" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
"lambda_rank_param"
"lambdarank_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "rank:ndcg" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
"lambda_rank_param"
"lambdarank_param"
]
},
{

View File

@ -233,7 +233,7 @@ Parameters for Tree Booster
.. note:: This parameter is working-in-progress.
- The strategy used for training multi-target models, including multi-target regression
and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
- ``one_output_per_tree``: One model for each target.
- ``multi_output_tree``: Use multi-target trees.
@ -380,9 +380,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
See :doc:`/tutorials/aft_survival_analysis` for details.
- ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
- ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class.
- ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized
- ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized
- ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
- ``rank:ndcg``: Use LambdaMART to perform pair-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized. This objective supports position debiasing for click data.
- ``rank:map``: Use LambdaMART to perform pair-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
- ``rank:pairwise``: Use LambdaRank to perform pair-wise ranking using the `ranknet` objective.
- ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed <https://en.wikipedia.org/wiki/Gamma_distribution#Occurrence_and_applications>`_.
- ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed <https://en.wikipedia.org/wiki/Tweedie_distribution#Occurrence_and_applications>`_.
@ -395,8 +395,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
* ``eval_metric`` [default according to objective]
- Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
- User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
- Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.)
- User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous ones
- The choices are listed below:
- ``rmse``: `root mean square error <http://en.wikipedia.org/wiki/Root_mean_square_error>`_
@ -480,6 +481,36 @@ Parameter for using AFT Survival Loss (``survival:aft``) and Negative Log Likeli
* ``aft_loss_distribution``: Probability Density Function, ``normal``, ``logistic``, or ``extreme``.
.. _ltr-param:
Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``)
================================================================================
These are parameters specific to learning to rank task. See :doc:`Learning to Rank </tutorials/learning_to_rank>` for an in-depth explanation.
* ``lambdarank_pair_method`` [default = ``mean``]
How to construct pairs for pair-wise learning.
- ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list.
- ``topk``: Focus on top-``lambdarank_num_pair_per_sample`` documents. Construct :math:`|query|` pairs for each document at the top-``lambdarank_num_pair_per_sample`` ranked by the model.
* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`]
It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``.
* ``lambdarank_unbiased`` [default = ``false``]
Specify whether do we need to debias input click data.
* ``lambdarank_bias_norm`` [default = 2.0]
:math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true.
* ``ndcg_exp_gain`` [default = ``true``]
Whether we should use exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``, one is using relevance value directly while the other is using :math:`2^{rel} - 1` to emphasize on retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), relevance degree cannot be greater than 31.
***********************
Command Line Parameters
***********************

View File

@ -431,8 +431,11 @@ def make_ltr(
"""Make a dataset for testing LTR."""
rng = np.random.default_rng(1994)
X = rng.normal(0, 1.0, size=n_samples * n_features).reshape(n_samples, n_features)
y = rng.integers(0, max_rel, size=n_samples)
qid = rng.integers(0, n_query_groups, size=n_samples)
y = np.sum(X, axis=1)
y -= y.min()
y = np.round(y / y.max() * max_rel).astype(np.int32)
qid = rng.integers(0, n_query_groups, size=n_samples, dtype=np.int32)
w = rng.normal(0, 1.0, size=n_query_groups)
w -= np.min(w)
w /= np.max(w)

View File

@ -493,7 +493,6 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());
common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1]));
auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
auto g_rank = rank_idx.subspan(gptr[g]);

View File

@ -69,6 +69,7 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
lj(i) += g_lj(i);
}
}
// The ti+ is not guaranteed to decrease since it depends on the |\delta Z|
//
// The update normalizes the ti+ to make ti+(0) equal to 1, which breaks the probability
@ -432,9 +433,201 @@ void LambdaRankUpdatePositionBias(Context const*, linalg::VectorView<double cons
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl
namespace cpu_impl {
void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
common::Span<std::size_t const> rank_idx, std::shared_ptr<ltr::MAPCache> p_cache) {
auto h_n_rel = p_cache->NumRelevant(ctx);
auto gptr = p_cache->DataGroupPtr(ctx);
CHECK_EQ(h_n_rel.size(), gptr.back());
CHECK_EQ(h_n_rel.size(), label.Size());
auto h_acc = p_cache->Acc(ctx);
common::ParallelFor(p_cache->Groups(), ctx->Threads(), [&](auto g) {
auto cnt = gptr[g + 1] - gptr[g];
auto g_n_rel = h_n_rel.subspan(gptr[g], cnt);
auto g_rank = rank_idx.subspan(gptr[g], cnt);
auto g_label = label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
// The number of relevant documents at each position
g_n_rel[0] = g_label(g_rank[0]);
for (std::size_t k = 1; k < g_rank.size(); ++k) {
g_n_rel[k] = g_n_rel[k - 1] + g_label(g_rank[k]);
}
// \sum l_k/k
auto g_acc = h_acc.subspan(gptr[g], cnt);
g_acc[0] = g_label(g_rank[0]) / 1.0;
for (std::size_t k = 1; k < g_rank.size(); ++k) {
g_acc[k] = g_acc[k - 1] + (g_label(g_rank[k]) / static_cast<double>(k + 1));
}
});
}
} // namespace cpu_impl
class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
public:
void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) {
CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective.";
if (ctx_->IsCUDA()) {
return cuda_impl::LambdaRankGetGradientMAP(
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
out_gpair);
}
auto gptr = p_cache_->DataGroupPtr(ctx_).data();
bst_group_t n_groups = p_cache_->Groups();
out_gpair->Resize(info.num_row_);
auto h_gpair = out_gpair->HostSpan();
auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
auto h_predt = predt.ConstHostSpan();
auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
cpu_impl::MAPStat(ctx_, h_label, rank_idx, GetCache());
auto n_rel = GetCache()->NumRelevant(ctx_);
auto acc = GetCache()->Acc(ctx_);
auto delta_map = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low,
bst_group_t g) {
if (rank_high > rank_low) {
std::swap(rank_high, rank_low);
std::swap(y_high, y_low);
}
auto cnt = gptr[g + 1] - gptr[g];
// In a hot loop
auto g_n_rel = common::Span<double const>{n_rel.data() + gptr[g], cnt};
auto g_acc = common::Span<double const>{acc.data() + gptr[g], cnt};
auto d = DeltaMAP(y_high, y_low, rank_high, rank_low, g_n_rel, g_acc);
return d;
};
using D = decltype(delta_map);
common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
auto cnt = gptr[g + 1] - gptr[g];
auto w = h_weight[g];
auto g_predt = h_predt.subspan(gptr[g], cnt);
auto g_gpair = h_gpair.subspan(gptr[g], cnt);
auto g_label = h_label.Slice(make_range(g));
auto g_rank = rank_idx.subspan(gptr[g], cnt);
auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta_map, g_gpair);
if (param_.lambdarank_unbiased) {
std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, D>, args);
} else {
std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, D>, args);
}
});
}
static char const* Name() { return "rank:map"; }
[[nodiscard]] const char* DefaultEvalMetric() const override {
return this->RankEvalMetric("map");
}
};
#if !defined(XGBOOST_USE_CUDA)
namespace cuda_impl {
void MAPStat(Context const*, MetaInfo const&, common::Span<std::size_t const>,
std::shared_ptr<ltr::MAPCache>) {
common::AssertGPUSupport();
}
void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector<float> const&,
const MetaInfo&, std::shared_ptr<ltr::MAPCache>,
linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double>, linalg::VectorView<double>,
HostDeviceVector<GradientPair>*) {
common::AssertGPUSupport();
}
} // namespace cuda_impl
#endif // !defined(XGBOOST_USE_CUDA)
/**
* \brief The RankNet loss.
*/
class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::RankingCache> {
public:
void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) {
CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective.";
if (ctx_->IsCUDA()) {
return cuda_impl::LambdaRankGetGradientPairwise(
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
out_gpair);
}
auto gptr = p_cache_->DataGroupPtr(ctx_);
bst_group_t n_groups = p_cache_->Groups();
out_gpair->Resize(info.num_row_);
auto h_gpair = out_gpair->HostSpan();
auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
auto h_predt = predt.ConstHostSpan();
auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
auto delta = [](auto...) { return 1.0; };
using D = decltype(delta);
common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
auto cnt = gptr[g + 1] - gptr[g];
auto w = h_weight[g];
auto g_predt = h_predt.subspan(gptr[g], cnt);
auto g_gpair = h_gpair.subspan(gptr[g], cnt);
auto g_label = h_label.Slice(make_range(g));
auto g_rank = rank_idx.subspan(gptr[g], cnt);
auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
if (param_.lambdarank_unbiased) {
std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, D>, args);
} else {
std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, D>, args);
}
});
}
static char const* Name() { return "rank:pairwise"; }
[[nodiscard]] const char* DefaultEvalMetric() const override {
return this->RankEvalMetric("ndcg");
}
};
#if !defined(XGBOOST_USE_CUDA)
namespace cuda_impl {
void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVector<float> const&,
const MetaInfo&, std::shared_ptr<ltr::RankingCache>,
linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double>, linalg::VectorView<double>,
HostDeviceVector<GradientPair>*) {
common::AssertGPUSupport();
}
} // namespace cuda_impl
#endif // !defined(XGBOOST_USE_CUDA)
XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name())
.describe("LambdaRank with NDCG loss as objective")
.set_body([]() { return new LambdaRankNDCG{}; });
XGBOOST_REGISTER_OBJECTIVE(LambdaRankPairwise, LambdaRankPairwise::Name())
.describe("LambdaRank with RankNet loss as objective")
.set_body([]() { return new LambdaRankPairwise{}; });
XGBOOST_REGISTER_OBJECTIVE(LambdaRankMAP, LambdaRankMAP::Name())
.describe("LambdaRank with MAP loss as objective.")
.set_body([]() { return new LambdaRankMAP{}; });
DMLC_REGISTRY_FILE_TAG(lambdarank_obj);
} // namespace xgboost::obj

View File

@ -390,6 +390,112 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
Launch(ctx, iter, preds, info, p_cache, delta_ndcg, ti_plus, tj_minus, li, lj, out_gpair);
}
void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t const> d_rank_idx,
std::shared_ptr<ltr::MAPCache> p_cache) {
common::Span<double> out_n_rel = p_cache->NumRelevant(ctx);
common::Span<double> out_acc = p_cache->Acc(ctx);
CHECK_EQ(out_n_rel.size(), info.num_row_);
CHECK_EQ(out_acc.size(), info.num_row_);
auto group_ptr = p_cache->DataGroupPtr(ctx);
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) -> std::size_t { return dh::SegmentId(group_ptr, i); });
auto label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
auto const* cuctx = ctx->CUDACtx();
{
// calculate number of relevant documents
auto val_it = dh::MakeTransformIterator<double>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> double {
auto g = dh::SegmentId(group_ptr, i);
auto g_label = label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]));
auto idx_in_group = i - group_ptr[g];
auto g_sorted_idx = d_rank_idx.subspan(group_ptr[g], group_ptr[g + 1] - group_ptr[g]);
return static_cast<double>(g_label(g_sorted_idx[idx_in_group]));
});
thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + info.num_row_, val_it,
out_n_rel.data());
}
{
// \sum l_k/k
auto val_it = dh::MakeTransformIterator<double>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> double {
auto g = dh::SegmentId(group_ptr, i);
auto g_label = label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]));
auto g_sorted_idx = d_rank_idx.subspan(group_ptr[g], group_ptr[g + 1] - group_ptr[g]);
auto idx_in_group = i - group_ptr[g];
double rank_in_group = idx_in_group + 1.0;
return static_cast<double>(g_label(g_sorted_idx[idx_in_group])) / rank_in_group;
});
thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + info.num_row_, val_it,
out_acc.data());
}
}
void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, const MetaInfo& info,
std::shared_ptr<ltr::MAPCache> p_cache,
linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair) {
std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id));
info.labels.SetDevice(device_id);
predt.SetDevice(device_id);
CHECK(p_cache);
auto d_predt = predt.ConstDeviceSpan();
auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt);
MAPStat(ctx, info, d_sorted_idx, p_cache);
auto d_n_rel = p_cache->NumRelevant(ctx);
auto d_acc = p_cache->Acc(ctx);
auto d_gptr = p_cache->DataGroupPtr(ctx).data();
auto delta_map = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high,
std::size_t rank_low, bst_group_t g) {
if (rank_high > rank_low) {
thrust::swap(rank_high, rank_low);
thrust::swap(y_high, y_low);
}
auto cnt = d_gptr[g + 1] - d_gptr[g];
auto g_n_rel = d_n_rel.subspan(d_gptr[g], cnt);
auto g_acc = d_acc.subspan(d_gptr[g], cnt);
auto d = DeltaMAP(y_high, y_low, rank_high, rank_low, g_n_rel, g_acc);
return d;
};
Launch(ctx, iter, predt, info, p_cache, delta_map, ti_plus, tj_minus, li, lj, out_gpair);
}
void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, const MetaInfo& info,
std::shared_ptr<ltr::RankingCache> p_cache,
linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair) {
std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id));
info.labels.SetDevice(device_id);
predt.SetDevice(device_id);
auto d_predt = predt.ConstDeviceSpan();
auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt);
auto delta = [] XGBOOST_DEVICE(float, float, std::size_t, std::size_t, bst_group_t) {
return 1.0;
};
Launch(ctx, iter, predt, info, p_cache, delta, ti_plus, tj_minus, li, lj, out_gpair);
}
namespace {
struct ReduceOp {
template <typename Tup>

View File

@ -156,6 +156,27 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
/**
* \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
*/
void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t const> d_rank_idx,
std::shared_ptr<ltr::MAPCache> p_cache);
void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, MetaInfo const& info,
std::shared_ptr<ltr::MAPCache> p_cache,
linalg::VectorView<double const> t_plus, // input bias ratio
linalg::VectorView<double const> t_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, const MetaInfo& info,
std::shared_ptr<ltr::RankingCache> p_cache,
linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
linalg::VectorView<double const> lj_full,
@ -165,6 +186,18 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
std::shared_ptr<ltr::RankingCache> p_cache);
} // namespace cuda_impl
namespace cpu_impl {
/**
* \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
*
* \param label Ground truth relevance label.
* \param rank_idx Sorted index of prediction.
* \param p_cache An initialized MAPCache.
*/
void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
common::Span<std::size_t const> rank_idx, std::shared_ptr<ltr::MAPCache> p_cache);
} // namespace cpu_impl
/**
* \param Construct pairs on CPU
*

View File

@ -47,7 +47,6 @@ DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu);
DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu);
DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu);
DMLC_REGISTRY_LINK_TAG(rank_obj_gpu);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj_cu);
#else
@ -55,7 +54,6 @@ DMLC_REGISTRY_LINK_TAG(regression_obj);
DMLC_REGISTRY_LINK_TAG(quantile_obj);
DMLC_REGISTRY_LINK_TAG(hinge_obj);
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
DMLC_REGISTRY_LINK_TAG(rank_obj);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj);
#endif // XGBOOST_USE_CUDA
} // namespace obj

View File

@ -1,17 +0,0 @@
/*!
* Copyright 2019 XGBoost contributors
*/
// Dummy file to keep the CUDA conditional compile trick.
#include <dmlc/registry.h>
namespace xgboost {
namespace obj {
DMLC_REGISTRY_FILE_TAG(rank_obj);
} // namespace obj
} // namespace xgboost
#ifndef XGBOOST_USE_CUDA
#include "rank_obj.cu"
#endif // XGBOOST_USE_CUDA

View File

@ -1,789 +0,0 @@
/*!
* Copyright 2015-2022 XGBoost contributors
*/
#include <dmlc/omp.h>
#include <dmlc/timer.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <vector>
#include <algorithm>
#include <utility>
#include "xgboost/json.h"
#include "xgboost/parameter.h"
#include "../common/math.h"
#include "../common/random.h"
#if defined(__CUDACC__)
#include <thrust/sort.h>
#include <thrust/gather.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/random/uniform_int_distribution.h>
#include <thrust/random/linear_congruential_engine.h>
#include <cub/util_allocator.cuh>
#include "../common/device_helpers.cuh"
#endif
namespace xgboost {
namespace obj {
#if defined(XGBOOST_USE_CUDA) && !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(rank_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA)
struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
size_t num_pairsample;
float fix_list_weight;
// declare parameters
DMLC_DECLARE_PARAMETER(LambdaRankParam) {
DMLC_DECLARE_FIELD(num_pairsample).set_lower_bound(1).set_default(1)
.describe("Number of pair generated for each instance.");
DMLC_DECLARE_FIELD(fix_list_weight).set_lower_bound(0.0f).set_default(0.0f)
.describe("Normalize the weight of each list by this value,"
" if equals 0, no effect will happen");
}
};
#if defined(__CUDACC__)
// Helper functions
template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsToTheLeftOf(const T *__restrict__ items, uint32_t n, T v) {
return thrust::lower_bound(thrust::seq, items, items + n, v,
thrust::greater<T>()) -
items;
}
template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsToTheRightOf(const T *__restrict__ items, uint32_t n, T v) {
return n - (thrust::upper_bound(thrust::seq, items, items + n, v,
thrust::greater<T>()) -
items);
}
#endif
/*! \brief helper information in a list */
struct ListEntry {
/*! \brief the predict score we in the data */
bst_float pred;
/*! \brief the actual label of the entry */
bst_float label;
/*! \brief row index in the data matrix */
unsigned rindex;
// constructor
ListEntry(bst_float pred, bst_float label, unsigned rindex)
: pred(pred), label(label), rindex(rindex) {}
// comparator by prediction
inline static bool CmpPred(const ListEntry &a, const ListEntry &b) {
return a.pred > b.pred;
}
// comparator by label
inline static bool CmpLabel(const ListEntry &a, const ListEntry &b) {
return a.label > b.label;
}
};
/*! \brief a pair in the lambda rank */
struct LambdaPair {
/*! \brief positive index: this is a position in the list */
unsigned pos_index;
/*! \brief negative index: this is a position in the list */
unsigned neg_index;
/*! \brief weight to be filled in */
bst_float weight;
// constructor
LambdaPair(unsigned pos_index, unsigned neg_index)
: pos_index(pos_index), neg_index(neg_index), weight(1.0f) {}
// constructor
LambdaPair(unsigned pos_index, unsigned neg_index, bst_float weight)
: pos_index(pos_index), neg_index(neg_index), weight(weight) {}
};
class PairwiseLambdaWeightComputer {
public:
/*!
* \brief get lambda weight for existing pairs - for pairwise objective
* \param list a list that is sorted by pred score
* \param io_pairs record of pairs, containing the pairs to fill in weights
*/
static void GetLambdaWeight(const std::vector<ListEntry>&,
std::vector<LambdaPair>*) {}
static char const* Name() {
return "rank:pairwise";
}
#if defined(__CUDACC__)
PairwiseLambdaWeightComputer(const bst_float*,
const bst_float*,
const dh::SegmentSorter<float>&) {}
class PairwiseLambdaWeightMultiplier {
public:
// Adjust the items weight by this value
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
return 1.0f;
}
};
inline const PairwiseLambdaWeightMultiplier GetWeightMultiplier() const {
return {};
}
#endif
};
#if defined(__CUDACC__)
class BaseLambdaWeightMultiplier {
public:
BaseLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
const dh::SegmentSorter<float> &segment_pred_sorter)
: dsorted_labels_(segment_label_sorter.GetItemsSpan()),
dorig_pos_(segment_label_sorter.GetOriginalPositionsSpan()),
dgroups_(segment_label_sorter.GetGroupsSpan()),
dindexable_sorted_preds_pos_(segment_pred_sorter.GetIndexableSortedPositionsSpan()) {}
protected:
const common::Span<const float> dsorted_labels_; // Labels sorted within a group
const common::Span<const uint32_t> dorig_pos_; // Original indices of the labels
// before they are sorted
const common::Span<const uint32_t> dgroups_; // The group indices
// Where can a prediction for a label be found in the original array, when they are sorted
const common::Span<const uint32_t> dindexable_sorted_preds_pos_;
};
// While computing the weight that needs to be adjusted by this ranking objective, we need
// to figure out where positive and negative labels chosen earlier exists, if the group
// were to be sorted by its predictions. To accommodate this, we employ the following algorithm.
// For a given group, let's assume the following:
// labels: 1 5 9 2 4 8 0 7 6 3
// predictions: 1 9 0 8 2 7 3 6 5 4
// position: 0 1 2 3 4 5 6 7 8 9
//
// After label sort:
// labels: 9 8 7 6 5 4 3 2 1 0
// position: 2 5 7 8 1 4 9 3 0 6
//
// After prediction sort:
// predictions: 9 8 7 6 5 4 3 2 1 0
// position: 1 3 5 7 8 9 6 4 0 2
//
// If a sorted label at position 'x' is chosen, then we need to find out where the prediction
// for this label 'x' exists, if the group were to be sorted by predictions.
// We first take the sorted prediction positions:
// position: 1 3 5 7 8 9 6 4 0 2
// at indices: 0 1 2 3 4 5 6 7 8 9
//
// We create a sorted prediction positional array, such that value at position 'x' gives
// us the position in the sorted prediction array where its related prediction lies.
// dindexable_sorted_preds_pos_: 8 0 9 1 7 2 6 3 4 5
// at indices: 0 1 2 3 4 5 6 7 8 9
// Basically, swap the previous 2 arrays, sort the indices and reorder positions
// for an O(1) lookup using the position where the sorted label exists.
//
// This type does that using the SegmentSorter
class IndexablePredictionSorter {
public:
IndexablePredictionSorter(const bst_float *dpreds,
const dh::SegmentSorter<float> &segment_label_sorter) {
// Sort the predictions first
segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(),
segment_label_sorter.GetGroupSegmentsSpan());
// Create an index for the sorted prediction positions
segment_pred_sorter_.CreateIndexableSortedPositions();
}
inline const dh::SegmentSorter<float> &GetPredictionSorter() const {
return segment_pred_sorter_;
}
private:
dh::SegmentSorter<float> segment_pred_sorter_; // For sorting the predictions
};
#endif
class MAPLambdaWeightComputer
#if defined(__CUDACC__)
: public IndexablePredictionSorter
#endif
{
public:
struct MAPStats {
/*! \brief the accumulated precision */
float ap_acc{0.0f};
/*!
* \brief the accumulated precision,
* assuming a positive instance is missing
*/
float ap_acc_miss{0.0f};
/*!
* \brief the accumulated precision,
* assuming that one more positive instance is inserted ahead
*/
float ap_acc_add{0.0f};
/* \brief the accumulated positive instance count */
float hits{0.0f};
XGBOOST_DEVICE MAPStats() {} // NOLINT
XGBOOST_DEVICE MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits)
: ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {}
// For prefix scan
XGBOOST_DEVICE MAPStats operator +(const MAPStats &v1) const {
return {ap_acc + v1.ap_acc, ap_acc_miss + v1.ap_acc_miss,
ap_acc_add + v1.ap_acc_add, hits + v1.hits};
}
// For test purposes - compare for equality
XGBOOST_DEVICE bool operator ==(const MAPStats &rhs) const {
return ap_acc == rhs.ap_acc && ap_acc_miss == rhs.ap_acc_miss &&
ap_acc_add == rhs.ap_acc_add && hits == rhs.hits;
}
};
private:
template <typename T>
XGBOOST_DEVICE inline static void Swap(T &v0, T &v1) {
#if defined(__CUDACC__)
thrust::swap(v0, v1);
#else
std::swap(v0, v1);
#endif
}
/*!
* \brief Obtain the delta MAP by trying to switch the positions of labels in pos_pred_pos or
* neg_pred_pos when sorted by predictions
* \param pos_pred_pos positive label's prediction value position when the groups prediction
* values are sorted
* \param neg_pred_pos negative label's prediction value position when the groups prediction
* values are sorted
* \param pos_label, neg_label the chosen positive and negative labels
* \param p_map_stats a vector containing the accumulated precisions for each position in a list
* \param map_stats_size size of the accumulated precisions vector
*/
XGBOOST_DEVICE inline static bst_float GetLambdaMAP(
int pos_pred_pos, int neg_pred_pos,
bst_float pos_label, bst_float neg_label,
const MAPStats *p_map_stats, uint32_t map_stats_size) {
if (pos_pred_pos == neg_pred_pos || p_map_stats[map_stats_size - 1].hits == 0) {
return 0.0f;
}
if (pos_pred_pos > neg_pred_pos) {
Swap(pos_pred_pos, neg_pred_pos);
Swap(pos_label, neg_label);
}
bst_float original = p_map_stats[neg_pred_pos].ap_acc;
if (pos_pred_pos != 0) original -= p_map_stats[pos_pred_pos - 1].ap_acc;
bst_float changed = 0;
bst_float label1 = pos_label > 0.0f ? 1.0f : 0.0f;
bst_float label2 = neg_label > 0.0f ? 1.0f : 0.0f;
if (label1 == label2) {
return 0.0;
} else if (label1 < label2) {
changed += p_map_stats[neg_pred_pos - 1].ap_acc_add - p_map_stats[pos_pred_pos].ap_acc_add;
changed += (p_map_stats[pos_pred_pos].hits + 1.0f) / (pos_pred_pos + 1);
} else {
changed += p_map_stats[neg_pred_pos - 1].ap_acc_miss - p_map_stats[pos_pred_pos].ap_acc_miss;
changed += p_map_stats[neg_pred_pos].hits / (neg_pred_pos + 1);
}
bst_float ans = (changed - original) / (p_map_stats[map_stats_size - 1].hits);
if (ans < 0) ans = -ans;
return ans;
}
public:
/*
* \brief obtain preprocessing results for calculating delta MAP
* \param sorted_list the list containing entry information
* \param map_stats a vector containing the accumulated precisions for each position in a list
*/
inline static void GetMAPStats(const std::vector<ListEntry> &sorted_list,
std::vector<MAPStats> *p_map_acc) {
std::vector<MAPStats> &map_acc = *p_map_acc;
map_acc.resize(sorted_list.size());
bst_float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0;
for (size_t i = 1; i <= sorted_list.size(); ++i) {
if (sorted_list[i - 1].label > 0.0f) {
hit++;
acc1 += hit / i;
acc2 += (hit - 1) / i;
acc3 += (hit + 1) / i;
}
map_acc[i - 1] = MAPStats(acc1, acc2, acc3, hit);
}
}
static char const* Name() {
return "rank:map";
}
static void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
std::vector<LambdaPair> *io_pairs) {
std::vector<LambdaPair> &pairs = *io_pairs;
std::vector<MAPStats> map_stats;
GetMAPStats(sorted_list, &map_stats);
for (auto & pair : pairs) {
pair.weight *=
GetLambdaMAP(pair.pos_index, pair.neg_index,
sorted_list[pair.pos_index].label, sorted_list[pair.neg_index].label,
&map_stats[0], map_stats.size());
}
}
#if defined(__CUDACC__)
MAPLambdaWeightComputer(const bst_float *dpreds,
const bst_float *dlabels,
const dh::SegmentSorter<float> &segment_label_sorter)
: IndexablePredictionSorter(dpreds, segment_label_sorter),
dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()),
weight_multiplier_(segment_label_sorter, *this) {
this->CreateMAPStats(dlabels, segment_label_sorter);
}
void CreateMAPStats(const bst_float *dlabels,
const dh::SegmentSorter<float> &segment_label_sorter) {
// For each group, go through the sorted prediction positions, and look up its corresponding
// label from the unsorted labels (from the original label list)
// For each item in the group, compute its MAP stats.
// Interleave the computation of map stats amongst different groups.
// First, determine postive labels in the dataset individually
auto nitems = segment_label_sorter.GetNumItems();
dh::caching_device_vector<uint32_t> dhits(nitems, 0);
// Original positions of the predictions after they have been sorted
const auto &pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsSpan();
// Unsorted labels
const float *unsorted_labels = dlabels;
auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) {
return (unsorted_labels[pred_original_pos[idx]] > 0.0f) ? 1 : 0;
}; // NOLINT
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(nitems),
dhits.begin(),
DeterminePositiveLabelLambda);
// Allocator to be used by sort for managing space overhead while performing prefix scans
dh::XGBCachingDeviceAllocator<char> alloc;
// Next, prefix scan the positive labels that are segmented to accumulate them.
// This is required for computing the accumulated precisions
const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
// Data segmented into different groups...
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
dhits.begin(), // Input value
dhits.begin()); // In-place scan
// Compute accumulated precisions for each item, assuming positive and
// negative instances are missing.
// But first, compute individual item precisions
const auto *dhits_arr = dhits.data().get();
// Group info on device
const auto &dgroups = segment_label_sorter.GetGroupsSpan();
auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) {
if (unsorted_labels[pred_original_pos[idx]] > 0.0f) {
auto idx_within_group = (idx - dgroups[group_segments[idx]]) + 1;
return MAPStats{static_cast<float>(dhits_arr[idx]) / idx_within_group,
static_cast<float>(dhits_arr[idx] - 1) / idx_within_group,
static_cast<float>(dhits_arr[idx] + 1) / idx_within_group,
1.0f};
}
return MAPStats{};
}; // NOLINT
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(nitems),
this->dmap_stats_.begin(),
ComputeItemPrecisionLambda);
// Lastly, compute the accumulated precisions for all the items segmented by groups.
// The precisions are accumulated within each group
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
this->dmap_stats_.begin(), // Input map stats
this->dmap_stats_.begin()); // In-place scan and output here
}
inline const common::Span<const MAPStats> GetMapStatsSpan() const {
return { dmap_stats_.data().get(), dmap_stats_.size() };
}
// Type containing device pointers that can be cheaply copied on the kernel
class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier {
public:
MAPLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
const MAPLambdaWeightComputer &lwc)
: BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()),
dmap_stats_(lwc.GetMapStatsSpan()) {}
// Adjust the items weight by this value
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
uint32_t group_begin = dgroups_[gidx];
uint32_t group_end = dgroups_[gidx + 1];
auto pos_lab_orig_posn = dorig_pos_[pidx];
auto neg_lab_orig_posn = dorig_pos_[nidx];
KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn);
// Note: the label positive and negative indices are relative to the entire dataset.
// Hence, scale them back to an index within the group
auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin;
auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin;
return MAPLambdaWeightComputer::GetLambdaMAP(
pos_pred_pos, neg_pred_pos,
dsorted_labels_[pidx], dsorted_labels_[nidx],
&dmap_stats_[group_begin], group_end - group_begin);
}
private:
common::Span<const MAPStats> dmap_stats_; // Start address of the map stats for every sorted
// prediction value
};
inline const MAPLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; }
private:
dh::caching_device_vector<MAPStats> dmap_stats_;
// This computes the adjustment to the weight
const MAPLambdaWeightMultiplier weight_multiplier_;
#endif
};
#if defined(__CUDACC__)
class SortedLabelList : dh::SegmentSorter<float> {
private:
const LambdaRankParam &param_; // Objective configuration
public:
explicit SortedLabelList(const LambdaRankParam &param)
: param_(param) {}
// Sort the labels that are grouped by 'groups'
void Sort(const HostDeviceVector<bst_float> &dlabels, const std::vector<uint32_t> &groups) {
this->SortItems(dlabels.ConstDevicePointer(), dlabels.Size(), groups);
}
// This kernel can only run *after* the kernel in sort is completed, as they
// use the default stream
template <typename LambdaWeightComputerT>
void ComputeGradients(const bst_float *dpreds, // Unsorted predictions
const bst_float *dlabels, // Unsorted labels
const HostDeviceVector<bst_float> &weights,
int iter,
GradientPair *out_gpair,
float weight_normalization_factor) {
// Group info on device
const auto &dgroups = this->GetGroupsSpan();
uint32_t ngroups = this->GetNumGroups() + 1;
uint32_t total_items = this->GetNumItems();
uint32_t niter = param_.num_pairsample * total_items;
float fix_list_weight = param_.fix_list_weight;
const auto &original_pos = this->GetOriginalPositionsSpan();
uint32_t num_weights = weights.Size();
auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr;
const auto &sorted_labels = this->GetItemsSpan();
// This is used to adjust the weight of different elements based on the different ranking
// objective function policies
LambdaWeightComputerT weight_computer(dpreds, dlabels, *this);
auto wmultiplier = weight_computer.GetWeightMultiplier();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// For each instance in the group, compute the gradient pair concurrently
dh::LaunchN(niter, nullptr, [=] __device__(uint32_t idx) {
// First, determine the group 'idx' belongs to
uint32_t item_idx = idx % total_items;
uint32_t group_idx =
thrust::upper_bound(thrust::seq, dgroups.begin(),
dgroups.begin() + ngroups, item_idx) -
dgroups.begin();
// Span of this group within the larger labels/predictions sorted tuple
uint32_t group_begin = dgroups[group_idx - 1];
uint32_t group_end = dgroups[group_idx];
uint32_t total_group_items = group_end - group_begin;
// Are the labels diverse enough? If they are all the same, then there is nothing to pick
// from another group - bail sooner
if (sorted_labels[group_begin] == sorted_labels[group_end - 1]) return;
// Find the number of labels less than and greater than the current label
// at the sorted index position item_idx
uint32_t nleft = CountNumItemsToTheLeftOf(
sorted_labels.data() + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]);
uint32_t nright = CountNumItemsToTheRightOf(
sorted_labels.data() + item_idx, group_end - item_idx, sorted_labels[item_idx]);
// Create a minstd_rand object to act as our source of randomness
thrust::minstd_rand rng((iter + 1) * 1111);
rng.discard(((idx / total_items) * total_group_items) + item_idx - group_begin);
// Create a uniform_int_distribution to produce a sample from outside of the
// present label group
thrust::uniform_int_distribution<int> dist(0, nleft + nright - 1);
int sample = dist(rng);
int pos_idx = -1; // Bigger label
int neg_idx = -1; // Smaller label
// Are we picking a sample to the left/right of the current group?
if (sample < nleft) {
// Go left
pos_idx = sample + group_begin;
neg_idx = item_idx;
} else {
pos_idx = item_idx;
uint32_t items_in_group = total_group_items - nleft - nright;
neg_idx = sample + items_in_group + group_begin;
}
// Compute and assign the gradients now
const float eps = 1e-16f;
bst_float p = common::Sigmoid(dpreds[original_pos[pos_idx]] - dpreds[original_pos[neg_idx]]);
bst_float g = p - 1.0f;
bst_float h = thrust::max(p * (1.0f - p), eps);
// Rescale each gradient and hessian so that the group has a weighted constant
float scale = __frcp_ru(niter / total_items);
if (fix_list_weight != 0.0f) {
scale *= fix_list_weight / total_group_items;
}
float weight = num_weights ? dweights[group_idx - 1] : 1.0f;
weight *= weight_normalization_factor;
weight *= wmultiplier.GetWeight(group_idx - 1, pos_idx, neg_idx);
weight *= scale;
// Accumulate gradient and hessian in both positive and negative indices
const GradientPair in_pos_gpair(g * weight, 2.0f * weight * h);
dh::AtomicAddGpair(&out_gpair[original_pos[pos_idx]], in_pos_gpair);
const GradientPair in_neg_gpair(-g * weight, 2.0f * weight * h);
dh::AtomicAddGpair(&out_gpair[original_pos[neg_idx]], in_neg_gpair);
});
// Wait until the computations done by the kernel is complete
dh::safe_cuda(cudaStreamSynchronize(nullptr));
}
};
#endif
// objective for lambda rank
template <typename LambdaWeightComputerT>
class LambdaRankObj : public ObjFunction {
public:
void Configure(Args const &args) override { param_.UpdateAllowUnknown(args); }
ObjInfo Task() const override { return ObjInfo::kRanking; }
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
CHECK_EQ(preds.Size(), info.labels.Size()) << "label size predict size not match";
// quick consistency when group is not available
std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(info.labels.Size());
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
CHECK(gptr.size() != 0 && gptr.back() == info.labels.Size())
<< "group structure not consistent with #rows" << ", "
<< "group ponter size: " << gptr.size() << ", "
<< "labels size: " << info.labels.Size() << ", "
<< "group pointer back: " << (gptr.size() == 0 ? 0 : gptr.back());
#if defined(__CUDACC__)
// Check if we have a GPU assignment; else, revert back to CPU
auto device = ctx_->gpu_id;
if (device >= 0) {
ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr);
} else {
// Revert back to CPU
#endif
ComputeGradientsOnCPU(preds, info, iter, out_gpair, gptr);
#if defined(__CUDACC__)
}
#endif
}
const char* DefaultEvalMetric() const override {
return "map";
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String(LambdaWeightComputerT::Name());
out["lambda_rank_param"] = ToJson(param_);
}
void LoadConfig(Json const& in) override {
FromJson(in["lambda_rank_param"], &param_);
}
private:
bst_float ComputeWeightNormalizationFactor(const MetaInfo& info,
const std::vector<unsigned> &gptr) {
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
bst_float sum_weights = 0;
for (bst_omp_uint k = 0; k < ngroup; ++k) {
sum_weights += info.GetWeight(k);
}
return ngroup / sum_weights;
}
void ComputeGradientsOnCPU(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
HostDeviceVector<GradientPair>* out_gpair,
const std::vector<unsigned> &gptr) {
LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on CPU.";
bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);
const auto& preds_h = preds.HostVector();
const auto& labels = info.labels.HostView();
std::vector<GradientPair>& gpair = out_gpair->HostVector();
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
out_gpair->Resize(preds.Size());
dmlc::OMPException exc;
#pragma omp parallel num_threads(ctx_->Threads())
{
exc.Run([&]() {
// parallel construct, declare random number generator here, so that each
// thread use its own random number generator, seed by thread id and current iteration
std::minstd_rand rnd((iter + 1) * 1111);
std::vector<LambdaPair> pairs;
std::vector<ListEntry> lst;
std::vector< std::pair<bst_float, unsigned> > rec;
#pragma omp for schedule(static)
for (bst_omp_uint k = 0; k < ngroup; ++k) {
exc.Run([&]() {
lst.clear(); pairs.clear();
for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) {
lst.emplace_back(preds_h[j], labels(j), j);
gpair[j] = GradientPair(0.0f, 0.0f);
}
std::stable_sort(lst.begin(), lst.end(), ListEntry::CmpPred);
rec.resize(lst.size());
for (unsigned i = 0; i < lst.size(); ++i) {
rec[i] = std::make_pair(lst[i].label, i);
}
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
// enumerate buckets with same label
// for each item in the lst, grab another sample randomly
for (unsigned i = 0; i < rec.size(); ) {
unsigned j = i + 1;
while (j < rec.size() && rec[j].first == rec[i].first) ++j;
// bucket in [i,j), get a sample outside bucket
unsigned nleft = i, nright = static_cast<unsigned>(rec.size() - j);
if (nleft + nright != 0) {
int nsample = param_.num_pairsample;
while (nsample --) {
for (unsigned pid = i; pid < j; ++pid) {
unsigned ridx =
std::uniform_int_distribution<unsigned>(0, nleft + nright - 1)(rnd);
if (ridx < nleft) {
pairs.emplace_back(rec[ridx].second, rec[pid].second,
info.GetWeight(k) * weight_normalization_factor);
} else {
pairs.emplace_back(rec[pid].second, rec[ridx+j-i].second,
info.GetWeight(k) * weight_normalization_factor);
}
}
}
}
i = j;
}
// get lambda weight for the pairs
LambdaWeightComputerT::GetLambdaWeight(lst, &pairs);
// rescale each gradient and hessian so that the lst have constant weighted
float scale = 1.0f / param_.num_pairsample;
if (param_.fix_list_weight != 0.0f) {
scale *= param_.fix_list_weight / (gptr[k + 1] - gptr[k]);
}
for (auto & pair : pairs) {
const ListEntry &pos = lst[pair.pos_index];
const ListEntry &neg = lst[pair.neg_index];
const bst_float w = pair.weight * scale;
const float eps = 1e-16f;
bst_float p = common::Sigmoid(pos.pred - neg.pred);
bst_float g = p - 1.0f;
bst_float h = std::max(p * (1.0f - p), eps);
// accumulate gradient and hessian in both pid, and nid
gpair[pos.rindex] += GradientPair(g * w, 2.0f*w*h);
gpair[neg.rindex] += GradientPair(-g * w, 2.0f*w*h);
}
});
}
});
}
exc.Rethrow();
}
#if defined(__CUDACC__)
void ComputeGradientsOnGPU(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
HostDeviceVector<GradientPair>* out_gpair,
const std::vector<unsigned> &gptr) {
LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU.";
auto device = ctx_->gpu_id;
dh::safe_cuda(cudaSetDevice(device));
bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);
// Set the device ID and copy them to the device
out_gpair->SetDevice(device);
info.labels.SetDevice(device);
preds.SetDevice(device);
info.weights_.SetDevice(device);
out_gpair->Resize(preds.Size());
auto d_preds = preds.ConstDevicePointer();
auto d_gpair = out_gpair->DevicePointer();
auto d_labels = info.labels.View(device);
SortedLabelList slist(param_);
// Sort the labels within the groups on the device
slist.Sort(*info.labels.Data(), gptr);
// Initialize the gradients next
out_gpair->Fill(GradientPair(0.0f, 0.0f));
// Finally, compute the gradients
slist.ComputeGradients<LambdaWeightComputerT>(d_preds, d_labels.Values().data(), info.weights_,
iter, d_gpair, weight_normalization_factor);
}
#endif
LambdaRankParam param_;
};
#if !defined(GTEST_TEST)
// register the objective functions
DMLC_REGISTER_PARAMETER(LambdaRankParam);
XGBOOST_REGISTER_OBJECTIVE(PairwiseRankObj, PairwiseLambdaWeightComputer::Name())
.describe("Pairwise rank objective.")
.set_body([]() { return new LambdaRankObj<PairwiseLambdaWeightComputer>(); });
XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name())
.describe("LambdaRank with MAP as objective.")
.set_body([]() { return new LambdaRankObj<MAPLambdaWeightComputer>(); });
#endif
} // namespace obj
} // namespace xgboost

View File

@ -223,4 +223,125 @@ TEST(LambdaRank, MakePair) {
ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair());
}
}
void TestMAPStat(Context const* ctx) {
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
ltr::LambdaRankParam param;
param.UpdateAllowUnknown(Args{});
{
std::vector<float> h_data{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f};
info.labels.Reshape(h_data.size(), 1);
info.labels.Data()->HostVector() = h_data;
info.num_row_ = h_data.size();
HostDeviceVector<float> predt;
auto& h_predt = predt.HostVector();
h_predt.resize(h_data.size());
std::iota(h_predt.rbegin(), h_predt.rend(), 0.0f);
auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);
predt.SetDevice(ctx->gpu_id);
auto rank_idx =
p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());
if (ctx->IsCPU()) {
obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
p_cache);
} else {
obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache);
}
Context cpu_ctx;
auto n_rel = p_cache->NumRelevant(&cpu_ctx);
auto acc = p_cache->Acc(&cpu_ctx);
ASSERT_EQ(n_rel[0], 1.0);
ASSERT_EQ(acc[0], 1.0);
ASSERT_EQ(n_rel.back(), h_data.size() - 1.0);
ASSERT_NEAR(acc.back(), 1.95 + (1.0 / h_data.size()), kRtEps);
}
{
info.labels.Reshape(16);
auto& h_label = info.labels.Data()->HostVector();
info.group_ptr_ = {0, 8, 16};
info.num_row_ = info.labels.Shape(0);
std::fill_n(h_label.begin(), 8, 1.0f);
std::fill_n(h_label.begin() + 8, 8, 0.0f);
HostDeviceVector<float> predt;
auto& h_predt = predt.HostVector();
h_predt.resize(h_label.size());
std::iota(h_predt.rbegin(), h_predt.rbegin() + 8, 0.0f);
std::iota(h_predt.rbegin() + 8, h_predt.rend(), 0.0f);
auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);
predt.SetDevice(ctx->gpu_id);
auto rank_idx =
p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());
if (ctx->IsCPU()) {
obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
p_cache);
} else {
obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache);
}
Context cpu_ctx;
auto n_rel = p_cache->NumRelevant(&cpu_ctx);
ASSERT_EQ(n_rel[7], 8); // first group
ASSERT_EQ(n_rel.back(), 0); // second group
}
}
TEST(LambdaRank, MAPStat) {
Context ctx;
TestMAPStat(&ctx);
}
void TestMAPGPair(Context const* ctx) {
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:map", ctx)};
Args args;
obj->Configure(args);
CheckConfigReload(obj, "rank:map");
CheckRankingObjFunction(obj, // obj
{0, 0.1f, 0, 0.1f}, // score
{0, 1, 0, 1}, // label
{2.0f, 2.0f}, // weight
{0, 2, 4}, // group
{1.2054923f, -1.2054923f, 1.2054923f, -1.2054923f}, // out grad
{1.2657166f, 1.2657166f, 1.2657166f, 1.2657166f});
// disable the second query group with 0 weight
CheckRankingObjFunction(obj, // obj
{0, 0.1f, 0, 0.1f}, // score
{0, 1, 0, 1}, // label
{2.0f, 0.0f}, // weight
{0, 2, 4}, // group
{1.2054923f, -1.2054923f, .0f, .0f}, // out grad
{1.2657166f, 1.2657166f, .0f, .0f});
}
TEST(LambdaRank, MAPGPair) {
Context ctx;
TestMAPGPair(&ctx);
}
void TestPairWiseGPair(Context const* ctx) {
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:pairwise", ctx)};
Args args;
obj->Configure(args);
args.emplace_back("lambdarank_unbiased", "true");
}
TEST(LambdaRank, Pairwise) {
Context ctx;
TestPairWiseGPair(&ctx);
}
} // namespace xgboost::obj

View File

@ -18,6 +18,12 @@ TEST(LambdaRank, GPUNDCGJsonIO) {
TestNDCGJsonIO(&ctx);
}
TEST(LambdaRank, GPUMAPStat) {
Context ctx;
ctx.gpu_id = 0;
TestMAPStat(&ctx);
}
TEST(LambdaRank, GPUNDCGGPair) {
Context ctx;
ctx.gpu_id = 0;
@ -153,4 +159,10 @@ TEST(LambdaRank, RankItemCountOnRight) {
RankItemCountImpl(sorted_items, wrapper, 1, static_cast<uint32_t>(1));
RankItemCountImpl(sorted_items, wrapper, 0, static_cast<uint32_t>(0));
}
TEST(LambdaRank, GPUMAPGPair) {
Context ctx;
ctx.gpu_id = 0;
TestMAPGPair(&ctx);
}
} // namespace xgboost::obj

View File

@ -18,6 +18,8 @@
#include "../helpers.h" // for EmptyDMatrix
namespace xgboost::obj {
void TestMAPStat(Context const* ctx);
inline void TestNDCGJsonIO(Context const* ctx) {
std::unique_ptr<xgboost::ObjFunction> obj{ObjFunction::Create("rank:ndcg", ctx)};
@ -37,6 +39,8 @@ void TestNDCGGPair(Context const* ctx);
void TestUnbiasedNDCG(Context const* ctx);
void TestMAPGPair(Context const* ctx);
/**
* \brief Initialize test data for make pair tests.
*/

View File

@ -1,83 +0,0 @@
// Copyright by Contributors
#include <xgboost/context.h>
#include <xgboost/json.h>
#include <xgboost/objective.h>
#include "../helpers.h"
namespace xgboost {
TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPair)) {
std::vector<std::pair<std::string, std::string>> args;
xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:pairwise", &ctx)};
obj->Configure(args);
CheckConfigReload(obj, "rank:pairwise");
// Test with setting sample weight to second query group
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 0.0f},
{0, 2, 4},
{1.9f, -1.9f, 0.0f, 0.0f},
{1.995f, 1.995f, 0.0f, 0.0f});
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{1.0f, 1.0f},
{0, 2, 4},
{0.95f, -0.95f, 0.95f, -0.95f},
{0.9975f, 0.9975f, 0.9975f, 0.9975f});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) {
std::vector<std::pair<std::string, std::string>> args;
xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<ObjFunction> obj{ObjFunction::Create("rank:pairwise", &ctx)};
obj->Configure(args);
// No computation of gradient/hessian, as there is no diversity in labels
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{1, 1, 1, 1},
{2.0f, 0.0f},
{0, 2, 4},
{0.0f, 0.0f, 0.0f, 0.0f},
{0.0f, 0.0f, 0.0f, 0.0f});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(Objective, DeclareUnifiedTest(MAPRankingGPair)) {
std::vector<std::pair<std::string, std::string>> args;
xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:map", &ctx)};
obj->Configure(args);
CheckConfigReload(obj, "rank:map");
// Test with setting sample weight to second query group
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 0.0f},
{0, 2, 4},
{0.95f, -0.95f, 0.0f, 0.0f},
{0.9975f, 0.9975f, 0.0f, 0.0f});
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{1.0f, 1.0f},
{0, 2, 4},
{0.475f, -0.475f, 0.475f, -0.475f},
{0.4988f, 0.4988f, 0.4988f, 0.4988f});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
} // namespace xgboost

View File

@ -1,175 +0,0 @@
/*!
* Copyright 2019-2021 by XGBoost Contributors
*/
#include <thrust/host_vector.h>
#include "test_ranking_obj.cc"
#include "../../../src/objective/rank_obj.cu"
namespace xgboost {
template <typename T = uint32_t, typename Comparator = thrust::greater<T>>
std::unique_ptr<dh::SegmentSorter<T>>
RankSegmentSorterTestImpl(const std::vector<uint32_t> &group_indices,
const std::vector<T> &hlabels,
const std::vector<T> &expected_sorted_hlabels,
const std::vector<uint32_t> &expected_orig_pos
) {
std::unique_ptr<dh::SegmentSorter<T>> seg_sorter_ptr(new dh::SegmentSorter<T>);
dh::SegmentSorter<T> &seg_sorter(*seg_sorter_ptr);
// Create a bunch of unsorted labels on the device and sort it via the segment sorter
dh::device_vector<T> dlabels(hlabels);
seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator());
auto num_items = seg_sorter.GetItemsSpan().size();
EXPECT_EQ(num_items, group_indices.back());
EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1);
// Check the labels
dh::device_vector<T> sorted_dlabels(num_items);
sorted_dlabels.assign(dh::tcbegin(seg_sorter.GetItemsSpan()),
dh::tcend(seg_sorter.GetItemsSpan()));
thrust::host_vector<T> sorted_hlabels(sorted_dlabels);
EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels);
// Check the indices
dh::device_vector<uint32_t> dorig_pos(num_items);
dorig_pos.assign(dh::tcbegin(seg_sorter.GetOriginalPositionsSpan()),
dh::tcend(seg_sorter.GetOriginalPositionsSpan()));
dh::device_vector<uint32_t> horig_pos(dorig_pos);
EXPECT_EQ(expected_orig_pos, horig_pos);
return seg_sorter_ptr;
}
TEST(Objective, RankSegmentSorterTest) {
RankSegmentSorterTestImpl({0, 2, 4, 7, 10, 14, 18, 22, 26}, // Groups
{1, 1, // Labels
1, 2,
3, 2, 1,
1, 2, 1,
1, 3, 4, 2,
1, 2, 1, 1,
1, 2, 2, 3,
3, 3, 1, 2},
{1, 1, // Expected sorted labels
2, 1,
3, 2, 1,
2, 1, 1,
4, 3, 2, 1,
2, 1, 1, 1,
3, 2, 2, 1,
3, 3, 2, 1},
{0, 1, // Expected original positions
3, 2,
4, 5, 6,
8, 7, 9,
12, 11, 13, 10,
15, 14, 16, 17,
21, 19, 20, 18,
22, 23, 25, 24});
}
TEST(Objective, RankSegmentSorterSingleGroupTest) {
RankSegmentSorterTestImpl({0, 7}, // Groups
{6, 1, 4, 3, 0, 5, 2}, // Labels
{6, 5, 4, 3, 2, 1, 0}, // Expected sorted labels
{0, 5, 2, 3, 6, 1, 4}); // Expected original positions
}
TEST(Objective, RankSegmentSorterAscendingTest) {
RankSegmentSorterTestImpl<uint32_t, thrust::less<uint32_t>>(
{0, 4, 7}, // Groups
{3, 1, 4, 2, // Labels
6, 5, 7},
{1, 2, 3, 4, // Expected sorted labels
5, 6, 7},
{1, 3, 0, 2, // Expected original positions
5, 4, 6});
}
TEST(Objective, IndexableSortedItemsTest) {
std::vector<float> hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels
7.8f, 5.01f, 6.96f,
10.3f, 8.7f, 11.4f, 9.45f, 11.4f};
dh::device_vector<bst_float> dlabels(hlabels);
auto segment_label_sorter = RankSegmentSorterTestImpl<float>(
{0, 4, 7, 12}, // Groups
hlabels,
{4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels
7.8f, 6.96f, 5.01f,
11.4f, 11.4f, 10.3f, 9.45f, 8.7f},
{3, 0, 2, 1, // Expected original positions
4, 6, 5,
9, 11, 7, 10, 8});
segment_label_sorter->CreateIndexableSortedPositions();
std::vector<uint32_t> sorted_indices(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&sorted_indices,
segment_label_sorter->GetIndexableSortedPositionsSpan());
std::vector<uint32_t> expected_sorted_indices = {
1, 3, 2, 0,
4, 6, 5,
9, 11, 7, 10, 8};
EXPECT_EQ(expected_sorted_indices, sorted_indices);
}
TEST(Objective, ComputeAndCompareMAPStatsTest) {
std::vector<float> hlabels = {3.1f, 0.0f, 2.3f, 4.4f, // Labels
0.0f, 5.01f, 0.0f,
10.3f, 0.0f, 11.4f, 9.45f, 11.4f};
dh::device_vector<bst_float> dlabels(hlabels);
auto segment_label_sorter = RankSegmentSorterTestImpl<float>(
{0, 4, 7, 12}, // Groups
hlabels,
{4.4f, 3.1f, 2.3f, 0.0f, // Expected sorted labels
5.01f, 0.0f, 0.0f,
11.4f, 11.4f, 10.3f, 9.45f, 0.0f},
{3, 0, 2, 1, // Expected original positions
5, 4, 6,
9, 11, 7, 10, 8});
// Create MAP stats on the device first using the objective
std::vector<bst_float> hpreds{-9.78f, 24.367f, 0.908f, -11.47f,
-1.03f, -2.79f, -3.1f,
104.22f, 103.1f, -101.7f, 100.5f, 45.1f};
dh::device_vector<bst_float> dpreds(hpreds);
xgboost::obj::MAPLambdaWeightComputer map_lw_computer(dpreds.data().get(),
dlabels.data().get(),
*segment_label_sorter);
// Get the device MAP stats on host
std::vector<xgboost::obj::MAPLambdaWeightComputer::MAPStats> dmap_stats(
segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&dmap_stats, map_lw_computer.GetMapStatsSpan());
// Compute the MAP stats on host next to compare
std::vector<uint32_t> hgroups(segment_label_sorter->GetNumGroups() + 1);
dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan());
for (size_t i = 0; i < hgroups.size() - 1; ++i) {
auto gbegin = hgroups[i];
auto gend = hgroups[i + 1];
std::vector<xgboost::obj::ListEntry> lst_entry;
for (auto j = gbegin; j < gend; ++j) {
lst_entry.emplace_back(hpreds[j], hlabels[j], j);
}
std::stable_sort(lst_entry.begin(), lst_entry.end(), xgboost::obj::ListEntry::CmpPred);
// Compute the MAP stats with this list and compare with the ones computed on the device
std::vector<xgboost::obj::MAPLambdaWeightComputer::MAPStats> hmap_stats;
xgboost::obj::MAPLambdaWeightComputer::GetMAPStats(lst_entry, &hmap_stats);
for (auto j = gbegin; j < gend; ++j) {
EXPECT_EQ(dmap_stats[j].hits, hmap_stats[j - gbegin].hits);
EXPECT_NEAR(dmap_stats[j].ap_acc, hmap_stats[j - gbegin].ap_acc, 0.01f);
EXPECT_NEAR(dmap_stats[j].ap_acc_miss, hmap_stats[j - gbegin].ap_acc_miss, 0.01f);
EXPECT_NEAR(dmap_stats[j].ap_acc_add, hmap_stats[j - gbegin].ap_acc_add, 0.01f);
}
}
}
} // namespace xgboost

View File

@ -176,7 +176,7 @@ def test_ranking():
def test_ranking_metric() -> None:
from sklearn.metrics import roc_auc_score
X, y, qid, w = tm.make_ltr(512, 4, 3, 2)
X, y, qid, w = tm.make_ltr(512, 4, 3, 1)
# use auc for test as ndcg_score in sklearn works only on label gain instead of exp
# gain.
# note that the auc in sklearn is different from the one in XGBoost. The one in

View File

@ -1343,61 +1343,94 @@ class XgboostLocalTest(SparkTestCase):
SparkXGBClassifier(evals_result={})
class XgboostRankerLocalTest(SparkTestCase):
def setUp(self):
self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
self.ranker_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
],
["features", "label", "qid"],
)
self.ranker_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
],
["features", "qid", "expected_prediction"],
)
self.ranker_df_train_1 = self.session.createDataFrame(
[
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9),
(Vectors.dense(1.0, 2.0, 3.0), 0, 8),
(Vectors.dense(4.0, 5.0, 6.0), 1, 8),
(Vectors.dense(9.0, 4.0, 8.0), 2, 8),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7),
(Vectors.dense(1.0, 2.0, 3.0), 0, 6),
(Vectors.dense(4.0, 5.0, 6.0), 1, 6),
(Vectors.dense(9.0, 4.0, 8.0), 2, 6),
]
* 4,
["features", "label", "qid"],
)
LTRData = namedtuple("LTRData", ("df_train", "df_test", "df_train_1"))
def test_ranker(self):
ranker = SparkXGBRanker(qid_col="qid")
@pytest.fixture
def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
ranker_df_train = spark.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
],
["features", "label", "qid"],
)
X_train = np.array(
[
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[9.0, 4.0, 8.0],
[np.NaN, 1.0, 5.5],
[np.NaN, 6.0, 7.5],
[np.NaN, 8.0, 9.5],
]
)
qid_train = np.array([0, 0, 0, 1, 1, 1])
y_train = np.array([0, 1, 2, 0, 1, 2])
X_test = np.array(
[
[1.5, 2.0, 3.0],
[4.5, 5.0, 6.0],
[9.0, 4.5, 8.0],
[np.NaN, 1.0, 6.0],
[np.NaN, 6.0, 7.0],
[np.NaN, 8.0, 10.5],
]
)
ltr = xgb.XGBRanker(tree_method="approx", objective="rank:pairwise")
ltr.fit(X_train, y_train, qid=qid_train)
predt = ltr.predict(X_test)
ranker_df_test = spark.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, float(predt[0])),
(Vectors.dense(4.5, 5.0, 6.0), 0, float(predt[1])),
(Vectors.dense(9.0, 4.5, 8.0), 0, float(predt[2])),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, float(predt[3])),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, float(predt[4])),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, float(predt[5])),
],
["features", "qid", "expected_prediction"],
)
ranker_df_train_1 = spark.createDataFrame(
[
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9),
(Vectors.dense(1.0, 2.0, 3.0), 0, 8),
(Vectors.dense(4.0, 5.0, 6.0), 1, 8),
(Vectors.dense(9.0, 4.0, 8.0), 2, 8),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7),
(Vectors.dense(1.0, 2.0, 3.0), 0, 6),
(Vectors.dense(4.0, 5.0, 6.0), 1, 6),
(Vectors.dense(9.0, 4.0, 8.0), 2, 6),
]
* 4,
["features", "label", "qid"],
)
yield LTRData(ranker_df_train, ranker_df_test, ranker_df_train_1)
class TestPySparkLocalLETOR:
def test_ranker(self, ltr_data: LTRData) -> None:
ranker = SparkXGBRanker(qid_col="qid", objective="rank:pairwise")
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
model = ranker.fit(self.ranker_df_train)
pred_result = model.transform(self.ranker_df_test).collect()
model = ranker.fit(ltr_data.df_train)
pred_result = model.transform(ltr_data.df_test).collect()
for row in pred_result:
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
def test_ranker_qid_sorted(self):
ranker = SparkXGBRanker(qid_col="qid", num_workers=4)
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
model = ranker.fit(self.ranker_df_train_1)
model.transform(self.ranker_df_test).collect()
def test_ranker_qid_sorted(self, ltr_data: LTRData) -> None:
ranker = SparkXGBRanker(qid_col="qid", num_workers=4, objective="rank:ndcg")
assert ranker.getOrDefault(ranker.objective) == "rank:ndcg"
model = ranker.fit(ltr_data.df_train_1)
model.transform(ltr_data.df_test).collect()