Rework MAP and Pairwise for LTR. (#9075)
Commit: e206b899ef
Parent: 0e470ef606
@ -32,7 +32,6 @@ OBJECTS= \
    $(PKGROOT)/src/objective/objective.o \
    $(PKGROOT)/src/objective/regression_obj.o \
    $(PKGROOT)/src/objective/multiclass_obj.o \
    $(PKGROOT)/src/objective/rank_obj.o \
    $(PKGROOT)/src/objective/lambdarank_obj.o \
    $(PKGROOT)/src/objective/hinge.o \
    $(PKGROOT)/src/objective/aft_obj.o \

@ -219,6 +219,16 @@
        "num_pairsample": { "type": "string" },
        "fix_list_weight": { "type": "string" }
      }
    },
    "lambdarank_param": {
      "type": "object",
      "properties": {
        "lambdarank_num_pair_per_sample": { "type": "string" },
        "lambdarank_pair_method": { "type": "string" },
        "lambdarank_unbiased": {"type": "string" },
        "lambdarank_bias_norm": {"type": "string" },
        "ndcg_exp_gain": {"type": "string"}
      }
    }
  },
  "type": "object",
@ -477,22 +487,22 @@
    "type": "object",
    "properties": {
      "name": { "const": "rank:pairwise" },
      "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
      "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
    },
    "required": [
      "name",
      "lambda_rank_param"
      "lambdarank_param"
    ]
  },
  {
    "type": "object",
    "properties": {
      "name": { "const": "rank:ndcg" },
      "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
      "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
    },
    "required": [
      "name",
      "lambda_rank_param"
      "lambdarank_param"
    ]
  },
  {

@ -233,7 +233,7 @@ Parameters for Tree Booster

  .. note:: This parameter is a work in progress.

  - The strategy used for training multi-target models, including multi-target regression
  and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
    and multi-class classification. See :doc:`/tutorials/multioutput` for more information.

  - ``one_output_per_tree``: One model for each target.
  - ``multi_output_tree``: Use multi-target trees.
@ -380,9 +380,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
  See :doc:`/tutorials/aft_survival_analysis` for details.
- ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class (number of classes)
- ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class.
- ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized
- ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized
- ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
- ``rank:ndcg``: Use LambdaMART to perform pair-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized. This objective supports position debiasing for click data.
- ``rank:map``: Use LambdaMART to perform pair-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
- ``rank:pairwise``: Use LambdaRank to perform pair-wise ranking using the ``RankNet`` objective.
- ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed <https://en.wikipedia.org/wiki/Gamma_distribution#Occurrence_and_applications>`_.
- ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed <https://en.wikipedia.org/wiki/Tweedie_distribution#Occurrence_and_applications>`_.
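
A minimal sketch of training with the ranking objectives above from the Python package; the synthetic data and all hyperparameter values here are illustrative only, not part of this patch:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(1994)
    X = rng.normal(size=(128, 8))
    y = rng.integers(0, 4, size=128)              # graded relevance labels
    qid = np.sort(rng.integers(0, 16, size=128))  # query ids, grouped together

    # "rank:ndcg", "rank:map", and "rank:pairwise" are all accepted here.
    ranker = xgb.XGBRanker(objective="rank:ndcg", n_estimators=8)
    ranker.fit(X, y, qid=qid)
    scores = ranker.predict(X)
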
@ -395,8 +395,9 @@ Specify the learning task and the corresponding learning objective. The objectiv

* ``eval_metric`` [default according to objective]

  - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
  - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
  - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.)
  - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous ones

  - The choices are listed below:

  - ``rmse``: `root mean square error <http://en.wikipedia.org/wiki/Root_mean_square_error>`_
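
For example, to keep both an NDCG and a MAP metric during evaluation, pass the parameters as a list of pairs; a dict would silently keep only the last ``eval_metric`` entry. A minimal sketch with synthetic data (values are illustrative):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(64, 4))
    y = rng.integers(0, 4, size=64)
    qid = np.sort(rng.integers(0, 8, size=64))
    dtrain = xgb.DMatrix(X, label=y, qid=qid)

    # Repeated ("eval_metric", ...) pairs are preserved in order.
    params = [("objective", "rank:ndcg"),
              ("eval_metric", "ndcg@8"),
              ("eval_metric", "map@8")]
    bst = xgb.train(params, dtrain, num_boost_round=4, evals=[(dtrain, "train")])
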
@ -480,6 +481,36 @@ Parameter for using AFT Survival Loss (``survival:aft``) and Negative Log Likeli

* ``aft_loss_distribution``: Probability Density Function, ``normal``, ``logistic``, or ``extreme``.

.. _ltr-param:

Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``)
================================================================================

These are parameters specific to the learning-to-rank task. See :doc:`Learning to Rank </tutorials/learning_to_rank>` for an in-depth explanation.

* ``lambdarank_pair_method`` [default = ``mean``]

  How to construct pairs for pair-wise learning.

  - ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list.
  - ``topk``: Focus on the top-``lambdarank_num_pair_per_sample`` documents ranked by the model. Construct :math:`|query|` pairs for each document in this truncated list.

* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`]

  It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``.
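
  Following the ``ndcg@6`` example in the text, the corresponding parameter dictionary might look like this (a sketch; the values are illustrative):

  .. code-block:: python

      # Truncate pair construction at the top 6 documents of each query.
      params = {
          "objective": "rank:ndcg",
          "lambdarank_pair_method": "topk",
          "lambdarank_num_pair_per_sample": 6,
          "eval_metric": "ndcg@6",
      }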

* ``lambdarank_unbiased`` [default = ``false``]

  Specify whether we need to debias the input click data.

* ``lambdarank_bias_norm`` [default = 2.0]

  :math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true.

* ``ndcg_exp_gain`` [default = ``true``]

  Whether we should use the exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``: one uses the relevance value directly, while the other uses :math:`2^{rel} - 1` to put more emphasis on retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), the relevance degree cannot be greater than 31.
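
A sketch combining the parameters above for position-debiased training on click-derived labels; every value here is an illustrative assumption, not a recommendation:

.. code-block:: python

    params = {
        "objective": "rank:ndcg",
        "lambdarank_pair_method": "mean",
        "lambdarank_num_pair_per_sample": 2,
        "lambdarank_unbiased": True,   # debias position-dependent click data
        "lambdarank_bias_norm": 2.0,   # L2 norm for the debiasing update
        "ndcg_exp_gain": False,        # clicks are binary, linear gain suffices
    }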

***********************
Command Line Parameters
***********************

@ -431,8 +431,11 @@ def make_ltr(
    """Make a dataset for testing LTR."""
    rng = np.random.default_rng(1994)
    X = rng.normal(0, 1.0, size=n_samples * n_features).reshape(n_samples, n_features)
    y = rng.integers(0, max_rel, size=n_samples)
    qid = rng.integers(0, n_query_groups, size=n_samples)
    y = np.sum(X, axis=1)
    y -= y.min()
    y = np.round(y / y.max() * max_rel).astype(np.int32)

    qid = rng.integers(0, n_query_groups, size=n_samples, dtype=np.int32)
    w = rng.normal(0, 1.0, size=n_query_groups)
    w -= np.min(w)
    w /= np.max(w)

@ -493,7 +493,6 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
    auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());

    common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
      auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1]));
      auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
      auto g_rank = rank_idx.subspan(gptr[g]);

@ -69,6 +69,7 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
      lj(i) += g_lj(i);
    }
  }

  // The ti+ is not guaranteed to decrease since it depends on the |\delta Z|
  //
  // The update normalizes the ti+ to make ti+(0) equal to 1, which breaks the probability
@ -432,9 +433,201 @@ void LambdaRankUpdatePositionBias(Context const*, linalg::VectorView<double cons
#endif  // !defined(XGBOOST_USE_CUDA)
}  // namespace cuda_impl

namespace cpu_impl {
void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
             common::Span<std::size_t const> rank_idx, std::shared_ptr<ltr::MAPCache> p_cache) {
  auto h_n_rel = p_cache->NumRelevant(ctx);
  auto gptr = p_cache->DataGroupPtr(ctx);

  CHECK_EQ(h_n_rel.size(), gptr.back());
  CHECK_EQ(h_n_rel.size(), label.Size());

  auto h_acc = p_cache->Acc(ctx);

  common::ParallelFor(p_cache->Groups(), ctx->Threads(), [&](auto g) {
    auto cnt = gptr[g + 1] - gptr[g];
    auto g_n_rel = h_n_rel.subspan(gptr[g], cnt);
    auto g_rank = rank_idx.subspan(gptr[g], cnt);
    auto g_label = label.Slice(linalg::Range(gptr[g], gptr[g + 1]));

    // The number of relevant documents at each position
    g_n_rel[0] = g_label(g_rank[0]);
    for (std::size_t k = 1; k < g_rank.size(); ++k) {
      g_n_rel[k] = g_n_rel[k - 1] + g_label(g_rank[k]);
    }

    // \sum l_k/k
    auto g_acc = h_acc.subspan(gptr[g], cnt);
    g_acc[0] = g_label(g_rank[0]) / 1.0;

    for (std::size_t k = 1; k < g_rank.size(); ++k) {
      g_acc[k] = g_acc[k - 1] + (g_label(g_rank[k]) / static_cast<double>(k + 1));
    }
  });
}
}  // namespace cpu_impl
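
// For intuition, the per-group statistics computed by cpu_impl::MAPStat above
// can be sketched in a few lines of NumPy; this is an editorial illustration,
// not part of the patch:
//
//   import numpy as np
//
//   def map_stat(label, rank_idx):
//       l = label[rank_idx]                  # labels sorted by model score
//       n_rel = np.cumsum(l)                 # relevant docs up to each rank
//       acc = np.cumsum(l / (np.arange(l.size) + 1.0))  # \sum_k l_k / k
//       return n_rel, acc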

class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
 public:
  void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
                       const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) {
    CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective.";
    if (ctx_->IsCUDA()) {
      return cuda_impl::LambdaRankGetGradientMAP(
          ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
          tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
          out_gpair);
    }

    auto gptr = p_cache_->DataGroupPtr(ctx_).data();
    bst_group_t n_groups = p_cache_->Groups();

    out_gpair->Resize(info.num_row_);
    auto h_gpair = out_gpair->HostSpan();
    auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
    auto h_predt = predt.ConstHostSpan();
    auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
    auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);

    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };

    cpu_impl::MAPStat(ctx_, h_label, rank_idx, GetCache());
    auto n_rel = GetCache()->NumRelevant(ctx_);
    auto acc = GetCache()->Acc(ctx_);

    auto delta_map = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low,
                         bst_group_t g) {
      if (rank_high > rank_low) {
        std::swap(rank_high, rank_low);
        std::swap(y_high, y_low);
      }
      auto cnt = gptr[g + 1] - gptr[g];
      // In a hot loop
      auto g_n_rel = common::Span<double const>{n_rel.data() + gptr[g], cnt};
      auto g_acc = common::Span<double const>{acc.data() + gptr[g], cnt};
      auto d = DeltaMAP(y_high, y_low, rank_high, rank_low, g_n_rel, g_acc);
      return d;
    };
    using D = decltype(delta_map);

    common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
      auto cnt = gptr[g + 1] - gptr[g];
      auto w = h_weight[g];
      auto g_predt = h_predt.subspan(gptr[g], cnt);
      auto g_gpair = h_gpair.subspan(gptr[g], cnt);
      auto g_label = h_label.Slice(make_range(g));
      auto g_rank = rank_idx.subspan(gptr[g], cnt);

      auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta_map, g_gpair);

      if (param_.lambdarank_unbiased) {
        std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, D>, args);
      } else {
        std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, D>, args);
      }
    });
  }
  static char const* Name() { return "rank:map"; }
  [[nodiscard]] const char* DefaultEvalMetric() const override {
    return this->RankEvalMetric("map");
  }
};

#if !defined(XGBOOST_USE_CUDA)
namespace cuda_impl {
void MAPStat(Context const*, MetaInfo const&, common::Span<std::size_t const>,
             std::shared_ptr<ltr::MAPCache>) {
  common::AssertGPUSupport();
}

void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector<float> const&,
                              const MetaInfo&, std::shared_ptr<ltr::MAPCache>,
                              linalg::VectorView<double const>,  // input bias ratio
                              linalg::VectorView<double const>,  // input bias ratio
                              linalg::VectorView<double>, linalg::VectorView<double>,
                              HostDeviceVector<GradientPair>*) {
  common::AssertGPUSupport();
}
}  // namespace cuda_impl
#endif  // !defined(XGBOOST_USE_CUDA)

/**
 * \brief The RankNet loss.
 */
class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::RankingCache> {
 public:
  void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
                       const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) {
    CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective.";
    if (ctx_->IsCUDA()) {
      return cuda_impl::LambdaRankGetGradientPairwise(
          ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
          tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
          out_gpair);
    }

    auto gptr = p_cache_->DataGroupPtr(ctx_);
    bst_group_t n_groups = p_cache_->Groups();

    out_gpair->Resize(info.num_row_);
    auto h_gpair = out_gpair->HostSpan();
    auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
    auto h_predt = predt.ConstHostSpan();
    auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);

    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
    auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);

    auto delta = [](auto...) { return 1.0; };
    using D = decltype(delta);

    common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
      auto cnt = gptr[g + 1] - gptr[g];
      auto w = h_weight[g];
      auto g_predt = h_predt.subspan(gptr[g], cnt);
      auto g_gpair = h_gpair.subspan(gptr[g], cnt);
      auto g_label = h_label.Slice(make_range(g));
      auto g_rank = rank_idx.subspan(gptr[g], cnt);

      auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
      if (param_.lambdarank_unbiased) {
        std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, D>, args);
      } else {
        std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, D>, args);
      }
    });
  }

  static char const* Name() { return "rank:pairwise"; }
  [[nodiscard]] const char* DefaultEvalMetric() const override {
    return this->RankEvalMetric("ndcg");
  }
};
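
// Editorial note: the constant delta above is what reduces LambdaRank to the
// plain RankNet loss. Per pair, the gradient has the familiar sigmoid form,
// matching the math in the removed rank_obj.cu; a sketch for illustration only:
//
//   p = sigmoid(s_high - s_low)    # probability high is ranked above low
//   g = p - 1.0                    # gradient w.r.t. s_high
//   h = max(p * (1.0 - p), eps)    # hessian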

#if !defined(XGBOOST_USE_CUDA)
namespace cuda_impl {
void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVector<float> const&,
                                   const MetaInfo&, std::shared_ptr<ltr::RankingCache>,
                                   linalg::VectorView<double const>,  // input bias ratio
                                   linalg::VectorView<double const>,  // input bias ratio
                                   linalg::VectorView<double>, linalg::VectorView<double>,
                                   HostDeviceVector<GradientPair>*) {
  common::AssertGPUSupport();
}
}  // namespace cuda_impl
#endif  // !defined(XGBOOST_USE_CUDA)

XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name())
    .describe("LambdaRank with NDCG loss as objective")
    .set_body([]() { return new LambdaRankNDCG{}; });

XGBOOST_REGISTER_OBJECTIVE(LambdaRankPairwise, LambdaRankPairwise::Name())
    .describe("LambdaRank with RankNet loss as objective")
    .set_body([]() { return new LambdaRankPairwise{}; });

XGBOOST_REGISTER_OBJECTIVE(LambdaRankMAP, LambdaRankMAP::Name())
    .describe("LambdaRank with MAP loss as objective.")
    .set_body([]() { return new LambdaRankMAP{}; });

DMLC_REGISTRY_FILE_TAG(lambdarank_obj);
}  // namespace xgboost::obj

@ -390,6 +390,112 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
  Launch(ctx, iter, preds, info, p_cache, delta_ndcg, ti_plus, tj_minus, li, lj, out_gpair);
}

void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t const> d_rank_idx,
             std::shared_ptr<ltr::MAPCache> p_cache) {
  common::Span<double> out_n_rel = p_cache->NumRelevant(ctx);
  common::Span<double> out_acc = p_cache->Acc(ctx);

  CHECK_EQ(out_n_rel.size(), info.num_row_);
  CHECK_EQ(out_acc.size(), info.num_row_);

  auto group_ptr = p_cache->DataGroupPtr(ctx);
  auto key_it = dh::MakeTransformIterator<std::size_t>(
      thrust::make_counting_iterator(0ul),
      [=] XGBOOST_DEVICE(std::size_t i) -> std::size_t { return dh::SegmentId(group_ptr, i); });
  auto label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
  auto const* cuctx = ctx->CUDACtx();

  {
    // calculate number of relevant documents
    auto val_it = dh::MakeTransformIterator<double>(
        thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> double {
          auto g = dh::SegmentId(group_ptr, i);
          auto g_label = label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]));
          auto idx_in_group = i - group_ptr[g];
          auto g_sorted_idx = d_rank_idx.subspan(group_ptr[g], group_ptr[g + 1] - group_ptr[g]);
          return static_cast<double>(g_label(g_sorted_idx[idx_in_group]));
        });
    thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + info.num_row_, val_it,
                                  out_n_rel.data());
  }
  {
    // \sum l_k/k
    auto val_it = dh::MakeTransformIterator<double>(
        thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> double {
          auto g = dh::SegmentId(group_ptr, i);
          auto g_label = label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]));
          auto g_sorted_idx = d_rank_idx.subspan(group_ptr[g], group_ptr[g + 1] - group_ptr[g]);
          auto idx_in_group = i - group_ptr[g];
          double rank_in_group = idx_in_group + 1.0;
          return static_cast<double>(g_label(g_sorted_idx[idx_in_group])) / rank_in_group;
        });
    thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + info.num_row_, val_it,
                                  out_acc.data());
  }
}

void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
                              HostDeviceVector<float> const& predt, const MetaInfo& info,
                              std::shared_ptr<ltr::MAPCache> p_cache,
                              linalg::VectorView<double const> ti_plus,   // input bias ratio
                              linalg::VectorView<double const> tj_minus,  // input bias ratio
                              linalg::VectorView<double> li, linalg::VectorView<double> lj,
                              HostDeviceVector<GradientPair>* out_gpair) {
  std::int32_t device_id = ctx->gpu_id;
  dh::safe_cuda(cudaSetDevice(device_id));

  info.labels.SetDevice(device_id);
  predt.SetDevice(device_id);

  CHECK(p_cache);

  auto d_predt = predt.ConstDeviceSpan();
  auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt);

  MAPStat(ctx, info, d_sorted_idx, p_cache);
  auto d_n_rel = p_cache->NumRelevant(ctx);
  auto d_acc = p_cache->Acc(ctx);
  auto d_gptr = p_cache->DataGroupPtr(ctx).data();

  auto delta_map = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high,
                                      std::size_t rank_low, bst_group_t g) {
    if (rank_high > rank_low) {
      thrust::swap(rank_high, rank_low);
      thrust::swap(y_high, y_low);
    }
    auto cnt = d_gptr[g + 1] - d_gptr[g];
    auto g_n_rel = d_n_rel.subspan(d_gptr[g], cnt);
    auto g_acc = d_acc.subspan(d_gptr[g], cnt);
    auto d = DeltaMAP(y_high, y_low, rank_high, rank_low, g_n_rel, g_acc);
    return d;
  };

  Launch(ctx, iter, predt, info, p_cache, delta_map, ti_plus, tj_minus, li, lj, out_gpair);
}

void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
                                   HostDeviceVector<float> const& predt, const MetaInfo& info,
                                   std::shared_ptr<ltr::RankingCache> p_cache,
                                   linalg::VectorView<double const> ti_plus,   // input bias ratio
                                   linalg::VectorView<double const> tj_minus,  // input bias ratio
                                   linalg::VectorView<double> li, linalg::VectorView<double> lj,
                                   HostDeviceVector<GradientPair>* out_gpair) {
  std::int32_t device_id = ctx->gpu_id;
  dh::safe_cuda(cudaSetDevice(device_id));

  info.labels.SetDevice(device_id);
  predt.SetDevice(device_id);

  auto d_predt = predt.ConstDeviceSpan();
  auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt);

  auto delta = [] XGBOOST_DEVICE(float, float, std::size_t, std::size_t, bst_group_t) {
    return 1.0;
  };

  Launch(ctx, iter, predt, info, p_cache, delta, ti_plus, tj_minus, li, lj, out_gpair);
}

namespace {
struct ReduceOp {
  template <typename Tup>

@ -156,6 +156,27 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
                               linalg::VectorView<double> li, linalg::VectorView<double> lj,
                               HostDeviceVector<GradientPair>* out_gpair);

/**
 * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
 */
void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t const> d_rank_idx,
             std::shared_ptr<ltr::MAPCache> p_cache);

void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
                              HostDeviceVector<float> const& predt, MetaInfo const& info,
                              std::shared_ptr<ltr::MAPCache> p_cache,
                              linalg::VectorView<double const> t_plus,   // input bias ratio
                              linalg::VectorView<double const> t_minus,  // input bias ratio
                              linalg::VectorView<double> li, linalg::VectorView<double> lj,
                              HostDeviceVector<GradientPair>* out_gpair);

void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
                                   HostDeviceVector<float> const& predt, const MetaInfo& info,
                                   std::shared_ptr<ltr::RankingCache> p_cache,
                                   linalg::VectorView<double const> ti_plus,   // input bias ratio
                                   linalg::VectorView<double const> tj_minus,  // input bias ratio
                                   linalg::VectorView<double> li, linalg::VectorView<double> lj,
                                   HostDeviceVector<GradientPair>* out_gpair);

void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
                                  linalg::VectorView<double const> lj_full,
@ -165,6 +186,18 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
                                  std::shared_ptr<ltr::RankingCache> p_cache);
}  // namespace cuda_impl

namespace cpu_impl {
/**
 * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
 *
 * \param label     Ground truth relevance label.
 * \param rank_idx  Sorted index of prediction.
 * \param p_cache   An initialized MAPCache.
 */
void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
             common::Span<std::size_t const> rank_idx, std::shared_ptr<ltr::MAPCache> p_cache);
}  // namespace cpu_impl

/**
 * \brief Construct pairs on CPU
 *

@ -47,7 +47,6 @@ DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu);
DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu);
DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu);
DMLC_REGISTRY_LINK_TAG(rank_obj_gpu);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj_cu);
#else
@ -55,7 +54,6 @@ DMLC_REGISTRY_LINK_TAG(regression_obj);
DMLC_REGISTRY_LINK_TAG(quantile_obj);
DMLC_REGISTRY_LINK_TAG(hinge_obj);
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
DMLC_REGISTRY_LINK_TAG(rank_obj);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj);
#endif  // XGBOOST_USE_CUDA
}  // namespace obj

@ -1,17 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2019 XGBoost contributors
|
||||
*/
|
||||
|
||||
// Dummy file to keep the CUDA conditional compile trick.
|
||||
#include <dmlc/registry.h>
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(rank_obj);
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
#include "rank_obj.cu"
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
@ -1,789 +0,0 @@
/*!
 * Copyright 2015-2022 XGBoost contributors
 */
#include <dmlc/omp.h>
#include <dmlc/timer.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <vector>
#include <algorithm>
#include <utility>

#include "xgboost/json.h"
#include "xgboost/parameter.h"

#include "../common/math.h"
#include "../common/random.h"

#if defined(__CUDACC__)
#include <thrust/sort.h>
#include <thrust/gather.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/random/uniform_int_distribution.h>
#include <thrust/random/linear_congruential_engine.h>

#include <cub/util_allocator.cuh>

#include "../common/device_helpers.cuh"
#endif

namespace xgboost {
namespace obj {

#if defined(XGBOOST_USE_CUDA) && !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(rank_obj_gpu);
#endif  // defined(XGBOOST_USE_CUDA)

struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
  size_t num_pairsample;
  float fix_list_weight;
  // declare parameters
  DMLC_DECLARE_PARAMETER(LambdaRankParam) {
    DMLC_DECLARE_FIELD(num_pairsample).set_lower_bound(1).set_default(1)
        .describe("Number of pair generated for each instance.");
    DMLC_DECLARE_FIELD(fix_list_weight).set_lower_bound(0.0f).set_default(0.0f)
        .describe("Normalize the weight of each list by this value,"
                  " if equals 0, no effect will happen");
  }
};

#if defined(__CUDACC__)
// Helper functions

template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsToTheLeftOf(const T *__restrict__ items, uint32_t n, T v) {
  return thrust::lower_bound(thrust::seq, items, items + n, v,
                             thrust::greater<T>()) -
         items;
}

template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsToTheRightOf(const T *__restrict__ items, uint32_t n, T v) {
  return n - (thrust::upper_bound(thrust::seq, items, items + n, v,
                                  thrust::greater<T>()) -
              items);
}
#endif

/*! \brief helper information in a list */
struct ListEntry {
  /*! \brief the predicted score of the entry */
  bst_float pred;
  /*! \brief the actual label of the entry */
  bst_float label;
  /*! \brief row index in the data matrix */
  unsigned rindex;
  // constructor
  ListEntry(bst_float pred, bst_float label, unsigned rindex)
      : pred(pred), label(label), rindex(rindex) {}
  // comparator by prediction
  inline static bool CmpPred(const ListEntry &a, const ListEntry &b) {
    return a.pred > b.pred;
  }
  // comparator by label
  inline static bool CmpLabel(const ListEntry &a, const ListEntry &b) {
    return a.label > b.label;
  }
};

/*! \brief a pair in the lambda rank */
struct LambdaPair {
  /*! \brief positive index: this is a position in the list */
  unsigned pos_index;
  /*! \brief negative index: this is a position in the list */
  unsigned neg_index;
  /*! \brief weight to be filled in */
  bst_float weight;
  // constructor
  LambdaPair(unsigned pos_index, unsigned neg_index)
      : pos_index(pos_index), neg_index(neg_index), weight(1.0f) {}
  // constructor
  LambdaPair(unsigned pos_index, unsigned neg_index, bst_float weight)
      : pos_index(pos_index), neg_index(neg_index), weight(weight) {}
};

class PairwiseLambdaWeightComputer {
 public:
  /*!
   * \brief get lambda weight for existing pairs - for pairwise objective
   * \param list a list that is sorted by pred score
   * \param io_pairs record of pairs, containing the pairs to fill in weights
   */
  static void GetLambdaWeight(const std::vector<ListEntry>&,
                              std::vector<LambdaPair>*) {}

  static char const* Name() {
    return "rank:pairwise";
  }

#if defined(__CUDACC__)
  PairwiseLambdaWeightComputer(const bst_float*,
                               const bst_float*,
                               const dh::SegmentSorter<float>&) {}

  class PairwiseLambdaWeightMultiplier {
   public:
    // Adjust the items weight by this value
    __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
      return 1.0f;
    }
  };

  inline const PairwiseLambdaWeightMultiplier GetWeightMultiplier() const {
    return {};
  }
#endif
};

#if defined(__CUDACC__)
class BaseLambdaWeightMultiplier {
 public:
  BaseLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
                             const dh::SegmentSorter<float> &segment_pred_sorter)
      : dsorted_labels_(segment_label_sorter.GetItemsSpan()),
        dorig_pos_(segment_label_sorter.GetOriginalPositionsSpan()),
        dgroups_(segment_label_sorter.GetGroupsSpan()),
        dindexable_sorted_preds_pos_(segment_pred_sorter.GetIndexableSortedPositionsSpan()) {}

 protected:
  const common::Span<const float> dsorted_labels_;  // Labels sorted within a group
  const common::Span<const uint32_t> dorig_pos_;    // Original indices of the labels
                                                    // before they are sorted
  const common::Span<const uint32_t> dgroups_;      // The group indices
  // Where can a prediction for a label be found in the original array, when they are sorted
  const common::Span<const uint32_t> dindexable_sorted_preds_pos_;
};

// While computing the weight that needs to be adjusted by this ranking objective, we need
// to figure out where positive and negative labels chosen earlier exist, if the group
// were to be sorted by its predictions. To accommodate this, we employ the following algorithm.
// For a given group, let's assume the following:
// labels:        1 5 9 2 4 8 0 7 6 3
// predictions:   1 9 0 8 2 7 3 6 5 4
// position:      0 1 2 3 4 5 6 7 8 9
//
// After label sort:
// labels:        9 8 7 6 5 4 3 2 1 0
// position:      2 5 7 8 1 4 9 3 0 6
//
// After prediction sort:
// predictions:   9 8 7 6 5 4 3 2 1 0
// position:      1 3 5 7 8 9 6 4 0 2
//
// If a sorted label at position 'x' is chosen, then we need to find out where the prediction
// for this label 'x' exists, if the group were to be sorted by predictions.
// We first take the sorted prediction positions:
// position:      1 3 5 7 8 9 6 4 0 2
// at indices:    0 1 2 3 4 5 6 7 8 9
//
// We create a sorted prediction positional array, such that value at position 'x' gives
// us the position in the sorted prediction array where its related prediction lies.
// dindexable_sorted_preds_pos_: 8 0 9 1 7 2 6 3 4 5
// at indices:                   0 1 2 3 4 5 6 7 8 9
// Basically, swap the previous 2 arrays, sort the indices and reorder positions
// for an O(1) lookup using the position where the sorted label exists.
//
// This type does that using the SegmentSorter
class IndexablePredictionSorter {
 public:
  IndexablePredictionSorter(const bst_float *dpreds,
                            const dh::SegmentSorter<float> &segment_label_sorter) {
    // Sort the predictions first
    segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(),
                                   segment_label_sorter.GetGroupSegmentsSpan());

    // Create an index for the sorted prediction positions
    segment_pred_sorter_.CreateIndexableSortedPositions();
  }

  inline const dh::SegmentSorter<float> &GetPredictionSorter() const {
    return segment_pred_sorter_;
  }

 private:
  dh::SegmentSorter<float> segment_pred_sorter_;  // For sorting the predictions
};
#endif

class MAPLambdaWeightComputer
#if defined(__CUDACC__)
    : public IndexablePredictionSorter
#endif
{
 public:
  struct MAPStats {
    /*! \brief the accumulated precision */
    float ap_acc{0.0f};
    /*!
     * \brief the accumulated precision,
     *   assuming a positive instance is missing
     */
    float ap_acc_miss{0.0f};
    /*!
     * \brief the accumulated precision,
     *   assuming that one more positive instance is inserted ahead
     */
    float ap_acc_add{0.0f};
    /* \brief the accumulated positive instance count */
    float hits{0.0f};

    XGBOOST_DEVICE MAPStats() {}  // NOLINT
    XGBOOST_DEVICE MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits)
        : ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {}

    // For prefix scan
    XGBOOST_DEVICE MAPStats operator +(const MAPStats &v1) const {
      return {ap_acc + v1.ap_acc, ap_acc_miss + v1.ap_acc_miss,
              ap_acc_add + v1.ap_acc_add, hits + v1.hits};
    }

    // For test purposes - compare for equality
    XGBOOST_DEVICE bool operator ==(const MAPStats &rhs) const {
      return ap_acc == rhs.ap_acc && ap_acc_miss == rhs.ap_acc_miss &&
             ap_acc_add == rhs.ap_acc_add && hits == rhs.hits;
    }
  };

 private:
  template <typename T>
  XGBOOST_DEVICE inline static void Swap(T &v0, T &v1) {
#if defined(__CUDACC__)
    thrust::swap(v0, v1);
#else
    std::swap(v0, v1);
#endif
  }

  /*!
   * \brief Obtain the delta MAP by trying to switch the positions of labels in pos_pred_pos or
   *        neg_pred_pos when sorted by predictions
   * \param pos_pred_pos positive label's prediction value position when the group's prediction
   *        values are sorted
   * \param neg_pred_pos negative label's prediction value position when the group's prediction
   *        values are sorted
   * \param pos_label, neg_label the chosen positive and negative labels
   * \param p_map_stats a vector containing the accumulated precisions for each position in a list
   * \param map_stats_size size of the accumulated precisions vector
   */
  XGBOOST_DEVICE inline static bst_float GetLambdaMAP(
      int pos_pred_pos, int neg_pred_pos,
      bst_float pos_label, bst_float neg_label,
      const MAPStats *p_map_stats, uint32_t map_stats_size) {
    if (pos_pred_pos == neg_pred_pos || p_map_stats[map_stats_size - 1].hits == 0) {
      return 0.0f;
    }
    if (pos_pred_pos > neg_pred_pos) {
      Swap(pos_pred_pos, neg_pred_pos);
      Swap(pos_label, neg_label);
    }
    bst_float original = p_map_stats[neg_pred_pos].ap_acc;
    if (pos_pred_pos != 0) original -= p_map_stats[pos_pred_pos - 1].ap_acc;
    bst_float changed = 0;
    bst_float label1 = pos_label > 0.0f ? 1.0f : 0.0f;
    bst_float label2 = neg_label > 0.0f ? 1.0f : 0.0f;
    if (label1 == label2) {
      return 0.0;
    } else if (label1 < label2) {
      changed += p_map_stats[neg_pred_pos - 1].ap_acc_add - p_map_stats[pos_pred_pos].ap_acc_add;
      changed += (p_map_stats[pos_pred_pos].hits + 1.0f) / (pos_pred_pos + 1);
    } else {
      changed += p_map_stats[neg_pred_pos - 1].ap_acc_miss - p_map_stats[pos_pred_pos].ap_acc_miss;
      changed += p_map_stats[neg_pred_pos].hits / (neg_pred_pos + 1);
    }
    bst_float ans = (changed - original) / (p_map_stats[map_stats_size - 1].hits);
    if (ans < 0) ans = -ans;
    return ans;
  }

 public:
  /*
   * \brief obtain preprocessing results for calculating delta MAP
   * \param sorted_list the list containing entry information
   * \param map_stats a vector containing the accumulated precisions for each position in a list
   */
  inline static void GetMAPStats(const std::vector<ListEntry> &sorted_list,
                                 std::vector<MAPStats> *p_map_acc) {
    std::vector<MAPStats> &map_acc = *p_map_acc;
    map_acc.resize(sorted_list.size());
    bst_float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0;
    for (size_t i = 1; i <= sorted_list.size(); ++i) {
      if (sorted_list[i - 1].label > 0.0f) {
        hit++;
        acc1 += hit / i;
        acc2 += (hit - 1) / i;
        acc3 += (hit + 1) / i;
      }
      map_acc[i - 1] = MAPStats(acc1, acc2, acc3, hit);
    }
  }

  static char const* Name() {
    return "rank:map";
  }

  static void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
                              std::vector<LambdaPair> *io_pairs) {
    std::vector<LambdaPair> &pairs = *io_pairs;
    std::vector<MAPStats> map_stats;
    GetMAPStats(sorted_list, &map_stats);
    for (auto & pair : pairs) {
      pair.weight *=
          GetLambdaMAP(pair.pos_index, pair.neg_index,
                       sorted_list[pair.pos_index].label, sorted_list[pair.neg_index].label,
                       &map_stats[0], map_stats.size());
    }
  }

#if defined(__CUDACC__)
  MAPLambdaWeightComputer(const bst_float *dpreds,
                          const bst_float *dlabels,
                          const dh::SegmentSorter<float> &segment_label_sorter)
      : IndexablePredictionSorter(dpreds, segment_label_sorter),
        dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()),
        weight_multiplier_(segment_label_sorter, *this) {
    this->CreateMAPStats(dlabels, segment_label_sorter);
  }

  void CreateMAPStats(const bst_float *dlabels,
                      const dh::SegmentSorter<float> &segment_label_sorter) {
    // For each group, go through the sorted prediction positions, and look up its corresponding
    // label from the unsorted labels (from the original label list)

    // For each item in the group, compute its MAP stats.
    // Interleave the computation of map stats amongst different groups.

    // First, determine positive labels in the dataset individually
    auto nitems = segment_label_sorter.GetNumItems();
    dh::caching_device_vector<uint32_t> dhits(nitems, 0);
    // Original positions of the predictions after they have been sorted
    const auto &pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsSpan();
    // Unsorted labels
    const float *unsorted_labels = dlabels;
    auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) {
      return (unsorted_labels[pred_original_pos[idx]] > 0.0f) ? 1 : 0;
    };  // NOLINT

    thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
                      thrust::make_counting_iterator(nitems),
                      dhits.begin(),
                      DeterminePositiveLabelLambda);

    // Allocator to be used by sort for managing space overhead while performing prefix scans
    dh::XGBCachingDeviceAllocator<char> alloc;

    // Next, prefix scan the positive labels that are segmented to accumulate them.
    // This is required for computing the accumulated precisions
    const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
    // Data segmented into different groups...
    thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
                                  dh::tcbegin(group_segments), dh::tcend(group_segments),
                                  dhits.begin(),   // Input value
                                  dhits.begin());  // In-place scan

    // Compute accumulated precisions for each item, assuming positive and
    // negative instances are missing.
    // But first, compute individual item precisions
    const auto *dhits_arr = dhits.data().get();
    // Group info on device
    const auto &dgroups = segment_label_sorter.GetGroupsSpan();
    auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) {
      if (unsorted_labels[pred_original_pos[idx]] > 0.0f) {
        auto idx_within_group = (idx - dgroups[group_segments[idx]]) + 1;
        return MAPStats{static_cast<float>(dhits_arr[idx]) / idx_within_group,
                        static_cast<float>(dhits_arr[idx] - 1) / idx_within_group,
                        static_cast<float>(dhits_arr[idx] + 1) / idx_within_group,
                        1.0f};
      }
      return MAPStats{};
    };  // NOLINT

    thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
                      thrust::make_counting_iterator(nitems),
                      this->dmap_stats_.begin(),
                      ComputeItemPrecisionLambda);

    // Lastly, compute the accumulated precisions for all the items segmented by groups.
    // The precisions are accumulated within each group
    thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
                                  dh::tcbegin(group_segments), dh::tcend(group_segments),
                                  this->dmap_stats_.begin(),   // Input map stats
                                  this->dmap_stats_.begin());  // In-place scan and output here
  }

  inline const common::Span<const MAPStats> GetMapStatsSpan() const {
    return { dmap_stats_.data().get(), dmap_stats_.size() };
  }

  // Type containing device pointers that can be cheaply copied on the kernel
  class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier {
   public:
    MAPLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
                              const MAPLambdaWeightComputer &lwc)
        : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()),
          dmap_stats_(lwc.GetMapStatsSpan()) {}

    // Adjust the items weight by this value
    __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
      uint32_t group_begin = dgroups_[gidx];
      uint32_t group_end = dgroups_[gidx + 1];

      auto pos_lab_orig_posn = dorig_pos_[pidx];
      auto neg_lab_orig_posn = dorig_pos_[nidx];
      KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn);

      // Note: the label positive and negative indices are relative to the entire dataset.
      // Hence, scale them back to an index within the group
      auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin;
      auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin;
      return MAPLambdaWeightComputer::GetLambdaMAP(
          pos_pred_pos, neg_pred_pos,
          dsorted_labels_[pidx], dsorted_labels_[nidx],
          &dmap_stats_[group_begin], group_end - group_begin);
    }

   private:
    common::Span<const MAPStats> dmap_stats_;  // Start address of the map stats for every sorted
                                               // prediction value
  };

  inline const MAPLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; }

 private:
  dh::caching_device_vector<MAPStats> dmap_stats_;
  // This computes the adjustment to the weight
  const MAPLambdaWeightMultiplier weight_multiplier_;
#endif
};

#if defined(__CUDACC__)
class SortedLabelList : dh::SegmentSorter<float> {
 private:
  const LambdaRankParam &param_;  // Objective configuration

 public:
  explicit SortedLabelList(const LambdaRankParam &param)
      : param_(param) {}

  // Sort the labels that are grouped by 'groups'
  void Sort(const HostDeviceVector<bst_float> &dlabels, const std::vector<uint32_t> &groups) {
    this->SortItems(dlabels.ConstDevicePointer(), dlabels.Size(), groups);
  }

  // This kernel can only run *after* the kernel in sort is completed, as they
  // use the default stream
  template <typename LambdaWeightComputerT>
  void ComputeGradients(const bst_float *dpreds,   // Unsorted predictions
                        const bst_float *dlabels,  // Unsorted labels
                        const HostDeviceVector<bst_float> &weights,
                        int iter,
                        GradientPair *out_gpair,
                        float weight_normalization_factor) {
    // Group info on device
    const auto &dgroups = this->GetGroupsSpan();
    uint32_t ngroups = this->GetNumGroups() + 1;

    uint32_t total_items = this->GetNumItems();
    uint32_t niter = param_.num_pairsample * total_items;

    float fix_list_weight = param_.fix_list_weight;

    const auto &original_pos = this->GetOriginalPositionsSpan();

    uint32_t num_weights = weights.Size();
    auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr;

    const auto &sorted_labels = this->GetItemsSpan();

    // This is used to adjust the weight of different elements based on the different ranking
    // objective function policies
    LambdaWeightComputerT weight_computer(dpreds, dlabels, *this);
    auto wmultiplier = weight_computer.GetWeightMultiplier();

    int device_id = -1;
    dh::safe_cuda(cudaGetDevice(&device_id));
    // For each instance in the group, compute the gradient pair concurrently
    dh::LaunchN(niter, nullptr, [=] __device__(uint32_t idx) {
      // First, determine the group 'idx' belongs to
      uint32_t item_idx = idx % total_items;
      uint32_t group_idx =
          thrust::upper_bound(thrust::seq, dgroups.begin(),
                              dgroups.begin() + ngroups, item_idx) -
          dgroups.begin();
      // Span of this group within the larger labels/predictions sorted tuple
      uint32_t group_begin = dgroups[group_idx - 1];
      uint32_t group_end = dgroups[group_idx];
      uint32_t total_group_items = group_end - group_begin;

      // Are the labels diverse enough? If they are all the same, then there is nothing to pick
      // from another group - bail sooner
      if (sorted_labels[group_begin] == sorted_labels[group_end - 1]) return;

      // Find the number of labels less than and greater than the current label
      // at the sorted index position item_idx
      uint32_t nleft = CountNumItemsToTheLeftOf(
          sorted_labels.data() + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]);
      uint32_t nright = CountNumItemsToTheRightOf(
          sorted_labels.data() + item_idx, group_end - item_idx, sorted_labels[item_idx]);

      // Create a minstd_rand object to act as our source of randomness
      thrust::minstd_rand rng((iter + 1) * 1111);
      rng.discard(((idx / total_items) * total_group_items) + item_idx - group_begin);
      // Create a uniform_int_distribution to produce a sample from outside of the
      // present label group
      thrust::uniform_int_distribution<int> dist(0, nleft + nright - 1);

      int sample = dist(rng);
      int pos_idx = -1;  // Bigger label
      int neg_idx = -1;  // Smaller label
      // Are we picking a sample to the left/right of the current group?
      if (sample < nleft) {
        // Go left
        pos_idx = sample + group_begin;
        neg_idx = item_idx;
      } else {
        pos_idx = item_idx;
        uint32_t items_in_group = total_group_items - nleft - nright;
        neg_idx = sample + items_in_group + group_begin;
      }

      // Compute and assign the gradients now
      const float eps = 1e-16f;
      bst_float p = common::Sigmoid(dpreds[original_pos[pos_idx]] - dpreds[original_pos[neg_idx]]);
      bst_float g = p - 1.0f;
      bst_float h = thrust::max(p * (1.0f - p), eps);

      // Rescale each gradient and hessian so that the group has a constant weight
      float scale = __frcp_ru(niter / total_items);
      if (fix_list_weight != 0.0f) {
        scale *= fix_list_weight / total_group_items;
      }

      float weight = num_weights ? dweights[group_idx - 1] : 1.0f;
      weight *= weight_normalization_factor;
      weight *= wmultiplier.GetWeight(group_idx - 1, pos_idx, neg_idx);
      weight *= scale;
      // Accumulate gradient and hessian in both positive and negative indices
      const GradientPair in_pos_gpair(g * weight, 2.0f * weight * h);
      dh::AtomicAddGpair(&out_gpair[original_pos[pos_idx]], in_pos_gpair);

      const GradientPair in_neg_gpair(-g * weight, 2.0f * weight * h);
      dh::AtomicAddGpair(&out_gpair[original_pos[neg_idx]], in_neg_gpair);
    });

    // Wait until the computations done by the kernel are complete
    dh::safe_cuda(cudaStreamSynchronize(nullptr));
  }
};
#endif

// objective for lambda rank
|
||||
template <typename LambdaWeightComputerT>
|
||||
class LambdaRankObj : public ObjFunction {
|
||||
public:
|
||||
void Configure(Args const &args) override { param_.UpdateAllowUnknown(args); }
|
||||
ObjInfo Task() const override { return ObjInfo::kRanking; }
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair>* out_gpair) override {
|
||||
CHECK_EQ(preds.Size(), info.labels.Size()) << "label size predict size not match";
|
||||
|
||||
// quick consistency when group is not available
|
||||
std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(info.labels.Size());
|
||||
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
|
||||
CHECK(gptr.size() != 0 && gptr.back() == info.labels.Size())
|
||||
<< "group structure not consistent with #rows" << ", "
|
||||
<< "group ponter size: " << gptr.size() << ", "
|
||||
<< "labels size: " << info.labels.Size() << ", "
|
||||
<< "group pointer back: " << (gptr.size() == 0 ? 0 : gptr.back());
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
// Check if we have a GPU assignment; else, revert back to CPU
|
||||
auto device = ctx_->gpu_id;
|
||||
if (device >= 0) {
|
||||
ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr);
|
||||
} else {
|
||||
// Revert back to CPU
|
||||
#endif
|
||||
ComputeGradientsOnCPU(preds, info, iter, out_gpair, gptr);
|
||||
#if defined(__CUDACC__)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "map";
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String(LambdaWeightComputerT::Name());
|
||||
out["lambda_rank_param"] = ToJson(param_);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
FromJson(in["lambda_rank_param"], ¶m_);
|
||||
}
|
||||
|
||||
private:
|
||||
bst_float ComputeWeightNormalizationFactor(const MetaInfo& info,
|
||||
const std::vector<unsigned> &gptr) {
|
||||
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
bst_float sum_weights = 0;
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
sum_weights += info.GetWeight(k);
|
||||
}
|
||||
return ngroup / sum_weights;
|
||||
}

void ComputeGradientsOnCPU(const HostDeviceVector<bst_float>& preds,
                           const MetaInfo& info,
                           int iter,
                           HostDeviceVector<GradientPair>* out_gpair,
                           const std::vector<unsigned> &gptr) {
  LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on CPU.";

  bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);

  const auto& preds_h = preds.HostVector();
  const auto& labels = info.labels.HostView();
  std::vector<GradientPair>& gpair = out_gpair->HostVector();
  const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
  out_gpair->Resize(preds.Size());

  dmlc::OMPException exc;
#pragma omp parallel num_threads(ctx_->Threads())
  {
    exc.Run([&]() {
      // parallel construct: declare the random number generator here so that each
      // thread uses its own instance, seeded by the current iteration
      std::minstd_rand rnd((iter + 1) * 1111);
      std::vector<LambdaPair> pairs;
      std::vector<ListEntry> lst;
      std::vector< std::pair<bst_float, unsigned> > rec;

#pragma omp for schedule(static)
      for (bst_omp_uint k = 0; k < ngroup; ++k) {
        exc.Run([&]() {
          lst.clear(); pairs.clear();
          for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) {
            lst.emplace_back(preds_h[j], labels(j), j);
            gpair[j] = GradientPair(0.0f, 0.0f);
          }
          std::stable_sort(lst.begin(), lst.end(), ListEntry::CmpPred);
          rec.resize(lst.size());
          for (unsigned i = 0; i < lst.size(); ++i) {
            rec[i] = std::make_pair(lst[i].label, i);
          }
          std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
          // enumerate buckets with the same label;
          // for each item in the lst, grab another sample randomly
          for (unsigned i = 0; i < rec.size(); ) {
            unsigned j = i + 1;
            while (j < rec.size() && rec[j].first == rec[i].first) ++j;
            // bucket in [i,j), get a sample outside the bucket
            unsigned nleft = i, nright = static_cast<unsigned>(rec.size() - j);
            if (nleft + nright != 0) {
              int nsample = param_.num_pairsample;
              while (nsample --) {
                for (unsigned pid = i; pid < j; ++pid) {
                  unsigned ridx =
                      std::uniform_int_distribution<unsigned>(0, nleft + nright - 1)(rnd);
                  if (ridx < nleft) {
                    pairs.emplace_back(rec[ridx].second, rec[pid].second,
                                       info.GetWeight(k) * weight_normalization_factor);
                  } else {
                    pairs.emplace_back(rec[pid].second, rec[ridx+j-i].second,
                                       info.GetWeight(k) * weight_normalization_factor);
                  }
                }
              }
            }
            i = j;
          }
          // get lambda weight for the pairs
          LambdaWeightComputerT::GetLambdaWeight(lst, &pairs);
          // rescale each gradient and hessian so that the list has a constant weight
          float scale = 1.0f / param_.num_pairsample;
          if (param_.fix_list_weight != 0.0f) {
            scale *= param_.fix_list_weight / (gptr[k + 1] - gptr[k]);
          }
          for (auto & pair : pairs) {
            const ListEntry &pos = lst[pair.pos_index];
            const ListEntry &neg = lst[pair.neg_index];
            const bst_float w = pair.weight * scale;
            const float eps = 1e-16f;
            bst_float p = common::Sigmoid(pos.pred - neg.pred);
            bst_float g = p - 1.0f;
            bst_float h = std::max(p * (1.0f - p), eps);
            // accumulate gradient and hessian in both pid and nid
            gpair[pos.rindex] += GradientPair(g * w, 2.0f*w*h);
            gpair[neg.rindex] += GradientPair(-g * w, 2.0f*w*h);
          }
        });
      }
    });
  }
  exc.Rethrow();
}
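The per-pair update above is the standard pairwise logistic (RankNet-style) gradient. A minimal NumPy sketch of the same arithmetic, assuming raw scores s_pos and s_neg for the preferred and non-preferred documents of one pair:

    import numpy as np

    def pairwise_grad(s_pos, s_neg, w=1.0, eps=1e-16):
        """Gradient/hessian of the pairwise logistic loss for one (pos, neg) pair."""
        p = 1.0 / (1.0 + np.exp(-(s_pos - s_neg)))  # sigmoid of the score difference
        g = p - 1.0                                  # gradient w.r.t. the positive score
        h = max(p * (1.0 - p), eps)                  # hessian, clamped away from zero
        # the positive document accumulates (g * w, 2 * w * h),
        # the negative document accumulates (-g * w, 2 * w * h)
        return g * w, 2.0 * w * h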

#if defined(__CUDACC__)
void ComputeGradientsOnGPU(const HostDeviceVector<bst_float>& preds,
                           const MetaInfo& info,
                           int iter,
                           HostDeviceVector<GradientPair>* out_gpair,
                           const std::vector<unsigned> &gptr) {
  LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU.";

  auto device = ctx_->gpu_id;
  dh::safe_cuda(cudaSetDevice(device));

  bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);

  // Set the device ID and copy them to the device
  out_gpair->SetDevice(device);
  info.labels.SetDevice(device);
  preds.SetDevice(device);
  info.weights_.SetDevice(device);

  out_gpair->Resize(preds.Size());

  auto d_preds = preds.ConstDevicePointer();
  auto d_gpair = out_gpair->DevicePointer();
  auto d_labels = info.labels.View(device);

  SortedLabelList slist(param_);

  // Sort the labels within the groups on the device
  slist.Sort(*info.labels.Data(), gptr);

  // Initialize the gradients next
  out_gpair->Fill(GradientPair(0.0f, 0.0f));

  // Finally, compute the gradients
  slist.ComputeGradients<LambdaWeightComputerT>(d_preds, d_labels.Values().data(), info.weights_,
                                                iter, d_gpair, weight_normalization_factor);
}
#endif

  LambdaRankParam param_;
};

#if !defined(GTEST_TEST)
// register the objective functions
DMLC_REGISTER_PARAMETER(LambdaRankParam);

XGBOOST_REGISTER_OBJECTIVE(PairwiseRankObj, PairwiseLambdaWeightComputer::Name())
    .describe("Pairwise rank objective.")
    .set_body([]() { return new LambdaRankObj<PairwiseLambdaWeightComputer>(); });

XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name())
    .describe("LambdaRank with MAP as objective.")
    .set_body([]() { return new LambdaRankObj<MAPLambdaWeightComputer>(); });
#endif

} // namespace obj
} // namespace xgboost
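These registrations are what make the objective names resolvable from the public API. A hedged end-to-end sketch of selecting them in Python, using synthetic data (the array shapes and hyperparameters here are illustrative assumptions, not from the commit):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(32, 4))
    y = rng.integers(0, 2, size=32)      # binary relevance labels
    qid = np.repeat(np.arange(4), 8)     # 4 query groups of 8 documents, group-contiguous

    for objective in ("rank:pairwise", "rank:map", "rank:ndcg"):
        ranker = xgb.XGBRanker(objective=objective, n_estimators=4)
        ranker.fit(X, y, qid=qid)        # qid must be sorted so groups are contiguous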

@ -223,4 +223,125 @@ TEST(LambdaRank, MakePair) {
  ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair());
}
}

void TestMAPStat(Context const* ctx) {
  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();
  ltr::LambdaRankParam param;
  param.UpdateAllowUnknown(Args{});

  {
    std::vector<float> h_data{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f};
    info.labels.Reshape(h_data.size(), 1);
    info.labels.Data()->HostVector() = h_data;
    info.num_row_ = h_data.size();

    HostDeviceVector<float> predt;
    auto& h_predt = predt.HostVector();
    h_predt.resize(h_data.size());
    std::iota(h_predt.rbegin(), h_predt.rend(), 0.0f);

    auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);

    predt.SetDevice(ctx->gpu_id);
    auto rank_idx =
        p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());

    if (ctx->IsCPU()) {
      obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
                             p_cache);
    } else {
      obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache);
    }

    Context cpu_ctx;
    auto n_rel = p_cache->NumRelevant(&cpu_ctx);
    auto acc = p_cache->Acc(&cpu_ctx);

    ASSERT_EQ(n_rel[0], 1.0);
    ASSERT_EQ(acc[0], 1.0);

    ASSERT_EQ(n_rel.back(), h_data.size() - 1.0);
    ASSERT_NEAR(acc.back(), 1.95 + (1.0 / h_data.size()), kRtEps);
  }
  {
    info.labels.Reshape(16);
    auto& h_label = info.labels.Data()->HostVector();
    info.group_ptr_ = {0, 8, 16};
    info.num_row_ = info.labels.Shape(0);

    std::fill_n(h_label.begin(), 8, 1.0f);
    std::fill_n(h_label.begin() + 8, 8, 0.0f);
    HostDeviceVector<float> predt;
    auto& h_predt = predt.HostVector();
    h_predt.resize(h_label.size());
    std::iota(h_predt.rbegin(), h_predt.rbegin() + 8, 0.0f);
    std::iota(h_predt.rbegin() + 8, h_predt.rend(), 0.0f);

    auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);

    predt.SetDevice(ctx->gpu_id);
    auto rank_idx =
        p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());

    if (ctx->IsCPU()) {
      obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
                             p_cache);
    } else {
      obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache);
    }

    Context cpu_ctx;
    auto n_rel = p_cache->NumRelevant(&cpu_ctx);
    ASSERT_EQ(n_rel[7], 8);      // first group
    ASSERT_EQ(n_rel.back(), 0);  // second group
  }
}

TEST(LambdaRank, MAPStat) {
  Context ctx;
  TestMAPStat(&ctx);
}
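The two cached statistics checked here are the running count of relevant documents and a running accumulator over the prediction-sorted list; the first case's assertions are consistent with acc accumulating y_k / k over sorted labels [1, 1, 0, 1, 1, 1], since 1 + 1/2 + 1/4 + 1/5 = 1.95 plus the final 1/6 term. For orientation, a plain sketch of the average precision these per-position statistics feed into (not the cache implementation itself):

    import numpy as np

    def average_precision(sorted_labels):
        """AP for binary relevance labels ordered by descending score."""
        y = np.asarray(sorted_labels, dtype=float)
        n_rel = np.cumsum(y)                          # relevant documents seen so far
        prec_at_k = n_rel / np.arange(1, y.size + 1)  # precision@k at every position
        if n_rel[-1] == 0:
            return 0.0
        return float((prec_at_k * y).sum() / n_rel[-1])

    average_precision([1, 1, 0, 1, 1, 1])  # the label order used in the first test case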

void TestMAPGPair(Context const* ctx) {
  std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:map", ctx)};
  Args args;
  obj->Configure(args);

  CheckConfigReload(obj, "rank:map");

  CheckRankingObjFunction(obj,                 // obj
                          {0, 0.1f, 0, 0.1f},  // score
                          {0, 1, 0, 1},        // label
                          {2.0f, 2.0f},        // weight
                          {0, 2, 4},           // group
                          {1.2054923f, -1.2054923f, 1.2054923f, -1.2054923f},  // out grad
                          {1.2657166f, 1.2657166f, 1.2657166f, 1.2657166f});
  // disable the second query group with 0 weight
  CheckRankingObjFunction(obj,                 // obj
                          {0, 0.1f, 0, 0.1f},  // score
                          {0, 1, 0, 1},        // label
                          {2.0f, 0.0f},        // weight
                          {0, 2, 4},           // group
                          {1.2054923f, -1.2054923f, .0f, .0f},  // out grad
                          {1.2657166f, 1.2657166f, .0f, .0f});
}

TEST(LambdaRank, MAPGPair) {
  Context ctx;
  TestMAPGPair(&ctx);
}

void TestPairWiseGPair(Context const* ctx) {
  std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:pairwise", ctx)};
  Args args;
  obj->Configure(args);

  args.emplace_back("lambdarank_unbiased", "true");
}

TEST(LambdaRank, Pairwise) {
  Context ctx;
  TestPairWiseGPair(&ctx);
}
} // namespace xgboost::obj

@ -18,6 +18,12 @@ TEST(LambdaRank, GPUNDCGJsonIO) {
  TestNDCGJsonIO(&ctx);
}

TEST(LambdaRank, GPUMAPStat) {
  Context ctx;
  ctx.gpu_id = 0;
  TestMAPStat(&ctx);
}

TEST(LambdaRank, GPUNDCGGPair) {
  Context ctx;
  ctx.gpu_id = 0;
@ -153,4 +159,10 @@ TEST(LambdaRank, RankItemCountOnRight) {
  RankItemCountImpl(sorted_items, wrapper, 1, static_cast<uint32_t>(1));
  RankItemCountImpl(sorted_items, wrapper, 0, static_cast<uint32_t>(0));
}

TEST(LambdaRank, GPUMAPGPair) {
  Context ctx;
  ctx.gpu_id = 0;
  TestMAPGPair(&ctx);
}
} // namespace xgboost::obj

@ -18,6 +18,8 @@
#include "../helpers.h"  // for EmptyDMatrix

namespace xgboost::obj {
void TestMAPStat(Context const* ctx);

inline void TestNDCGJsonIO(Context const* ctx) {
  std::unique_ptr<xgboost::ObjFunction> obj{ObjFunction::Create("rank:ndcg", ctx)};

@ -37,6 +39,8 @@ void TestNDCGGPair(Context const* ctx);

void TestUnbiasedNDCG(Context const* ctx);

void TestMAPGPair(Context const* ctx);

/**
 * \brief Initialize test data for make pair tests.
 */

@ -1,83 +0,0 @@
// Copyright by Contributors
#include <xgboost/context.h>
#include <xgboost/json.h>
#include <xgboost/objective.h>

#include "../helpers.h"

namespace xgboost {

TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPair)) {
  std::vector<std::pair<std::string, std::string>> args;
  xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);

  std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:pairwise", &ctx)};
  obj->Configure(args);
  CheckConfigReload(obj, "rank:pairwise");

  // Test with setting sample weight to second query group
  CheckRankingObjFunction(obj,
                          {0, 0.1f, 0, 0.1f},
                          {0, 1, 0, 1},
                          {2.0f, 0.0f},
                          {0, 2, 4},
                          {1.9f, -1.9f, 0.0f, 0.0f},
                          {1.995f, 1.995f, 0.0f, 0.0f});

  CheckRankingObjFunction(obj,
                          {0, 0.1f, 0, 0.1f},
                          {0, 1, 0, 1},
                          {1.0f, 1.0f},
                          {0, 2, 4},
                          {0.95f, -0.95f, 0.95f, -0.95f},
                          {0.9975f, 0.9975f, 0.9975f, 0.9975f});

  ASSERT_NO_THROW(obj->DefaultEvalMetric());
}

TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) {
  std::vector<std::pair<std::string, std::string>> args;
  xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);

  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("rank:pairwise", &ctx)};
  obj->Configure(args);
  // No computation of gradient/hessian, as there is no diversity in labels
  CheckRankingObjFunction(obj,
                          {0, 0.1f, 0, 0.1f},
                          {1, 1, 1, 1},
                          {2.0f, 0.0f},
                          {0, 2, 4},
                          {0.0f, 0.0f, 0.0f, 0.0f},
                          {0.0f, 0.0f, 0.0f, 0.0f});

  ASSERT_NO_THROW(obj->DefaultEvalMetric());
}

TEST(Objective, DeclareUnifiedTest(MAPRankingGPair)) {
  std::vector<std::pair<std::string, std::string>> args;
  xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);

  std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:map", &ctx)};
  obj->Configure(args);
  CheckConfigReload(obj, "rank:map");

  // Test with setting sample weight to second query group
  CheckRankingObjFunction(obj,
                          {0, 0.1f, 0, 0.1f},
                          {0, 1, 0, 1},
                          {2.0f, 0.0f},
                          {0, 2, 4},
                          {0.95f, -0.95f, 0.0f, 0.0f},
                          {0.9975f, 0.9975f, 0.0f, 0.0f});

  CheckRankingObjFunction(obj,
                          {0, 0.1f, 0, 0.1f},
                          {0, 1, 0, 1},
                          {1.0f, 1.0f},
                          {0, 2, 4},
                          {0.475f, -0.475f, 0.475f, -0.475f},
                          {0.4988f, 0.4988f, 0.4988f, 0.4988f});
  ASSERT_NO_THROW(obj->DefaultEvalMetric());
}

} // namespace xgboost
@ -1,175 +0,0 @@
/*!
 * Copyright 2019-2021 by XGBoost Contributors
 */
#include <thrust/host_vector.h>

#include "test_ranking_obj.cc"
#include "../../../src/objective/rank_obj.cu"

namespace xgboost {

template <typename T = uint32_t, typename Comparator = thrust::greater<T>>
std::unique_ptr<dh::SegmentSorter<T>>
RankSegmentSorterTestImpl(const std::vector<uint32_t> &group_indices,
                          const std::vector<T> &hlabels,
                          const std::vector<T> &expected_sorted_hlabels,
                          const std::vector<uint32_t> &expected_orig_pos) {
  std::unique_ptr<dh::SegmentSorter<T>> seg_sorter_ptr(new dh::SegmentSorter<T>);
  dh::SegmentSorter<T> &seg_sorter(*seg_sorter_ptr);

  // Create a bunch of unsorted labels on the device and sort it via the segment sorter
  dh::device_vector<T> dlabels(hlabels);
  seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator());

  auto num_items = seg_sorter.GetItemsSpan().size();
  EXPECT_EQ(num_items, group_indices.back());
  EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1);

  // Check the labels
  dh::device_vector<T> sorted_dlabels(num_items);
  sorted_dlabels.assign(dh::tcbegin(seg_sorter.GetItemsSpan()),
                        dh::tcend(seg_sorter.GetItemsSpan()));
  thrust::host_vector<T> sorted_hlabels(sorted_dlabels);
  EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels);

  // Check the indices
  dh::device_vector<uint32_t> dorig_pos(num_items);
  dorig_pos.assign(dh::tcbegin(seg_sorter.GetOriginalPositionsSpan()),
                   dh::tcend(seg_sorter.GetOriginalPositionsSpan()));
  dh::device_vector<uint32_t> horig_pos(dorig_pos);
  EXPECT_EQ(expected_orig_pos, horig_pos);

  return seg_sorter_ptr;
}

TEST(Objective, RankSegmentSorterTest) {
  RankSegmentSorterTestImpl({0, 2, 4, 7, 10, 14, 18, 22, 26},  // Groups
                            {1, 1,        // Labels
                             1, 2,
                             3, 2, 1,
                             1, 2, 1,
                             1, 3, 4, 2,
                             1, 2, 1, 1,
                             1, 2, 2, 3,
                             3, 3, 1, 2},
                            {1, 1,        // Expected sorted labels
                             2, 1,
                             3, 2, 1,
                             2, 1, 1,
                             4, 3, 2, 1,
                             2, 1, 1, 1,
                             3, 2, 2, 1,
                             3, 3, 2, 1},
                            {0, 1,        // Expected original positions
                             3, 2,
                             4, 5, 6,
                             8, 7, 9,
                             12, 11, 13, 10,
                             15, 14, 16, 17,
                             21, 19, 20, 18,
                             22, 23, 25, 24});
}

TEST(Objective, RankSegmentSorterSingleGroupTest) {
  RankSegmentSorterTestImpl({0, 7},                  // Groups
                            {6, 1, 4, 3, 0, 5, 2},   // Labels
                            {6, 5, 4, 3, 2, 1, 0},   // Expected sorted labels
                            {0, 5, 2, 3, 6, 1, 4});  // Expected original positions
}

TEST(Objective, RankSegmentSorterAscendingTest) {
  RankSegmentSorterTestImpl<uint32_t, thrust::less<uint32_t>>(
      {0, 4, 7},    // Groups
      {3, 1, 4, 2,  // Labels
       6, 5, 7},
      {1, 2, 3, 4,  // Expected sorted labels
       5, 6, 7},
      {1, 3, 0, 2,  // Expected original positions
       5, 4, 6});
}

TEST(Objective, IndexableSortedItemsTest) {
  std::vector<float> hlabels = {3.1f, 1.2f, 2.3f, 4.4f,  // Labels
                                7.8f, 5.01f, 6.96f,
                                10.3f, 8.7f, 11.4f, 9.45f, 11.4f};
  dh::device_vector<bst_float> dlabels(hlabels);

  auto segment_label_sorter = RankSegmentSorterTestImpl<float>(
      {0, 4, 7, 12},  // Groups
      hlabels,
      {4.4f, 3.1f, 2.3f, 1.2f,  // Expected sorted labels
       7.8f, 6.96f, 5.01f,
       11.4f, 11.4f, 10.3f, 9.45f, 8.7f},
      {3, 0, 2, 1,  // Expected original positions
       4, 6, 5,
       9, 11, 7, 10, 8});

  segment_label_sorter->CreateIndexableSortedPositions();
  std::vector<uint32_t> sorted_indices(segment_label_sorter->GetNumItems());
  dh::CopyDeviceSpanToVector(&sorted_indices,
                             segment_label_sorter->GetIndexableSortedPositionsSpan());
  std::vector<uint32_t> expected_sorted_indices = {
      1, 3, 2, 0,
      4, 6, 5,
      9, 11, 7, 10, 8};
  EXPECT_EQ(expected_sorted_indices, sorted_indices);
}

TEST(Objective, ComputeAndCompareMAPStatsTest) {
  std::vector<float> hlabels = {3.1f, 0.0f, 2.3f, 4.4f,  // Labels
                                0.0f, 5.01f, 0.0f,
                                10.3f, 0.0f, 11.4f, 9.45f, 11.4f};
  dh::device_vector<bst_float> dlabels(hlabels);

  auto segment_label_sorter = RankSegmentSorterTestImpl<float>(
      {0, 4, 7, 12},  // Groups
      hlabels,
      {4.4f, 3.1f, 2.3f, 0.0f,  // Expected sorted labels
       5.01f, 0.0f, 0.0f,
       11.4f, 11.4f, 10.3f, 9.45f, 0.0f},
      {3, 0, 2, 1,  // Expected original positions
       5, 4, 6,
       9, 11, 7, 10, 8});

  // Create MAP stats on the device first using the objective
  std::vector<bst_float> hpreds{-9.78f, 24.367f, 0.908f, -11.47f,
                                -1.03f, -2.79f, -3.1f,
                                104.22f, 103.1f, -101.7f, 100.5f, 45.1f};
  dh::device_vector<bst_float> dpreds(hpreds);

  xgboost::obj::MAPLambdaWeightComputer map_lw_computer(dpreds.data().get(),
                                                        dlabels.data().get(),
                                                        *segment_label_sorter);

  // Get the device MAP stats on host
  std::vector<xgboost::obj::MAPLambdaWeightComputer::MAPStats> dmap_stats(
      segment_label_sorter->GetNumItems());
  dh::CopyDeviceSpanToVector(&dmap_stats, map_lw_computer.GetMapStatsSpan());

  // Compute the MAP stats on host next to compare
  std::vector<uint32_t> hgroups(segment_label_sorter->GetNumGroups() + 1);
  dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan());

  for (size_t i = 0; i < hgroups.size() - 1; ++i) {
    auto gbegin = hgroups[i];
    auto gend = hgroups[i + 1];
    std::vector<xgboost::obj::ListEntry> lst_entry;
    for (auto j = gbegin; j < gend; ++j) {
      lst_entry.emplace_back(hpreds[j], hlabels[j], j);
    }
    std::stable_sort(lst_entry.begin(), lst_entry.end(), xgboost::obj::ListEntry::CmpPred);

    // Compute the MAP stats with this list and compare with the ones computed on the device
    std::vector<xgboost::obj::MAPLambdaWeightComputer::MAPStats> hmap_stats;
    xgboost::obj::MAPLambdaWeightComputer::GetMAPStats(lst_entry, &hmap_stats);
    for (auto j = gbegin; j < gend; ++j) {
      EXPECT_EQ(dmap_stats[j].hits, hmap_stats[j - gbegin].hits);
      EXPECT_NEAR(dmap_stats[j].ap_acc, hmap_stats[j - gbegin].ap_acc, 0.01f);
      EXPECT_NEAR(dmap_stats[j].ap_acc_miss, hmap_stats[j - gbegin].ap_acc_miss, 0.01f);
      EXPECT_NEAR(dmap_stats[j].ap_acc_add, hmap_stats[j - gbegin].ap_acc_add, 0.01f);
    }
  }
}

} // namespace xgboost
@ -176,7 +176,7 @@ def test_ranking():
def test_ranking_metric() -> None:
    from sklearn.metrics import roc_auc_score

    X, y, qid, w = tm.make_ltr(512, 4, 3, 2)
    X, y, qid, w = tm.make_ltr(512, 4, 3, 1)
    # use auc for test as ndcg_score in sklearn works only on label gain instead of exp
    # gain.
    # note that the auc in sklearn is different from the one in XGBoost. The one in
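The comment above refers to the gain convention: sklearn's ndcg_score consumes the labels as gains directly, while XGBoost's rank:ndcg applies the exponential gain 2^y - 1 by default (controlled by ndcg_exp_gain). A small standalone sketch of exponential-gain DCG for illustration, not either library's implementation:

    import numpy as np

    def dcg_exp_gain(labels_by_score):
        """DCG with exponential gain; labels ordered by descending predicted score."""
        y = np.asarray(labels_by_score, dtype=float)
        gains = 2.0 ** y - 1.0                            # exponential gain
        discounts = 1.0 / np.log2(np.arange(2, y.size + 2))
        return float((gains * discounts).sum())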
@ -1343,61 +1343,94 @@ class XgboostLocalTest(SparkTestCase):
        SparkXGBClassifier(evals_result={})


class XgboostRankerLocalTest(SparkTestCase):
    def setUp(self):
        self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
        self.ranker_df_train = self.session.createDataFrame(
            [
                (Vectors.dense(1.0, 2.0, 3.0), 0, 0),
                (Vectors.dense(4.0, 5.0, 6.0), 1, 0),
                (Vectors.dense(9.0, 4.0, 8.0), 2, 0),
                (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
                (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
                (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
            ],
            ["features", "label", "qid"],
        )
        self.ranker_df_test = self.session.createDataFrame(
            [
                (Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
                (Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
                (Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
                (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
                (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
                (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
            ],
            ["features", "qid", "expected_prediction"],
        )
        self.ranker_df_train_1 = self.session.createDataFrame(
            [
                (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9),
                (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9),
                (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9),
                (Vectors.dense(1.0, 2.0, 3.0), 0, 8),
                (Vectors.dense(4.0, 5.0, 6.0), 1, 8),
                (Vectors.dense(9.0, 4.0, 8.0), 2, 8),
                (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7),
                (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7),
                (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7),
                (Vectors.dense(1.0, 2.0, 3.0), 0, 6),
                (Vectors.dense(4.0, 5.0, 6.0), 1, 6),
                (Vectors.dense(9.0, 4.0, 8.0), 2, 6),
            ]
            * 4,
            ["features", "label", "qid"],
        )
LTRData = namedtuple("LTRData", ("df_train", "df_test", "df_train_1"))

    def test_ranker(self):
        ranker = SparkXGBRanker(qid_col="qid")

@pytest.fixture
def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
    ranker_df_train = spark.createDataFrame(
        [
            (Vectors.dense(1.0, 2.0, 3.0), 0, 0),
            (Vectors.dense(4.0, 5.0, 6.0), 1, 0),
            (Vectors.dense(9.0, 4.0, 8.0), 2, 0),
            (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
            (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
            (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
        ],
        ["features", "label", "qid"],
    )
    X_train = np.array(
        [
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [9.0, 4.0, 8.0],
            [np.NaN, 1.0, 5.5],
            [np.NaN, 6.0, 7.5],
            [np.NaN, 8.0, 9.5],
        ]
    )
    qid_train = np.array([0, 0, 0, 1, 1, 1])
    y_train = np.array([0, 1, 2, 0, 1, 2])

    X_test = np.array(
        [
            [1.5, 2.0, 3.0],
            [4.5, 5.0, 6.0],
            [9.0, 4.5, 8.0],
            [np.NaN, 1.0, 6.0],
            [np.NaN, 6.0, 7.0],
            [np.NaN, 8.0, 10.5],
        ]
    )

    ltr = xgb.XGBRanker(tree_method="approx", objective="rank:pairwise")
    ltr.fit(X_train, y_train, qid=qid_train)
    predt = ltr.predict(X_test)

    ranker_df_test = spark.createDataFrame(
        [
            (Vectors.dense(1.5, 2.0, 3.0), 0, float(predt[0])),
            (Vectors.dense(4.5, 5.0, 6.0), 0, float(predt[1])),
            (Vectors.dense(9.0, 4.5, 8.0), 0, float(predt[2])),
            (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, float(predt[3])),
            (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, float(predt[4])),
            (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, float(predt[5])),
        ],
        ["features", "qid", "expected_prediction"],
    )
    ranker_df_train_1 = spark.createDataFrame(
        [
            (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9),
            (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9),
            (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9),
            (Vectors.dense(1.0, 2.0, 3.0), 0, 8),
            (Vectors.dense(4.0, 5.0, 6.0), 1, 8),
            (Vectors.dense(9.0, 4.0, 8.0), 2, 8),
            (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7),
            (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7),
            (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7),
            (Vectors.dense(1.0, 2.0, 3.0), 0, 6),
            (Vectors.dense(4.0, 5.0, 6.0), 1, 6),
            (Vectors.dense(9.0, 4.0, 8.0), 2, 6),
        ]
        * 4,
        ["features", "label", "qid"],
    )
    yield LTRData(ranker_df_train, ranker_df_test, ranker_df_train_1)


class TestPySparkLocalLETOR:
    def test_ranker(self, ltr_data: LTRData) -> None:
        ranker = SparkXGBRanker(qid_col="qid", objective="rank:pairwise")
        assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
        model = ranker.fit(self.ranker_df_train)
        pred_result = model.transform(self.ranker_df_test).collect()

        model = ranker.fit(ltr_data.df_train)
        pred_result = model.transform(ltr_data.df_test).collect()
        for row in pred_result:
            assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)

    def test_ranker_qid_sorted(self):
        ranker = SparkXGBRanker(qid_col="qid", num_workers=4)
        assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
        model = ranker.fit(self.ranker_df_train_1)
        model.transform(self.ranker_df_test).collect()
    def test_ranker_qid_sorted(self, ltr_data: LTRData) -> None:
        ranker = SparkXGBRanker(qid_col="qid", num_workers=4, objective="rank:ndcg")
        assert ranker.getOrDefault(ranker.objective) == "rank:ndcg"
        model = ranker.fit(ltr_data.df_train_1)
        model.transform(ltr_data.df_test).collect()
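For readers following along outside the test suite, a minimal sketch of the same Spark flow; train_df and test_df are assumed to exist with the features/label/qid column layout used by the fixture above:

    from xgboost.spark import SparkXGBRanker

    ranker = SparkXGBRanker(qid_col="qid", objective="rank:ndcg")
    model = ranker.fit(train_df)               # train_df: "features", "label", "qid" columns
    preds = model.transform(test_df).collect() # adds a "prediction" column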