Move segment sorter to common (#5378)

- move segment sorter to common
- this is the first of a handful of pr's that splits the larger pr #5326
- it moves this facility to common (from ranking objective class), so that it can be
    used for metric computation
- it also wraps all the bald device pointers into span.
This commit is contained in:
sriramch 2020-02-28 23:42:07 -08:00 committed by GitHub
parent 2ba8c13b69
commit b81f8cbbc0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 275 additions and 260 deletions

View File

@ -8,7 +8,9 @@
#include <thrust/system/cuda/error.h> #include <thrust/system/cuda/error.h>
#include <thrust/system_error.h> #include <thrust/system_error.h>
#include <thrust/logical.h> #include <thrust/logical.h>
#include <thrust/gather.h>
#include <omp.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <cub/util_allocator.cuh> #include <cub/util_allocator.cuh>
@ -1285,6 +1287,175 @@ thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) {
return tcbegin(span) + span.size(); return tcbegin(span) + span.size();
} }
// This type sorts an array which is divided into multiple groups. The sorting is influenced
// by the function object 'Comparator'
template <typename T>
class SegmentSorter {
private:
// Items sorted within the group
caching_device_vector<T> ditems_;
// Original position of the items before they are sorted descendingly within its groups
caching_device_vector<uint32_t> doriginal_pos_;
// Segments within the original list that delineates the different groups
caching_device_vector<uint32_t> group_segments_;
// Need this on the device as it is used in the kernels
caching_device_vector<uint32_t> dgroups_; // Group information on device
// Where did the item that was originally present at position 'x' move to after they are sorted
caching_device_vector<uint32_t> dindexable_sorted_pos_;
// Initialize everything but the segments
void Init(uint32_t num_elems) {
ditems_.resize(num_elems);
doriginal_pos_.resize(num_elems);
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
}
// Initialize all with group info
void Init(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
this->Init(num_elems);
this->CreateGroupSegments(groups);
}
public:
// This needs to be public due to device lambda
void CreateGroupSegments(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
group_segments_.resize(num_elems, 0);
dgroups_ = groups;
if (GetNumGroups() == 1) return; // There are no segments; hence, no need to compute them
// Define the segments by assigning a group ID to each element
const uint32_t *dgroups = dgroups_.data().get();
uint32_t ngroups = dgroups_.size();
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
return dh::UpperBound(dgroups, ngroups, idx) - 1;
}; // NOLINT
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(num_elems),
group_segments_.begin(),
ComputeGroupIDLambda);
}
// Accessors that returns device pointer
inline uint32_t GetNumItems() const { return ditems_.size(); }
inline const xgboost::common::Span<const T> GetItemsSpan() const {
return { ditems_.data().get(), ditems_.size() };
}
inline const xgboost::common::Span<const uint32_t> GetOriginalPositionsSpan() const {
return { doriginal_pos_.data().get(), doriginal_pos_.size() };
}
inline const xgboost::common::Span<const uint32_t> GetGroupSegmentsSpan() const {
return { group_segments_.data().get(), group_segments_.size() };
}
inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; }
inline const xgboost::common::Span<const uint32_t> GetGroupsSpan() const {
return { dgroups_.data().get(), dgroups_.size() };
}
inline const xgboost::common::Span<const uint32_t> GetIndexableSortedPositionsSpan() const {
return { dindexable_sorted_pos_.data().get(), dindexable_sorted_pos_.size() };
}
// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the host.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size, const std::vector<uint32_t> &groups,
const Comparator &comp = Comparator()) {
this->Init(groups);
this->SortItems(ditems, item_size, this->GetGroupSegmentsSpan(), comp);
}
// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the device.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size,
const xgboost::common::Span<const uint32_t> &group_segments,
const Comparator &comp = Comparator()) {
this->Init(item_size);
// Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
// when comparators are used. Hence, the following algorithm is used. This is done so that
// we can grab the appropriate related values from the original list later, after the
// items are sorted.
//
// Here is the internal representation:
// dgroups_: [ 0, 3, 5, 8, 10 ]
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
// ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
//
// Sort the items first and make a note of the original positions in doriginal_pos_
// based on the sort
// ditems_: 4 4 3 3 2 1 1 1 1 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
// in kernel, sorting using predicates etc.
ditems_.assign(thrust::device_ptr<const T>(ditems),
thrust::device_ptr<const T>(ditems) + item_size);
// Allocator to be used by sort for managing space overhead while sorting
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
ditems_.begin(), ditems_.end(),
doriginal_pos_.begin(), comp);
if (GetNumGroups() == 1) return; // The entire array is sorted, as it isn't segmented
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
// holisitic item sort order on the segments
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
caching_device_vector<uint32_t> group_segments_c(item_size);
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
dh::tcbegin(group_segments), group_segments_c.begin());
// Now, sort the group segments so that you may bring the items within the group together,
// in the process also noting the relative changes to the doriginal_pos_ while that happens
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
group_segments_c.begin(), group_segments_c.end(),
doriginal_pos_.begin(), thrust::less<uint32_t>());
// Finally, gather the original items based on doriginal_pos_ to sort the input and
// to store them in ditems_
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
// ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
thrust::device_ptr<const T>(ditems), ditems_.begin());
}
// Determine where an item that was originally present at position 'x' has been relocated to
// after a sort. Creation of such an index has to be explicitly requested after a sort
void CreateIndexableSortedPositions() {
dindexable_sorted_pos_.resize(GetNumItems());
thrust::scatter(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(GetNumItems()), // Rearrange indices...
// ...based on this map
dh::tcbegin(GetOriginalPositionsSpan()),
dindexable_sorted_pos_.begin()); // Write results into this
}
};
template <typename FunctionT> template <typename FunctionT>
class LauncherItr { class LauncherItr {
public: public:

View File

@ -48,172 +48,6 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
}; };
#if defined(__CUDACC__) #if defined(__CUDACC__)
// This type sorts an array which is divided into multiple groups. The sorting is influenced
// by the function object 'Comparator'
template <typename T>
class SegmentSorter {
private:
// Items sorted within the group
dh::caching_device_vector<T> ditems_;
// Original position of the items before they are sorted descendingly within its groups
dh::caching_device_vector<uint32_t> doriginal_pos_;
// Segments within the original list that delineates the different groups
dh::caching_device_vector<uint32_t> group_segments_;
// Need this on the device as it is used in the kernels
dh::caching_device_vector<uint32_t> dgroups_; // Group information on device
// Where did the item that was originally present at position 'x' move to after they are sorted
dh::caching_device_vector<uint32_t> dindexable_sorted_pos_;
// Initialize everything but the segments
void Init(uint32_t num_elems) {
ditems_.resize(num_elems);
doriginal_pos_.resize(num_elems);
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
}
// Initialize all with group info
void Init(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
this->Init(num_elems);
this->CreateGroupSegments(groups);
}
public:
// This needs to be public due to device lambda
void CreateGroupSegments(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
group_segments_.resize(num_elems);
dgroups_ = groups;
// Define the segments by assigning a group ID to each element
const uint32_t *dgroups = dgroups_.data().get();
uint32_t ngroups = dgroups_.size();
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
return dh::UpperBound(dgroups, ngroups, idx) - 1;
}; // NOLINT
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(num_elems),
group_segments_.begin(),
ComputeGroupIDLambda);
}
// Accessors that returns device pointer
inline const T *GetItemsPtr() const { return ditems_.data().get(); }
inline uint32_t GetNumItems() const { return ditems_.size(); }
inline const dh::caching_device_vector<T> &GetItems() const {
return ditems_;
}
inline const uint32_t *GetOriginalPositionsPtr() const { return doriginal_pos_.data().get(); }
inline const dh::caching_device_vector<uint32_t> &GetOriginalPositions() const {
return doriginal_pos_;
}
inline const dh::caching_device_vector<uint32_t> &GetGroupSegments() const {
return group_segments_;
}
inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; }
inline const uint32_t *GetGroupsPtr() const { return dgroups_.data().get(); }
inline const dh::caching_device_vector<uint32_t> &GetGroups() const { return dgroups_; }
inline const dh::caching_device_vector<uint32_t> &GetIndexableSortedPositions() const {
return dindexable_sorted_pos_;
}
// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the host.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size, const std::vector<uint32_t> &groups,
const Comparator &comp = Comparator()) {
this->Init(groups);
this->SortItems(ditems, item_size, group_segments_, comp);
}
// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the device.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size,
const dh::caching_device_vector<uint32_t> &group_segments,
const Comparator &comp = Comparator()) {
this->Init(item_size);
// Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
// when comparators are used. Hence, the following algorithm is used. This is done so that
// we can grab the appropriate related values from the original list later, after the
// items are sorted.
//
// Here is the internal representation:
// dgroups_: [ 0, 3, 5, 8, 10 ]
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
// ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
//
// Sort the items first and make a note of the original positions in doriginal_pos_
// based on the sort
// ditems_: 4 4 3 3 2 1 1 1 1 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
// in kernel, sorting using predicates etc.
ditems_.assign(thrust::device_ptr<const T>(ditems),
thrust::device_ptr<const T>(ditems) + item_size);
// Allocator to be used by sort for managing space overhead while sorting
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
ditems_.begin(), ditems_.end(),
doriginal_pos_.begin(), comp);
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
// holisitic item sort order on the segments
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
dh::caching_device_vector<uint32_t> group_segments_c(group_segments);
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
group_segments.begin(), group_segments_c.begin());
// Now, sort the group segments so that you may bring the items within the group together,
// in the process also noting the relative changes to the doriginal_pos_ while that happens
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
group_segments_c.begin(), group_segments_c.end(),
doriginal_pos_.begin(), thrust::less<uint32_t>());
// Finally, gather the original items based on doriginal_pos_ to sort the input and
// to store them in ditems_
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
// ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
thrust::device_ptr<const T>(ditems), ditems_.begin());
}
// Determine where an item that was originally present at position 'x' has been relocated to
// after a sort. Creation of such an index has to be explicitly requested after a sort
void CreateIndexableSortedPositions() {
dindexable_sorted_pos_.resize(GetNumItems());
thrust::scatter(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(GetNumItems()), // Rearrange indices...
// ...based on this map
thrust::device_ptr<const uint32_t>(GetOriginalPositionsPtr()),
dindexable_sorted_pos_.begin()); // Write results into this
}
};
// Helper functions // Helper functions
template <typename T> template <typename T>
@ -283,7 +117,7 @@ class PairwiseLambdaWeightComputer {
#if defined(__CUDACC__) #if defined(__CUDACC__)
PairwiseLambdaWeightComputer(const bst_float *dpreds, PairwiseLambdaWeightComputer(const bst_float *dpreds,
const bst_float *dlabels, const bst_float *dlabels,
const SegmentSorter<float> &segment_label_sorter) {} const dh::SegmentSorter<float> &segment_label_sorter) {}
class PairwiseLambdaWeightMultiplier { class PairwiseLambdaWeightMultiplier {
public: public:
@ -302,20 +136,20 @@ class PairwiseLambdaWeightComputer {
#if defined(__CUDACC__) #if defined(__CUDACC__)
class BaseLambdaWeightMultiplier { class BaseLambdaWeightMultiplier {
public: public:
BaseLambdaWeightMultiplier(const SegmentSorter<float> &segment_label_sorter, BaseLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
const SegmentSorter<float> &segment_pred_sorter) const dh::SegmentSorter<float> &segment_pred_sorter)
: dsorted_labels_(segment_label_sorter.GetItemsPtr()), : dsorted_labels_(segment_label_sorter.GetItemsSpan()),
dorig_pos_(segment_label_sorter.GetOriginalPositionsPtr()), dorig_pos_(segment_label_sorter.GetOriginalPositionsSpan()),
dgroups_(segment_label_sorter.GetGroupsPtr()), dgroups_(segment_label_sorter.GetGroupsSpan()),
dindexable_sorted_preds_pos_ptr_( dindexable_sorted_preds_pos_(segment_pred_sorter.GetIndexableSortedPositionsSpan()) {}
segment_pred_sorter.GetIndexableSortedPositions().data().get()) {}
protected: protected:
const float *dsorted_labels_{nullptr}; // Labels sorted within a group const common::Span<const float> dsorted_labels_; // Labels sorted within a group
const uint32_t *dorig_pos_{nullptr}; // Original indices of the labels before they are sorted const common::Span<const uint32_t> dorig_pos_; // Original indices of the labels
const uint32_t *dgroups_{nullptr}; // The group indices // before they are sorted
const common::Span<const uint32_t> dgroups_; // The group indices
// Where can a prediction for a label be found in the original array, when they are sorted // Where can a prediction for a label be found in the original array, when they are sorted
const uint32_t *dindexable_sorted_preds_pos_ptr_{nullptr}; const common::Span<const uint32_t> dindexable_sorted_preds_pos_;
}; };
// While computing the weight that needs to be adjusted by this ranking objective, we need // While computing the weight that needs to be adjusted by this ranking objective, we need
@ -342,7 +176,7 @@ class BaseLambdaWeightMultiplier {
// //
// We create a sorted prediction positional array, such that value at position 'x' gives // We create a sorted prediction positional array, such that value at position 'x' gives
// us the position in the sorted prediction array where its related prediction lies. // us the position in the sorted prediction array where its related prediction lies.
// dindexable_sorted_preds_pos_ptr_: 8 0 9 1 7 2 6 3 4 5 // dindexable_sorted_preds_pos_: 8 0 9 1 7 2 6 3 4 5
// at indices: 0 1 2 3 4 5 6 7 8 9 // at indices: 0 1 2 3 4 5 6 7 8 9
// Basically, swap the previous 2 arrays, sort the indices and reorder positions // Basically, swap the previous 2 arrays, sort the indices and reorder positions
// for an O(1) lookup using the position where the sorted label exists. // for an O(1) lookup using the position where the sorted label exists.
@ -351,21 +185,21 @@ class BaseLambdaWeightMultiplier {
class IndexablePredictionSorter { class IndexablePredictionSorter {
public: public:
IndexablePredictionSorter(const bst_float *dpreds, IndexablePredictionSorter(const bst_float *dpreds,
const SegmentSorter<float> &segment_label_sorter) { const dh::SegmentSorter<float> &segment_label_sorter) {
// Sort the predictions first // Sort the predictions first
segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(), segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(),
segment_label_sorter.GetGroupSegments()); segment_label_sorter.GetGroupSegmentsSpan());
// Create an index for the sorted prediction positions // Create an index for the sorted prediction positions
segment_pred_sorter_.CreateIndexableSortedPositions(); segment_pred_sorter_.CreateIndexableSortedPositions();
} }
inline const SegmentSorter<float> &GetPredictionSorter() const { inline const dh::SegmentSorter<float> &GetPredictionSorter() const {
return segment_pred_sorter_; return segment_pred_sorter_;
} }
private: private:
SegmentSorter<float> segment_pred_sorter_; // For sorting the predictions dh::SegmentSorter<float> segment_pred_sorter_; // For sorting the predictions
}; };
#endif #endif
@ -380,9 +214,9 @@ class NDCGLambdaWeightComputer
// This function object computes the item's DCG value // This function object computes the item's DCG value
class ComputeItemDCG : public thrust::unary_function<uint32_t, float> { class ComputeItemDCG : public thrust::unary_function<uint32_t, float> {
public: public:
XGBOOST_DEVICE ComputeItemDCG(const float *dsorted_labels, XGBOOST_DEVICE ComputeItemDCG(const common::Span<const float> &dsorted_labels,
const uint32_t *dgroups, const common::Span<const uint32_t> &dgroups,
const uint32_t *gidxs) const common::Span<const uint32_t> &gidxs)
: dsorted_labels_(dsorted_labels), : dsorted_labels_(dsorted_labels),
dgroups_(dgroups), dgroups_(dgroups),
dgidxs_(gidxs) {} dgidxs_(gidxs) {}
@ -393,22 +227,23 @@ class NDCGLambdaWeightComputer
} }
private: private:
const float *dsorted_labels_{nullptr}; // Labels sorted within a group const common::Span<const float> dsorted_labels_; // Labels sorted within a group
const uint32_t *dgroups_{nullptr}; // The group indices - where each group begins and ends const common::Span<const uint32_t> dgroups_; // The group indices - where each group
const uint32_t *dgidxs_{nullptr}; // The group each items belongs to // begins and ends
const common::Span<const uint32_t> dgidxs_; // The group each items belongs to
}; };
// Type containing device pointers that can be cheaply copied on the kernel // Type containing device pointers that can be cheaply copied on the kernel
class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier {
public: public:
NDCGLambdaWeightMultiplier(const SegmentSorter<float> &segment_label_sorter, NDCGLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
const NDCGLambdaWeightComputer &lwc) const NDCGLambdaWeightComputer &lwc)
: BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()),
dgroup_dcg_ptr_(lwc.GetGroupDcgs().data().get()) {} dgroup_dcgs_(lwc.GetGroupDcgsSpan()) {}
// Adjust the items weight by this value // Adjust the items weight by this value
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
if (dgroup_dcg_ptr_[gidx] == 0.0) return 0.0f; if (dgroup_dcgs_[gidx] == 0.0) return 0.0f;
uint32_t group_begin = dgroups_[gidx]; uint32_t group_begin = dgroups_[gidx];
@ -418,43 +253,47 @@ class NDCGLambdaWeightComputer
// Note: the label positive and negative indices are relative to the entire dataset. // Note: the label positive and negative indices are relative to the entire dataset.
// Hence, scale them back to an index within the group // Hence, scale them back to an index within the group
auto pos_pred_pos = dindexable_sorted_preds_pos_ptr_[pos_lab_orig_posn] - group_begin; auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin;
auto neg_pred_pos = dindexable_sorted_preds_pos_ptr_[neg_lab_orig_posn] - group_begin; auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin;
return NDCGLambdaWeightComputer::ComputeDeltaWeight( return NDCGLambdaWeightComputer::ComputeDeltaWeight(
pos_pred_pos, neg_pred_pos, pos_pred_pos, neg_pred_pos,
static_cast<int>(dsorted_labels_[pidx]), static_cast<int>(dsorted_labels_[nidx]), static_cast<int>(dsorted_labels_[pidx]), static_cast<int>(dsorted_labels_[nidx]),
dgroup_dcg_ptr_[gidx]); dgroup_dcgs_[gidx]);
} }
private: private:
const float *dgroup_dcg_ptr_{nullptr}; // Start address of the group DCG values const common::Span<const float> dgroup_dcgs_; // Group DCG values
}; };
NDCGLambdaWeightComputer(const bst_float *dpreds, NDCGLambdaWeightComputer(const bst_float *dpreds,
const bst_float *dlabels, const bst_float *dlabels,
const SegmentSorter<float> &segment_label_sorter) const dh::SegmentSorter<float> &segment_label_sorter)
: IndexablePredictionSorter(dpreds, segment_label_sorter), : IndexablePredictionSorter(dpreds, segment_label_sorter),
dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f), dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f),
weight_multiplier_(segment_label_sorter, *this) { weight_multiplier_(segment_label_sorter, *this) {
const auto &group_segments = segment_label_sorter.GetGroupSegments(); const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
// Allocator to be used for managing space overhead while performing transformed reductions
dh::XGBCachingDeviceAllocator<char> alloc;
// Compute each elements DCG values and reduce them across groups concurrently. // Compute each elements DCG values and reduce them across groups concurrently.
auto end_range = auto end_range =
thrust::reduce_by_key(group_segments.begin(), group_segments.end(), thrust::reduce_by_key(thrust::cuda::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
thrust::make_transform_iterator( thrust::make_transform_iterator(
// The indices need not be sequential within a group, as we care only // The indices need not be sequential within a group, as we care only
// about the sum of items DCG values within a group // about the sum of items DCG values within a group
segment_label_sorter.GetOriginalPositions().begin(), dh::tcbegin(segment_label_sorter.GetOriginalPositionsSpan()),
ComputeItemDCG(segment_label_sorter.GetItemsPtr(), ComputeItemDCG(segment_label_sorter.GetItemsSpan(),
segment_label_sorter.GetGroupsPtr(), segment_label_sorter.GetGroupsSpan(),
group_segments.data().get())), group_segments)),
thrust::make_discard_iterator(), // We don't care for the group indices thrust::make_discard_iterator(), // We don't care for the group indices
dgroup_dcg_.begin()); // Sum of the item's DCG values in the group dgroup_dcg_.begin()); // Sum of the item's DCG values in the group
CHECK(end_range.second - dgroup_dcg_.begin() == dgroup_dcg_.size()); CHECK(end_range.second - dgroup_dcg_.begin() == dgroup_dcg_.size());
} }
inline const dh::caching_device_vector<float> &GetGroupDcgs() const { inline const common::Span<const float> GetGroupDcgsSpan() const {
return dgroup_dcg_; return { dgroup_dcg_.data().get(), dgroup_dcg_.size() };
} }
inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const { inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const {
@ -664,7 +503,7 @@ class MAPLambdaWeightComputer
#if defined(__CUDACC__) #if defined(__CUDACC__)
MAPLambdaWeightComputer(const bst_float *dpreds, MAPLambdaWeightComputer(const bst_float *dpreds,
const bst_float *dlabels, const bst_float *dlabels,
const SegmentSorter<float> &segment_label_sorter) const dh::SegmentSorter<float> &segment_label_sorter)
: IndexablePredictionSorter(dpreds, segment_label_sorter), : IndexablePredictionSorter(dpreds, segment_label_sorter),
dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()), dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()),
weight_multiplier_(segment_label_sorter, *this) { weight_multiplier_(segment_label_sorter, *this) {
@ -672,7 +511,7 @@ class MAPLambdaWeightComputer
} }
void CreateMAPStats(const bst_float *dlabels, void CreateMAPStats(const bst_float *dlabels,
const SegmentSorter<float> &segment_label_sorter) { const dh::SegmentSorter<float> &segment_label_sorter) {
// For each group, go through the sorted prediction positions, and look up its corresponding // For each group, go through the sorted prediction positions, and look up its corresponding
// label from the unsorted labels (from the original label list) // label from the unsorted labels (from the original label list)
@ -683,7 +522,7 @@ class MAPLambdaWeightComputer
auto nitems = segment_label_sorter.GetNumItems(); auto nitems = segment_label_sorter.GetNumItems();
dh::caching_device_vector<uint32_t> dhits(nitems, 0); dh::caching_device_vector<uint32_t> dhits(nitems, 0);
// Original positions of the predictions after they have been sorted // Original positions of the predictions after they have been sorted
const uint32_t *pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsPtr(); const auto &pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsSpan();
// Unsorted labels // Unsorted labels
const float *unsorted_labels = dlabels; const float *unsorted_labels = dlabels;
auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) { auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) {
@ -700,24 +539,23 @@ class MAPLambdaWeightComputer
// Next, prefix scan the positive labels that are segmented to accumulate them. // Next, prefix scan the positive labels that are segmented to accumulate them.
// This is required for computing the accumulated precisions // This is required for computing the accumulated precisions
const auto &group_segments = segment_label_sorter.GetGroupSegments(); const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
// Data segmented into different groups... // Data segmented into different groups...
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
group_segments.begin(), group_segments.end(), dh::tcbegin(group_segments), dh::tcend(group_segments),
dhits.begin(), // Input value dhits.begin(), // Input value
dhits.begin()); // In-place scan dhits.begin()); // In-place scan
// Compute accumulated precisions for each item, assuming positive and // Compute accumulated precisions for each item, assuming positive and
// negative instances are missing. // negative instances are missing.
// But first, compute individual item precisions // But first, compute individual item precisions
const auto *dgidx_arr = group_segments.data().get();
const auto *dhits_arr = dhits.data().get(); const auto *dhits_arr = dhits.data().get();
// Group info on device // Group info on device
const uint32_t *dgroups = segment_label_sorter.GetGroupsPtr(); const auto &dgroups = segment_label_sorter.GetGroupsSpan();
uint32_t ngroups = segment_label_sorter.GetNumGroups(); uint32_t ngroups = segment_label_sorter.GetNumGroups();
auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) { auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) {
if (unsorted_labels[pred_original_pos[idx]] > 0.0f) { if (unsorted_labels[pred_original_pos[idx]] > 0.0f) {
auto idx_within_group = (idx - dgroups[dgidx_arr[idx]]) + 1; auto idx_within_group = (idx - dgroups[group_segments[idx]]) + 1;
return MAPStats(static_cast<float>(dhits_arr[idx]) / idx_within_group, return MAPStats(static_cast<float>(dhits_arr[idx]) / idx_within_group,
static_cast<float>(dhits_arr[idx] - 1) / idx_within_group, static_cast<float>(dhits_arr[idx] - 1) / idx_within_group,
static_cast<float>(dhits_arr[idx] + 1) / idx_within_group, static_cast<float>(dhits_arr[idx] + 1) / idx_within_group,
@ -734,22 +572,22 @@ class MAPLambdaWeightComputer
// Lastly, compute the accumulated precisions for all the items segmented by groups. // Lastly, compute the accumulated precisions for all the items segmented by groups.
// The precisions are accumulated within each group // The precisions are accumulated within each group
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
group_segments.begin(), group_segments.end(), dh::tcbegin(group_segments), dh::tcend(group_segments),
this->dmap_stats_.begin(), // Input map stats this->dmap_stats_.begin(), // Input map stats
this->dmap_stats_.begin()); // In-place scan and output here this->dmap_stats_.begin()); // In-place scan and output here
} }
inline const dh::caching_device_vector<MAPStats> &GetMapStats() const { inline const common::Span<const MAPStats> GetMapStatsSpan() const {
return dmap_stats_; return { dmap_stats_.data().get(), dmap_stats_.size() };
} }
// Type containing device pointers that can be cheaply copied on the kernel // Type containing device pointers that can be cheaply copied on the kernel
class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier {
public: public:
MAPLambdaWeightMultiplier(const SegmentSorter<float> &segment_label_sorter, MAPLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
const MAPLambdaWeightComputer &lwc) const MAPLambdaWeightComputer &lwc)
: BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()),
dmap_stats_ptr_(lwc.GetMapStats().data().get()) {} dmap_stats_(lwc.GetMapStatsSpan()) {}
// Adjust the items weight by this value // Adjust the items weight by this value
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
@ -762,16 +600,16 @@ class MAPLambdaWeightComputer
// Note: the label positive and negative indices are relative to the entire dataset. // Note: the label positive and negative indices are relative to the entire dataset.
// Hence, scale them back to an index within the group // Hence, scale them back to an index within the group
auto pos_pred_pos = dindexable_sorted_preds_pos_ptr_[pos_lab_orig_posn] - group_begin; auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin;
auto neg_pred_pos = dindexable_sorted_preds_pos_ptr_[neg_lab_orig_posn] - group_begin; auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin;
return MAPLambdaWeightComputer::GetLambdaMAP( return MAPLambdaWeightComputer::GetLambdaMAP(
pos_pred_pos, neg_pred_pos, pos_pred_pos, neg_pred_pos,
dsorted_labels_[pidx], dsorted_labels_[nidx], dsorted_labels_[pidx], dsorted_labels_[nidx],
&dmap_stats_ptr_[group_begin], group_end - group_begin); &dmap_stats_[group_begin], group_end - group_begin);
} }
private: private:
const MAPStats *dmap_stats_ptr_{nullptr}; // Start address of the map stats for every sorted common::Span<const MAPStats> dmap_stats_; // Start address of the map stats for every sorted
// prediction value // prediction value
}; };
@ -785,7 +623,7 @@ class MAPLambdaWeightComputer
}; };
#if defined(__CUDACC__) #if defined(__CUDACC__)
class SortedLabelList : SegmentSorter<float> { class SortedLabelList : dh::SegmentSorter<float> {
private: private:
const LambdaRankParam &param_; // Objective configuration const LambdaRankParam &param_; // Objective configuration
@ -808,7 +646,7 @@ class SortedLabelList : SegmentSorter<float> {
GradientPair *out_gpair, GradientPair *out_gpair,
float weight_normalization_factor) { float weight_normalization_factor) {
// Group info on device // Group info on device
const uint32_t *dgroups = this->GetGroupsPtr(); const auto &dgroups = this->GetGroupsSpan();
uint32_t ngroups = this->GetNumGroups() + 1; uint32_t ngroups = this->GetNumGroups() + 1;
uint32_t total_items = this->GetNumItems(); uint32_t total_items = this->GetNumItems();
@ -816,12 +654,12 @@ class SortedLabelList : SegmentSorter<float> {
float fix_list_weight = param_.fix_list_weight; float fix_list_weight = param_.fix_list_weight;
const uint32_t *original_pos = this->GetOriginalPositionsPtr(); const auto &original_pos = this->GetOriginalPositionsSpan();
uint32_t num_weights = weights.Size(); uint32_t num_weights = weights.Size();
auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr; auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr;
const bst_float *sorted_labels = this->GetItemsPtr(); const auto &sorted_labels = this->GetItemsSpan();
// This is used to adjust the weight of different elements based on the different ranking // This is used to adjust the weight of different elements based on the different ranking
// objective function policies // objective function policies
@ -834,7 +672,7 @@ class SortedLabelList : SegmentSorter<float> {
dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) { dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) {
// First, determine the group 'idx' belongs to // First, determine the group 'idx' belongs to
uint32_t item_idx = idx % total_items; uint32_t item_idx = idx % total_items;
uint32_t group_idx = dh::UpperBound(dgroups, ngroups, item_idx); uint32_t group_idx = dh::UpperBound(dgroups.data(), ngroups, item_idx);
// Span of this group within the larger labels/predictions sorted tuple // Span of this group within the larger labels/predictions sorted tuple
uint32_t group_begin = dgroups[group_idx - 1]; uint32_t group_begin = dgroups[group_idx - 1];
uint32_t group_end = dgroups[group_idx]; uint32_t group_end = dgroups[group_idx];
@ -847,9 +685,9 @@ class SortedLabelList : SegmentSorter<float> {
// Find the number of labels less than and greater than the current label // Find the number of labels less than and greater than the current label
// at the sorted index position item_idx // at the sorted index position item_idx
uint32_t nleft = CountNumItemsToTheLeftOf( uint32_t nleft = CountNumItemsToTheLeftOf(
sorted_labels + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]); sorted_labels.data() + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]);
uint32_t nright = CountNumItemsToTheRightOf( uint32_t nright = CountNumItemsToTheRightOf(
sorted_labels + item_idx, group_end - item_idx, sorted_labels[item_idx]); sorted_labels.data() + item_idx, group_end - item_idx, sorted_labels[item_idx]);
// Create a minstd_rand object to act as our source of randomness // Create a minstd_rand object to act as our source of randomness
thrust::minstd_rand rng((iter + 1) * 1111); thrust::minstd_rand rng((iter + 1) * 1111);

View File

@ -5,36 +5,34 @@
namespace xgboost { namespace xgboost {
template <typename T = uint32_t, typename Comparator = thrust::greater<T>> template <typename T = uint32_t, typename Comparator = thrust::greater<T>>
std::unique_ptr<xgboost::obj::SegmentSorter<T>> std::unique_ptr<dh::SegmentSorter<T>>
RankSegmentSorterTestImpl(const std::vector<uint32_t> &group_indices, RankSegmentSorterTestImpl(const std::vector<uint32_t> &group_indices,
const std::vector<T> &hlabels, const std::vector<T> &hlabels,
const std::vector<T> &expected_sorted_hlabels, const std::vector<T> &expected_sorted_hlabels,
const std::vector<uint32_t> &expected_orig_pos const std::vector<uint32_t> &expected_orig_pos
) { ) {
std::unique_ptr<xgboost::obj::SegmentSorter<T>> seg_sorter_ptr( std::unique_ptr<dh::SegmentSorter<T>> seg_sorter_ptr(new dh::SegmentSorter<T>);
new xgboost::obj::SegmentSorter<T>); dh::SegmentSorter<T> &seg_sorter(*seg_sorter_ptr);
xgboost::obj::SegmentSorter<T> &seg_sorter(*seg_sorter_ptr);
// Create a bunch of unsorted labels on the device and sort it via the segment sorter // Create a bunch of unsorted labels on the device and sort it via the segment sorter
dh::device_vector<T> dlabels(hlabels); dh::device_vector<T> dlabels(hlabels);
seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator()); seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator());
EXPECT_EQ(seg_sorter.GetNumItems(), group_indices.back()); auto num_items = seg_sorter.GetItemsSpan().size();
EXPECT_EQ(num_items, group_indices.back());
EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1); EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1);
// Check the labels // Check the labels
dh::device_vector<T> sorted_dlabels(seg_sorter.GetNumItems()); dh::device_vector<T> sorted_dlabels(num_items);
sorted_dlabels.assign(thrust::device_ptr<const T>(seg_sorter.GetItemsPtr()), sorted_dlabels.assign(dh::tcbegin(seg_sorter.GetItemsSpan()),
thrust::device_ptr<const T>(seg_sorter.GetItemsPtr()) dh::tcend(seg_sorter.GetItemsSpan()));
+ seg_sorter.GetNumItems());
thrust::host_vector<T> sorted_hlabels(sorted_dlabels); thrust::host_vector<T> sorted_hlabels(sorted_dlabels);
EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels); EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels);
// Check the indices // Check the indices
dh::device_vector<uint32_t> dorig_pos(seg_sorter.GetNumItems()); dh::device_vector<uint32_t> dorig_pos(num_items);
dorig_pos.assign(thrust::device_ptr<const uint32_t>(seg_sorter.GetOriginalPositionsPtr()), dorig_pos.assign(dh::tcbegin(seg_sorter.GetOriginalPositionsSpan()),
thrust::device_ptr<const uint32_t>(seg_sorter.GetOriginalPositionsPtr()) dh::tcend(seg_sorter.GetOriginalPositionsSpan()));
+ seg_sorter.GetNumItems());
dh::device_vector<uint32_t> horig_pos(dorig_pos); dh::device_vector<uint32_t> horig_pos(dorig_pos);
EXPECT_EQ(expected_orig_pos, horig_pos); EXPECT_EQ(expected_orig_pos, horig_pos);
@ -152,18 +150,22 @@ TEST(Objective, NDCGLambdaWeightComputerTest) {
// Where will the predictions move from its current position, if they were sorted // Where will the predictions move from its current position, if they were sorted
// descendingly? // descendingly?
auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositions(); auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositionsSpan();
thrust::host_vector<uint32_t> hsorted_pred_pos(dsorted_pred_pos); std::vector<uint32_t> hsorted_pred_pos(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&hsorted_pred_pos, dsorted_pred_pos);
std::vector<uint32_t> expected_sorted_pred_pos{2, 0, 1, 3, std::vector<uint32_t> expected_sorted_pred_pos{2, 0, 1, 3,
4, 5, 6, 4, 5, 6,
7, 8, 11, 9, 10}; 7, 8, 11, 9, 10};
EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos); EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos);
// Check group DCG values // Check group DCG values
thrust::host_vector<float> hgroup_dcgs(ndcg_lw_computer.GetGroupDcgs()); std::vector<float> hgroup_dcgs(segment_label_sorter->GetNumGroups());
thrust::host_vector<uint32_t> hgroups(segment_label_sorter->GetGroups()); dh::CopyDeviceSpanToVector(&hgroup_dcgs, ndcg_lw_computer.GetGroupDcgsSpan());
thrust::host_vector<float> hsorted_labels(segment_label_sorter->GetItems()); std::vector<uint32_t> hgroups(segment_label_sorter->GetNumGroups() + 1);
dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan());
EXPECT_EQ(hgroup_dcgs.size(), segment_label_sorter->GetNumGroups()); EXPECT_EQ(hgroup_dcgs.size(), segment_label_sorter->GetNumGroups());
std::vector<float> hsorted_labels(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&hsorted_labels, segment_label_sorter->GetItemsSpan());
for (auto i = 0; i < hgroup_dcgs.size(); ++i) { for (auto i = 0; i < hgroup_dcgs.size(); ++i) {
// Compute group DCG value on CPU and compare // Compute group DCG value on CPU and compare
auto gbegin = hgroups[i]; auto gbegin = hgroups[i];
@ -193,7 +195,9 @@ TEST(Objective, IndexableSortedItemsTest) {
9, 11, 7, 10, 8}); 9, 11, 7, 10, 8});
segment_label_sorter->CreateIndexableSortedPositions(); segment_label_sorter->CreateIndexableSortedPositions();
thrust::host_vector<uint32_t> sorted_indices(segment_label_sorter->GetIndexableSortedPositions()); std::vector<uint32_t> sorted_indices(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&sorted_indices,
segment_label_sorter->GetIndexableSortedPositionsSpan());
std::vector<uint32_t> expected_sorted_indices = { std::vector<uint32_t> expected_sorted_indices = {
1, 3, 2, 0, 1, 3, 2, 0,
4, 6, 5, 4, 6, 5,
@ -228,11 +232,13 @@ TEST(Objective, ComputeAndCompareMAPStatsTest) {
*segment_label_sorter); *segment_label_sorter);
// Get the device MAP stats on host // Get the device MAP stats on host
thrust::host_vector<xgboost::obj::MAPLambdaWeightComputer::MAPStats> dmap_stats( std::vector<xgboost::obj::MAPLambdaWeightComputer::MAPStats> dmap_stats(
map_lw_computer.GetMapStats()); segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&dmap_stats, map_lw_computer.GetMapStatsSpan());
// Compute the MAP stats on host next to compare // Compute the MAP stats on host next to compare
thrust::host_vector<uint32_t> hgroups(segment_label_sorter->GetGroups()); std::vector<uint32_t> hgroups(segment_label_sorter->GetNumGroups() + 1);
dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan());
for (auto i = 0; i < hgroups.size() - 1; ++i) { for (auto i = 0; i < hgroups.size() - 1; ++i) {
auto gbegin = hgroups[i]; auto gbegin = hgroups[i];

View File

@ -40,9 +40,9 @@ void VerifySampling(size_t page_size,
EXPECT_EQ(sample.page->matrix.n_rows, kRows); EXPECT_EQ(sample.page->matrix.n_rows, kRows);
EXPECT_EQ(sample.gpair.size(), kRows); EXPECT_EQ(sample.gpair.size(), kRows);
} else { } else {
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.012f); EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.016f);
EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.012f); EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.016f);
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.012f); EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.016f);
} }
GradientPair sum_sampled_gpair{}; GradientPair sum_sampled_gpair{};
@ -52,11 +52,11 @@ void VerifySampling(size_t page_size,
sum_sampled_gpair += gp; sum_sampled_gpair += gp;
} }
if (check_sum) { if (check_sum) {
EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.02f * kRows); EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.03f * kRows);
EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.02f * kRows); EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.03f * kRows);
} else { } else {
EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.02f); EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.03f);
EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.02f); EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.03f);
} }
} }