From ee81ba8e1fc82f165119bf8fb86b147c35e06e52 Mon Sep 17 00:00:00 2001 From: sriramch <33358417+sriramch@users.noreply.github.com> Date: Thu, 26 Dec 2019 15:05:38 -0800 Subject: [PATCH] implementation of map ranking algorithm on gpu (#5129) * - implementation of map ranking algorithm - also effected necessary suggestions mentioned in the earlier ranking pr's - made some performance improvements to the ndcg algo as well --- src/common/device_helpers.cuh | 125 +++- src/objective/rank_obj.cu | 601 +++++++++++++------- tests/cpp/common/test_device_helpers.cu | 78 +++ tests/cpp/objective/test_ranking_obj.cc | 29 + tests/cpp/objective/test_ranking_obj_gpu.cu | 129 ++++- tests/python-gpu/test_gpu_ranking.py | 18 + 6 files changed, 714 insertions(+), 266 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 42f027d10..c695087ac 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -142,31 +142,92 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8)); } -/*! - * \brief Find the strict upper bound for an element in a sorted array - * using binary search. 
- * \param cuts pointer to the first element of the sorted array - * \param n length of the sorted array - * \param v value for which to find the upper bound - * \return the smallest index i such that v < cuts[i], or n if v is greater or equal - * than all elements of the array -*/ -template -DEV_INLINE int UpperBound(const T* __restrict__ cuts, int n, T v) { - if (n == 0) { return 0; } - if (cuts[n - 1] <= v) { return n; } - if (cuts[0] > v) { return 0; } +namespace internal { - int left = 0, right = n - 1; - while (right - left > 1) { - int middle = left + (right - left) / 2; - if (cuts[middle] > v) { - right = middle; +// Items of size 'n' are sorted in an order determined by the Comparator +// If left is true, find the number of elements where 'comp(item, v)' returns true; +// 0 if nothing is true +// If left is false, find the number of elements where '!comp(item, v)' returns true; +// 0 if nothing is true +template > +XGBOOST_DEVICE __forceinline__ uint32_t +CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v, + const Comparator &comp = Comparator()) { + const T *items_begin = items; + uint32_t num_remaining = n; + const T *middle_item = nullptr; + uint32_t middle; + while (num_remaining > 0) { + middle_item = items_begin; + middle = num_remaining / 2; + middle_item += middle; + if ((left && comp(*middle_item, v)) || (!left && !comp(v, *middle_item))) { + items_begin = ++middle_item; + num_remaining -= middle + 1; } else { - left = middle; + num_remaining = middle; } } - return right; + + return left ? items_begin - items : items + n - items_begin; +} + +} + +/*! + * \brief Find the strict upper bound for an element in a sorted array + * using binary search. 
+ * \param items pointer to the first element of the sorted array + * \param n length of the sorted array + * \param v value for which to find the upper bound + * \param comp determines how the items are sorted ascending/descending order - should conform + * to ordering semantics + * \return the smallest index i that has a value > v, or n if none is larger when sorted ascendingly + * or, an index i with a value < v, or 0 if none is smaller when sorted descendingly +*/ +// Preserve existing default behavior of upper bound +template > +XGBOOST_DEVICE __forceinline__ uint32_t UpperBound(const T *__restrict__ items, + uint32_t n, + T v, + const Comp &comp = Comp()) { + if (std::is_same>::value || + std::is_same>::value) { + return n - internal::CountNumItemsImpl(false, items, n, v, comp); + } else { + static_assert(std::is_same>::value || + std::is_same>::value, + "Invalid comparator used in Upperbound - can only be thrust::greater/less"); + return std::numeric_limits::max(); // Simply to quiesce the compiler + } +} + +/*! + * \brief Find the strict lower bound for an element in a sorted array + * using binary search. 
+ * \param items pointer to the first element of the sorted array + * \param n length of the sorted array + * \param v value for which to find the lower bound + * \param comp determines how the items are sorted ascending/descending order - should conform + * to ordering semantics + * \return the smallest index i that has a value >= v, or n if none is larger + * when sorted ascendingly + * or, an index i with a value <= v, or 0 if none is smaller when sorted descendingly +*/ +template > +XGBOOST_DEVICE __forceinline__ uint32_t LowerBound(const T *__restrict__ items, + uint32_t n, + T v, + const Comp &comp = Comp()) { + if (std::is_same>::value || + std::is_same>::value) { + return internal::CountNumItemsImpl(true, items, n, v, comp); + } else { + static_assert(std::is_same>::value || + std::is_same>::value, + "Invalid comparator used in LowerBound - can only be thrust::greater/less"); + return std::numeric_limits::max(); // Simply to quiesce the compiler + } } template @@ -510,7 +571,7 @@ void CopyDeviceSpan(xgboost::common::Span dst, class BulkAllocator { std::vector d_ptr_; std::vector size_; - std::vector device_idx_; + int device_idx_{-1}; static const int kAlign = 256; @@ -593,14 +654,15 @@ class BulkAllocator { * This frees the GPU memory managed by this allocator. */ void Clear() { - for (size_t i = 0; i < d_ptr_.size(); i++) { // NOLINT(modernize-loop-convert) - if (d_ptr_[i] != nullptr) { - safe_cuda(cudaSetDevice(device_idx_[i])); - XGBDeviceAllocator allocator; - allocator.deallocate(thrust::device_ptr(d_ptr_[i]), size_[i]); - d_ptr_[i] = nullptr; - } - } + if (d_ptr_.empty()) return; + + safe_cuda(cudaSetDevice(device_idx_)); + size_t idx = 0; + std::for_each(d_ptr_.begin(), d_ptr_.end(), [&](char *dptr) { + XGBDeviceAllocator().deallocate(thrust::device_ptr(dptr), size_[idx++]); + }); + d_ptr_.clear(); + size_.clear(); } ~BulkAllocator() { @@ -614,6 +676,8 @@ class BulkAllocator { template void Allocate(int device_idx, Args... 
args) { + if (device_idx_ == -1) device_idx_ = device_idx; + else CHECK(device_idx_ == device_idx); size_t size = GetSizeBytes(args...); char *ptr = AllocateDevice(device_idx, size); @@ -622,7 +686,6 @@ class BulkAllocator { d_ptr_.push_back(ptr); size_.push_back(size); - device_idx_.push_back(device_idx); } }; diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index cb5276776..9b25a03dc 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -18,6 +18,7 @@ #if defined(__CUDACC__) #include #include +#include #include #include @@ -64,6 +65,9 @@ class SegmentSorter { // Need this on the device as it is used in the kernels dh::caching_device_vector dgroups_; // Group information on device + // Where did the item that was originally present at position 'x' move to after they are sorted + dh::caching_device_vector dindexable_sorted_pos_; + // Initialize everything but the segments void Init(uint32_t num_elems) { ditems_.resize(num_elems); @@ -87,28 +91,42 @@ class SegmentSorter { dgroups_ = groups; - // Launch a kernel that populates the segment information for the different groups - uint32_t *gsegs = group_segments_.data().get(); + // Define the segments by assigning a group ID to each element const uint32_t *dgroups = dgroups_.data().get(); uint32_t ngroups = dgroups_.size(); - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - dh::LaunchN(device_id, num_elems, nullptr, [=] __device__(uint32_t idx){ - // Find the group first - uint32_t group_idx = dh::UpperBound(dgroups, ngroups, idx); - gsegs[idx] = group_idx - 1; - }); + auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) { + return dh::UpperBound(dgroups, ngroups, idx) - 1; + }; // NOLINT + + thrust::transform(thrust::make_counting_iterator(static_cast(0)), + thrust::make_counting_iterator(num_elems), + group_segments_.begin(), + ComputeGroupIDLambda); } // Accessors that returns device pointer - inline const T *Items() const { return ditems_.data().get(); } - 
inline uint32_t NumItems() const { return ditems_.size(); } - inline const uint32_t *OriginalPositions() const { return doriginal_pos_.data().get(); } - inline const dh::caching_device_vector &GroupSegments() const { + inline const T *GetItemsPtr() const { return ditems_.data().get(); } + inline uint32_t GetNumItems() const { return ditems_.size(); } + inline const dh::caching_device_vector &GetItems() const { + return ditems_; + } + + inline const uint32_t *GetOriginalPositionsPtr() const { return doriginal_pos_.data().get(); } + inline const dh::caching_device_vector &GetOriginalPositions() const { + return doriginal_pos_; + } + + inline const dh::caching_device_vector &GetGroupSegments() const { return group_segments_; } - inline uint32_t NumGroups() const { return dgroups_.size() - 1; } - inline const uint32_t *GroupIndices() const { return dgroups_.data().get(); } + + inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; } + inline const uint32_t *GetGroupsPtr() const { return dgroups_.data().get(); } + inline const dh::caching_device_vector &GetGroups() const { return dgroups_; } + + inline const dh::caching_device_vector &GetIndexableSortedPositions() const { + return dindexable_sorted_pos_; + } // Sort an array that is divided into multiple groups. The array is sorted within each group. // This version provides the group information that is on the host. @@ -183,45 +201,31 @@ class SegmentSorter { thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), thrust::device_ptr(ditems), ditems_.begin()); } + + // Determine where an item that was originally present at position 'x' has been relocated to + // after a sort. Creation of such an index has to be explicitly requested after a sort + void CreateIndexableSortedPositions() { + dindexable_sorted_pos_.resize(GetNumItems()); + thrust::scatter(thrust::make_counting_iterator(static_cast(0)), + thrust::make_counting_iterator(GetNumItems()), // Rearrange indices... 
+ // ...based on this map + thrust::device_ptr(GetOriginalPositionsPtr()), + dindexable_sorted_pos_.begin()); // Write results into this + } }; // Helper functions -// Items of size 'n' are sorted in a descending order -// If left is true, find the number of elements > v; 0 if nothing is greater -// If left is false, find the number of elements < v; 0 if nothing is lesser -template -XGBOOST_DEVICE __forceinline__ uint32_t -CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v) { - const T *items_begin = items; - uint32_t num_remaining = n; - const T *middle_item = nullptr; - uint32_t middle; - while (num_remaining > 0) { - middle_item = items_begin; - middle = num_remaining / 2; - middle_item += middle; - if ((left && *middle_item > v) || (!left && !(v > *middle_item))) { - items_begin = ++middle_item; - num_remaining -= middle + 1; - } else { - num_remaining = middle; - } - } - - return left ? items_begin - items : items + n - items_begin; -} - template XGBOOST_DEVICE __forceinline__ uint32_t CountNumItemsToTheLeftOf(const T * __restrict__ items, uint32_t n, T v) { - return CountNumItemsImpl(true, items, n, v); + return dh::LowerBound(items, n, v, thrust::greater()); } template XGBOOST_DEVICE __forceinline__ uint32_t CountNumItemsToTheRightOf(const T * __restrict__ items, uint32_t n, T v) { - return CountNumItemsImpl(false, items, n, v); + return n - dh::UpperBound(items, n, v, thrust::greater()); } #endif @@ -262,7 +266,8 @@ struct LambdaPair { : pos_index(pos_index), neg_index(neg_index), weight(weight) {} }; -struct PairwiseLambdaWeightComputer { +class PairwiseLambdaWeightComputer { + public: /*! 
* \brief get lambda weight for existing pairs - for pairwise objective * \param list a list that is sorted by pred score @@ -275,65 +280,131 @@ struct PairwiseLambdaWeightComputer { return "rank:pairwise"; } - // Stopgap method - will be removed when we support other type of ranking - map - // on GPU later - inline static bool SupportOnGPU() { return true; } - #if defined(__CUDACC__) PairwiseLambdaWeightComputer(const bst_float *dpreds, - uint32_t pred_size, + const bst_float *dlabels, const SegmentSorter &segment_label_sorter) {} - struct PairwiseLambdaWeightMultiplier { + class PairwiseLambdaWeightMultiplier { + public: // Adjust the items weight by this value __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { return 1.0f; } }; - inline PairwiseLambdaWeightMultiplier GetWeightMultiplier() const { + inline const PairwiseLambdaWeightMultiplier GetWeightMultiplier() const { return {}; } #endif }; +#if defined(__CUDACC__) +class BaseLambdaWeightMultiplier { + public: + BaseLambdaWeightMultiplier(const SegmentSorter &segment_label_sorter, + const SegmentSorter &segment_pred_sorter) + : dsorted_labels_(segment_label_sorter.GetItemsPtr()), + dorig_pos_(segment_label_sorter.GetOriginalPositionsPtr()), + dgroups_(segment_label_sorter.GetGroupsPtr()), + dindexable_sorted_preds_pos_ptr_( + segment_pred_sorter.GetIndexableSortedPositions().data().get()) {} + + protected: + const float *dsorted_labels_{nullptr}; // Labels sorted within a group + const uint32_t *dorig_pos_{nullptr}; // Original indices of the labels before they are sorted + const uint32_t *dgroups_{nullptr}; // The group indices + // Where can a prediction for a label be found in the original array, when they are sorted + const uint32_t *dindexable_sorted_preds_pos_ptr_{nullptr}; +}; + +// While computing the weight that needs to be adjusted by this ranking objective, we need +// to figure out where positive and negative labels chosen earlier exists, if the group +// were 
to be sorted by its predictions. To accommodate this, we employ the following algorithm. +// For a given group, let's assume the following: +// labels: 1 5 9 2 4 8 0 7 6 3 +// predictions: 1 9 0 8 2 7 3 6 5 4 +// position: 0 1 2 3 4 5 6 7 8 9 +// +// After label sort: +// labels: 9 8 7 6 5 4 3 2 1 0 +// position: 2 5 7 8 1 4 9 3 0 6 +// +// After prediction sort: +// predictions: 9 8 7 6 5 4 3 2 1 0 +// position: 1 3 5 7 8 9 6 4 0 2 +// +// If a sorted label at position 'x' is chosen, then we need to find out where the prediction +// for this label 'x' exists, if the group were to be sorted by predictions. +// We first take the sorted prediction positions: +// position: 1 3 5 7 8 9 6 4 0 2 +// at indices: 0 1 2 3 4 5 6 7 8 9 +// +// We create a sorted prediction positional array, such that value at position 'x' gives +// us the position in the sorted prediction array where its related prediction lies. +// dindexable_sorted_preds_pos_ptr_: 8 0 9 1 7 2 6 3 4 5 +// at indices: 0 1 2 3 4 5 6 7 8 9 +// Basically, swap the previous 2 arrays, sort the indices and reorder positions +// for an O(1) lookup using the position where the sorted label exists. 
+// +// This type does that using the SegmentSorter +class IndexablePredictionSorter { + public: + IndexablePredictionSorter(const bst_float *dpreds, + const SegmentSorter &segment_label_sorter) { + // Sort the predictions first + segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(), + segment_label_sorter.GetGroupSegments()); + + // Create an index for the sorted prediction positions + segment_pred_sorter_.CreateIndexableSortedPositions(); + } + + inline const SegmentSorter &GetPredictionSorter() const { + return segment_pred_sorter_; + } + + private: + SegmentSorter segment_pred_sorter_; // For sorting the predictions +}; +#endif + // beta version: NDCG lambda rank -struct NDCGLambdaWeightComputer { +class NDCGLambdaWeightComputer +#if defined(__CUDACC__) + : public IndexablePredictionSorter +#endif +{ public: #if defined(__CUDACC__) - // This function object computes the group's DCG for a given group - struct ComputeGroupDCG { + // This function object computes the item's DCG value + class ComputeItemDCG : public thrust::unary_function { public: - XGBOOST_DEVICE ComputeGroupDCG(const float *dsorted_labels, const uint32_t *dgroups) + XGBOOST_DEVICE ComputeItemDCG(const float *dsorted_labels, + const uint32_t *dgroups, + const uint32_t *gidxs) : dsorted_labels_(dsorted_labels), - dgroups_(dgroups) {} + dgroups_(dgroups), + dgidxs_(gidxs) {} - // Compute DCG for group 'gidx' - __device__ __forceinline__ float operator()(uint32_t gidx) const { - uint32_t group_begin = dgroups_[gidx]; - uint32_t group_end = dgroups_[gidx + 1]; - uint32_t group_size = group_end - group_begin; - return ComputeGroupDCGWeight(&dsorted_labels_[group_begin], group_size); + // Compute DCG for the item at 'idx' + __device__ __forceinline__ float operator()(uint32_t idx) const { + return ComputeItemDCGWeight(dsorted_labels_[idx], idx - dgroups_[dgidxs_[idx]]); } private: const float *dsorted_labels_{nullptr}; // Labels sorted within a group const uint32_t 
*dgroups_{nullptr}; // The group indices - where each group begins and ends + const uint32_t *dgidxs_{nullptr}; // The group each items belongs to }; // Type containing device pointers that can be cheaply copied on the kernel - class NDCGLambdaWeightMultiplier { + class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { public: - NDCGLambdaWeightMultiplier(const float *dsorted_labels, - const uint32_t *dorig_pos, - const uint32_t *dgroups, - const float *dgroup_dcg_ptr, - uint32_t *dindexable_sorted_preds_pos_ptr) - : dsorted_labels_(dsorted_labels), - dorig_pos_(dorig_pos), - dgroups_(dgroups), - dgroup_dcg_ptr_(dgroup_dcg_ptr), - dindexable_sorted_preds_pos_ptr_(dindexable_sorted_preds_pos_ptr) {} + NDCGLambdaWeightMultiplier(const SegmentSorter &segment_label_sorter, + const NDCGLambdaWeightComputer &lwc) + : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), + dgroup_dcg_ptr_(lwc.GetGroupDcgs().data().get()) {} // Adjust the items weight by this value __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { @@ -341,68 +412,56 @@ struct NDCGLambdaWeightComputer { uint32_t group_begin = dgroups_[gidx]; - auto ppred_idx = dorig_pos_[pidx]; - auto npred_idx = dorig_pos_[nidx]; - KERNEL_CHECK(ppred_idx != npred_idx); + auto pos_lab_orig_posn = dorig_pos_[pidx]; + auto neg_lab_orig_posn = dorig_pos_[nidx]; + KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn); // Note: the label positive and negative indices are relative to the entire dataset. 
// Hence, scale them back to an index within the group - ppred_idx = dindexable_sorted_preds_pos_ptr_[ppred_idx] - group_begin; - npred_idx = dindexable_sorted_preds_pos_ptr_[npred_idx] - group_begin; + auto pos_pred_pos = dindexable_sorted_preds_pos_ptr_[pos_lab_orig_posn] - group_begin; + auto neg_pred_pos = dindexable_sorted_preds_pos_ptr_[neg_lab_orig_posn] - group_begin; return NDCGLambdaWeightComputer::ComputeDeltaWeight( - ppred_idx, npred_idx, + pos_pred_pos, neg_pred_pos, static_cast(dsorted_labels_[pidx]), static_cast(dsorted_labels_[nidx]), dgroup_dcg_ptr_[gidx]); } private: - const float *dsorted_labels_{nullptr}; // Labels sorted within a group - const uint32_t *dorig_pos_{nullptr}; // Original indices of the labels before they are sorted - const uint32_t *dgroups_{nullptr}; // The group indices const float *dgroup_dcg_ptr_{nullptr}; // Start address of the group DCG values - // Where can a prediction for a label be found in the original array, when they are sorted - uint32_t *dindexable_sorted_preds_pos_ptr_{nullptr}; }; NDCGLambdaWeightComputer(const bst_float *dpreds, - uint32_t pred_size, + const bst_float *dlabels, const SegmentSorter &segment_label_sorter) - : dgroup_dcg_(segment_label_sorter.NumGroups()), - dindexable_sorted_preds_pos_(pred_size), - weight_multiplier_(segment_label_sorter.Items(), - segment_label_sorter.OriginalPositions(), - segment_label_sorter.GroupIndices(), - dgroup_dcg_.data().get(), - dindexable_sorted_preds_pos_.data().get()) { - // Sort the predictions first and get the sorted position - SegmentSorter segment_prediction_sorter; - segment_prediction_sorter.SortItems(dpreds, pred_size, segment_label_sorter.GroupSegments()); + : IndexablePredictionSorter(dpreds, segment_label_sorter), + dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f), + weight_multiplier_(segment_label_sorter, *this) { + const auto &group_segments = segment_label_sorter.GetGroupSegments(); - 
this->CreateIndexableSortedPredictionPositions(segment_prediction_sorter.OriginalPositions()); - - // Compute each group's DCG concurrently - // Set the values to be the group indices first so that the predicate knows which - // group it is dealing with - thrust::sequence(dgroup_dcg_.begin(), dgroup_dcg_.end()); - - // TODO(sriramch): parallelize across all elements, if possible - // Transform each group - the predictate computes the group's DCG - thrust::transform(dgroup_dcg_.begin(), dgroup_dcg_.end(), - dgroup_dcg_.begin(), - ComputeGroupDCG(segment_label_sorter.Items(), - segment_label_sorter.GroupIndices())); + // Compute each elements DCG values and reduce them across groups concurrently. + auto end_range = + thrust::reduce_by_key(group_segments.begin(), group_segments.end(), + thrust::make_transform_iterator( + // The indices need not be sequential within a group, as we care only + // about the sum of items DCG values within a group + segment_label_sorter.GetOriginalPositions().begin(), + ComputeItemDCG(segment_label_sorter.GetItemsPtr(), + segment_label_sorter.GetGroupsPtr(), + group_segments.data().get())), + thrust::make_discard_iterator(), // We don't care for the group indices + dgroup_dcg_.begin()); // Sum of the item's DCG values in the group + CHECK(end_range.second - dgroup_dcg_.begin() == dgroup_dcg_.size()); } - inline NDCGLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; } - inline const dh::caching_device_vector &GetSortedPredPos() const { - return dindexable_sorted_preds_pos_; + inline const dh::caching_device_vector &GetGroupDcgs() const { + return dgroup_dcg_; + } + + inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const { + return weight_multiplier_; } #endif - // Stopgap method - will be removed when we support other type of ranking - map - // on GPU later - inline static bool SupportOnGPU() { return true; } - static void GetLambdaWeight(const std::vector &sorted_list, std::vector *io_pairs) { 
std::vector &pairs = *io_pairs; @@ -434,29 +493,31 @@ struct NDCGLambdaWeightComputer { return "rank:ndcg"; } - private: - XGBOOST_DEVICE inline static bst_float ComputeGroupDCGWeight(const float *sorted_labels, - uint32_t size) { + inline static bst_float ComputeGroupDCGWeight(const float *sorted_labels, uint32_t size) { double sumdcg = 0.0; for (uint32_t i = 0; i < size; ++i) { - const auto rel = static_cast(sorted_labels[i]); - if (rel != 0) { - sumdcg += ((1 << rel) - 1) / std::log2(static_cast(i + 2)); - } + sumdcg += ComputeItemDCGWeight(sorted_labels[i], i); } + return static_cast(sumdcg); } + private: + XGBOOST_DEVICE inline static bst_float ComputeItemDCGWeight(unsigned label, uint32_t idx) { + return (label != 0) ? (((1 << label) - 1) / std::log2(static_cast(idx + 2))) : 0; + } + // Compute the weight adjustment for an item within a group: - // ppred_idx => Where does the positive label live, had the list been sorted by prediction - // npred_idx => Where does the negative label live, had the list been sorted by prediction + // pos_pred_pos => Where does the positive label live, had the list been sorted by prediction + // neg_pred_pos => Where does the negative label live, had the list been sorted by prediction // pos_label => positive label value from sorted label list // neg_label => negative label value from sorted label list - XGBOOST_DEVICE inline static bst_float ComputeDeltaWeight(uint32_t ppred_idx, uint32_t npred_idx, + XGBOOST_DEVICE inline static bst_float ComputeDeltaWeight(uint32_t pos_pred_pos, + uint32_t neg_pred_pos, int pos_label, int neg_label, float idcg) { - float pos_loginv = 1.0f / std::log2(ppred_idx + 2.0f); - float neg_loginv = 1.0f / std::log2(npred_idx + 2.0f); + float pos_loginv = 1.0f / std::log2(pos_pred_pos + 2.0f); + float neg_loginv = 1.0f / std::log2(neg_pred_pos + 2.0f); bst_float original = ((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv; float changed = ((1 << neg_label) - 1) * pos_loginv + ((1 
<< pos_label) - 1) * neg_loginv; bst_float delta = (original - changed) * (1.0f / idcg); @@ -465,105 +526,103 @@ struct NDCGLambdaWeightComputer { } #if defined(__CUDACC__) - // While computing the weight that needs to be adjusted by this ranking objective, we need - // to figure out where positive and negative labels chosen earlier exists, if the group - // were to be sorted by its predictions. To accommodate this, we employ the following algorithm. - // For a given group, let's assume the following: - // labels: 1 5 9 2 4 8 0 7 6 3 - // predictions: 1 9 0 8 2 7 3 6 5 4 - // position: 0 1 2 3 4 5 6 7 8 9 - // - // After label sort: - // labels: 9 8 7 6 5 4 3 2 1 0 - // position: 2 5 7 8 1 4 9 3 0 6 - // - // After prediction sort: - // predictions: 9 8 7 6 5 4 3 2 1 0 - // position: 1 3 5 7 8 9 6 4 0 2 - // - // If a sorted label at position 'x' is chosen, then we need to find out where the prediction - // for this label 'x' exists, if the group were to be sorted by predictions. - // We first take the sorted prediction positions: - // position: 1 3 5 7 8 9 6 4 0 2 - // at indices: 0 1 2 3 4 5 6 7 8 9 - // - // We create a sorted prediction positional array, such that value at position 'x' gives - // us the position in the sorted prediction array where its related prediction lies. - // dindexable_sorted_preds_pos_ptr_: 8 0 9 1 7 2 6 3 4 5 - // at indices: 0 1 2 3 4 5 6 7 8 9 - // Basically, swap the previous 2 arrays, sort the indices and reorder positions - // for an O(1) lookup using the position where the sorted label exists - void CreateIndexableSortedPredictionPositions(const uint32_t *dsorted_preds_pos) { - dh::caching_device_vector indices(dindexable_sorted_preds_pos_.size()); - thrust::sequence(indices.begin(), indices.end()); - thrust::scatter(indices.begin(), indices.end(), // Rearrange indices... 
- thrust::device_ptr(dsorted_preds_pos), // ...based on this map - dindexable_sorted_preds_pos_.begin()); // Write results into this - } - dh::caching_device_vector dgroup_dcg_; - // Where can a prediction for a label be found in the original array, when they are sorted - dh::caching_device_vector dindexable_sorted_preds_pos_; - NDCGLambdaWeightMultiplier weight_multiplier_; // This computes the adjustment to the weight + // This computes the adjustment to the weight + const NDCGLambdaWeightMultiplier weight_multiplier_; #endif }; -struct MAPLambdaWeightComputer { - private: +class MAPLambdaWeightComputer +#if defined(__CUDACC__) + : public IndexablePredictionSorter +#endif +{ + public: struct MAPStats { /*! \brief the accumulated precision */ - float ap_acc; + float ap_acc{0.0f}; /*! * \brief the accumulated precision, * assuming a positive instance is missing */ - float ap_acc_miss; + float ap_acc_miss{0.0f}; /*! * \brief the accumulated precision, * assuming that one more positive instance is inserted ahead */ - float ap_acc_add; + float ap_acc_add{0.0f}; /* \brief the accumulated positive instance count */ - float hits; - MAPStats() = default; - MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits) - : ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {} + float hits{0.0f}; + + XGBOOST_DEVICE MAPStats() {} // NOLINT + XGBOOST_DEVICE MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits) + : ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {} + + // For prefix scan + XGBOOST_DEVICE MAPStats operator +(const MAPStats &v1) const { + return {ap_acc + v1.ap_acc, ap_acc_miss + v1.ap_acc_miss, + ap_acc_add + v1.ap_acc_add, hits + v1.hits}; + } + + // For test purposes - compare for equality + XGBOOST_DEVICE bool operator ==(const MAPStats &rhs) const { + return ap_acc == rhs.ap_acc && ap_acc_miss == rhs.ap_acc_miss && + ap_acc_add == rhs.ap_acc_add && hits == rhs.hits; + } }; + 
private: + template + XGBOOST_DEVICE inline static void Swap(T &v0, T &v1) { +#if defined(__CUDACC__) + thrust::swap(v0, v1); +#else + std::swap(v0, v1); +#endif + } + /*! - * \brief Obtain the delta MAP if trying to switch the positions of instances in index1 or index2 - * in sorted triples - * \param sorted_list the list containing entry information - * \param index1,index2 the instances switched - * \param map_stats a vector containing the accumulated precisions for each position in a list + * \brief Obtain the delta MAP by trying to switch the positions of labels in pos_pred_pos or + * neg_pred_pos when sorted by predictions + * \param pos_pred_pos positive label's prediction value position when the groups prediction + * values are sorted + * \param neg_pred_pos negative label's prediction value position when the groups prediction + * values are sorted + * \param pos_label, neg_label the chosen positive and negative labels + * \param p_map_stats a vector containing the accumulated precisions for each position in a list + * \param map_stats_size size of the accumulated precisions vector */ - inline static bst_float GetLambdaMAP(const std::vector &sorted_list, - int index1, int index2, - std::vector *p_map_stats) { - std::vector &map_stats = *p_map_stats; - if (index1 == index2 || map_stats[map_stats.size() - 1].hits == 0) { + XGBOOST_DEVICE inline static bst_float GetLambdaMAP( + int pos_pred_pos, int neg_pred_pos, + bst_float pos_label, bst_float neg_label, + const MAPStats *p_map_stats, uint32_t map_stats_size) { + if (pos_pred_pos == neg_pred_pos || p_map_stats[map_stats_size - 1].hits == 0) { return 0.0f; } - if (index1 > index2) std::swap(index1, index2); - bst_float original = map_stats[index2].ap_acc; - if (index1 != 0) original -= map_stats[index1 - 1].ap_acc; + if (pos_pred_pos > neg_pred_pos) { + Swap(pos_pred_pos, neg_pred_pos); + Swap(pos_label, neg_label); + } + bst_float original = p_map_stats[neg_pred_pos].ap_acc; + if (pos_pred_pos != 0) original 
-= p_map_stats[pos_pred_pos - 1].ap_acc; bst_float changed = 0; - bst_float label1 = sorted_list[index1].label > 0.0f ? 1.0f : 0.0f; - bst_float label2 = sorted_list[index2].label > 0.0f ? 1.0f : 0.0f; + bst_float label1 = pos_label > 0.0f ? 1.0f : 0.0f; + bst_float label2 = neg_label > 0.0f ? 1.0f : 0.0f; if (label1 == label2) { return 0.0; } else if (label1 < label2) { - changed += map_stats[index2 - 1].ap_acc_add - map_stats[index1].ap_acc_add; - changed += (map_stats[index1].hits + 1.0f) / (index1 + 1); + changed += p_map_stats[neg_pred_pos - 1].ap_acc_add - p_map_stats[pos_pred_pos].ap_acc_add; + changed += (p_map_stats[pos_pred_pos].hits + 1.0f) / (pos_pred_pos + 1); } else { - changed += map_stats[index2 - 1].ap_acc_miss - map_stats[index1].ap_acc_miss; - changed += map_stats[index2].hits / (index2 + 1); + changed += p_map_stats[neg_pred_pos - 1].ap_acc_miss - p_map_stats[pos_pred_pos].ap_acc_miss; + changed += p_map_stats[neg_pred_pos].hits / (neg_pred_pos + 1); } - bst_float ans = (changed - original) / (map_stats[map_stats.size() - 1].hits); + bst_float ans = (changed - original) / (p_map_stats[map_stats_size - 1].hits); if (ans < 0) ans = -ans; return ans; } + public: /* * \brief obtain preprocessing results for calculating delta MAP * \param sorted_list the list containing entry information @@ -585,11 +644,6 @@ struct MAPLambdaWeightComputer { } } - public: - // Stopgap method - will be removed when we support other type of ranking - map - // on GPU later - inline static bool SupportOnGPU() { return false; } - static char const* Name() { return "rank:map"; } @@ -601,26 +655,132 @@ struct MAPLambdaWeightComputer { GetMAPStats(sorted_list, &map_stats); for (auto & pair : pairs) { pair.weight *= - GetLambdaMAP(sorted_list, pair.pos_index, - pair.neg_index, &map_stats); + GetLambdaMAP(pair.pos_index, pair.neg_index, + sorted_list[pair.pos_index].label, sorted_list[pair.neg_index].label, + &map_stats[0], map_stats.size()); } } #if defined(__CUDACC__) 
MAPLambdaWeightComputer(const bst_float *dpreds, - uint32_t pred_size, - const SegmentSorter &segment_label_sorter) {} + const bst_float *dlabels, + const SegmentSorter &segment_label_sorter) + : IndexablePredictionSorter(dpreds, segment_label_sorter), + dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()), + weight_multiplier_(segment_label_sorter, *this) { + this->CreateMAPStats(dlabels, segment_label_sorter); + } + + void CreateMAPStats(const bst_float *dlabels, + const SegmentSorter &segment_label_sorter) { + // For each group, go through the sorted prediction positions, and look up its corresponding + // label from the unsorted labels (from the original label list) + + // For each item in the group, compute its MAP stats. + // Interleave the computation of map stats amongst different groups. + + // First, determine postive labels in the dataset individually + auto nitems = segment_label_sorter.GetNumItems(); + dh::caching_device_vector dhits(nitems, 0); + // Original positions of the predictions after they have been sorted + const uint32_t *pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsPtr(); + // Unsorted labels + const float *unsorted_labels = dlabels; + auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) { + return (unsorted_labels[pred_original_pos[idx]] > 0.0f) ? 1 : 0; + }; // NOLINT + + thrust::transform(thrust::make_counting_iterator(static_cast(0)), + thrust::make_counting_iterator(nitems), + dhits.begin(), + DeterminePositiveLabelLambda); + + // Allocator to be used by sort for managing space overhead while performing prefix scans + dh::XGBCachingDeviceAllocator alloc; + + // Next, prefix scan the positive labels that are segmented to accumulate them. + // This is required for computing the accumulated precisions + const auto &group_segments = segment_label_sorter.GetGroupSegments(); + // Data segmented into different groups... 
+ thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), + group_segments.begin(), group_segments.end(), + dhits.begin(), // Input value + dhits.begin()); // In-place scan + + // Compute accumulated precisions for each item, assuming positive and + // negative instances are missing. + // But first, compute individual item precisions + const auto *dgidx_arr = group_segments.data().get(); + const auto *dhits_arr = dhits.data().get(); + // Group info on device + const uint32_t *dgroups = segment_label_sorter.GetGroupsPtr(); + uint32_t ngroups = segment_label_sorter.GetNumGroups(); + auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) { + if (unsorted_labels[pred_original_pos[idx]] > 0.0f) { + auto idx_within_group = (idx - dgroups[dgidx_arr[idx]]) + 1; + return MAPStats(static_cast(dhits_arr[idx]) / idx_within_group, + static_cast(dhits_arr[idx] - 1) / idx_within_group, + static_cast(dhits_arr[idx] + 1) / idx_within_group, + 1.0f); + } + return MAPStats(); + }; // NOLINT + + thrust::transform(thrust::make_counting_iterator(static_cast(0)), + thrust::make_counting_iterator(nitems), + this->dmap_stats_.begin(), + ComputeItemPrecisionLambda); + + // Lastly, compute the accumulated precisions for all the items segmented by groups. 
+ // The precisions are accumulated within each group + thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), + group_segments.begin(), group_segments.end(), + this->dmap_stats_.begin(), // Input map stats + this->dmap_stats_.begin()); // In-place scan and output here + } + + inline const dh::caching_device_vector &GetMapStats() const { + return dmap_stats_; + } + + // Type containing device pointers that can be cheaply copied on the kernel + class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { + public: + MAPLambdaWeightMultiplier(const SegmentSorter &segment_label_sorter, + const MAPLambdaWeightComputer &lwc) + : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), + dmap_stats_ptr_(lwc.GetMapStats().data().get()) {} - struct MAPLambdaWeightMultiplier { // Adjust the items weight by this value __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { - return 1.0f; + uint32_t group_begin = dgroups_[gidx]; + uint32_t group_end = dgroups_[gidx + 1]; + + auto pos_lab_orig_posn = dorig_pos_[pidx]; + auto neg_lab_orig_posn = dorig_pos_[nidx]; + KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn); + + // Note: the label positive and negative indices are relative to the entire dataset. 
+ // Hence, scale them back to an index within the group + auto pos_pred_pos = dindexable_sorted_preds_pos_ptr_[pos_lab_orig_posn] - group_begin; + auto neg_pred_pos = dindexable_sorted_preds_pos_ptr_[neg_lab_orig_posn] - group_begin; + return MAPLambdaWeightComputer::GetLambdaMAP( + pos_pred_pos, neg_pred_pos, + dsorted_labels_[pidx], dsorted_labels_[nidx], + &dmap_stats_ptr_[group_begin], group_end - group_begin); } + + private: + const MAPStats *dmap_stats_ptr_{nullptr}; // Start address of the map stats for every sorted + // prediction value }; - inline MAPLambdaWeightMultiplier GetWeightMultiplier() const { - return {}; - } + inline const MAPLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; } + + private: + dh::caching_device_vector dmap_stats_; + // This computes the adjustment to the weight + const MAPLambdaWeightMultiplier weight_multiplier_; #endif }; @@ -641,30 +801,31 @@ class SortedLabelList : SegmentSorter { // This kernel can only run *after* the kernel in sort is completed, as they // use the default stream template - void ComputeGradients(const bst_float *dpreds, + void ComputeGradients(const bst_float *dpreds, // Unsorted predictions + const bst_float *dlabels, // Unsorted labels const HostDeviceVector &weights, int iter, GradientPair *out_gpair, float weight_normalization_factor) { // Group info on device - const uint32_t *dgroups = this->GroupIndices(); - uint32_t ngroups = this->NumGroups() + 1; + const uint32_t *dgroups = this->GetGroupsPtr(); + uint32_t ngroups = this->GetNumGroups() + 1; - uint32_t total_items = this->NumItems(); + uint32_t total_items = this->GetNumItems(); uint32_t niter = param_.num_pairsample * total_items; float fix_list_weight = param_.fix_list_weight; - const uint32_t *original_pos = this->OriginalPositions(); + const uint32_t *original_pos = this->GetOriginalPositionsPtr(); uint32_t num_weights = weights.Size(); auto dweights = num_weights ? 
weights.ConstDevicePointer() : nullptr; - const bst_float *sorted_labels = this->Items(); + const bst_float *sorted_labels = this->GetItemsPtr(); // This is used to adjust the weight of different elements based on the different ranking // objective function policies - LambdaWeightComputerT weight_computer(dpreds, total_items, *this); + LambdaWeightComputerT weight_computer(dpreds, dlabels, *this); auto wmultiplier = weight_computer.GetWeightMultiplier(); int device_id = -1; @@ -762,10 +923,9 @@ class LambdaRankObj : public ObjFunction { << "group structure not consistent with #rows"; #if defined(__CUDACC__) - // For now, we only support pairwise ranking computation on GPU. // Check if we have a GPU assignment; else, revert back to CPU auto device = tparam_->gpu_id; - if (device >= 0 && LambdaWeightComputerT::SupportOnGPU()) { + if (device >= 0) { ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr); } else { // Revert back to CPU @@ -809,7 +969,7 @@ class LambdaRankObj : public ObjFunction { int iter, HostDeviceVector* out_gpair, const std::vector &gptr) { - LOG(DEBUG) << "Computing pairwise gradients on CPU."; + LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on CPU."; bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr); @@ -893,7 +1053,7 @@ class LambdaRankObj : public ObjFunction { int iter, HostDeviceVector* out_gpair, const std::vector &gptr) { - LOG(DEBUG) << "Computing pairwise gradients on GPU."; + LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU."; auto device = tparam_->gpu_id; dh::safe_cuda(cudaSetDevice(device)); @@ -910,6 +1070,7 @@ class LambdaRankObj : public ObjFunction { auto d_preds = preds.ConstDevicePointer(); auto d_gpair = out_gpair->DevicePointer(); + auto d_labels = info.labels_.ConstDevicePointer(); SortedLabelList slist(param_); @@ -921,7 +1082,7 @@ class LambdaRankObj : public ObjFunction { // Finally, compute the gradients 
slist.ComputeGradients - (d_preds, info.weights_, iter, d_gpair, weight_normalization_factor); + (d_preds, d_labels, info.weights_, iter, d_gpair, weight_normalization_factor); } #endif diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 795f9420d..c2293a78a 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -84,3 +84,81 @@ void TestAllocator() { TEST(bulkAllocator, Test) { TestAllocator(); } + +template > +void TestUpperBoundImpl(const std::vector &vec, T val_to_find, + const Comp &comp = Comp()) { + EXPECT_EQ(dh::UpperBound(vec.data(), vec.size(), val_to_find, comp), + std::upper_bound(vec.begin(), vec.end(), val_to_find, comp) - vec.begin()); +} + +template > +void TestLowerBoundImpl(const std::vector &vec, T val_to_find, + const Comp &comp = Comp()) { + EXPECT_EQ(dh::LowerBound(vec.data(), vec.size(), val_to_find, comp), + std::lower_bound(vec.begin(), vec.end(), val_to_find, comp) - vec.begin()); +} + +TEST(UpperBound, DataAscending) { + std::vector hvec{0, 3, 5, 5, 7, 8, 9, 10, 10}; + + // Test boundary conditions + TestUpperBoundImpl(hvec, hvec.front()); // Result 1 + TestUpperBoundImpl(hvec, hvec.front() - 1); // Result 0 + TestUpperBoundImpl(hvec, hvec.back() + 1); // Result hvec.size() + TestUpperBoundImpl(hvec, hvec.back()); // Result hvec.size() + + // Test other values - both missing and present + TestUpperBoundImpl(hvec, 3); // Result 2 + TestUpperBoundImpl(hvec, 4); // Result 2 + TestUpperBoundImpl(hvec, 5); // Result 4 +} + +TEST(UpperBound, DataDescending) { + std::vector hvec{10, 10, 9, 8, 7, 5, 5, 3, 0, 0}; + const auto &comparator = thrust::greater(); + + // Test boundary conditions + TestUpperBoundImpl(hvec, hvec.front(), comparator); // Result 2 + TestUpperBoundImpl(hvec, hvec.front() + 1, comparator); // Result 0 + TestUpperBoundImpl(hvec, hvec.back(), comparator); // Result hvec.size() + TestUpperBoundImpl(hvec, hvec.back() - 1, comparator); 
// Result hvec.size() + + // Test other values - both missing and present + TestUpperBoundImpl(hvec, 9, comparator); // Result 3 + TestUpperBoundImpl(hvec, 7, comparator); // Result 5 + TestUpperBoundImpl(hvec, 4, comparator); // Result 7 + TestUpperBoundImpl(hvec, 8, comparator); // Result 4 +} + +TEST(LowerBound, DataAscending) { + std::vector hvec{0, 3, 5, 5, 7, 8, 9, 10, 10}; + + // Test boundary conditions + TestLowerBoundImpl(hvec, hvec.front()); // Result 0 + TestLowerBoundImpl(hvec, hvec.front() - 1); // Result 0 + TestLowerBoundImpl(hvec, hvec.back()); // Result 7 + TestLowerBoundImpl(hvec, hvec.back() + 1); // Result hvec.size() + + // Test other values - both missing and present + TestLowerBoundImpl(hvec, 3); // Result 1 + TestLowerBoundImpl(hvec, 4); // Result 2 + TestLowerBoundImpl(hvec, 5); // Result 2 +} + +TEST(LowerBound, DataDescending) { + std::vector hvec{10, 10, 9, 8, 7, 5, 5, 3, 0, 0}; + const auto &comparator = thrust::greater(); + + // Test boundary conditions + TestLowerBoundImpl(hvec, hvec.front(), comparator); // Result 0 + TestLowerBoundImpl(hvec, hvec.front() + 1, comparator); // Result 0 + TestLowerBoundImpl(hvec, hvec.back(), comparator); // Result 8 + TestLowerBoundImpl(hvec, hvec.back() - 1, comparator); // Result hvec.size() + + // Test other values - both missing and present + TestLowerBoundImpl(hvec, 9, comparator); // Result 2 + TestLowerBoundImpl(hvec, 7, comparator); // Result 4 + TestLowerBoundImpl(hvec, 4, comparator); // Result 7 + TestLowerBoundImpl(hvec, 8, comparator); // Result 3 +} diff --git a/tests/cpp/objective/test_ranking_obj.cc b/tests/cpp/objective/test_ranking_obj.cc index 9b2c0ded5..6f3571a0f 100644 --- a/tests/cpp/objective/test_ranking_obj.cc +++ b/tests/cpp/objective/test_ranking_obj.cc @@ -105,4 +105,33 @@ TEST(Objective, DeclareUnifiedTest(NDCGRankingGPair)) { ASSERT_NO_THROW(obj->DefaultEvalMetric()); } +TEST(Objective, DeclareUnifiedTest(MAPRankingGPair)) { + std::vector> args; + 
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX); + + std::unique_ptr obj { + xgboost::ObjFunction::Create("rank:map", &lparam) + }; + obj->Configure(args); + CheckConfigReload(obj, "rank:map"); + + // Test with setting sample weight to second query group + CheckRankingObjFunction(obj, + {0, 0.1f, 0, 0.1f}, + {0, 1, 0, 1}, + {2.0f, 0.0f}, + {0, 2, 4}, + {0.95f, -0.95f, 0.0f, 0.0f}, + {0.9975f, 0.9975f, 0.0f, 0.0f}); + + CheckRankingObjFunction(obj, + {0, 0.1f, 0, 0.1f}, + {0, 1, 0, 1}, + {1.0f, 1.0f}, + {0, 2, 4}, + {0.475f, -0.475f, 0.475f, -0.475f}, + {0.4988f, 0.4988f, 0.4988f, 0.4988f}); + ASSERT_NO_THROW(obj->DefaultEvalMetric()); +} + } // namespace xgboost diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu index 394cd0092..d48284ac2 100644 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ b/tests/cpp/objective/test_ranking_obj_gpu.cu @@ -19,22 +19,22 @@ RankSegmentSorterTestImpl(const std::vector &group_indices, dh::device_vector dlabels(hlabels); seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator()); - EXPECT_EQ(seg_sorter.NumItems(), group_indices.back()); - EXPECT_EQ(seg_sorter.NumGroups(), group_indices.size() - 1); + EXPECT_EQ(seg_sorter.GetNumItems(), group_indices.back()); + EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1); // Check the labels - dh::device_vector sorted_dlabels(seg_sorter.NumItems()); - sorted_dlabels.assign(thrust::device_ptr(seg_sorter.Items()), - thrust::device_ptr(seg_sorter.Items()) - + seg_sorter.NumItems()); + dh::device_vector sorted_dlabels(seg_sorter.GetNumItems()); + sorted_dlabels.assign(thrust::device_ptr(seg_sorter.GetItemsPtr()), + thrust::device_ptr(seg_sorter.GetItemsPtr()) + + seg_sorter.GetNumItems()); thrust::host_vector sorted_hlabels(sorted_dlabels); EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels); // Check the indices - dh::device_vector dorig_pos(seg_sorter.NumItems()); - 
dorig_pos.assign(thrust::device_ptr(seg_sorter.OriginalPositions()), - thrust::device_ptr(seg_sorter.OriginalPositions()) - + seg_sorter.NumItems()); + dh::device_vector dorig_pos(seg_sorter.GetNumItems()); + dorig_pos.assign(thrust::device_ptr(seg_sorter.GetOriginalPositionsPtr()), + thrust::device_ptr(seg_sorter.GetOriginalPositionsPtr()) + + seg_sorter.GetNumItems()); dh::device_vector horig_pos(dorig_pos); EXPECT_EQ(expected_orig_pos, horig_pos); @@ -125,11 +125,14 @@ TEST(Objective, RankItemCountOnRight) { } TEST(Objective, NDCGLambdaWeightComputerTest) { + std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels + 7.8f, 5.01f, 6.96f, + 10.3f, 8.7f, 11.4f, 9.45f, 11.4f}; + dh::device_vector dlabels(hlabels); + auto segment_label_sorter = RankSegmentSorterTestImpl( {0, 4, 7, 12}, // Groups - {3.1f, 1.2f, 2.3f, 4.4f, // Labels - 7.8f, 5.01f, 6.96f, - 10.3f, 8.7f, 11.4f, 9.45f, 11.4f}, + hlabels, {4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels 7.8f, 6.96f, 5.01f, 11.4f, 11.4f, 10.3f, 9.45f, 8.7f}, @@ -142,18 +145,114 @@ TEST(Objective, NDCGLambdaWeightComputerTest) { -1.03f, -2.79f, -3.1f, 104.22f, 103.1f, -101.7f, 100.5f, 45.1f}; dh::device_vector dpreds(hpreds); + xgboost::obj::NDCGLambdaWeightComputer ndcg_lw_computer(dpreds.data().get(), - dpreds.size(), + dlabels.data().get(), *segment_label_sorter); // Where will the predictions move from its current position, if they were sorted // descendingly? 
- auto dsorted_pred_pos = ndcg_lw_computer.GetSortedPredPos(); + auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositions(); thrust::host_vector hsorted_pred_pos(dsorted_pred_pos); std::vector expected_sorted_pred_pos{2, 0, 1, 3, 4, 5, 6, 7, 8, 11, 9, 10}; EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos); + + // Check group DCG values + thrust::host_vector hgroup_dcgs(ndcg_lw_computer.GetGroupDcgs()); + thrust::host_vector hgroups(segment_label_sorter->GetGroups()); + thrust::host_vector hsorted_labels(segment_label_sorter->GetItems()); + EXPECT_EQ(hgroup_dcgs.size(), segment_label_sorter->GetNumGroups()); + for (auto i = 0; i < hgroup_dcgs.size(); ++i) { + // Compute group DCG value on CPU and compare + auto gbegin = hgroups[i]; + auto gend = hgroups[i + 1]; + EXPECT_NEAR( + hgroup_dcgs[i], + xgboost::obj::NDCGLambdaWeightComputer::ComputeGroupDCGWeight(&hsorted_labels[gbegin], + gend - gbegin), + 0.01f); + } +} + +TEST(Objective, IndexableSortedItemsTest) { + std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels + 7.8f, 5.01f, 6.96f, + 10.3f, 8.7f, 11.4f, 9.45f, 11.4f}; + dh::device_vector dlabels(hlabels); + + auto segment_label_sorter = RankSegmentSorterTestImpl( + {0, 4, 7, 12}, // Groups + hlabels, + {4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels + 7.8f, 6.96f, 5.01f, + 11.4f, 11.4f, 10.3f, 9.45f, 8.7f}, + {3, 0, 2, 1, // Expected original positions + 4, 6, 5, + 9, 11, 7, 10, 8}); + + segment_label_sorter->CreateIndexableSortedPositions(); + thrust::host_vector sorted_indices(segment_label_sorter->GetIndexableSortedPositions()); + std::vector expected_sorted_indices = { + 1, 3, 2, 0, + 4, 6, 5, + 9, 11, 7, 10, 8}; + EXPECT_EQ(expected_sorted_indices, sorted_indices); +} + +TEST(Objective, ComputeAndCompareMAPStatsTest) { + std::vector hlabels = {3.1f, 0.0f, 2.3f, 4.4f, // Labels + 0.0f, 5.01f, 0.0f, + 10.3f, 0.0f, 11.4f, 9.45f, 11.4f}; + dh::device_vector dlabels(hlabels); + + auto segment_label_sorter = 
RankSegmentSorterTestImpl( + {0, 4, 7, 12}, // Groups + hlabels, + {4.4f, 3.1f, 2.3f, 0.0f, // Expected sorted labels + 5.01f, 0.0f, 0.0f, + 11.4f, 11.4f, 10.3f, 9.45f, 0.0f}, + {3, 0, 2, 1, // Expected original positions + 5, 4, 6, + 9, 11, 7, 10, 8}); + + // Create MAP stats on the device first using the objective + std::vector hpreds{-9.78f, 24.367f, 0.908f, -11.47f, + -1.03f, -2.79f, -3.1f, + 104.22f, 103.1f, -101.7f, 100.5f, 45.1f}; + dh::device_vector dpreds(hpreds); + + xgboost::obj::MAPLambdaWeightComputer map_lw_computer(dpreds.data().get(), + dlabels.data().get(), + *segment_label_sorter); + + // Get the device MAP stats on host + thrust::host_vector dmap_stats( + map_lw_computer.GetMapStats()); + + // Compute the MAP stats on host next to compare + thrust::host_vector hgroups(segment_label_sorter->GetGroups()); + + for (auto i = 0; i < hgroups.size() - 1; ++i) { + auto gbegin = hgroups[i]; + auto gend = hgroups[i + 1]; + std::vector lst_entry; + for (auto j = gbegin; j < gend; ++j) { + lst_entry.emplace_back(hpreds[j], hlabels[j], j); + } + std::stable_sort(lst_entry.begin(), lst_entry.end(), xgboost::obj::ListEntry::CmpPred); + + // Compute the MAP stats with this list and compare with the ones computed on the device + std::vector hmap_stats; + xgboost::obj::MAPLambdaWeightComputer::GetMAPStats(lst_entry, &hmap_stats); + for (auto j = gbegin; j < gend; ++j) { + EXPECT_EQ(dmap_stats[j].hits, hmap_stats[j - gbegin].hits); + EXPECT_NEAR(dmap_stats[j].ap_acc, hmap_stats[j - gbegin].ap_acc, 0.01f); + EXPECT_NEAR(dmap_stats[j].ap_acc_miss, hmap_stats[j - gbegin].ap_acc_miss, 0.01f); + EXPECT_NEAR(dmap_stats[j].ap_acc_add, hmap_stats[j - gbegin].ap_acc_add, 0.01f); + } + } } } // namespace xgboost diff --git a/tests/python-gpu/test_gpu_ranking.py b/tests/python-gpu/test_gpu_ranking.py index d51d7f006..58c2fd78d 100644 --- a/tests/python-gpu/test_gpu_ranking.py +++ b/tests/python-gpu/test_gpu_ranking.py @@ -141,3 +141,21 @@ class TestRanking(unittest.TestCase): 
Train an XGBoost ranking model with ndcg objective function and compare ndcg metric """ self.__test_training_with_rank_objective('rank:ndcg', 'ndcg') + + def test_training_rank_map_map(self): + """ + Train an XGBoost ranking model with map objective function and compare map metric + """ + self.__test_training_with_rank_objective('rank:map', 'map') + + def test_training_rank_map_auc(self): + """ + Train an XGBoost ranking model with map objective function and compare auc metric + """ + self.__test_training_with_rank_objective('rank:map', 'auc') + + def test_training_rank_map_ndcg(self): + """ + Train an XGBoost ranking model with map objective function and compare ndcg metric + """ + self.__test_training_with_rank_objective('rank:map', 'ndcg')