- ndcg ltr implementation on gpu (#5004)
* - ndcg ltr implementation on gpu - this is a follow-up to the pairwise ltr implementation
This commit is contained in:
parent
f4e7b707c9
commit
2abe69d774
@ -98,7 +98,7 @@ struct EvalAMS : public Metric {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
rec[i] = std::make_pair(h_preds[i], i);
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||
if (ntop == 0) ntop = ndata;
|
||||
const double br = 10.0;
|
||||
@ -168,7 +168,7 @@ struct EvalAuc : public Metric {
|
||||
for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
|
||||
rec.emplace_back(h_preds[j], j);
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// calculate AUC
|
||||
double sum_pospair = 0.0;
|
||||
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
|
||||
@ -321,7 +321,7 @@ struct EvalPrecision : public EvalRankList{
|
||||
protected:
|
||||
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
|
||||
// calculate Precision
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned nhit = 0;
|
||||
for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) {
|
||||
nhit += (rec[j].second != 0);
|
||||
@ -369,7 +369,7 @@ struct EvalMAP : public EvalRankList {
|
||||
|
||||
protected:
|
||||
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned nhits = 0;
|
||||
double sumap = 0.0;
|
||||
for (size_t i = 0; i < rec.size(); ++i) {
|
||||
@ -481,7 +481,7 @@ struct EvalAucPR : public Metric {
|
||||
total_neg += wt * (1.0f - h_labels[j]);
|
||||
rec.emplace_back(h_preds[j], j);
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// we need pos > 0 && neg > 0
|
||||
if (0.0 == total_pos || 0.0 == total_neg) {
|
||||
auc_error += 1;
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#if defined(XGBOOST_USE_CUDA) && !defined(GTEST_TEST)
|
||||
DMLC_REGISTRY_FILE_TAG(rank_obj_gpu);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
@ -46,6 +46,185 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
// This type sorts an array which is divided into multiple groups. The sorting is influenced
|
||||
// by the function object 'Comparator'
|
||||
template <typename T>
|
||||
class SegmentSorter {
|
||||
private:
|
||||
// Items sorted within the group
|
||||
dh::caching_device_vector<T> ditems_;
|
||||
|
||||
// Original position of the items before they are sorted descendingly within its groups
|
||||
dh::caching_device_vector<uint32_t> doriginal_pos_;
|
||||
|
||||
// Segments within the original list that delineates the different groups
|
||||
dh::caching_device_vector<uint32_t> group_segments_;
|
||||
|
||||
// Need this on the device as it is used in the kernels
|
||||
dh::caching_device_vector<uint32_t> dgroups_; // Group information on device
|
||||
|
||||
// Initialize everything but the segments
|
||||
void Init(uint32_t num_elems) {
|
||||
ditems_.resize(num_elems);
|
||||
|
||||
doriginal_pos_.resize(num_elems);
|
||||
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
|
||||
}
|
||||
|
||||
// Initialize all with group info
|
||||
void Init(const std::vector<uint32_t> &groups) {
|
||||
uint32_t num_elems = groups.back();
|
||||
this->Init(num_elems);
|
||||
this->CreateGroupSegments(groups);
|
||||
}
|
||||
|
||||
public:
|
||||
// This needs to be public due to device lambda
|
||||
void CreateGroupSegments(const std::vector<uint32_t> &groups) {
|
||||
uint32_t num_elems = groups.back();
|
||||
group_segments_.resize(num_elems);
|
||||
|
||||
dgroups_ = groups;
|
||||
|
||||
// Launch a kernel that populates the segment information for the different groups
|
||||
uint32_t *gsegs = group_segments_.data().get();
|
||||
const uint32_t *dgroups = dgroups_.data().get();
|
||||
uint32_t ngroups = dgroups_.size();
|
||||
int device_id = -1;
|
||||
dh::safe_cuda(cudaGetDevice(&device_id));
|
||||
dh::LaunchN(device_id, num_elems, nullptr, [=] __device__(uint32_t idx){
|
||||
// Find the group first
|
||||
uint32_t group_idx = dh::UpperBound(dgroups, ngroups, idx);
|
||||
gsegs[idx] = group_idx - 1;
|
||||
});
|
||||
}
|
||||
|
||||
// Accessors that returns device pointer
|
||||
inline const T *Items() const { return ditems_.data().get(); }
|
||||
inline uint32_t NumItems() const { return ditems_.size(); }
|
||||
inline const uint32_t *OriginalPositions() const { return doriginal_pos_.data().get(); }
|
||||
inline const dh::caching_device_vector<uint32_t> &GroupSegments() const {
|
||||
return group_segments_;
|
||||
}
|
||||
inline uint32_t NumGroups() const { return dgroups_.size() - 1; }
|
||||
inline const uint32_t *GroupIndices() const { return dgroups_.data().get(); }
|
||||
|
||||
// Sort an array that is divided into multiple groups. The array is sorted within each group.
|
||||
// This version provides the group information that is on the host.
|
||||
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
|
||||
// is used.
|
||||
template <typename Comparator = thrust::greater<T>>
|
||||
void SortItems(const T *ditems, uint32_t item_size, const std::vector<uint32_t> &groups,
|
||||
const Comparator &comp = Comparator()) {
|
||||
this->Init(groups);
|
||||
this->SortItems(ditems, item_size, group_segments_, comp);
|
||||
}
|
||||
|
||||
// Sort an array that is divided into multiple groups. The array is sorted within each group.
|
||||
// This version provides the group information that is on the device.
|
||||
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
|
||||
// is used.
|
||||
template <typename Comparator = thrust::greater<T>>
|
||||
void SortItems(const T *ditems, uint32_t item_size,
|
||||
const dh::caching_device_vector<uint32_t> &group_segments,
|
||||
const Comparator &comp = Comparator()) {
|
||||
this->Init(item_size);
|
||||
|
||||
// Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
|
||||
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
|
||||
// when comparators are used. Hence, the following algorithm is used. This is done so that
|
||||
// we can grab the appropriate related values from the original list later, after the
|
||||
// items are sorted.
|
||||
//
|
||||
// Here is the internal representation:
|
||||
// dgroups_: [ 0, 3, 5, 8, 10 ]
|
||||
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
|
||||
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
|
||||
// ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
|
||||
//
|
||||
// Sort the items first and make a note of the original positions in doriginal_pos_
|
||||
// based on the sort
|
||||
// ditems_: 4 4 3 3 2 1 1 1 1 0
|
||||
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
|
||||
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
|
||||
// in kernel, sorting using predicates etc.
|
||||
|
||||
ditems_.assign(thrust::device_ptr<const T>(ditems),
|
||||
thrust::device_ptr<const T>(ditems) + item_size);
|
||||
|
||||
// Allocator to be used by sort for managing space overhead while sorting
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
|
||||
ditems_.begin(), ditems_.end(),
|
||||
doriginal_pos_.begin(), comp);
|
||||
|
||||
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
|
||||
// holisitic item sort order on the segments
|
||||
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
|
||||
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
|
||||
dh::caching_device_vector<uint32_t> group_segments_c(group_segments);
|
||||
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
|
||||
group_segments.begin(), group_segments_c.begin());
|
||||
|
||||
// Now, sort the group segments so that you may bring the items within the group together,
|
||||
// in the process also noting the relative changes to the doriginal_pos_ while that happens
|
||||
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
|
||||
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
|
||||
group_segments_c.begin(), group_segments_c.end(),
|
||||
doriginal_pos_.begin(), thrust::less<uint32_t>());
|
||||
|
||||
// Finally, gather the original items based on doriginal_pos_ to sort the input and
|
||||
// to store them in ditems_
|
||||
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
|
||||
// ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
|
||||
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
|
||||
thrust::device_ptr<const T>(ditems), ditems_.begin());
|
||||
}
|
||||
};
|
||||
|
||||
// Helper functions
|
||||
|
||||
// Items of size 'n' are sorted in a descending order
|
||||
// If left is true, find the number of elements > v; 0 if nothing is greater
|
||||
// If left is false, find the number of elements < v; 0 if nothing is lesser
|
||||
template <typename T>
|
||||
XGBOOST_DEVICE __forceinline__ uint32_t
|
||||
CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v) {
|
||||
const T *items_begin = items;
|
||||
uint32_t num_remaining = n;
|
||||
const T *middle_item = nullptr;
|
||||
uint32_t middle;
|
||||
while (num_remaining > 0) {
|
||||
middle_item = items_begin;
|
||||
middle = num_remaining / 2;
|
||||
middle_item += middle;
|
||||
if ((left && *middle_item > v) || (!left && !(v > *middle_item))) {
|
||||
items_begin = ++middle_item;
|
||||
num_remaining -= middle + 1;
|
||||
} else {
|
||||
num_remaining = middle;
|
||||
}
|
||||
}
|
||||
|
||||
return left ? items_begin - items : items + n - items_begin;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
XGBOOST_DEVICE __forceinline__ uint32_t
|
||||
CountNumItemsToTheLeftOf(const T * __restrict__ items, uint32_t n, T v) {
|
||||
return CountNumItemsImpl(true, items, n, v);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
XGBOOST_DEVICE __forceinline__ uint32_t
|
||||
CountNumItemsToTheRightOf(const T * __restrict__ items, uint32_t n, T v) {
|
||||
return CountNumItemsImpl(false, items, n, v);
|
||||
}
|
||||
#endif
|
||||
|
||||
/*! \brief helper information in a list */
|
||||
struct ListEntry {
|
||||
/*! \brief the predict score we in the data */
|
||||
@ -96,16 +275,133 @@ struct PairwiseLambdaWeightComputer {
|
||||
return "rank:pairwise";
|
||||
}
|
||||
|
||||
// Stopgap method - will be removed when we support other type of ranking - ndcg, map etc.
|
||||
// Stopgap method - will be removed when we support other type of ranking - map
|
||||
// on GPU later
|
||||
inline static bool SupportOnGPU() { return true; }
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
PairwiseLambdaWeightComputer(const bst_float *dpreds,
|
||||
uint32_t pred_size,
|
||||
const SegmentSorter<float> &segment_label_sorter) {}
|
||||
|
||||
struct PairwiseLambdaWeightMultiplier {
|
||||
// Adjust the items weight by this value
|
||||
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
|
||||
return 1.0f;
|
||||
}
|
||||
};
|
||||
|
||||
inline PairwiseLambdaWeightMultiplier GetWeightMultiplier() const {
|
||||
return {};
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
// beta version: NDCG lambda rank
|
||||
struct NDCGLambdaWeightComputer {
|
||||
// Stopgap method - will be removed when we support other type of ranking - ndcg, map etc.
|
||||
public:
|
||||
#if defined(__CUDACC__)
|
||||
// This function object computes the group's DCG for a given group
|
||||
struct ComputeGroupDCG {
|
||||
public:
|
||||
XGBOOST_DEVICE ComputeGroupDCG(const float *dsorted_labels, const uint32_t *dgroups)
|
||||
: dsorted_labels_(dsorted_labels),
|
||||
dgroups_(dgroups) {}
|
||||
|
||||
// Compute DCG for group 'gidx'
|
||||
__device__ __forceinline__ float operator()(uint32_t gidx) const {
|
||||
uint32_t group_begin = dgroups_[gidx];
|
||||
uint32_t group_end = dgroups_[gidx + 1];
|
||||
uint32_t group_size = group_end - group_begin;
|
||||
return ComputeGroupDCGWeight(&dsorted_labels_[group_begin], group_size);
|
||||
}
|
||||
|
||||
private:
|
||||
const float *dsorted_labels_{nullptr}; // Labels sorted within a group
|
||||
const uint32_t *dgroups_{nullptr}; // The group indices - where each group begins and ends
|
||||
};
|
||||
|
||||
// Type containing device pointers that can be cheaply copied on the kernel
|
||||
class NDCGLambdaWeightMultiplier {
|
||||
public:
|
||||
NDCGLambdaWeightMultiplier(const float *dsorted_labels,
|
||||
const uint32_t *dorig_pos,
|
||||
const uint32_t *dgroups,
|
||||
const float *dgroup_dcg_ptr,
|
||||
uint32_t *dindexable_sorted_preds_pos_ptr)
|
||||
: dsorted_labels_(dsorted_labels),
|
||||
dorig_pos_(dorig_pos),
|
||||
dgroups_(dgroups),
|
||||
dgroup_dcg_ptr_(dgroup_dcg_ptr),
|
||||
dindexable_sorted_preds_pos_ptr_(dindexable_sorted_preds_pos_ptr) {}
|
||||
|
||||
// Adjust the items weight by this value
|
||||
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
|
||||
if (dgroup_dcg_ptr_[gidx] == 0.0) return 0.0f;
|
||||
|
||||
uint32_t group_begin = dgroups_[gidx];
|
||||
|
||||
auto ppred_idx = dorig_pos_[pidx];
|
||||
auto npred_idx = dorig_pos_[nidx];
|
||||
KERNEL_CHECK(ppred_idx != npred_idx);
|
||||
|
||||
// Note: the label positive and negative indices are relative to the entire dataset.
|
||||
// Hence, scale them back to an index within the group
|
||||
ppred_idx = dindexable_sorted_preds_pos_ptr_[ppred_idx] - group_begin;
|
||||
npred_idx = dindexable_sorted_preds_pos_ptr_[npred_idx] - group_begin;
|
||||
return NDCGLambdaWeightComputer::ComputeDeltaWeight(
|
||||
ppred_idx, npred_idx,
|
||||
static_cast<int>(dsorted_labels_[pidx]), static_cast<int>(dsorted_labels_[nidx]),
|
||||
dgroup_dcg_ptr_[gidx]);
|
||||
}
|
||||
|
||||
private:
|
||||
const float *dsorted_labels_{nullptr}; // Labels sorted within a group
|
||||
const uint32_t *dorig_pos_{nullptr}; // Original indices of the labels before they are sorted
|
||||
const uint32_t *dgroups_{nullptr}; // The group indices
|
||||
const float *dgroup_dcg_ptr_{nullptr}; // Start address of the group DCG values
|
||||
// Where can a prediction for a label be found in the original array, when they are sorted
|
||||
uint32_t *dindexable_sorted_preds_pos_ptr_{nullptr};
|
||||
};
|
||||
|
||||
NDCGLambdaWeightComputer(const bst_float *dpreds,
|
||||
uint32_t pred_size,
|
||||
const SegmentSorter<float> &segment_label_sorter)
|
||||
: dgroup_dcg_(segment_label_sorter.NumGroups()),
|
||||
dindexable_sorted_preds_pos_(pred_size),
|
||||
weight_multiplier_(segment_label_sorter.Items(),
|
||||
segment_label_sorter.OriginalPositions(),
|
||||
segment_label_sorter.GroupIndices(),
|
||||
dgroup_dcg_.data().get(),
|
||||
dindexable_sorted_preds_pos_.data().get()) {
|
||||
// Sort the predictions first and get the sorted position
|
||||
SegmentSorter<float> segment_prediction_sorter;
|
||||
segment_prediction_sorter.SortItems(dpreds, pred_size, segment_label_sorter.GroupSegments());
|
||||
|
||||
this->CreateIndexableSortedPredictionPositions(segment_prediction_sorter.OriginalPositions());
|
||||
|
||||
// Compute each group's DCG concurrently
|
||||
// Set the values to be the group indices first so that the predicate knows which
|
||||
// group it is dealing with
|
||||
thrust::sequence(dgroup_dcg_.begin(), dgroup_dcg_.end());
|
||||
|
||||
// TODO(sriramch): parallelize across all elements, if possible
|
||||
// Transform each group - the predictate computes the group's DCG
|
||||
thrust::transform(dgroup_dcg_.begin(), dgroup_dcg_.end(),
|
||||
dgroup_dcg_.begin(),
|
||||
ComputeGroupDCG(segment_label_sorter.Items(),
|
||||
segment_label_sorter.GroupIndices()));
|
||||
}
|
||||
|
||||
inline NDCGLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; }
|
||||
inline const dh::caching_device_vector<uint32_t> &GetSortedPredPos() const {
|
||||
return dindexable_sorted_preds_pos_;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Stopgap method - will be removed when we support other type of ranking - map
|
||||
// on GPU later
|
||||
inline static bool SupportOnGPU() { return false; }
|
||||
inline static bool SupportOnGPU() { return true; }
|
||||
|
||||
static void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
|
||||
std::vector<LambdaPair> *io_pairs) {
|
||||
@ -116,29 +412,20 @@ struct NDCGLambdaWeightComputer {
|
||||
for (size_t i = 0; i < sorted_list.size(); ++i) {
|
||||
labels[i] = sorted_list[i].label;
|
||||
}
|
||||
std::sort(labels.begin(), labels.end(), std::greater<bst_float>());
|
||||
IDCG = CalcDCG(labels);
|
||||
std::stable_sort(labels.begin(), labels.end(), std::greater<bst_float>());
|
||||
IDCG = ComputeGroupDCGWeight(&labels[0], labels.size());
|
||||
}
|
||||
if (IDCG == 0.0) {
|
||||
for (auto & pair : pairs) {
|
||||
pair.weight = 0.0f;
|
||||
}
|
||||
} else {
|
||||
IDCG = 1.0f / IDCG;
|
||||
for (auto & pair : pairs) {
|
||||
unsigned pos_idx = pair.pos_index;
|
||||
unsigned neg_idx = pair.neg_index;
|
||||
float pos_loginv = 1.0f / std::log2(pos_idx + 2.0f);
|
||||
float neg_loginv = 1.0f / std::log2(neg_idx + 2.0f);
|
||||
auto pos_label = static_cast<int>(sorted_list[pos_idx].label);
|
||||
auto neg_label = static_cast<int>(sorted_list[neg_idx].label);
|
||||
bst_float original =
|
||||
((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
|
||||
float changed =
|
||||
((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
|
||||
bst_float delta = (original - changed) * IDCG;
|
||||
if (delta < 0.0f) delta = - delta;
|
||||
pair.weight *= delta;
|
||||
pair.weight *= ComputeDeltaWeight(pos_idx, neg_idx,
|
||||
sorted_list[pos_idx].label, sorted_list[neg_idx].label,
|
||||
IDCG);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -148,16 +435,77 @@ struct NDCGLambdaWeightComputer {
|
||||
}
|
||||
|
||||
private:
|
||||
inline static bst_float CalcDCG(const std::vector<bst_float> &labels) {
|
||||
XGBOOST_DEVICE inline static bst_float ComputeGroupDCGWeight(const float *sorted_labels,
|
||||
uint32_t size) {
|
||||
double sumdcg = 0.0;
|
||||
for (size_t i = 0; i < labels.size(); ++i) {
|
||||
const auto rel = static_cast<unsigned>(labels[i]);
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
const auto rel = static_cast<unsigned>(sorted_labels[i]);
|
||||
if (rel != 0) {
|
||||
sumdcg += ((1 << rel) - 1) / std::log2(static_cast<bst_float>(i + 2));
|
||||
}
|
||||
}
|
||||
return static_cast<bst_float>(sumdcg);
|
||||
}
|
||||
|
||||
// Compute the weight adjustment for an item within a group:
|
||||
// ppred_idx => Where does the positive label live, had the list been sorted by prediction
|
||||
// npred_idx => Where does the negative label live, had the list been sorted by prediction
|
||||
// pos_label => positive label value from sorted label list
|
||||
// neg_label => negative label value from sorted label list
|
||||
XGBOOST_DEVICE inline static bst_float ComputeDeltaWeight(uint32_t ppred_idx, uint32_t npred_idx,
|
||||
int pos_label, int neg_label,
|
||||
float idcg) {
|
||||
float pos_loginv = 1.0f / std::log2(ppred_idx + 2.0f);
|
||||
float neg_loginv = 1.0f / std::log2(npred_idx + 2.0f);
|
||||
bst_float original = ((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
|
||||
float changed = ((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
|
||||
bst_float delta = (original - changed) * (1.0f / idcg);
|
||||
if (delta < 0.0f) delta = - delta;
|
||||
return delta;
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
// While computing the weight that needs to be adjusted by this ranking objective, we need
|
||||
// to figure out where positive and negative labels chosen earlier exists, if the group
|
||||
// were to be sorted by its predictions. To accommodate this, we employ the following algorithm.
|
||||
// For a given group, let's assume the following:
|
||||
// labels: 1 5 9 2 4 8 0 7 6 3
|
||||
// predictions: 1 9 0 8 2 7 3 6 5 4
|
||||
// position: 0 1 2 3 4 5 6 7 8 9
|
||||
//
|
||||
// After label sort:
|
||||
// labels: 9 8 7 6 5 4 3 2 1 0
|
||||
// position: 2 5 7 8 1 4 9 3 0 6
|
||||
//
|
||||
// After prediction sort:
|
||||
// predictions: 9 8 7 6 5 4 3 2 1 0
|
||||
// position: 1 3 5 7 8 9 6 4 0 2
|
||||
//
|
||||
// If a sorted label at position 'x' is chosen, then we need to find out where the prediction
|
||||
// for this label 'x' exists, if the group were to be sorted by predictions.
|
||||
// We first take the sorted prediction positions:
|
||||
// position: 1 3 5 7 8 9 6 4 0 2
|
||||
// at indices: 0 1 2 3 4 5 6 7 8 9
|
||||
//
|
||||
// We create a sorted prediction positional array, such that value at position 'x' gives
|
||||
// us the position in the sorted prediction array where its related prediction lies.
|
||||
// dindexable_sorted_preds_pos_ptr_: 8 0 9 1 7 2 6 3 4 5
|
||||
// at indices: 0 1 2 3 4 5 6 7 8 9
|
||||
// Basically, swap the previous 2 arrays, sort the indices and reorder positions
|
||||
// for an O(1) lookup using the position where the sorted label exists
|
||||
void CreateIndexableSortedPredictionPositions(const uint32_t *dsorted_preds_pos) {
|
||||
dh::caching_device_vector<uint32_t> indices(dindexable_sorted_preds_pos_.size());
|
||||
thrust::sequence(indices.begin(), indices.end());
|
||||
thrust::scatter(indices.begin(), indices.end(), // Rearrange indices...
|
||||
thrust::device_ptr<const uint32_t>(dsorted_preds_pos), // ...based on this map
|
||||
dindexable_sorted_preds_pos_.begin()); // Write results into this
|
||||
}
|
||||
|
||||
dh::caching_device_vector<float> dgroup_dcg_;
|
||||
// Where can a prediction for a label be found in the original array, when they are sorted
|
||||
dh::caching_device_vector<uint32_t> dindexable_sorted_preds_pos_;
|
||||
NDCGLambdaWeightMultiplier weight_multiplier_; // This computes the adjustment to the weight
|
||||
#endif
|
||||
};
|
||||
|
||||
struct MAPLambdaWeightComputer {
|
||||
@ -238,7 +586,7 @@ struct MAPLambdaWeightComputer {
|
||||
}
|
||||
|
||||
public:
|
||||
// Stopgap method - will be removed when we support other type of ranking - ndcg, map etc.
|
||||
// Stopgap method - will be removed when we support other type of ranking - map
|
||||
// on GPU later
|
||||
inline static bool SupportOnGPU() { return false; }
|
||||
|
||||
@ -257,177 +605,79 @@ struct MAPLambdaWeightComputer {
|
||||
pair.neg_index, &map_stats);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
MAPLambdaWeightComputer(const bst_float *dpreds,
|
||||
uint32_t pred_size,
|
||||
const SegmentSorter<float> &segment_label_sorter) {}
|
||||
|
||||
struct MAPLambdaWeightMultiplier {
|
||||
// Adjust the items weight by this value
|
||||
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
|
||||
return 1.0f;
|
||||
}
|
||||
};
|
||||
|
||||
inline MAPLambdaWeightMultiplier GetWeightMultiplier() const {
|
||||
return {};
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
// Helper functions
|
||||
|
||||
// Labels of size 'n' are sorted in a descending order
|
||||
// If left is true, find the number of elements > v; 0 if nothing is greater
|
||||
// If left is false, find the number of elements < v; 0 if nothing is lesser
|
||||
__device__ __forceinline__ int
|
||||
CountNumLabelsImpl(bool left, const float * __restrict__ labels, int n, float v) {
|
||||
const float *labels_begin = labels;
|
||||
int num_remaining = n;
|
||||
const float *middle_item = nullptr;
|
||||
int middle;
|
||||
while (num_remaining > 0) {
|
||||
middle_item = labels_begin;
|
||||
middle = num_remaining / 2;
|
||||
middle_item += middle;
|
||||
if ((left && *middle_item > v) || (!left && !(v > *middle_item))) {
|
||||
labels_begin = ++middle_item;
|
||||
num_remaining -= middle + 1;
|
||||
} else {
|
||||
num_remaining = middle;
|
||||
}
|
||||
}
|
||||
|
||||
return left ? labels_begin - labels : labels + n - labels_begin;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int
|
||||
CountNumLabelsToTheLeftOf(const float * __restrict__ labels, int n, float v) {
|
||||
return CountNumLabelsImpl(true, labels, n, v);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int
|
||||
CountNumLabelsToTheRightOf(const float * __restrict__ labels, int n, float v) {
|
||||
return CountNumLabelsImpl(false, labels, n, v);
|
||||
}
|
||||
|
||||
class SortedLabelList {
|
||||
class SortedLabelList : SegmentSorter<float> {
|
||||
private:
|
||||
// Labels sorted within the group
|
||||
dh::caching_device_vector<float> dsorted_labels_;
|
||||
|
||||
// Original position of the labels before they are sorted descendingly within its groups
|
||||
dh::caching_device_vector<uint32_t> doriginal_pos_;
|
||||
|
||||
// Segments within the original list that delineates the different groups
|
||||
dh::caching_device_vector<uint32_t> group_segments_;
|
||||
|
||||
// Need this on the device as it is used in the kernels
|
||||
dh::caching_device_vector<uint32_t> dgroups_; // Group information on device
|
||||
|
||||
int device_id_{-1}; // GPU device ID
|
||||
const LambdaRankParam ¶m_; // Objective configuration
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc_; // Allocator to be used by sort for managing
|
||||
// space overhead while sorting
|
||||
|
||||
public:
|
||||
SortedLabelList(int dev_id,
|
||||
const LambdaRankParam ¶m)
|
||||
: device_id_(dev_id),
|
||||
param_(param) {}
|
||||
explicit SortedLabelList(const LambdaRankParam ¶m)
|
||||
: param_(param) {}
|
||||
|
||||
void InitWithTrainingInfo(const std::vector<uint32_t> &groups) {
|
||||
int num_elems = groups.back();
|
||||
|
||||
dsorted_labels_.resize(num_elems);
|
||||
|
||||
doriginal_pos_.resize(num_elems);
|
||||
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
|
||||
|
||||
group_segments_.resize(num_elems);
|
||||
|
||||
dgroups_ = groups;
|
||||
|
||||
// Launch a kernel that populates the segment information for the different groups
|
||||
uint32_t *gsegs = group_segments_.data().get();
|
||||
const unsigned *dgroups = dgroups_.data().get();
|
||||
size_t ngroups = dgroups_.size();
|
||||
dh::LaunchN(device_id_, num_elems, nullptr, [=] __device__(unsigned idx){
|
||||
// Find the group first
|
||||
int group_idx = dh::UpperBound(dgroups, ngroups, idx);
|
||||
gsegs[idx] = group_idx - 1;
|
||||
});
|
||||
}
|
||||
|
||||
// Sort the groups by labels. We would like to avoid using predicates to perform the sort,
|
||||
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
|
||||
// when comparators are used. Hence, the following algorithm is used. This is done so that
|
||||
// we can grab the appropriate prediction values from the original list later, after the
|
||||
// labels are sorted.
|
||||
//
|
||||
// Here is the internal representation:
|
||||
// dgroups_: [ 0, 3, 5, 8, 10 ]
|
||||
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
|
||||
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
|
||||
// dsorted_labels_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original labels)
|
||||
//
|
||||
// Sort the labels first and make a note of the original positions in doriginal_pos_
|
||||
// based on the sort
|
||||
// dsorted_labels_: 4 4 3 3 2 1 1 1 1 0
|
||||
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
|
||||
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
|
||||
// in kernel, sorting using predicates etc.
|
||||
void Sort(const HostDeviceVector<bst_float> &dlabels) {
|
||||
dsorted_labels_.assign(dh::tcbegin(dlabels), dh::tcend(dlabels));
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc_),
|
||||
dsorted_labels_.begin(), dsorted_labels_.end(),
|
||||
doriginal_pos_.begin(), thrust::greater<float>());
|
||||
|
||||
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
|
||||
// holisitic label sort order on the segments
|
||||
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
|
||||
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
|
||||
thrust::device_vector<uint32_t> group_segments_c(group_segments_);
|
||||
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
|
||||
group_segments_.begin(), group_segments_c.begin());
|
||||
|
||||
// Now, sort the group segments so that you may bring the labels within the group together,
|
||||
// in the process also noting the relative changes to the doriginal_pos_ while that happens
|
||||
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
|
||||
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc_),
|
||||
group_segments_c.begin(), group_segments_c.end(),
|
||||
doriginal_pos_.begin(), thrust::less<int>());
|
||||
|
||||
// Finally, gather the original labels based on doriginal_pos_ to sort the input and
|
||||
// to store them in dsorted_labels_
|
||||
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
|
||||
// dsorted_labels_: 1 1 0 2 1 3 3 1 4 4 (from unsorted dlabels - dlabels)
|
||||
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
|
||||
dh::tcbegin(dlabels), dsorted_labels_.begin());
|
||||
}
|
||||
|
||||
~SortedLabelList() {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
// Sort the labels that are grouped by 'groups'
|
||||
void Sort(const HostDeviceVector<bst_float> &dlabels, const std::vector<uint32_t> &groups) {
|
||||
this->SortItems(dlabels.ConstDevicePointer(), dlabels.Size(), groups);
|
||||
}
|
||||
|
||||
// This kernel can only run *after* the kernel in sort is completed, as they
|
||||
// use the default stream
|
||||
template <typename LambdaWeightComputerT>
|
||||
void ComputeGradients(const bst_float *dpreds,
|
||||
GradientPair *out_gpair,
|
||||
const HostDeviceVector<bst_float> &weights,
|
||||
int iter,
|
||||
GradientPair *out_gpair,
|
||||
float weight_normalization_factor) {
|
||||
// Group info on device
|
||||
const unsigned *dgroups = dgroups_.data().get();
|
||||
size_t ngroups = dgroups_.size();
|
||||
const uint32_t *dgroups = this->GroupIndices();
|
||||
uint32_t ngroups = this->NumGroups() + 1;
|
||||
|
||||
auto total_items = group_segments_.size();
|
||||
size_t niter = param_.num_pairsample * total_items;
|
||||
uint32_t total_items = this->NumItems();
|
||||
uint32_t niter = param_.num_pairsample * total_items;
|
||||
|
||||
float fix_list_weight = param_.fix_list_weight;
|
||||
|
||||
const uint32_t *original_pos = doriginal_pos_.data().get();
|
||||
const uint32_t *original_pos = this->OriginalPositions();
|
||||
|
||||
size_t num_weights = weights.Size();
|
||||
uint32_t num_weights = weights.Size();
|
||||
auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr;
|
||||
|
||||
const bst_float *sorted_labels = dsorted_labels_.data().get();
|
||||
const bst_float *sorted_labels = this->Items();
|
||||
|
||||
// This is used to adjust the weight of different elements based on the different ranking
|
||||
// objective function policies
|
||||
LambdaWeightComputerT weight_computer(dpreds, total_items, *this);
|
||||
auto wmultiplier = weight_computer.GetWeightMultiplier();
|
||||
|
||||
int device_id = -1;
|
||||
dh::safe_cuda(cudaGetDevice(&device_id));
|
||||
// For each instance in the group, compute the gradient pair concurrently
|
||||
dh::LaunchN(device_id_, niter, nullptr, [=] __device__(size_t idx) {
|
||||
dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) {
|
||||
// First, determine the group 'idx' belongs to
|
||||
unsigned item_idx = idx % total_items;
|
||||
int group_idx = dh::UpperBound(dgroups, ngroups, item_idx);
|
||||
uint32_t item_idx = idx % total_items;
|
||||
uint32_t group_idx = dh::UpperBound(dgroups, ngroups, item_idx);
|
||||
// Span of this group within the larger labels/predictions sorted tuple
|
||||
int group_begin = dgroups[group_idx - 1];
|
||||
int group_end = dgroups[group_idx];
|
||||
int total_group_items = group_end - group_begin;
|
||||
uint32_t group_begin = dgroups[group_idx - 1];
|
||||
uint32_t group_end = dgroups[group_idx];
|
||||
uint32_t total_group_items = group_end - group_begin;
|
||||
|
||||
// Are the labels diverse enough? If they are all the same, then there is nothing to pick
|
||||
// from another group - bail sooner
|
||||
@ -435,14 +685,14 @@ class SortedLabelList {
|
||||
|
||||
// Find the number of labels less than and greater than the current label
|
||||
// at the sorted index position item_idx
|
||||
int nleft = CountNumLabelsToTheLeftOf(
|
||||
sorted_labels + group_begin, total_group_items, sorted_labels[item_idx]);
|
||||
int nright = CountNumLabelsToTheRightOf(
|
||||
sorted_labels + group_begin, total_group_items, sorted_labels[item_idx]);
|
||||
uint32_t nleft = CountNumItemsToTheLeftOf(
|
||||
sorted_labels + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]);
|
||||
uint32_t nright = CountNumItemsToTheRightOf(
|
||||
sorted_labels + item_idx, group_end - item_idx, sorted_labels[item_idx]);
|
||||
|
||||
// Create a minstd_rand object to act as our source of randomness
|
||||
thrust::minstd_rand rng;
|
||||
rng.discard(idx);
|
||||
thrust::minstd_rand rng((iter + 1) * 1111);
|
||||
rng.discard(((idx / total_items) * total_group_items) + item_idx - group_begin);
|
||||
// Create a uniform_int_distribution to produce a sample from outside of the
|
||||
// present label group
|
||||
thrust::uniform_int_distribution<int> dist(0, nleft + nright - 1);
|
||||
@ -457,7 +707,7 @@ class SortedLabelList {
|
||||
neg_idx = item_idx;
|
||||
} else {
|
||||
pos_idx = item_idx;
|
||||
int items_in_group = total_group_items - nleft - nright;
|
||||
uint32_t items_in_group = total_group_items - nleft - nright;
|
||||
neg_idx = sample + items_in_group + group_begin;
|
||||
}
|
||||
|
||||
@ -468,13 +718,14 @@ class SortedLabelList {
|
||||
bst_float h = thrust::max(p * (1.0f - p), eps);
|
||||
|
||||
// Rescale each gradient and hessian so that the group has a weighted constant
|
||||
float scale = 1.0f / (niter / total_items);
|
||||
float scale = __frcp_ru(niter / total_items);
|
||||
if (fix_list_weight != 0.0f) {
|
||||
scale *= fix_list_weight / total_group_items;
|
||||
}
|
||||
|
||||
float weight = num_weights ? dweights[group_idx - 1] : 1.0f;
|
||||
weight *= weight_normalization_factor;
|
||||
weight *= wmultiplier.GetWeight(group_idx - 1, pos_idx, neg_idx);
|
||||
weight *= scale;
|
||||
// Accumulate gradient and hessian in both positive and negative indices
|
||||
const GradientPair in_pos_gpair(g * weight, 2.0f * weight * h);
|
||||
@ -483,6 +734,9 @@ class SortedLabelList {
|
||||
const GradientPair in_neg_gpair(-g * weight, 2.0f * weight * h);
|
||||
dh::AtomicAddGpair(&out_gpair[original_pos[neg_idx]], in_neg_gpair);
|
||||
});
|
||||
|
||||
// Wait until the computations done by the kernel is complete
|
||||
dh::safe_cuda(cudaStreamSynchronize(nullptr));
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@ -512,7 +766,7 @@ class LambdaRankObj : public ObjFunction {
|
||||
// Check if we have a GPU assignment; else, revert back to CPU
|
||||
auto device = tparam_->gpu_id;
|
||||
if (device >= 0 && LambdaWeightComputerT::SupportOnGPU()) {
|
||||
ComputeGradientsOnGPU(preds, info, out_gpair, gptr);
|
||||
ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr);
|
||||
} else {
|
||||
// Revert back to CPU
|
||||
#endif
|
||||
@ -558,20 +812,21 @@ class LambdaRankObj : public ObjFunction {
|
||||
LOG(DEBUG) << "Computing pairwise gradients on CPU.";
|
||||
|
||||
bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);
|
||||
|
||||
const auto& preds_h = preds.HostVector();
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
std::vector<GradientPair>& gpair = out_gpair->HostVector();
|
||||
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
out_gpair->Resize(preds.Size());
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
// parallel construct, declare random number generator here, so that each
|
||||
// thread use its own random number generator, seed by thread id and current iteration
|
||||
common::RandomEngine rnd(iter * 1111 + omp_get_thread_num());
|
||||
|
||||
std::minstd_rand rnd((iter + 1) * 1111);
|
||||
std::vector<LambdaPair> pairs;
|
||||
std::vector<ListEntry> lst;
|
||||
std::vector< std::pair<bst_float, unsigned> > rec;
|
||||
const auto& preds_h = preds.HostVector();
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
std::vector<GradientPair>& gpair = out_gpair->HostVector();
|
||||
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
|
||||
#pragma omp for schedule(static)
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
@ -580,12 +835,12 @@ class LambdaRankObj : public ObjFunction {
|
||||
lst.emplace_back(preds_h[j], labels[j], j);
|
||||
gpair[j] = GradientPair(0.0f, 0.0f);
|
||||
}
|
||||
std::sort(lst.begin(), lst.end(), ListEntry::CmpPred);
|
||||
std::stable_sort(lst.begin(), lst.end(), ListEntry::CmpPred);
|
||||
rec.resize(lst.size());
|
||||
for (unsigned i = 0; i < lst.size(); ++i) {
|
||||
rec[i] = std::make_pair(lst[i].label, i);
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// enumerate buckets with same label, for each item in the lst, grab another sample randomly
|
||||
for (unsigned i = 0; i < rec.size(); ) {
|
||||
unsigned j = i + 1;
|
||||
@ -635,6 +890,7 @@ class LambdaRankObj : public ObjFunction {
|
||||
#if defined(__CUDACC__)
|
||||
void ComputeGradientsOnGPU(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair>* out_gpair,
|
||||
const std::vector<unsigned> &gptr) {
|
||||
LOG(DEBUG) << "Computing pairwise gradients on GPU.";
|
||||
@ -655,33 +911,24 @@ class LambdaRankObj : public ObjFunction {
|
||||
auto d_preds = preds.ConstDevicePointer();
|
||||
auto d_gpair = out_gpair->DevicePointer();
|
||||
|
||||
if (!slist_) {
|
||||
slist_.reset(new SortedLabelList(device, param_));
|
||||
}
|
||||
|
||||
// Create segments based on group info
|
||||
slist_->InitWithTrainingInfo(gptr);
|
||||
SortedLabelList slist(param_);
|
||||
|
||||
// Sort the labels within the groups on the device
|
||||
slist_->Sort(info.labels_);
|
||||
slist.Sort(info.labels_, gptr);
|
||||
|
||||
// Initialize the gradients next
|
||||
out_gpair->Fill(GradientPair(0.0f, 0.0f));
|
||||
|
||||
// Finally, compute the gradients
|
||||
slist_->ComputeGradients(d_preds, d_gpair, info.weights_, weight_normalization_factor);
|
||||
|
||||
// Wait until the computations done by the kernel is complete
|
||||
dh::safe_cuda(cudaStreamSynchronize(nullptr));
|
||||
slist.ComputeGradients<LambdaWeightComputerT>
|
||||
(d_preds, info.weights_, iter, d_gpair, weight_normalization_factor);
|
||||
}
|
||||
#endif
|
||||
|
||||
LambdaRankParam param_;
|
||||
#if defined(__CUDACC__)
|
||||
std::unique_ptr<SortedLabelList> slist_;
|
||||
#endif
|
||||
};
|
||||
|
||||
#if !defined(GTEST_TEST)
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(LambdaRankParam);
|
||||
|
||||
@ -696,6 +943,7 @@ XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, NDCGLambdaWeightComputer::Name())
|
||||
XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name())
|
||||
.describe("LambdaRank with MAP as objective.")
|
||||
.set_body([]() { return new LambdaRankObj<MAPLambdaWeightComputer>(); });
|
||||
#endif
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
@ -76,4 +76,33 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) {
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
}
|
||||
|
||||
TEST(Objective, DeclareUnifiedTest(NDCGRankingGPair)) {
|
||||
std::vector<std::pair<std::string, std::string>> args;
|
||||
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
std::unique_ptr<xgboost::ObjFunction> obj {
|
||||
xgboost::ObjFunction::Create("rank:ndcg", &lparam)
|
||||
};
|
||||
obj->Configure(args);
|
||||
CheckConfigReload(obj, "rank:ndcg");
|
||||
|
||||
// Test with setting sample weight to second query group
|
||||
CheckRankingObjFunction(obj,
|
||||
{0, 0.1f, 0, 0.1f},
|
||||
{0, 1, 0, 1},
|
||||
{2.0f, 0.0f},
|
||||
{0, 2, 4},
|
||||
{0.7f, -0.7f, 0.0f, 0.0f},
|
||||
{0.74f, 0.74f, 0.0f, 0.0f});
|
||||
|
||||
CheckRankingObjFunction(obj,
|
||||
{0, 0.1f, 0, 0.1f},
|
||||
{0, 1, 0, 1},
|
||||
{1.0f, 1.0f},
|
||||
{0, 2, 4},
|
||||
{0.35f, -0.35f, 0.35f, -0.35f},
|
||||
{0.368f, 0.368f, 0.368f, 0.368f});
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1 +1,159 @@
|
||||
#include "test_ranking_obj.cc"
|
||||
|
||||
#include "../../../src/objective/rank_obj.cu"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
template <typename T = uint32_t, typename Comparator = thrust::greater<T>>
|
||||
std::unique_ptr<xgboost::obj::SegmentSorter<T>>
|
||||
RankSegmentSorterTestImpl(const std::vector<uint32_t> &group_indices,
|
||||
const std::vector<T> &hlabels,
|
||||
const std::vector<T> &expected_sorted_hlabels,
|
||||
const std::vector<uint32_t> &expected_orig_pos
|
||||
) {
|
||||
std::unique_ptr<xgboost::obj::SegmentSorter<T>> seg_sorter_ptr(
|
||||
new xgboost::obj::SegmentSorter<T>);
|
||||
xgboost::obj::SegmentSorter<T> &seg_sorter(*seg_sorter_ptr);
|
||||
|
||||
// Create a bunch of unsorted labels on the device and sort it via the segment sorter
|
||||
dh::device_vector<T> dlabels(hlabels);
|
||||
seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator());
|
||||
|
||||
EXPECT_EQ(seg_sorter.NumItems(), group_indices.back());
|
||||
EXPECT_EQ(seg_sorter.NumGroups(), group_indices.size() - 1);
|
||||
|
||||
// Check the labels
|
||||
dh::device_vector<T> sorted_dlabels(seg_sorter.NumItems());
|
||||
sorted_dlabels.assign(thrust::device_ptr<const T>(seg_sorter.Items()),
|
||||
thrust::device_ptr<const T>(seg_sorter.Items())
|
||||
+ seg_sorter.NumItems());
|
||||
thrust::host_vector<T> sorted_hlabels(sorted_dlabels);
|
||||
EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels);
|
||||
|
||||
// Check the indices
|
||||
dh::device_vector<uint32_t> dorig_pos(seg_sorter.NumItems());
|
||||
dorig_pos.assign(thrust::device_ptr<const uint32_t>(seg_sorter.OriginalPositions()),
|
||||
thrust::device_ptr<const uint32_t>(seg_sorter.OriginalPositions())
|
||||
+ seg_sorter.NumItems());
|
||||
dh::device_vector<uint32_t> horig_pos(dorig_pos);
|
||||
EXPECT_EQ(expected_orig_pos, horig_pos);
|
||||
|
||||
return seg_sorter_ptr;
|
||||
}
|
||||
|
||||
TEST(Objective, RankSegmentSorterTest) {
|
||||
RankSegmentSorterTestImpl({0, 2, 4, 7, 10, 14, 18, 22, 26}, // Groups
|
||||
{1, 1, // Labels
|
||||
1, 2,
|
||||
3, 2, 1,
|
||||
1, 2, 1,
|
||||
1, 3, 4, 2,
|
||||
1, 2, 1, 1,
|
||||
1, 2, 2, 3,
|
||||
3, 3, 1, 2},
|
||||
{1, 1, // Expected sorted labels
|
||||
2, 1,
|
||||
3, 2, 1,
|
||||
2, 1, 1,
|
||||
4, 3, 2, 1,
|
||||
2, 1, 1, 1,
|
||||
3, 2, 2, 1,
|
||||
3, 3, 2, 1},
|
||||
{0, 1, // Expected original positions
|
||||
3, 2,
|
||||
4, 5, 6,
|
||||
8, 7, 9,
|
||||
12, 11, 13, 10,
|
||||
15, 14, 16, 17,
|
||||
21, 19, 20, 18,
|
||||
22, 23, 25, 24});
|
||||
}
|
||||
|
||||
TEST(Objective, RankSegmentSorterSingleGroupTest) {
|
||||
RankSegmentSorterTestImpl({0, 7}, // Groups
|
||||
{6, 1, 4, 3, 0, 5, 2}, // Labels
|
||||
{6, 5, 4, 3, 2, 1, 0}, // Expected sorted labels
|
||||
{0, 5, 2, 3, 6, 1, 4}); // Expected original positions
|
||||
}
|
||||
|
||||
TEST(Objective, RankSegmentSorterAscendingTest) {
|
||||
RankSegmentSorterTestImpl<uint32_t, thrust::less<uint32_t>>(
|
||||
{0, 4, 7}, // Groups
|
||||
{3, 1, 4, 2, // Labels
|
||||
6, 5, 7},
|
||||
{1, 2, 3, 4, // Expected sorted labels
|
||||
5, 6, 7},
|
||||
{1, 3, 0, 2, // Expected original positions
|
||||
5, 4, 6});
|
||||
}
|
||||
|
||||
using CountFunctor = uint32_t (*)(const int *, uint32_t, int);
|
||||
void RankItemCountImpl(const std::vector<int> &sorted_items, CountFunctor f,
|
||||
int find_val, uint32_t exp_val) {
|
||||
EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end());
|
||||
EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val);
|
||||
}
|
||||
|
||||
TEST(Objective, RankItemCountOnLeft) {
|
||||
// Items sorted descendingly
|
||||
std::vector<int> sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
|
||||
10, static_cast<uint32_t>(0));
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
|
||||
6, static_cast<uint32_t>(2));
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
|
||||
4, static_cast<uint32_t>(3));
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
|
||||
1, static_cast<uint32_t>(7));
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
|
||||
0, static_cast<uint32_t>(12));
|
||||
}
|
||||
|
||||
TEST(Objective, RankItemCountOnRight) {
|
||||
// Items sorted descendingly
|
||||
std::vector<int> sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
|
||||
10, static_cast<uint32_t>(11));
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
|
||||
6, static_cast<uint32_t>(10));
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
|
||||
4, static_cast<uint32_t>(6));
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
|
||||
1, static_cast<uint32_t>(1));
|
||||
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
|
||||
0, static_cast<uint32_t>(0));
|
||||
}
|
||||
|
||||
TEST(Objective, NDCGLambdaWeightComputerTest) {
|
||||
auto segment_label_sorter = RankSegmentSorterTestImpl<float>(
|
||||
{0, 4, 7, 12}, // Groups
|
||||
{3.1f, 1.2f, 2.3f, 4.4f, // Labels
|
||||
7.8f, 5.01f, 6.96f,
|
||||
10.3f, 8.7f, 11.4f, 9.45f, 11.4f},
|
||||
{4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels
|
||||
7.8f, 6.96f, 5.01f,
|
||||
11.4f, 11.4f, 10.3f, 9.45f, 8.7f},
|
||||
{3, 0, 2, 1, // Expected original positions
|
||||
4, 6, 5,
|
||||
9, 11, 7, 10, 8});
|
||||
|
||||
// Created segmented predictions for the labels from above
|
||||
std::vector<bst_float> hpreds{-9.78f, 24.367f, 0.908f, -11.47f,
|
||||
-1.03f, -2.79f, -3.1f,
|
||||
104.22f, 103.1f, -101.7f, 100.5f, 45.1f};
|
||||
dh::device_vector<bst_float> dpreds(hpreds);
|
||||
xgboost::obj::NDCGLambdaWeightComputer ndcg_lw_computer(dpreds.data().get(),
|
||||
dpreds.size(),
|
||||
*segment_label_sorter);
|
||||
|
||||
// Where will the predictions move from its current position, if they were sorted
|
||||
// descendingly?
|
||||
auto dsorted_pred_pos = ndcg_lw_computer.GetSortedPredPos();
|
||||
thrust::host_vector<uint32_t> hsorted_pred_pos(dsorted_pred_pos);
|
||||
std::vector<uint32_t> expected_sorted_pred_pos{2, 0, 1, 3,
|
||||
4, 5, 6,
|
||||
7, 8, 11, 9, 10};
|
||||
EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
143
tests/python-gpu/test_gpu_ranking.py
Normal file
143
tests/python-gpu/test_gpu_ranking.py
Normal file
@ -0,0 +1,143 @@
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
import xgboost
|
||||
import os
|
||||
import math
|
||||
import unittest
|
||||
import itertools
|
||||
import shutil
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
class TestRanking(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Download and setup the test fixtures
|
||||
"""
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
# download the test data
|
||||
cls.dpath = 'demo/rank/'
|
||||
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
||||
target = cls.dpath + '/MQ2008.zip'
|
||||
|
||||
if os.path.exists(cls.dpath) and os.path.exists(target):
|
||||
print ("Skipping dataset download...")
|
||||
else:
|
||||
urllib.request.urlretrieve(url=src, filename=target)
|
||||
with zipfile.ZipFile(target, 'r') as f:
|
||||
f.extractall(path=cls.dpath)
|
||||
|
||||
(x_train, y_train, qid_train, x_test, y_test, qid_test,
|
||||
x_valid, y_valid, qid_valid) = load_svmlight_files(
|
||||
(cls.dpath + "MQ2008/Fold1/train.txt",
|
||||
cls.dpath + "MQ2008/Fold1/test.txt",
|
||||
cls.dpath + "MQ2008/Fold1/vali.txt"),
|
||||
query_id=True, zero_based=False)
|
||||
# instantiate the matrices
|
||||
cls.dtrain = xgboost.DMatrix(x_train, y_train)
|
||||
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
|
||||
cls.dtest = xgboost.DMatrix(x_test, y_test)
|
||||
# set the group counts from the query IDs
|
||||
cls.dtrain.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_train)])
|
||||
cls.dtest.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_test)])
|
||||
cls.dvalid.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_valid)])
|
||||
# save the query IDs for testing
|
||||
cls.qid_train = qid_train
|
||||
cls.qid_test = qid_test
|
||||
cls.qid_valid = qid_valid
|
||||
|
||||
# model training parameters
|
||||
cls.params = {'booster': 'gbtree',
|
||||
'tree_method': 'gpu_hist',
|
||||
'gpu_id': 0,
|
||||
'predictor': 'gpu_predictor'
|
||||
}
|
||||
cls.cpu_params = {'booster': 'gbtree',
|
||||
'tree_method': 'hist',
|
||||
'gpu_id': -1,
|
||||
'predictor': 'cpu_predictor'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
"""
|
||||
Cleanup test artifacts from download and unpacking
|
||||
:return:
|
||||
"""
|
||||
os.remove(cls.dpath + "MQ2008.zip")
|
||||
shutil.rmtree(cls.dpath + "MQ2008")
|
||||
|
||||
@classmethod
|
||||
def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolerance=1e-02):
|
||||
"""
|
||||
Internal method that trains the dataset using the rank objective on GPU and CPU, evaluates
|
||||
the metric and determines if the delta between the metric is within the tolerance level
|
||||
:return:
|
||||
"""
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(cls.dtest, 'eval'), (cls.dtrain, 'train')]
|
||||
|
||||
num_trees=2500
|
||||
check_metric_improvement_rounds=10
|
||||
|
||||
evals_result = {}
|
||||
cls.params['objective'] = rank_objective
|
||||
cls.params['eval_metric'] = metric_name
|
||||
bst = xgboost.train(cls.params, cls.dtrain, num_boost_round=num_trees,
|
||||
early_stopping_rounds=check_metric_improvement_rounds,
|
||||
evals=watchlist, evals_result=evals_result)
|
||||
gpu_map_metric = evals_result['train'][metric_name][-1]
|
||||
|
||||
evals_result = {}
|
||||
cls.cpu_params['objective'] = rank_objective
|
||||
cls.cpu_params['eval_metric'] = metric_name
|
||||
bstc = xgboost.train(cls.cpu_params, cls.dtrain, num_boost_round=num_trees,
|
||||
early_stopping_rounds=check_metric_improvement_rounds,
|
||||
evals=watchlist, evals_result=evals_result)
|
||||
cpu_map_metric = evals_result['train'][metric_name][-1]
|
||||
|
||||
print("{0} gpu {1} metric {2}".format(rank_objective, metric_name, gpu_map_metric))
|
||||
print("{0} cpu {1} metric {2}".format(rank_objective, metric_name, cpu_map_metric))
|
||||
print("gpu best score {0} cpu best score {1}".format(bst.best_score, bstc.best_score))
|
||||
assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, tolerance)
|
||||
assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance)
|
||||
|
||||
def test_training_rank_pairwise_map_metric(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with pairwise objective function and compare map metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:pairwise', 'map')
|
||||
|
||||
def test_training_rank_pairwise_auc_metric(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with pairwise objective function and compare auc metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:pairwise', 'auc')
|
||||
|
||||
def test_training_rank_pairwise_ndcg_metric(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with pairwise objective function and compare ndcg metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:pairwise', 'ndcg')
|
||||
|
||||
def test_training_rank_ndcg_map(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with ndcg objective function and compare map metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:ndcg', 'map')
|
||||
|
||||
def test_training_rank_ndcg_auc(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with ndcg objective function and compare auc metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:ndcg', 'auc')
|
||||
|
||||
def test_training_rank_ndcg_ndcg(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with ndcg objective function and compare ndcg metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:ndcg', 'ndcg')
|
||||
Loading…
x
Reference in New Issue
Block a user