diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 382505835..b567ffc0e 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -8,7 +8,9 @@ #include #include #include +#include +#include #include #include #include @@ -1285,6 +1287,175 @@ thrust::device_ptr tcend(xgboost::common::Span const& span) { return tcbegin(span) + span.size(); } +// This type sorts an array which is divided into multiple groups. The sorting is influenced +// by the function object 'Comparator' +template +class SegmentSorter { + private: + // Items sorted within the group + caching_device_vector ditems_; + + // Original position of the items before they are sorted descendingly within its groups + caching_device_vector doriginal_pos_; + + // Segments within the original list that delineates the different groups + caching_device_vector group_segments_; + + // Need this on the device as it is used in the kernels + caching_device_vector dgroups_; // Group information on device + + // Where did the item that was originally present at position 'x' move to after they are sorted + caching_device_vector dindexable_sorted_pos_; + + // Initialize everything but the segments + void Init(uint32_t num_elems) { + ditems_.resize(num_elems); + + doriginal_pos_.resize(num_elems); + thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end()); + } + + // Initialize all with group info + void Init(const std::vector &groups) { + uint32_t num_elems = groups.back(); + this->Init(num_elems); + this->CreateGroupSegments(groups); + } + + public: + // This needs to be public due to device lambda + void CreateGroupSegments(const std::vector &groups) { + uint32_t num_elems = groups.back(); + group_segments_.resize(num_elems, 0); + + dgroups_ = groups; + + if (GetNumGroups() == 1) return; // There are no segments; hence, no need to compute them + + // Define the segments by assigning a group ID to each element + const uint32_t *dgroups = dgroups_.data().get(); + uint32_t ngroups = dgroups_.size(); + auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) { + return dh::UpperBound(dgroups, ngroups, idx) - 1; + }; // NOLINT + + thrust::transform(thrust::make_counting_iterator(static_cast(0)), + thrust::make_counting_iterator(num_elems), + group_segments_.begin(), + ComputeGroupIDLambda); + } + + // Accessors that returns device pointer + inline uint32_t GetNumItems() const { return ditems_.size(); } + inline const xgboost::common::Span GetItemsSpan() const { + return { ditems_.data().get(), ditems_.size() }; + } + + inline const xgboost::common::Span GetOriginalPositionsSpan() const { + return { doriginal_pos_.data().get(), doriginal_pos_.size() }; + } + + inline const xgboost::common::Span GetGroupSegmentsSpan() const { + return { group_segments_.data().get(), group_segments_.size() }; + } + + inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; } + inline const xgboost::common::Span GetGroupsSpan() const { + return { dgroups_.data().get(), dgroups_.size() }; + } + + inline const xgboost::common::Span GetIndexableSortedPositionsSpan() const { + return { dindexable_sorted_pos_.data().get(), dindexable_sorted_pos_.size() }; + } + + // Sort an array that is divided into multiple groups. The array is sorted within each group. + // This version provides the group information that is on the host. + // The array is sorted based on an adaptable binary predicate. By default a stateless predicate + // is used. + template > + void SortItems(const T *ditems, uint32_t item_size, const std::vector &groups, + const Comparator &comp = Comparator()) { + this->Init(groups); + this->SortItems(ditems, item_size, this->GetGroupSegmentsSpan(), comp); + } + + // Sort an array that is divided into multiple groups. The array is sorted within each group. + // This version provides the group information that is on the device. + // The array is sorted based on an adaptable binary predicate. By default a stateless predicate + // is used. + template > + void SortItems(const T *ditems, uint32_t item_size, + const xgboost::common::Span &group_segments, + const Comparator &comp = Comparator()) { + this->Init(item_size); + + // Sort the items that are grouped. We would like to avoid using predicates to perform the sort, + // as thrust resorts to using a merge sort as opposed to a much much faster radix sort + // when comparators are used. Hence, the following algorithm is used. This is done so that + // we can grab the appropriate related values from the original list later, after the + // items are sorted. + // + // Here is the internal representation: + // dgroups_: [ 0, 3, 5, 8, 10 ] + // group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3 + // doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9 + // ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items) + // + // Sort the items first and make a note of the original positions in doriginal_pos_ + // based on the sort + // ditems_: 4 4 3 3 2 1 1 1 1 0 + // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 + // NOTE: This consumes space, but is much faster than some of the other approaches - sorting + // in kernel, sorting using predicates etc. + + ditems_.assign(thrust::device_ptr(ditems), + thrust::device_ptr(ditems) + item_size); + + // Allocator to be used by sort for managing space overhead while sorting + dh::XGBCachingDeviceAllocator alloc; + + thrust::stable_sort_by_key(thrust::cuda::par(alloc), + ditems_.begin(), ditems_.end(), + doriginal_pos_.begin(), comp); + + if (GetNumGroups() == 1) return; // The entire array is sorted, as it isn't segmented + + // Next, gather the segments based on the doriginal_pos_. This is to reflect the + // holisitic item sort order on the segments + // group_segments_c_: 3 3 2 2 1 0 0 1 2 0 + // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same) + caching_device_vector group_segments_c(item_size); + thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), + dh::tcbegin(group_segments), group_segments_c.begin()); + + // Now, sort the group segments so that you may bring the items within the group together, + // in the process also noting the relative changes to the doriginal_pos_ while that happens + // group_segments_c_: 0 0 0 1 1 2 2 2 3 3 + // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 + thrust::stable_sort_by_key(thrust::cuda::par(alloc), + group_segments_c.begin(), group_segments_c.end(), + doriginal_pos_.begin(), thrust::less()); + + // Finally, gather the original items based on doriginal_pos_ to sort the input and + // to store them in ditems_ + // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same) + // ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems) + thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), + thrust::device_ptr(ditems), ditems_.begin()); + } + + // Determine where an item that was originally present at position 'x' has been relocated to + // after a sort. Creation of such an index has to be explicitly requested after a sort + void CreateIndexableSortedPositions() { + dindexable_sorted_pos_.resize(GetNumItems()); + thrust::scatter(thrust::make_counting_iterator(static_cast(0)), + thrust::make_counting_iterator(GetNumItems()), // Rearrange indices... + // ...based on this map + dh::tcbegin(GetOriginalPositionsSpan()), + dindexable_sorted_pos_.begin()); // Write results into this + } +}; + template class LauncherItr { public: diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index 117dcd243..a4f3dad4f 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -48,172 +48,6 @@ struct LambdaRankParam : public XGBoostParameter { }; #if defined(__CUDACC__) -// This type sorts an array which is divided into multiple groups. The sorting is influenced -// by the function object 'Comparator' -template -class SegmentSorter { - private: - // Items sorted within the group - dh::caching_device_vector ditems_; - - // Original position of the items before they are sorted descendingly within its groups - dh::caching_device_vector doriginal_pos_; - - // Segments within the original list that delineates the different groups - dh::caching_device_vector group_segments_; - - // Need this on the device as it is used in the kernels - dh::caching_device_vector dgroups_; // Group information on device - - // Where did the item that was originally present at position 'x' move to after they are sorted - dh::caching_device_vector dindexable_sorted_pos_; - - // Initialize everything but the segments - void Init(uint32_t num_elems) { - ditems_.resize(num_elems); - - doriginal_pos_.resize(num_elems); - thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end()); - } - - // Initialize all with group info - void Init(const std::vector &groups) { - uint32_t num_elems = groups.back(); - this->Init(num_elems); - this->CreateGroupSegments(groups); - } - - public: - // This needs to be public due to device lambda - void CreateGroupSegments(const std::vector &groups) { - uint32_t num_elems = groups.back(); - group_segments_.resize(num_elems); - - dgroups_ = groups; - - // Define the segments by assigning a group ID to each element - const uint32_t *dgroups = dgroups_.data().get(); - uint32_t ngroups = dgroups_.size(); - auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) { - return dh::UpperBound(dgroups, ngroups, idx) - 1; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(num_elems), - group_segments_.begin(), - ComputeGroupIDLambda); - } - - // Accessors that returns device pointer - inline const T *GetItemsPtr() const { return ditems_.data().get(); } - inline uint32_t GetNumItems() const { return ditems_.size(); } - inline const dh::caching_device_vector &GetItems() const { - return ditems_; - } - - inline const uint32_t *GetOriginalPositionsPtr() const { return doriginal_pos_.data().get(); } - inline const dh::caching_device_vector &GetOriginalPositions() const { - return doriginal_pos_; - } - - inline const dh::caching_device_vector &GetGroupSegments() const { - return group_segments_; - } - - inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; } - inline const uint32_t *GetGroupsPtr() const { return dgroups_.data().get(); } - inline const dh::caching_device_vector &GetGroups() const { return dgroups_; } - - inline const dh::caching_device_vector &GetIndexableSortedPositions() const { - return dindexable_sorted_pos_; - } - - // Sort an array that is divided into multiple groups. The array is sorted within each group. - // This version provides the group information that is on the host. - // The array is sorted based on an adaptable binary predicate. By default a stateless predicate - // is used. - template > - void SortItems(const T *ditems, uint32_t item_size, const std::vector &groups, - const Comparator &comp = Comparator()) { - this->Init(groups); - this->SortItems(ditems, item_size, group_segments_, comp); - } - - // Sort an array that is divided into multiple groups. The array is sorted within each group. - // This version provides the group information that is on the device. - // The array is sorted based on an adaptable binary predicate. By default a stateless predicate - // is used. - template > - void SortItems(const T *ditems, uint32_t item_size, - const dh::caching_device_vector &group_segments, - const Comparator &comp = Comparator()) { - this->Init(item_size); - - // Sort the items that are grouped. We would like to avoid using predicates to perform the sort, - // as thrust resorts to using a merge sort as opposed to a much much faster radix sort - // when comparators are used. Hence, the following algorithm is used. This is done so that - // we can grab the appropriate related values from the original list later, after the - // items are sorted. - // - // Here is the internal representation: - // dgroups_: [ 0, 3, 5, 8, 10 ] - // group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3 - // doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9 - // ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items) - // - // Sort the items first and make a note of the original positions in doriginal_pos_ - // based on the sort - // ditems_: 4 4 3 3 2 1 1 1 1 0 - // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 - // NOTE: This consumes space, but is much faster than some of the other approaches - sorting - // in kernel, sorting using predicates etc. - - ditems_.assign(thrust::device_ptr(ditems), - thrust::device_ptr(ditems) + item_size); - - // Allocator to be used by sort for managing space overhead while sorting - dh::XGBCachingDeviceAllocator alloc; - - thrust::stable_sort_by_key(thrust::cuda::par(alloc), - ditems_.begin(), ditems_.end(), - doriginal_pos_.begin(), comp); - - // Next, gather the segments based on the doriginal_pos_. This is to reflect the - // holisitic item sort order on the segments - // group_segments_c_: 3 3 2 2 1 0 0 1 2 0 - // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same) - dh::caching_device_vector group_segments_c(group_segments); - thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), - group_segments.begin(), group_segments_c.begin()); - - // Now, sort the group segments so that you may bring the items within the group together, - // in the process also noting the relative changes to the doriginal_pos_ while that happens - // group_segments_c_: 0 0 0 1 1 2 2 2 3 3 - // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 - thrust::stable_sort_by_key(thrust::cuda::par(alloc), - group_segments_c.begin(), group_segments_c.end(), - doriginal_pos_.begin(), thrust::less()); - - // Finally, gather the original items based on doriginal_pos_ to sort the input and - // to store them in ditems_ - // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same) - // ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems) - thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), - thrust::device_ptr(ditems), ditems_.begin()); - } - - // Determine where an item that was originally present at position 'x' has been relocated to - // after a sort. Creation of such an index has to be explicitly requested after a sort - void CreateIndexableSortedPositions() { - dindexable_sorted_pos_.resize(GetNumItems()); - thrust::scatter(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(GetNumItems()), // Rearrange indices... - // ...based on this map - thrust::device_ptr(GetOriginalPositionsPtr()), - dindexable_sorted_pos_.begin()); // Write results into this - } -}; - // Helper functions template @@ -283,7 +117,7 @@ class PairwiseLambdaWeightComputer { #if defined(__CUDACC__) PairwiseLambdaWeightComputer(const bst_float *dpreds, const bst_float *dlabels, - const SegmentSorter &segment_label_sorter) {} + const dh::SegmentSorter &segment_label_sorter) {} class PairwiseLambdaWeightMultiplier { public: @@ -302,20 +136,20 @@ class PairwiseLambdaWeightComputer { #if defined(__CUDACC__) class BaseLambdaWeightMultiplier { public: - BaseLambdaWeightMultiplier(const SegmentSorter &segment_label_sorter, - const SegmentSorter &segment_pred_sorter) - : dsorted_labels_(segment_label_sorter.GetItemsPtr()), - dorig_pos_(segment_label_sorter.GetOriginalPositionsPtr()), - dgroups_(segment_label_sorter.GetGroupsPtr()), - dindexable_sorted_preds_pos_ptr_( - segment_pred_sorter.GetIndexableSortedPositions().data().get()) {} + BaseLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, + const dh::SegmentSorter &segment_pred_sorter) + : dsorted_labels_(segment_label_sorter.GetItemsSpan()), + dorig_pos_(segment_label_sorter.GetOriginalPositionsSpan()), + dgroups_(segment_label_sorter.GetGroupsSpan()), + dindexable_sorted_preds_pos_(segment_pred_sorter.GetIndexableSortedPositionsSpan()) {} protected: - const float *dsorted_labels_{nullptr}; // Labels sorted within a group - const uint32_t *dorig_pos_{nullptr}; // Original indices of the labels before they are sorted - const uint32_t *dgroups_{nullptr}; // The group indices + const common::Span dsorted_labels_; // Labels sorted within a group + const common::Span dorig_pos_; // Original indices of the labels + // before they are sorted + const common::Span dgroups_; // The group indices // Where can a prediction for a label be found in the original array, when they are sorted - const uint32_t *dindexable_sorted_preds_pos_ptr_{nullptr}; + const common::Span dindexable_sorted_preds_pos_; }; // While computing the weight that needs to be adjusted by this ranking objective, we need @@ -342,8 +176,8 @@ class BaseLambdaWeightMultiplier { // // We create a sorted prediction positional array, such that value at position 'x' gives // us the position in the sorted prediction array where its related prediction lies. -// dindexable_sorted_preds_pos_ptr_: 8 0 9 1 7 2 6 3 4 5 -// at indices: 0 1 2 3 4 5 6 7 8 9 +// dindexable_sorted_preds_pos_: 8 0 9 1 7 2 6 3 4 5 +// at indices: 0 1 2 3 4 5 6 7 8 9 // Basically, swap the previous 2 arrays, sort the indices and reorder positions // for an O(1) lookup using the position where the sorted label exists. // @@ -351,21 +185,21 @@ class BaseLambdaWeightMultiplier { class IndexablePredictionSorter { public: IndexablePredictionSorter(const bst_float *dpreds, - const SegmentSorter &segment_label_sorter) { + const dh::SegmentSorter &segment_label_sorter) { // Sort the predictions first segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(), - segment_label_sorter.GetGroupSegments()); + segment_label_sorter.GetGroupSegmentsSpan()); // Create an index for the sorted prediction positions segment_pred_sorter_.CreateIndexableSortedPositions(); } - inline const SegmentSorter &GetPredictionSorter() const { + inline const dh::SegmentSorter &GetPredictionSorter() const { return segment_pred_sorter_; } private: - SegmentSorter segment_pred_sorter_; // For sorting the predictions + dh::SegmentSorter segment_pred_sorter_; // For sorting the predictions }; #endif @@ -380,9 +214,9 @@ class NDCGLambdaWeightComputer // This function object computes the item's DCG value class ComputeItemDCG : public thrust::unary_function { public: - XGBOOST_DEVICE ComputeItemDCG(const float *dsorted_labels, - const uint32_t *dgroups, - const uint32_t *gidxs) + XGBOOST_DEVICE ComputeItemDCG(const common::Span &dsorted_labels, + const common::Span &dgroups, + const common::Span &gidxs) : dsorted_labels_(dsorted_labels), dgroups_(dgroups), dgidxs_(gidxs) {} @@ -393,22 +227,23 @@ class NDCGLambdaWeightComputer } private: - const float *dsorted_labels_{nullptr}; // Labels sorted within a group - const uint32_t *dgroups_{nullptr}; // The group indices - where each group begins and ends - const uint32_t *dgidxs_{nullptr}; // The group each items belongs to + const common::Span dsorted_labels_; // Labels sorted within a group + const common::Span dgroups_; // The group indices - where each group + // begins and ends + const common::Span dgidxs_; // The group each items belongs to }; // Type containing device pointers that can be cheaply copied on the kernel class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { public: - NDCGLambdaWeightMultiplier(const SegmentSorter &segment_label_sorter, + NDCGLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, const NDCGLambdaWeightComputer &lwc) : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), - dgroup_dcg_ptr_(lwc.GetGroupDcgs().data().get()) {} + dgroup_dcgs_(lwc.GetGroupDcgsSpan()) {} // Adjust the items weight by this value __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { - if (dgroup_dcg_ptr_[gidx] == 0.0) return 0.0f; + if (dgroup_dcgs_[gidx] == 0.0) return 0.0f; uint32_t group_begin = dgroups_[gidx]; @@ -418,43 +253,47 @@ class NDCGLambdaWeightComputer // Note: the label positive and negative indices are relative to the entire dataset. // Hence, scale them back to an index within the group - auto pos_pred_pos = dindexable_sorted_preds_pos_ptr_[pos_lab_orig_posn] - group_begin; - auto neg_pred_pos = dindexable_sorted_preds_pos_ptr_[neg_lab_orig_posn] - group_begin; + auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin; + auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin; return NDCGLambdaWeightComputer::ComputeDeltaWeight( pos_pred_pos, neg_pred_pos, static_cast(dsorted_labels_[pidx]), static_cast(dsorted_labels_[nidx]), - dgroup_dcg_ptr_[gidx]); + dgroup_dcgs_[gidx]); } private: - const float *dgroup_dcg_ptr_{nullptr}; // Start address of the group DCG values + const common::Span dgroup_dcgs_; // Group DCG values }; NDCGLambdaWeightComputer(const bst_float *dpreds, const bst_float *dlabels, - const SegmentSorter &segment_label_sorter) + const dh::SegmentSorter &segment_label_sorter) : IndexablePredictionSorter(dpreds, segment_label_sorter), dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f), weight_multiplier_(segment_label_sorter, *this) { - const auto &group_segments = segment_label_sorter.GetGroupSegments(); + const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan(); + + // Allocator to be used for managing space overhead while performing transformed reductions + dh::XGBCachingDeviceAllocator alloc; // Compute each elements DCG values and reduce them across groups concurrently. auto end_range = - thrust::reduce_by_key(group_segments.begin(), group_segments.end(), + thrust::reduce_by_key(thrust::cuda::par(alloc), + dh::tcbegin(group_segments), dh::tcend(group_segments), thrust::make_transform_iterator( // The indices need not be sequential within a group, as we care only // about the sum of items DCG values within a group - segment_label_sorter.GetOriginalPositions().begin(), - ComputeItemDCG(segment_label_sorter.GetItemsPtr(), - segment_label_sorter.GetGroupsPtr(), - group_segments.data().get())), + dh::tcbegin(segment_label_sorter.GetOriginalPositionsSpan()), + ComputeItemDCG(segment_label_sorter.GetItemsSpan(), + segment_label_sorter.GetGroupsSpan(), + group_segments)), thrust::make_discard_iterator(), // We don't care for the group indices dgroup_dcg_.begin()); // Sum of the item's DCG values in the group CHECK(end_range.second - dgroup_dcg_.begin() == dgroup_dcg_.size()); } - inline const dh::caching_device_vector &GetGroupDcgs() const { - return dgroup_dcg_; + inline const common::Span GetGroupDcgsSpan() const { + return { dgroup_dcg_.data().get(), dgroup_dcg_.size() }; } inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const { @@ -664,7 +503,7 @@ class MAPLambdaWeightComputer #if defined(__CUDACC__) MAPLambdaWeightComputer(const bst_float *dpreds, const bst_float *dlabels, - const SegmentSorter &segment_label_sorter) + const dh::SegmentSorter &segment_label_sorter) : IndexablePredictionSorter(dpreds, segment_label_sorter), dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()), weight_multiplier_(segment_label_sorter, *this) { @@ -672,7 +511,7 @@ class MAPLambdaWeightComputer } void CreateMAPStats(const bst_float *dlabels, - const SegmentSorter &segment_label_sorter) { + const dh::SegmentSorter &segment_label_sorter) { // For each group, go through the sorted prediction positions, and look up its corresponding // label from the unsorted labels (from the original label list) @@ -683,7 +522,7 @@ class MAPLambdaWeightComputer auto nitems = segment_label_sorter.GetNumItems(); dh::caching_device_vector dhits(nitems, 0); // Original positions of the predictions after they have been sorted - const uint32_t *pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsPtr(); + const auto &pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsSpan(); // Unsorted labels const float *unsorted_labels = dlabels; auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) { @@ -700,24 +539,23 @@ class MAPLambdaWeightComputer // Next, prefix scan the positive labels that are segmented to accumulate them. // This is required for computing the accumulated precisions - const auto &group_segments = segment_label_sorter.GetGroupSegments(); + const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan(); // Data segmented into different groups... thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - group_segments.begin(), group_segments.end(), + dh::tcbegin(group_segments), dh::tcend(group_segments), dhits.begin(), // Input value dhits.begin()); // In-place scan // Compute accumulated precisions for each item, assuming positive and // negative instances are missing. // But first, compute individual item precisions - const auto *dgidx_arr = group_segments.data().get(); const auto *dhits_arr = dhits.data().get(); // Group info on device - const uint32_t *dgroups = segment_label_sorter.GetGroupsPtr(); + const auto &dgroups = segment_label_sorter.GetGroupsSpan(); uint32_t ngroups = segment_label_sorter.GetNumGroups(); auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) { if (unsorted_labels[pred_original_pos[idx]] > 0.0f) { - auto idx_within_group = (idx - dgroups[dgidx_arr[idx]]) + 1; + auto idx_within_group = (idx - dgroups[group_segments[idx]]) + 1; return MAPStats(static_cast(dhits_arr[idx]) / idx_within_group, static_cast(dhits_arr[idx] - 1) / idx_within_group, static_cast(dhits_arr[idx] + 1) / idx_within_group, @@ -734,22 +572,22 @@ class MAPLambdaWeightComputer // Lastly, compute the accumulated precisions for all the items segmented by groups. // The precisions are accumulated within each group thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - group_segments.begin(), group_segments.end(), + dh::tcbegin(group_segments), dh::tcend(group_segments), this->dmap_stats_.begin(), // Input map stats this->dmap_stats_.begin()); // In-place scan and output here } - inline const dh::caching_device_vector &GetMapStats() const { - return dmap_stats_; + inline const common::Span GetMapStatsSpan() const { + return { dmap_stats_.data().get(), dmap_stats_.size() }; } // Type containing device pointers that can be cheaply copied on the kernel class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { public: - MAPLambdaWeightMultiplier(const SegmentSorter &segment_label_sorter, + MAPLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, const MAPLambdaWeightComputer &lwc) : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), - dmap_stats_ptr_(lwc.GetMapStats().data().get()) {} + dmap_stats_(lwc.GetMapStatsSpan()) {} // Adjust the items weight by this value __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { @@ -762,17 +600,17 @@ class MAPLambdaWeightComputer // Note: the label positive and negative indices are relative to the entire dataset. // Hence, scale them back to an index within the group - auto pos_pred_pos = dindexable_sorted_preds_pos_ptr_[pos_lab_orig_posn] - group_begin; - auto neg_pred_pos = dindexable_sorted_preds_pos_ptr_[neg_lab_orig_posn] - group_begin; + auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin; + auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin; return MAPLambdaWeightComputer::GetLambdaMAP( pos_pred_pos, neg_pred_pos, dsorted_labels_[pidx], dsorted_labels_[nidx], - &dmap_stats_ptr_[group_begin], group_end - group_begin); + &dmap_stats_[group_begin], group_end - group_begin); } private: - const MAPStats *dmap_stats_ptr_{nullptr}; // Start address of the map stats for every sorted - // prediction value + common::Span dmap_stats_; // Start address of the map stats for every sorted + // prediction value }; inline const MAPLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; } @@ -785,7 +623,7 @@ class MAPLambdaWeightComputer }; #if defined(__CUDACC__) -class SortedLabelList : SegmentSorter { +class SortedLabelList : dh::SegmentSorter { private: const LambdaRankParam ¶m_; // Objective configuration @@ -808,7 +646,7 @@ class SortedLabelList : SegmentSorter { GradientPair *out_gpair, float weight_normalization_factor) { // Group info on device - const uint32_t *dgroups = this->GetGroupsPtr(); + const auto &dgroups = this->GetGroupsSpan(); uint32_t ngroups = this->GetNumGroups() + 1; uint32_t total_items = this->GetNumItems(); @@ -816,12 +654,12 @@ class SortedLabelList : SegmentSorter { float fix_list_weight = param_.fix_list_weight; - const uint32_t *original_pos = this->GetOriginalPositionsPtr(); + const auto &original_pos = this->GetOriginalPositionsSpan(); uint32_t num_weights = weights.Size(); auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr; - const bst_float *sorted_labels = this->GetItemsPtr(); + const auto &sorted_labels = this->GetItemsSpan(); // This is used to adjust the weight of different elements based on the different ranking // objective function policies @@ -834,7 +672,7 @@ class SortedLabelList : SegmentSorter { dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) { // First, determine the group 'idx' belongs to uint32_t item_idx = idx % total_items; - uint32_t group_idx = dh::UpperBound(dgroups, ngroups, item_idx); + uint32_t group_idx = dh::UpperBound(dgroups.data(), ngroups, item_idx); // Span of this group within the larger labels/predictions sorted tuple uint32_t group_begin = dgroups[group_idx - 1]; uint32_t group_end = dgroups[group_idx]; @@ -847,9 +685,9 @@ class SortedLabelList : SegmentSorter { // Find the number of labels less than and greater than the current label // at the sorted index position item_idx uint32_t nleft = CountNumItemsToTheLeftOf( - sorted_labels + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]); + sorted_labels.data() + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]); uint32_t nright = CountNumItemsToTheRightOf( - sorted_labels + item_idx, group_end - item_idx, sorted_labels[item_idx]); + sorted_labels.data() + item_idx, group_end - item_idx, sorted_labels[item_idx]); // Create a minstd_rand object to act as our source of randomness thrust::minstd_rand rng((iter + 1) * 1111); diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu index d48284ac2..dc8fd267d 100644 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ b/tests/cpp/objective/test_ranking_obj_gpu.cu @@ -5,36 +5,34 @@ namespace xgboost { template > -std::unique_ptr> +std::unique_ptr> RankSegmentSorterTestImpl(const std::vector &group_indices, const std::vector &hlabels, const std::vector &expected_sorted_hlabels, const std::vector &expected_orig_pos ) { - std::unique_ptr> seg_sorter_ptr( - new xgboost::obj::SegmentSorter); - xgboost::obj::SegmentSorter &seg_sorter(*seg_sorter_ptr); + std::unique_ptr> seg_sorter_ptr(new dh::SegmentSorter); + dh::SegmentSorter &seg_sorter(*seg_sorter_ptr); // Create a bunch of unsorted labels on the device and sort it via the segment sorter dh::device_vector dlabels(hlabels); seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator()); - EXPECT_EQ(seg_sorter.GetNumItems(), group_indices.back()); + auto num_items = seg_sorter.GetItemsSpan().size(); + EXPECT_EQ(num_items, group_indices.back()); EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1); // Check the labels - dh::device_vector sorted_dlabels(seg_sorter.GetNumItems()); - sorted_dlabels.assign(thrust::device_ptr(seg_sorter.GetItemsPtr()), - thrust::device_ptr(seg_sorter.GetItemsPtr()) - + seg_sorter.GetNumItems()); + dh::device_vector sorted_dlabels(num_items); + sorted_dlabels.assign(dh::tcbegin(seg_sorter.GetItemsSpan()), + dh::tcend(seg_sorter.GetItemsSpan())); thrust::host_vector sorted_hlabels(sorted_dlabels); EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels); // Check the indices - dh::device_vector dorig_pos(seg_sorter.GetNumItems()); - dorig_pos.assign(thrust::device_ptr(seg_sorter.GetOriginalPositionsPtr()), - thrust::device_ptr(seg_sorter.GetOriginalPositionsPtr()) - + seg_sorter.GetNumItems()); + dh::device_vector dorig_pos(num_items); + dorig_pos.assign(dh::tcbegin(seg_sorter.GetOriginalPositionsSpan()), + dh::tcend(seg_sorter.GetOriginalPositionsSpan())); dh::device_vector horig_pos(dorig_pos); EXPECT_EQ(expected_orig_pos, horig_pos); @@ -152,18 +150,22 @@ TEST(Objective, NDCGLambdaWeightComputerTest) { // Where will the predictions move from its current position, if they were sorted // descendingly? - auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositions(); - thrust::host_vector hsorted_pred_pos(dsorted_pred_pos); + auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositionsSpan(); + std::vector hsorted_pred_pos(segment_label_sorter->GetNumItems()); + dh::CopyDeviceSpanToVector(&hsorted_pred_pos, dsorted_pred_pos); std::vector expected_sorted_pred_pos{2, 0, 1, 3, 4, 5, 6, 7, 8, 11, 9, 10}; EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos); // Check group DCG values - thrust::host_vector hgroup_dcgs(ndcg_lw_computer.GetGroupDcgs()); - thrust::host_vector hgroups(segment_label_sorter->GetGroups()); - thrust::host_vector hsorted_labels(segment_label_sorter->GetItems()); + std::vector hgroup_dcgs(segment_label_sorter->GetNumGroups()); + dh::CopyDeviceSpanToVector(&hgroup_dcgs, ndcg_lw_computer.GetGroupDcgsSpan()); + std::vector hgroups(segment_label_sorter->GetNumGroups() + 1); + dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan()); EXPECT_EQ(hgroup_dcgs.size(), segment_label_sorter->GetNumGroups()); + std::vector hsorted_labels(segment_label_sorter->GetNumItems()); + dh::CopyDeviceSpanToVector(&hsorted_labels, segment_label_sorter->GetItemsSpan()); for (auto i = 0; i < hgroup_dcgs.size(); ++i) { // Compute group DCG value on CPU and compare auto gbegin = hgroups[i]; @@ -193,7 +195,9 @@ TEST(Objective, IndexableSortedItemsTest) { 9, 11, 7, 10, 8}); segment_label_sorter->CreateIndexableSortedPositions(); - thrust::host_vector sorted_indices(segment_label_sorter->GetIndexableSortedPositions()); + std::vector sorted_indices(segment_label_sorter->GetNumItems()); + dh::CopyDeviceSpanToVector(&sorted_indices, + segment_label_sorter->GetIndexableSortedPositionsSpan()); std::vector expected_sorted_indices = { 1, 3, 2, 0, 4, 6, 5, @@ -228,11 +232,13 @@ TEST(Objective, ComputeAndCompareMAPStatsTest) { *segment_label_sorter); // Get the device MAP stats on host - thrust::host_vector dmap_stats( - map_lw_computer.GetMapStats()); + std::vector dmap_stats( + segment_label_sorter->GetNumItems()); + dh::CopyDeviceSpanToVector(&dmap_stats, map_lw_computer.GetMapStatsSpan()); // Compute the MAP stats on host next to compare - thrust::host_vector hgroups(segment_label_sorter->GetGroups()); + std::vector hgroups(segment_label_sorter->GetNumGroups() + 1); + dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan()); for (auto i = 0; i < hgroups.size() - 1; ++i) { auto gbegin = hgroups[i]; diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 579436245..68144bc6e 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -40,9 +40,9 @@ void VerifySampling(size_t page_size, EXPECT_EQ(sample.page->matrix.n_rows, kRows); EXPECT_EQ(sample.gpair.size(), kRows); } else { - EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.012f); - EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.012f); - EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.012f); + EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.016f); + EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.016f); + EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.016f); } GradientPair sum_sampled_gpair{}; @@ -52,11 +52,11 @@ void VerifySampling(size_t page_size, sum_sampled_gpair += gp; } if (check_sum) { - EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.02f * kRows); - EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.02f * kRows); + EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.03f * kRows); + EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.03f * kRows); } else { - EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.02f); - EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.02f); + EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.03f); + EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.03f); } }