Move segment sorter to common (#5378)
- Move segment sorter to common. - This is the first of a handful of PRs that split the larger PR #5326. - It moves this facility to common (from the ranking objective class) so that it can be used for metric computation. - It also wraps all the bare device pointers in spans.
This commit is contained in:
@@ -48,172 +48,6 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
|
||||
};
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
// This type sorts an array which is divided into multiple groups. The sorting is influenced
|
||||
// by the function object 'Comparator'
|
||||
template <typename T>
|
||||
class SegmentSorter {
|
||||
private:
|
||||
// Items sorted within the group
|
||||
dh::caching_device_vector<T> ditems_;
|
||||
|
||||
// Original position of the items before they are sorted descendingly within its groups
|
||||
dh::caching_device_vector<uint32_t> doriginal_pos_;
|
||||
|
||||
// Segments within the original list that delineates the different groups
|
||||
dh::caching_device_vector<uint32_t> group_segments_;
|
||||
|
||||
// Need this on the device as it is used in the kernels
|
||||
dh::caching_device_vector<uint32_t> dgroups_; // Group information on device
|
||||
|
||||
// Where did the item that was originally present at position 'x' move to after they are sorted
|
||||
dh::caching_device_vector<uint32_t> dindexable_sorted_pos_;
|
||||
|
||||
// Initialize everything but the segments
|
||||
void Init(uint32_t num_elems) {
|
||||
ditems_.resize(num_elems);
|
||||
|
||||
doriginal_pos_.resize(num_elems);
|
||||
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
|
||||
}
|
||||
|
||||
// Initialize all with group info
|
||||
void Init(const std::vector<uint32_t> &groups) {
|
||||
uint32_t num_elems = groups.back();
|
||||
this->Init(num_elems);
|
||||
this->CreateGroupSegments(groups);
|
||||
}
|
||||
|
||||
public:
|
||||
// This needs to be public due to device lambda
|
||||
void CreateGroupSegments(const std::vector<uint32_t> &groups) {
|
||||
uint32_t num_elems = groups.back();
|
||||
group_segments_.resize(num_elems);
|
||||
|
||||
dgroups_ = groups;
|
||||
|
||||
// Define the segments by assigning a group ID to each element
|
||||
const uint32_t *dgroups = dgroups_.data().get();
|
||||
uint32_t ngroups = dgroups_.size();
|
||||
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
|
||||
return dh::UpperBound(dgroups, ngroups, idx) - 1;
|
||||
}; // NOLINT
|
||||
|
||||
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
|
||||
thrust::make_counting_iterator(num_elems),
|
||||
group_segments_.begin(),
|
||||
ComputeGroupIDLambda);
|
||||
}
|
||||
|
||||
// Accessors that returns device pointer
|
||||
inline const T *GetItemsPtr() const { return ditems_.data().get(); }
|
||||
inline uint32_t GetNumItems() const { return ditems_.size(); }
|
||||
inline const dh::caching_device_vector<T> &GetItems() const {
|
||||
return ditems_;
|
||||
}
|
||||
|
||||
inline const uint32_t *GetOriginalPositionsPtr() const { return doriginal_pos_.data().get(); }
|
||||
inline const dh::caching_device_vector<uint32_t> &GetOriginalPositions() const {
|
||||
return doriginal_pos_;
|
||||
}
|
||||
|
||||
inline const dh::caching_device_vector<uint32_t> &GetGroupSegments() const {
|
||||
return group_segments_;
|
||||
}
|
||||
|
||||
inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; }
|
||||
inline const uint32_t *GetGroupsPtr() const { return dgroups_.data().get(); }
|
||||
inline const dh::caching_device_vector<uint32_t> &GetGroups() const { return dgroups_; }
|
||||
|
||||
inline const dh::caching_device_vector<uint32_t> &GetIndexableSortedPositions() const {
|
||||
return dindexable_sorted_pos_;
|
||||
}
|
||||
|
||||
// Sort an array that is divided into multiple groups. The array is sorted within each group.
|
||||
// This version provides the group information that is on the host.
|
||||
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
|
||||
// is used.
|
||||
template <typename Comparator = thrust::greater<T>>
|
||||
void SortItems(const T *ditems, uint32_t item_size, const std::vector<uint32_t> &groups,
|
||||
const Comparator &comp = Comparator()) {
|
||||
this->Init(groups);
|
||||
this->SortItems(ditems, item_size, group_segments_, comp);
|
||||
}
|
||||
|
||||
// Sort an array that is divided into multiple groups. The array is sorted within each group.
|
||||
// This version provides the group information that is on the device.
|
||||
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
|
||||
// is used.
|
||||
template <typename Comparator = thrust::greater<T>>
|
||||
void SortItems(const T *ditems, uint32_t item_size,
|
||||
const dh::caching_device_vector<uint32_t> &group_segments,
|
||||
const Comparator &comp = Comparator()) {
|
||||
this->Init(item_size);
|
||||
|
||||
// Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
|
||||
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
|
||||
// when comparators are used. Hence, the following algorithm is used. This is done so that
|
||||
// we can grab the appropriate related values from the original list later, after the
|
||||
// items are sorted.
|
||||
//
|
||||
// Here is the internal representation:
|
||||
// dgroups_: [ 0, 3, 5, 8, 10 ]
|
||||
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
|
||||
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
|
||||
// ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
|
||||
//
|
||||
// Sort the items first and make a note of the original positions in doriginal_pos_
|
||||
// based on the sort
|
||||
// ditems_: 4 4 3 3 2 1 1 1 1 0
|
||||
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
|
||||
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
|
||||
// in kernel, sorting using predicates etc.
|
||||
|
||||
ditems_.assign(thrust::device_ptr<const T>(ditems),
|
||||
thrust::device_ptr<const T>(ditems) + item_size);
|
||||
|
||||
// Allocator to be used by sort for managing space overhead while sorting
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
|
||||
ditems_.begin(), ditems_.end(),
|
||||
doriginal_pos_.begin(), comp);
|
||||
|
||||
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
|
||||
// holisitic item sort order on the segments
|
||||
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
|
||||
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
|
||||
dh::caching_device_vector<uint32_t> group_segments_c(group_segments);
|
||||
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
|
||||
group_segments.begin(), group_segments_c.begin());
|
||||
|
||||
// Now, sort the group segments so that you may bring the items within the group together,
|
||||
// in the process also noting the relative changes to the doriginal_pos_ while that happens
|
||||
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
|
||||
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
|
||||
group_segments_c.begin(), group_segments_c.end(),
|
||||
doriginal_pos_.begin(), thrust::less<uint32_t>());
|
||||
|
||||
// Finally, gather the original items based on doriginal_pos_ to sort the input and
|
||||
// to store them in ditems_
|
||||
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
|
||||
// ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
|
||||
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
|
||||
thrust::device_ptr<const T>(ditems), ditems_.begin());
|
||||
}
|
||||
|
||||
// Determine where an item that was originally present at position 'x' has been relocated to
|
||||
// after a sort. Creation of such an index has to be explicitly requested after a sort
|
||||
void CreateIndexableSortedPositions() {
|
||||
dindexable_sorted_pos_.resize(GetNumItems());
|
||||
thrust::scatter(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
|
||||
thrust::make_counting_iterator(GetNumItems()), // Rearrange indices...
|
||||
// ...based on this map
|
||||
thrust::device_ptr<const uint32_t>(GetOriginalPositionsPtr()),
|
||||
dindexable_sorted_pos_.begin()); // Write results into this
|
||||
}
|
||||
};
|
||||
|
||||
// Helper functions
|
||||
|
||||
template <typename T>
|
||||
@@ -283,7 +117,7 @@ class PairwiseLambdaWeightComputer {
|
||||
#if defined(__CUDACC__)
|
||||
PairwiseLambdaWeightComputer(const bst_float *dpreds,
|
||||
const bst_float *dlabels,
|
||||
const SegmentSorter<float> &segment_label_sorter) {}
|
||||
const dh::SegmentSorter<float> &segment_label_sorter) {}
|
||||
|
||||
class PairwiseLambdaWeightMultiplier {
|
||||
public:
|
||||
@@ -302,20 +136,20 @@ class PairwiseLambdaWeightComputer {
|
||||
#if defined(__CUDACC__)
|
||||
class BaseLambdaWeightMultiplier {
|
||||
public:
|
||||
BaseLambdaWeightMultiplier(const SegmentSorter<float> &segment_label_sorter,
|
||||
const SegmentSorter<float> &segment_pred_sorter)
|
||||
: dsorted_labels_(segment_label_sorter.GetItemsPtr()),
|
||||
dorig_pos_(segment_label_sorter.GetOriginalPositionsPtr()),
|
||||
dgroups_(segment_label_sorter.GetGroupsPtr()),
|
||||
dindexable_sorted_preds_pos_ptr_(
|
||||
segment_pred_sorter.GetIndexableSortedPositions().data().get()) {}
|
||||
BaseLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
|
||||
const dh::SegmentSorter<float> &segment_pred_sorter)
|
||||
: dsorted_labels_(segment_label_sorter.GetItemsSpan()),
|
||||
dorig_pos_(segment_label_sorter.GetOriginalPositionsSpan()),
|
||||
dgroups_(segment_label_sorter.GetGroupsSpan()),
|
||||
dindexable_sorted_preds_pos_(segment_pred_sorter.GetIndexableSortedPositionsSpan()) {}
|
||||
|
||||
protected:
|
||||
const float *dsorted_labels_{nullptr}; // Labels sorted within a group
|
||||
const uint32_t *dorig_pos_{nullptr}; // Original indices of the labels before they are sorted
|
||||
const uint32_t *dgroups_{nullptr}; // The group indices
|
||||
const common::Span<const float> dsorted_labels_; // Labels sorted within a group
|
||||
const common::Span<const uint32_t> dorig_pos_; // Original indices of the labels
|
||||
// before they are sorted
|
||||
const common::Span<const uint32_t> dgroups_; // The group indices
|
||||
// Where can a prediction for a label be found in the original array, when they are sorted
|
||||
const uint32_t *dindexable_sorted_preds_pos_ptr_{nullptr};
|
||||
const common::Span<const uint32_t> dindexable_sorted_preds_pos_;
|
||||
};
|
||||
|
||||
// While computing the weight that needs to be adjusted by this ranking objective, we need
|
||||
@@ -342,8 +176,8 @@ class BaseLambdaWeightMultiplier {
|
||||
//
|
||||
// We create a sorted prediction positional array, such that value at position 'x' gives
|
||||
// us the position in the sorted prediction array where its related prediction lies.
|
||||
// dindexable_sorted_preds_pos_ptr_: 8 0 9 1 7 2 6 3 4 5
|
||||
// at indices: 0 1 2 3 4 5 6 7 8 9
|
||||
// dindexable_sorted_preds_pos_: 8 0 9 1 7 2 6 3 4 5
|
||||
// at indices: 0 1 2 3 4 5 6 7 8 9
|
||||
// Basically, swap the previous 2 arrays, sort the indices and reorder positions
|
||||
// for an O(1) lookup using the position where the sorted label exists.
|
||||
//
|
||||
@@ -351,21 +185,21 @@ class BaseLambdaWeightMultiplier {
|
||||
class IndexablePredictionSorter {
|
||||
public:
|
||||
IndexablePredictionSorter(const bst_float *dpreds,
|
||||
const SegmentSorter<float> &segment_label_sorter) {
|
||||
const dh::SegmentSorter<float> &segment_label_sorter) {
|
||||
// Sort the predictions first
|
||||
segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(),
|
||||
segment_label_sorter.GetGroupSegments());
|
||||
segment_label_sorter.GetGroupSegmentsSpan());
|
||||
|
||||
// Create an index for the sorted prediction positions
|
||||
segment_pred_sorter_.CreateIndexableSortedPositions();
|
||||
}
|
||||
|
||||
inline const SegmentSorter<float> &GetPredictionSorter() const {
|
||||
inline const dh::SegmentSorter<float> &GetPredictionSorter() const {
|
||||
return segment_pred_sorter_;
|
||||
}
|
||||
|
||||
private:
|
||||
SegmentSorter<float> segment_pred_sorter_; // For sorting the predictions
|
||||
dh::SegmentSorter<float> segment_pred_sorter_; // For sorting the predictions
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -380,9 +214,9 @@ class NDCGLambdaWeightComputer
|
||||
// This function object computes the item's DCG value
|
||||
class ComputeItemDCG : public thrust::unary_function<uint32_t, float> {
|
||||
public:
|
||||
XGBOOST_DEVICE ComputeItemDCG(const float *dsorted_labels,
|
||||
const uint32_t *dgroups,
|
||||
const uint32_t *gidxs)
|
||||
XGBOOST_DEVICE ComputeItemDCG(const common::Span<const float> &dsorted_labels,
|
||||
const common::Span<const uint32_t> &dgroups,
|
||||
const common::Span<const uint32_t> &gidxs)
|
||||
: dsorted_labels_(dsorted_labels),
|
||||
dgroups_(dgroups),
|
||||
dgidxs_(gidxs) {}
|
||||
@@ -393,22 +227,23 @@ class NDCGLambdaWeightComputer
|
||||
}
|
||||
|
||||
private:
|
||||
const float *dsorted_labels_{nullptr}; // Labels sorted within a group
|
||||
const uint32_t *dgroups_{nullptr}; // The group indices - where each group begins and ends
|
||||
const uint32_t *dgidxs_{nullptr}; // The group each items belongs to
|
||||
const common::Span<const float> dsorted_labels_; // Labels sorted within a group
|
||||
const common::Span<const uint32_t> dgroups_; // The group indices - where each group
|
||||
// begins and ends
|
||||
const common::Span<const uint32_t> dgidxs_; // The group each items belongs to
|
||||
};
|
||||
|
||||
// Type containing device pointers that can be cheaply copied on the kernel
|
||||
class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier {
|
||||
public:
|
||||
NDCGLambdaWeightMultiplier(const SegmentSorter<float> &segment_label_sorter,
|
||||
NDCGLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
|
||||
const NDCGLambdaWeightComputer &lwc)
|
||||
: BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()),
|
||||
dgroup_dcg_ptr_(lwc.GetGroupDcgs().data().get()) {}
|
||||
dgroup_dcgs_(lwc.GetGroupDcgsSpan()) {}
|
||||
|
||||
// Adjust the items weight by this value
|
||||
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
|
||||
if (dgroup_dcg_ptr_[gidx] == 0.0) return 0.0f;
|
||||
if (dgroup_dcgs_[gidx] == 0.0) return 0.0f;
|
||||
|
||||
uint32_t group_begin = dgroups_[gidx];
|
||||
|
||||
@@ -418,43 +253,47 @@ class NDCGLambdaWeightComputer
|
||||
|
||||
// Note: the label positive and negative indices are relative to the entire dataset.
|
||||
// Hence, scale them back to an index within the group
|
||||
auto pos_pred_pos = dindexable_sorted_preds_pos_ptr_[pos_lab_orig_posn] - group_begin;
|
||||
auto neg_pred_pos = dindexable_sorted_preds_pos_ptr_[neg_lab_orig_posn] - group_begin;
|
||||
auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin;
|
||||
auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin;
|
||||
return NDCGLambdaWeightComputer::ComputeDeltaWeight(
|
||||
pos_pred_pos, neg_pred_pos,
|
||||
static_cast<int>(dsorted_labels_[pidx]), static_cast<int>(dsorted_labels_[nidx]),
|
||||
dgroup_dcg_ptr_[gidx]);
|
||||
dgroup_dcgs_[gidx]);
|
||||
}
|
||||
|
||||
private:
|
||||
const float *dgroup_dcg_ptr_{nullptr}; // Start address of the group DCG values
|
||||
const common::Span<const float> dgroup_dcgs_; // Group DCG values
|
||||
};
|
||||
|
||||
NDCGLambdaWeightComputer(const bst_float *dpreds,
|
||||
const bst_float *dlabels,
|
||||
const SegmentSorter<float> &segment_label_sorter)
|
||||
const dh::SegmentSorter<float> &segment_label_sorter)
|
||||
: IndexablePredictionSorter(dpreds, segment_label_sorter),
|
||||
dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f),
|
||||
weight_multiplier_(segment_label_sorter, *this) {
|
||||
const auto &group_segments = segment_label_sorter.GetGroupSegments();
|
||||
const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
|
||||
|
||||
// Allocator to be used for managing space overhead while performing transformed reductions
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
// Compute each elements DCG values and reduce them across groups concurrently.
|
||||
auto end_range =
|
||||
thrust::reduce_by_key(group_segments.begin(), group_segments.end(),
|
||||
thrust::reduce_by_key(thrust::cuda::par(alloc),
|
||||
dh::tcbegin(group_segments), dh::tcend(group_segments),
|
||||
thrust::make_transform_iterator(
|
||||
// The indices need not be sequential within a group, as we care only
|
||||
// about the sum of items DCG values within a group
|
||||
segment_label_sorter.GetOriginalPositions().begin(),
|
||||
ComputeItemDCG(segment_label_sorter.GetItemsPtr(),
|
||||
segment_label_sorter.GetGroupsPtr(),
|
||||
group_segments.data().get())),
|
||||
dh::tcbegin(segment_label_sorter.GetOriginalPositionsSpan()),
|
||||
ComputeItemDCG(segment_label_sorter.GetItemsSpan(),
|
||||
segment_label_sorter.GetGroupsSpan(),
|
||||
group_segments)),
|
||||
thrust::make_discard_iterator(), // We don't care for the group indices
|
||||
dgroup_dcg_.begin()); // Sum of the item's DCG values in the group
|
||||
CHECK(end_range.second - dgroup_dcg_.begin() == dgroup_dcg_.size());
|
||||
}
|
||||
|
||||
inline const dh::caching_device_vector<float> &GetGroupDcgs() const {
|
||||
return dgroup_dcg_;
|
||||
inline const common::Span<const float> GetGroupDcgsSpan() const {
|
||||
return { dgroup_dcg_.data().get(), dgroup_dcg_.size() };
|
||||
}
|
||||
|
||||
inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const {
|
||||
@@ -664,7 +503,7 @@ class MAPLambdaWeightComputer
|
||||
#if defined(__CUDACC__)
|
||||
MAPLambdaWeightComputer(const bst_float *dpreds,
|
||||
const bst_float *dlabels,
|
||||
const SegmentSorter<float> &segment_label_sorter)
|
||||
const dh::SegmentSorter<float> &segment_label_sorter)
|
||||
: IndexablePredictionSorter(dpreds, segment_label_sorter),
|
||||
dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()),
|
||||
weight_multiplier_(segment_label_sorter, *this) {
|
||||
@@ -672,7 +511,7 @@ class MAPLambdaWeightComputer
|
||||
}
|
||||
|
||||
void CreateMAPStats(const bst_float *dlabels,
|
||||
const SegmentSorter<float> &segment_label_sorter) {
|
||||
const dh::SegmentSorter<float> &segment_label_sorter) {
|
||||
// For each group, go through the sorted prediction positions, and look up its corresponding
|
||||
// label from the unsorted labels (from the original label list)
|
||||
|
||||
@@ -683,7 +522,7 @@ class MAPLambdaWeightComputer
|
||||
auto nitems = segment_label_sorter.GetNumItems();
|
||||
dh::caching_device_vector<uint32_t> dhits(nitems, 0);
|
||||
// Original positions of the predictions after they have been sorted
|
||||
const uint32_t *pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsPtr();
|
||||
const auto &pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsSpan();
|
||||
// Unsorted labels
|
||||
const float *unsorted_labels = dlabels;
|
||||
auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) {
|
||||
@@ -700,24 +539,23 @@ class MAPLambdaWeightComputer
|
||||
|
||||
// Next, prefix scan the positive labels that are segmented to accumulate them.
|
||||
// This is required for computing the accumulated precisions
|
||||
const auto &group_segments = segment_label_sorter.GetGroupSegments();
|
||||
const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
|
||||
// Data segmented into different groups...
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
|
||||
group_segments.begin(), group_segments.end(),
|
||||
dh::tcbegin(group_segments), dh::tcend(group_segments),
|
||||
dhits.begin(), // Input value
|
||||
dhits.begin()); // In-place scan
|
||||
|
||||
// Compute accumulated precisions for each item, assuming positive and
|
||||
// negative instances are missing.
|
||||
// But first, compute individual item precisions
|
||||
const auto *dgidx_arr = group_segments.data().get();
|
||||
const auto *dhits_arr = dhits.data().get();
|
||||
// Group info on device
|
||||
const uint32_t *dgroups = segment_label_sorter.GetGroupsPtr();
|
||||
const auto &dgroups = segment_label_sorter.GetGroupsSpan();
|
||||
uint32_t ngroups = segment_label_sorter.GetNumGroups();
|
||||
auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) {
|
||||
if (unsorted_labels[pred_original_pos[idx]] > 0.0f) {
|
||||
auto idx_within_group = (idx - dgroups[dgidx_arr[idx]]) + 1;
|
||||
auto idx_within_group = (idx - dgroups[group_segments[idx]]) + 1;
|
||||
return MAPStats(static_cast<float>(dhits_arr[idx]) / idx_within_group,
|
||||
static_cast<float>(dhits_arr[idx] - 1) / idx_within_group,
|
||||
static_cast<float>(dhits_arr[idx] + 1) / idx_within_group,
|
||||
@@ -734,22 +572,22 @@ class MAPLambdaWeightComputer
|
||||
// Lastly, compute the accumulated precisions for all the items segmented by groups.
|
||||
// The precisions are accumulated within each group
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
|
||||
group_segments.begin(), group_segments.end(),
|
||||
dh::tcbegin(group_segments), dh::tcend(group_segments),
|
||||
this->dmap_stats_.begin(), // Input map stats
|
||||
this->dmap_stats_.begin()); // In-place scan and output here
|
||||
}
|
||||
|
||||
inline const dh::caching_device_vector<MAPStats> &GetMapStats() const {
|
||||
return dmap_stats_;
|
||||
inline const common::Span<const MAPStats> GetMapStatsSpan() const {
|
||||
return { dmap_stats_.data().get(), dmap_stats_.size() };
|
||||
}
|
||||
|
||||
// Type containing device pointers that can be cheaply copied on the kernel
|
||||
class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier {
|
||||
public:
|
||||
MAPLambdaWeightMultiplier(const SegmentSorter<float> &segment_label_sorter,
|
||||
MAPLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
|
||||
const MAPLambdaWeightComputer &lwc)
|
||||
: BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()),
|
||||
dmap_stats_ptr_(lwc.GetMapStats().data().get()) {}
|
||||
dmap_stats_(lwc.GetMapStatsSpan()) {}
|
||||
|
||||
// Adjust the items weight by this value
|
||||
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
|
||||
@@ -762,17 +600,17 @@ class MAPLambdaWeightComputer
|
||||
|
||||
// Note: the label positive and negative indices are relative to the entire dataset.
|
||||
// Hence, scale them back to an index within the group
|
||||
auto pos_pred_pos = dindexable_sorted_preds_pos_ptr_[pos_lab_orig_posn] - group_begin;
|
||||
auto neg_pred_pos = dindexable_sorted_preds_pos_ptr_[neg_lab_orig_posn] - group_begin;
|
||||
auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin;
|
||||
auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin;
|
||||
return MAPLambdaWeightComputer::GetLambdaMAP(
|
||||
pos_pred_pos, neg_pred_pos,
|
||||
dsorted_labels_[pidx], dsorted_labels_[nidx],
|
||||
&dmap_stats_ptr_[group_begin], group_end - group_begin);
|
||||
&dmap_stats_[group_begin], group_end - group_begin);
|
||||
}
|
||||
|
||||
private:
|
||||
const MAPStats *dmap_stats_ptr_{nullptr}; // Start address of the map stats for every sorted
|
||||
// prediction value
|
||||
common::Span<const MAPStats> dmap_stats_; // Start address of the map stats for every sorted
|
||||
// prediction value
|
||||
};
|
||||
|
||||
inline const MAPLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; }
|
||||
@@ -785,7 +623,7 @@ class MAPLambdaWeightComputer
|
||||
};
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
class SortedLabelList : SegmentSorter<float> {
|
||||
class SortedLabelList : dh::SegmentSorter<float> {
|
||||
private:
|
||||
const LambdaRankParam ¶m_; // Objective configuration
|
||||
|
||||
@@ -808,7 +646,7 @@ class SortedLabelList : SegmentSorter<float> {
|
||||
GradientPair *out_gpair,
|
||||
float weight_normalization_factor) {
|
||||
// Group info on device
|
||||
const uint32_t *dgroups = this->GetGroupsPtr();
|
||||
const auto &dgroups = this->GetGroupsSpan();
|
||||
uint32_t ngroups = this->GetNumGroups() + 1;
|
||||
|
||||
uint32_t total_items = this->GetNumItems();
|
||||
@@ -816,12 +654,12 @@ class SortedLabelList : SegmentSorter<float> {
|
||||
|
||||
float fix_list_weight = param_.fix_list_weight;
|
||||
|
||||
const uint32_t *original_pos = this->GetOriginalPositionsPtr();
|
||||
const auto &original_pos = this->GetOriginalPositionsSpan();
|
||||
|
||||
uint32_t num_weights = weights.Size();
|
||||
auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr;
|
||||
|
||||
const bst_float *sorted_labels = this->GetItemsPtr();
|
||||
const auto &sorted_labels = this->GetItemsSpan();
|
||||
|
||||
// This is used to adjust the weight of different elements based on the different ranking
|
||||
// objective function policies
|
||||
@@ -834,7 +672,7 @@ class SortedLabelList : SegmentSorter<float> {
|
||||
dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) {
|
||||
// First, determine the group 'idx' belongs to
|
||||
uint32_t item_idx = idx % total_items;
|
||||
uint32_t group_idx = dh::UpperBound(dgroups, ngroups, item_idx);
|
||||
uint32_t group_idx = dh::UpperBound(dgroups.data(), ngroups, item_idx);
|
||||
// Span of this group within the larger labels/predictions sorted tuple
|
||||
uint32_t group_begin = dgroups[group_idx - 1];
|
||||
uint32_t group_end = dgroups[group_idx];
|
||||
@@ -847,9 +685,9 @@ class SortedLabelList : SegmentSorter<float> {
|
||||
// Find the number of labels less than and greater than the current label
|
||||
// at the sorted index position item_idx
|
||||
uint32_t nleft = CountNumItemsToTheLeftOf(
|
||||
sorted_labels + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]);
|
||||
sorted_labels.data() + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]);
|
||||
uint32_t nright = CountNumItemsToTheRightOf(
|
||||
sorted_labels + item_idx, group_end - item_idx, sorted_labels[item_idx]);
|
||||
sorted_labels.data() + item_idx, group_end - item_idx, sorted_labels[item_idx]);
|
||||
|
||||
// Create a minstd_rand object to act as our source of randomness
|
||||
thrust::minstd_rand rng((iter + 1) * 1111);
|
||||
|
||||
Reference in New Issue
Block a user