- ndcg ltr implementation on gpu (#5004)

* - ndcg ltr implementation on gpu
  - this is a follow-up to the pairwise ltr implementation
This commit is contained in:
sriramch 2019-11-12 14:21:04 -08:00 committed by Rory Mitchell
parent f4e7b707c9
commit 2abe69d774
5 changed files with 780 additions and 202 deletions

View File

@ -98,7 +98,7 @@ struct EvalAMS : public Metric {
for (bst_omp_uint i = 0; i < ndata; ++i) {
rec[i] = std::make_pair(h_preds[i], i);
}
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
auto ntop = static_cast<unsigned>(ratio_ * ndata);
if (ntop == 0) ntop = ndata;
const double br = 10.0;
@ -168,7 +168,7 @@ struct EvalAuc : public Metric {
for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
rec.emplace_back(h_preds[j], j);
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
// calculate AUC
double sum_pospair = 0.0;
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
@ -321,7 +321,7 @@ struct EvalPrecision : public EvalRankList{
protected:
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
// calculate Precision
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhit = 0;
for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) {
nhit += (rec[j].second != 0);
@ -369,7 +369,7 @@ struct EvalMAP : public EvalRankList {
protected:
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhits = 0;
double sumap = 0.0;
for (size_t i = 0; i < rec.size(); ++i) {
@ -481,7 +481,7 @@ struct EvalAucPR : public Metric {
total_neg += wt * (1.0f - h_labels[j]);
rec.emplace_back(h_preds[j], j);
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
// we need pos > 0 && neg > 0
if (0.0 == total_pos || 0.0 == total_neg) {
auc_error += 1;

View File

@ -29,7 +29,7 @@
namespace xgboost {
namespace obj {
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) && !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(rank_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA)
@ -46,6 +46,185 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
}
};
#if defined(__CUDACC__)
// This type sorts an array which is divided into multiple groups. The sorting is influenced
// by the function object 'Comparator'
template <typename T>
class SegmentSorter {
private:
// Items sorted within the group
dh::caching_device_vector<T> ditems_;
// Original position of the items before they are sorted descendingly within its groups
dh::caching_device_vector<uint32_t> doriginal_pos_;
// Segments within the original list that delineates the different groups
dh::caching_device_vector<uint32_t> group_segments_;
// Need this on the device as it is used in the kernels
dh::caching_device_vector<uint32_t> dgroups_; // Group information on device
// Initialize everything but the segments
void Init(uint32_t num_elems) {
ditems_.resize(num_elems);
doriginal_pos_.resize(num_elems);
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
}
// Initialize all with group info
void Init(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
this->Init(num_elems);
this->CreateGroupSegments(groups);
}
public:
// This needs to be public due to device lambda
void CreateGroupSegments(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
group_segments_.resize(num_elems);
dgroups_ = groups;
// Launch a kernel that populates the segment information for the different groups
uint32_t *gsegs = group_segments_.data().get();
const uint32_t *dgroups = dgroups_.data().get();
uint32_t ngroups = dgroups_.size();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
dh::LaunchN(device_id, num_elems, nullptr, [=] __device__(uint32_t idx){
// Find the group first
uint32_t group_idx = dh::UpperBound(dgroups, ngroups, idx);
gsegs[idx] = group_idx - 1;
});
}
// Accessors that returns device pointer
inline const T *Items() const { return ditems_.data().get(); }
inline uint32_t NumItems() const { return ditems_.size(); }
inline const uint32_t *OriginalPositions() const { return doriginal_pos_.data().get(); }
inline const dh::caching_device_vector<uint32_t> &GroupSegments() const {
return group_segments_;
}
inline uint32_t NumGroups() const { return dgroups_.size() - 1; }
inline const uint32_t *GroupIndices() const { return dgroups_.data().get(); }
// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the host.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size, const std::vector<uint32_t> &groups,
const Comparator &comp = Comparator()) {
this->Init(groups);
this->SortItems(ditems, item_size, group_segments_, comp);
}
// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the device.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size,
const dh::caching_device_vector<uint32_t> &group_segments,
const Comparator &comp = Comparator()) {
this->Init(item_size);
// Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
// when comparators are used. Hence, the following algorithm is used. This is done so that
// we can grab the appropriate related values from the original list later, after the
// items are sorted.
//
// Here is the internal representation:
// dgroups_: [ 0, 3, 5, 8, 10 ]
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
// ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
//
// Sort the items first and make a note of the original positions in doriginal_pos_
// based on the sort
// ditems_: 4 4 3 3 2 1 1 1 1 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
// in kernel, sorting using predicates etc.
ditems_.assign(thrust::device_ptr<const T>(ditems),
thrust::device_ptr<const T>(ditems) + item_size);
// Allocator to be used by sort for managing space overhead while sorting
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
ditems_.begin(), ditems_.end(),
doriginal_pos_.begin(), comp);
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
// holisitic item sort order on the segments
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
dh::caching_device_vector<uint32_t> group_segments_c(group_segments);
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
group_segments.begin(), group_segments_c.begin());
// Now, sort the group segments so that you may bring the items within the group together,
// in the process also noting the relative changes to the doriginal_pos_ while that happens
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
group_segments_c.begin(), group_segments_c.end(),
doriginal_pos_.begin(), thrust::less<uint32_t>());
// Finally, gather the original items based on doriginal_pos_ to sort the input and
// to store them in ditems_
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
// ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
thrust::device_ptr<const T>(ditems), ditems_.begin());
}
};
// Helper functions
// Items of size 'n' are sorted in a descending order
// If left is true, find the number of elements > v; 0 if nothing is greater
// If left is false, find the number of elements < v; 0 if nothing is lesser
template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v) {
const T *items_begin = items;
uint32_t num_remaining = n;
const T *middle_item = nullptr;
uint32_t middle;
while (num_remaining > 0) {
middle_item = items_begin;
middle = num_remaining / 2;
middle_item += middle;
if ((left && *middle_item > v) || (!left && !(v > *middle_item))) {
items_begin = ++middle_item;
num_remaining -= middle + 1;
} else {
num_remaining = middle;
}
}
return left ? items_begin - items : items + n - items_begin;
}
template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsToTheLeftOf(const T * __restrict__ items, uint32_t n, T v) {
return CountNumItemsImpl(true, items, n, v);
}
template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsToTheRightOf(const T * __restrict__ items, uint32_t n, T v) {
return CountNumItemsImpl(false, items, n, v);
}
#endif
/*! \brief helper information in a list */
struct ListEntry {
/*! \brief the predict score we in the data */
@ -96,16 +275,133 @@ struct PairwiseLambdaWeightComputer {
return "rank:pairwise";
}
// Stopgap method - will be removed when we support other type of ranking - ndcg, map etc.
// Stopgap method - will be removed when we support other type of ranking - map
// on GPU later
inline static bool SupportOnGPU() { return true; }
#if defined(__CUDACC__)
PairwiseLambdaWeightComputer(const bst_float *dpreds,
uint32_t pred_size,
const SegmentSorter<float> &segment_label_sorter) {}
struct PairwiseLambdaWeightMultiplier {
// Adjust the items weight by this value
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
return 1.0f;
}
};
inline PairwiseLambdaWeightMultiplier GetWeightMultiplier() const {
return {};
}
#endif
};
// beta version: NDCG lambda rank
struct NDCGLambdaWeightComputer {
// Stopgap method - will be removed when we support other type of ranking - ndcg, map etc.
public:
#if defined(__CUDACC__)
// This function object computes the group's DCG for a given group
struct ComputeGroupDCG {
public:
XGBOOST_DEVICE ComputeGroupDCG(const float *dsorted_labels, const uint32_t *dgroups)
: dsorted_labels_(dsorted_labels),
dgroups_(dgroups) {}
// Compute DCG for group 'gidx'
__device__ __forceinline__ float operator()(uint32_t gidx) const {
uint32_t group_begin = dgroups_[gidx];
uint32_t group_end = dgroups_[gidx + 1];
uint32_t group_size = group_end - group_begin;
return ComputeGroupDCGWeight(&dsorted_labels_[group_begin], group_size);
}
private:
const float *dsorted_labels_{nullptr}; // Labels sorted within a group
const uint32_t *dgroups_{nullptr}; // The group indices - where each group begins and ends
};
// Type containing device pointers that can be cheaply copied on the kernel
class NDCGLambdaWeightMultiplier {
public:
NDCGLambdaWeightMultiplier(const float *dsorted_labels,
const uint32_t *dorig_pos,
const uint32_t *dgroups,
const float *dgroup_dcg_ptr,
uint32_t *dindexable_sorted_preds_pos_ptr)
: dsorted_labels_(dsorted_labels),
dorig_pos_(dorig_pos),
dgroups_(dgroups),
dgroup_dcg_ptr_(dgroup_dcg_ptr),
dindexable_sorted_preds_pos_ptr_(dindexable_sorted_preds_pos_ptr) {}
// Adjust the items weight by this value
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
if (dgroup_dcg_ptr_[gidx] == 0.0) return 0.0f;
uint32_t group_begin = dgroups_[gidx];
auto ppred_idx = dorig_pos_[pidx];
auto npred_idx = dorig_pos_[nidx];
KERNEL_CHECK(ppred_idx != npred_idx);
// Note: the label positive and negative indices are relative to the entire dataset.
// Hence, scale them back to an index within the group
ppred_idx = dindexable_sorted_preds_pos_ptr_[ppred_idx] - group_begin;
npred_idx = dindexable_sorted_preds_pos_ptr_[npred_idx] - group_begin;
return NDCGLambdaWeightComputer::ComputeDeltaWeight(
ppred_idx, npred_idx,
static_cast<int>(dsorted_labels_[pidx]), static_cast<int>(dsorted_labels_[nidx]),
dgroup_dcg_ptr_[gidx]);
}
private:
const float *dsorted_labels_{nullptr}; // Labels sorted within a group
const uint32_t *dorig_pos_{nullptr}; // Original indices of the labels before they are sorted
const uint32_t *dgroups_{nullptr}; // The group indices
const float *dgroup_dcg_ptr_{nullptr}; // Start address of the group DCG values
// Where can a prediction for a label be found in the original array, when they are sorted
uint32_t *dindexable_sorted_preds_pos_ptr_{nullptr};
};
NDCGLambdaWeightComputer(const bst_float *dpreds,
uint32_t pred_size,
const SegmentSorter<float> &segment_label_sorter)
: dgroup_dcg_(segment_label_sorter.NumGroups()),
dindexable_sorted_preds_pos_(pred_size),
weight_multiplier_(segment_label_sorter.Items(),
segment_label_sorter.OriginalPositions(),
segment_label_sorter.GroupIndices(),
dgroup_dcg_.data().get(),
dindexable_sorted_preds_pos_.data().get()) {
// Sort the predictions first and get the sorted position
SegmentSorter<float> segment_prediction_sorter;
segment_prediction_sorter.SortItems(dpreds, pred_size, segment_label_sorter.GroupSegments());
this->CreateIndexableSortedPredictionPositions(segment_prediction_sorter.OriginalPositions());
// Compute each group's DCG concurrently
// Set the values to be the group indices first so that the predicate knows which
// group it is dealing with
thrust::sequence(dgroup_dcg_.begin(), dgroup_dcg_.end());
// TODO(sriramch): parallelize across all elements, if possible
// Transform each group - the predictate computes the group's DCG
thrust::transform(dgroup_dcg_.begin(), dgroup_dcg_.end(),
dgroup_dcg_.begin(),
ComputeGroupDCG(segment_label_sorter.Items(),
segment_label_sorter.GroupIndices()));
}
inline NDCGLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; }
inline const dh::caching_device_vector<uint32_t> &GetSortedPredPos() const {
return dindexable_sorted_preds_pos_;
}
#endif
// Stopgap method - will be removed when we support other type of ranking - map
// on GPU later
inline static bool SupportOnGPU() { return false; }
inline static bool SupportOnGPU() { return true; }
static void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
std::vector<LambdaPair> *io_pairs) {
@ -116,29 +412,20 @@ struct NDCGLambdaWeightComputer {
for (size_t i = 0; i < sorted_list.size(); ++i) {
labels[i] = sorted_list[i].label;
}
std::sort(labels.begin(), labels.end(), std::greater<bst_float>());
IDCG = CalcDCG(labels);
std::stable_sort(labels.begin(), labels.end(), std::greater<bst_float>());
IDCG = ComputeGroupDCGWeight(&labels[0], labels.size());
}
if (IDCG == 0.0) {
for (auto & pair : pairs) {
pair.weight = 0.0f;
}
} else {
IDCG = 1.0f / IDCG;
for (auto & pair : pairs) {
unsigned pos_idx = pair.pos_index;
unsigned neg_idx = pair.neg_index;
float pos_loginv = 1.0f / std::log2(pos_idx + 2.0f);
float neg_loginv = 1.0f / std::log2(neg_idx + 2.0f);
auto pos_label = static_cast<int>(sorted_list[pos_idx].label);
auto neg_label = static_cast<int>(sorted_list[neg_idx].label);
bst_float original =
((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
float changed =
((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
bst_float delta = (original - changed) * IDCG;
if (delta < 0.0f) delta = - delta;
pair.weight *= delta;
pair.weight *= ComputeDeltaWeight(pos_idx, neg_idx,
sorted_list[pos_idx].label, sorted_list[neg_idx].label,
IDCG);
}
}
}
@ -148,16 +435,77 @@ struct NDCGLambdaWeightComputer {
}
private:
inline static bst_float CalcDCG(const std::vector<bst_float> &labels) {
XGBOOST_DEVICE inline static bst_float ComputeGroupDCGWeight(const float *sorted_labels,
uint32_t size) {
double sumdcg = 0.0;
for (size_t i = 0; i < labels.size(); ++i) {
const auto rel = static_cast<unsigned>(labels[i]);
for (uint32_t i = 0; i < size; ++i) {
const auto rel = static_cast<unsigned>(sorted_labels[i]);
if (rel != 0) {
sumdcg += ((1 << rel) - 1) / std::log2(static_cast<bst_float>(i + 2));
}
}
return static_cast<bst_float>(sumdcg);
}
// Compute the weight adjustment for an item within a group:
// ppred_idx => Where does the positive label live, had the list been sorted by prediction
// npred_idx => Where does the negative label live, had the list been sorted by prediction
// pos_label => positive label value from sorted label list
// neg_label => negative label value from sorted label list
XGBOOST_DEVICE inline static bst_float ComputeDeltaWeight(uint32_t ppred_idx, uint32_t npred_idx,
int pos_label, int neg_label,
float idcg) {
float pos_loginv = 1.0f / std::log2(ppred_idx + 2.0f);
float neg_loginv = 1.0f / std::log2(npred_idx + 2.0f);
bst_float original = ((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
float changed = ((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
bst_float delta = (original - changed) * (1.0f / idcg);
if (delta < 0.0f) delta = - delta;
return delta;
}
#if defined(__CUDACC__)
// While computing the weight that needs to be adjusted by this ranking objective, we need
// to figure out where positive and negative labels chosen earlier exists, if the group
// were to be sorted by its predictions. To accommodate this, we employ the following algorithm.
// For a given group, let's assume the following:
// labels: 1 5 9 2 4 8 0 7 6 3
// predictions: 1 9 0 8 2 7 3 6 5 4
// position: 0 1 2 3 4 5 6 7 8 9
//
// After label sort:
// labels: 9 8 7 6 5 4 3 2 1 0
// position: 2 5 7 8 1 4 9 3 0 6
//
// After prediction sort:
// predictions: 9 8 7 6 5 4 3 2 1 0
// position: 1 3 5 7 8 9 6 4 0 2
//
// If a sorted label at position 'x' is chosen, then we need to find out where the prediction
// for this label 'x' exists, if the group were to be sorted by predictions.
// We first take the sorted prediction positions:
// position: 1 3 5 7 8 9 6 4 0 2
// at indices: 0 1 2 3 4 5 6 7 8 9
//
// We create a sorted prediction positional array, such that value at position 'x' gives
// us the position in the sorted prediction array where its related prediction lies.
// dindexable_sorted_preds_pos_ptr_: 8 0 9 1 7 2 6 3 4 5
// at indices: 0 1 2 3 4 5 6 7 8 9
// Basically, swap the previous 2 arrays, sort the indices and reorder positions
// for an O(1) lookup using the position where the sorted label exists
void CreateIndexableSortedPredictionPositions(const uint32_t *dsorted_preds_pos) {
dh::caching_device_vector<uint32_t> indices(dindexable_sorted_preds_pos_.size());
thrust::sequence(indices.begin(), indices.end());
thrust::scatter(indices.begin(), indices.end(), // Rearrange indices...
thrust::device_ptr<const uint32_t>(dsorted_preds_pos), // ...based on this map
dindexable_sorted_preds_pos_.begin()); // Write results into this
}
dh::caching_device_vector<float> dgroup_dcg_;
// Where can a prediction for a label be found in the original array, when they are sorted
dh::caching_device_vector<uint32_t> dindexable_sorted_preds_pos_;
NDCGLambdaWeightMultiplier weight_multiplier_; // This computes the adjustment to the weight
#endif
};
struct MAPLambdaWeightComputer {
@ -238,7 +586,7 @@ struct MAPLambdaWeightComputer {
}
public:
// Stopgap method - will be removed when we support other type of ranking - ndcg, map etc.
// Stopgap method - will be removed when we support other type of ranking - map
// on GPU later
inline static bool SupportOnGPU() { return false; }
@ -257,177 +605,79 @@ struct MAPLambdaWeightComputer {
pair.neg_index, &map_stats);
}
}
#if defined(__CUDACC__)
MAPLambdaWeightComputer(const bst_float *dpreds,
uint32_t pred_size,
const SegmentSorter<float> &segment_label_sorter) {}
struct MAPLambdaWeightMultiplier {
// Adjust the items weight by this value
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
return 1.0f;
}
};
inline MAPLambdaWeightMultiplier GetWeightMultiplier() const {
return {};
}
#endif
};
#if defined(__CUDACC__)
// Helper functions
// Labels of size 'n' are sorted in a descending order
// If left is true, find the number of elements > v; 0 if nothing is greater
// If left is false, find the number of elements < v; 0 if nothing is lesser
__device__ __forceinline__ int
CountNumLabelsImpl(bool left, const float * __restrict__ labels, int n, float v) {
const float *labels_begin = labels;
int num_remaining = n;
const float *middle_item = nullptr;
int middle;
while (num_remaining > 0) {
middle_item = labels_begin;
middle = num_remaining / 2;
middle_item += middle;
if ((left && *middle_item > v) || (!left && !(v > *middle_item))) {
labels_begin = ++middle_item;
num_remaining -= middle + 1;
} else {
num_remaining = middle;
}
}
return left ? labels_begin - labels : labels + n - labels_begin;
}
__device__ __forceinline__ int
CountNumLabelsToTheLeftOf(const float * __restrict__ labels, int n, float v) {
return CountNumLabelsImpl(true, labels, n, v);
}
__device__ __forceinline__ int
CountNumLabelsToTheRightOf(const float * __restrict__ labels, int n, float v) {
return CountNumLabelsImpl(false, labels, n, v);
}
class SortedLabelList {
class SortedLabelList : SegmentSorter<float> {
private:
// Labels sorted within the group
dh::caching_device_vector<float> dsorted_labels_;
// Original position of the labels before they are sorted descendingly within its groups
dh::caching_device_vector<uint32_t> doriginal_pos_;
// Segments within the original list that delineates the different groups
dh::caching_device_vector<uint32_t> group_segments_;
// Need this on the device as it is used in the kernels
dh::caching_device_vector<uint32_t> dgroups_; // Group information on device
int device_id_{-1}; // GPU device ID
const LambdaRankParam &param_; // Objective configuration
dh::XGBCachingDeviceAllocator<char> alloc_; // Allocator to be used by sort for managing
// space overhead while sorting
public:
SortedLabelList(int dev_id,
const LambdaRankParam &param)
: device_id_(dev_id),
param_(param) {}
explicit SortedLabelList(const LambdaRankParam &param)
: param_(param) {}
void InitWithTrainingInfo(const std::vector<uint32_t> &groups) {
int num_elems = groups.back();
dsorted_labels_.resize(num_elems);
doriginal_pos_.resize(num_elems);
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
group_segments_.resize(num_elems);
dgroups_ = groups;
// Launch a kernel that populates the segment information for the different groups
uint32_t *gsegs = group_segments_.data().get();
const unsigned *dgroups = dgroups_.data().get();
size_t ngroups = dgroups_.size();
dh::LaunchN(device_id_, num_elems, nullptr, [=] __device__(unsigned idx){
// Find the group first
int group_idx = dh::UpperBound(dgroups, ngroups, idx);
gsegs[idx] = group_idx - 1;
});
}
// Sort the groups by labels. We would like to avoid using predicates to perform the sort,
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
// when comparators are used. Hence, the following algorithm is used. This is done so that
// we can grab the appropriate prediction values from the original list later, after the
// labels are sorted.
//
// Here is the internal representation:
// dgroups_: [ 0, 3, 5, 8, 10 ]
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
// dsorted_labels_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original labels)
//
// Sort the labels first and make a note of the original positions in doriginal_pos_
// based on the sort
// dsorted_labels_: 4 4 3 3 2 1 1 1 1 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
// in kernel, sorting using predicates etc.
void Sort(const HostDeviceVector<bst_float> &dlabels) {
dsorted_labels_.assign(dh::tcbegin(dlabels), dh::tcend(dlabels));
thrust::stable_sort_by_key(thrust::cuda::par(alloc_),
dsorted_labels_.begin(), dsorted_labels_.end(),
doriginal_pos_.begin(), thrust::greater<float>());
// Next, gather the segments based on the doriginal_pos_. This is to reflect the
// holisitic label sort order on the segments
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
thrust::device_vector<uint32_t> group_segments_c(group_segments_);
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
group_segments_.begin(), group_segments_c.begin());
// Now, sort the group segments so that you may bring the labels within the group together,
// in the process also noting the relative changes to the doriginal_pos_ while that happens
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
thrust::stable_sort_by_key(thrust::cuda::par(alloc_),
group_segments_c.begin(), group_segments_c.end(),
doriginal_pos_.begin(), thrust::less<int>());
// Finally, gather the original labels based on doriginal_pos_ to sort the input and
// to store them in dsorted_labels_
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
// dsorted_labels_: 1 1 0 2 1 3 3 1 4 4 (from unsorted dlabels - dlabels)
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
dh::tcbegin(dlabels), dsorted_labels_.begin());
}
~SortedLabelList() {
dh::safe_cuda(cudaSetDevice(device_id_));
// Sort the labels that are grouped by 'groups'
void Sort(const HostDeviceVector<bst_float> &dlabels, const std::vector<uint32_t> &groups) {
this->SortItems(dlabels.ConstDevicePointer(), dlabels.Size(), groups);
}
// This kernel can only run *after* the kernel in sort is completed, as they
// use the default stream
template <typename LambdaWeightComputerT>
void ComputeGradients(const bst_float *dpreds,
GradientPair *out_gpair,
const HostDeviceVector<bst_float> &weights,
int iter,
GradientPair *out_gpair,
float weight_normalization_factor) {
// Group info on device
const unsigned *dgroups = dgroups_.data().get();
size_t ngroups = dgroups_.size();
const uint32_t *dgroups = this->GroupIndices();
uint32_t ngroups = this->NumGroups() + 1;
auto total_items = group_segments_.size();
size_t niter = param_.num_pairsample * total_items;
uint32_t total_items = this->NumItems();
uint32_t niter = param_.num_pairsample * total_items;
float fix_list_weight = param_.fix_list_weight;
const uint32_t *original_pos = doriginal_pos_.data().get();
const uint32_t *original_pos = this->OriginalPositions();
size_t num_weights = weights.Size();
uint32_t num_weights = weights.Size();
auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr;
const bst_float *sorted_labels = dsorted_labels_.data().get();
const bst_float *sorted_labels = this->Items();
// This is used to adjust the weight of different elements based on the different ranking
// objective function policies
LambdaWeightComputerT weight_computer(dpreds, total_items, *this);
auto wmultiplier = weight_computer.GetWeightMultiplier();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// For each instance in the group, compute the gradient pair concurrently
dh::LaunchN(device_id_, niter, nullptr, [=] __device__(size_t idx) {
dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) {
// First, determine the group 'idx' belongs to
unsigned item_idx = idx % total_items;
int group_idx = dh::UpperBound(dgroups, ngroups, item_idx);
uint32_t item_idx = idx % total_items;
uint32_t group_idx = dh::UpperBound(dgroups, ngroups, item_idx);
// Span of this group within the larger labels/predictions sorted tuple
int group_begin = dgroups[group_idx - 1];
int group_end = dgroups[group_idx];
int total_group_items = group_end - group_begin;
uint32_t group_begin = dgroups[group_idx - 1];
uint32_t group_end = dgroups[group_idx];
uint32_t total_group_items = group_end - group_begin;
// Are the labels diverse enough? If they are all the same, then there is nothing to pick
// from another group - bail sooner
@ -435,14 +685,14 @@ class SortedLabelList {
// Find the number of labels less than and greater than the current label
// at the sorted index position item_idx
int nleft = CountNumLabelsToTheLeftOf(
sorted_labels + group_begin, total_group_items, sorted_labels[item_idx]);
int nright = CountNumLabelsToTheRightOf(
sorted_labels + group_begin, total_group_items, sorted_labels[item_idx]);
uint32_t nleft = CountNumItemsToTheLeftOf(
sorted_labels + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]);
uint32_t nright = CountNumItemsToTheRightOf(
sorted_labels + item_idx, group_end - item_idx, sorted_labels[item_idx]);
// Create a minstd_rand object to act as our source of randomness
thrust::minstd_rand rng;
rng.discard(idx);
thrust::minstd_rand rng((iter + 1) * 1111);
rng.discard(((idx / total_items) * total_group_items) + item_idx - group_begin);
// Create a uniform_int_distribution to produce a sample from outside of the
// present label group
thrust::uniform_int_distribution<int> dist(0, nleft + nright - 1);
@ -457,7 +707,7 @@ class SortedLabelList {
neg_idx = item_idx;
} else {
pos_idx = item_idx;
int items_in_group = total_group_items - nleft - nright;
uint32_t items_in_group = total_group_items - nleft - nright;
neg_idx = sample + items_in_group + group_begin;
}
@ -468,13 +718,14 @@ class SortedLabelList {
bst_float h = thrust::max(p * (1.0f - p), eps);
// Rescale each gradient and hessian so that the group has a weighted constant
float scale = 1.0f / (niter / total_items);
float scale = __frcp_ru(niter / total_items);
if (fix_list_weight != 0.0f) {
scale *= fix_list_weight / total_group_items;
}
float weight = num_weights ? dweights[group_idx - 1] : 1.0f;
weight *= weight_normalization_factor;
weight *= wmultiplier.GetWeight(group_idx - 1, pos_idx, neg_idx);
weight *= scale;
// Accumulate gradient and hessian in both positive and negative indices
const GradientPair in_pos_gpair(g * weight, 2.0f * weight * h);
@ -483,6 +734,9 @@ class SortedLabelList {
const GradientPair in_neg_gpair(-g * weight, 2.0f * weight * h);
dh::AtomicAddGpair(&out_gpair[original_pos[neg_idx]], in_neg_gpair);
});
// Wait until the computations done by the kernel is complete
dh::safe_cuda(cudaStreamSynchronize(nullptr));
}
};
#endif
@ -512,7 +766,7 @@ class LambdaRankObj : public ObjFunction {
// Check if we have a GPU assignment; else, revert back to CPU
auto device = tparam_->gpu_id;
if (device >= 0 && LambdaWeightComputerT::SupportOnGPU()) {
ComputeGradientsOnGPU(preds, info, out_gpair, gptr);
ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr);
} else {
// Revert back to CPU
#endif
@ -558,20 +812,21 @@ class LambdaRankObj : public ObjFunction {
LOG(DEBUG) << "Computing pairwise gradients on CPU.";
bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);
const auto& preds_h = preds.HostVector();
const auto& labels = info.labels_.HostVector();
std::vector<GradientPair>& gpair = out_gpair->HostVector();
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
out_gpair->Resize(preds.Size());
#pragma omp parallel
{
// parallel construct, declare random number generator here, so that each
// thread use its own random number generator, seed by thread id and current iteration
common::RandomEngine rnd(iter * 1111 + omp_get_thread_num());
std::minstd_rand rnd((iter + 1) * 1111);
std::vector<LambdaPair> pairs;
std::vector<ListEntry> lst;
std::vector< std::pair<bst_float, unsigned> > rec;
const auto& preds_h = preds.HostVector();
const auto& labels = info.labels_.HostVector();
std::vector<GradientPair>& gpair = out_gpair->HostVector();
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
#pragma omp for schedule(static)
for (bst_omp_uint k = 0; k < ngroup; ++k) {
@ -580,12 +835,12 @@ class LambdaRankObj : public ObjFunction {
lst.emplace_back(preds_h[j], labels[j], j);
gpair[j] = GradientPair(0.0f, 0.0f);
}
std::sort(lst.begin(), lst.end(), ListEntry::CmpPred);
std::stable_sort(lst.begin(), lst.end(), ListEntry::CmpPred);
rec.resize(lst.size());
for (unsigned i = 0; i < lst.size(); ++i) {
rec[i] = std::make_pair(lst[i].label, i);
}
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
// enumerate buckets with same label, for each item in the lst, grab another sample randomly
for (unsigned i = 0; i < rec.size(); ) {
unsigned j = i + 1;
@ -635,6 +890,7 @@ class LambdaRankObj : public ObjFunction {
#if defined(__CUDACC__)
void ComputeGradientsOnGPU(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
HostDeviceVector<GradientPair>* out_gpair,
const std::vector<unsigned> &gptr) {
LOG(DEBUG) << "Computing pairwise gradients on GPU.";
@ -655,33 +911,24 @@ class LambdaRankObj : public ObjFunction {
auto d_preds = preds.ConstDevicePointer();
auto d_gpair = out_gpair->DevicePointer();
if (!slist_) {
slist_.reset(new SortedLabelList(device, param_));
}
// Create segments based on group info
slist_->InitWithTrainingInfo(gptr);
SortedLabelList slist(param_);
// Sort the labels within the groups on the device
slist_->Sort(info.labels_);
slist.Sort(info.labels_, gptr);
// Initialize the gradients next
out_gpair->Fill(GradientPair(0.0f, 0.0f));
// Finally, compute the gradients
slist_->ComputeGradients(d_preds, d_gpair, info.weights_, weight_normalization_factor);
// Wait until the computations done by the kernel is complete
dh::safe_cuda(cudaStreamSynchronize(nullptr));
slist.ComputeGradients<LambdaWeightComputerT>
(d_preds, info.weights_, iter, d_gpair, weight_normalization_factor);
}
#endif
LambdaRankParam param_;
#if defined(__CUDACC__)
std::unique_ptr<SortedLabelList> slist_;
#endif
};
#if !defined(GTEST_TEST)
// register the objective functions
DMLC_REGISTER_PARAMETER(LambdaRankParam);
@ -696,6 +943,7 @@ XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, NDCGLambdaWeightComputer::Name())
XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name())
.describe("LambdaRank with MAP as objective.")
.set_body([]() { return new LambdaRankObj<MAPLambdaWeightComputer>(); });
#endif
} // namespace obj
} // namespace xgboost

View File

@ -76,4 +76,33 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) {
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(Objective, DeclareUnifiedTest(NDCGRankingGPair)) {
std::vector<std::pair<std::string, std::string>> args;
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<xgboost::ObjFunction> obj {
xgboost::ObjFunction::Create("rank:ndcg", &lparam)
};
obj->Configure(args);
CheckConfigReload(obj, "rank:ndcg");
// Test with setting sample weight to second query group
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 0.0f},
{0, 2, 4},
{0.7f, -0.7f, 0.0f, 0.0f},
{0.74f, 0.74f, 0.0f, 0.0f});
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{1.0f, 1.0f},
{0, 2, 4},
{0.35f, -0.35f, 0.35f, -0.35f},
{0.368f, 0.368f, 0.368f, 0.368f});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
} // namespace xgboost

View File

@ -1 +1,159 @@
#include "test_ranking_obj.cc"
#include "../../../src/objective/rank_obj.cu"
namespace xgboost {
template <typename T = uint32_t, typename Comparator = thrust::greater<T>>
std::unique_ptr<xgboost::obj::SegmentSorter<T>>
RankSegmentSorterTestImpl(const std::vector<uint32_t> &group_indices,
const std::vector<T> &hlabels,
const std::vector<T> &expected_sorted_hlabels,
const std::vector<uint32_t> &expected_orig_pos
) {
std::unique_ptr<xgboost::obj::SegmentSorter<T>> seg_sorter_ptr(
new xgboost::obj::SegmentSorter<T>);
xgboost::obj::SegmentSorter<T> &seg_sorter(*seg_sorter_ptr);
// Create a bunch of unsorted labels on the device and sort it via the segment sorter
dh::device_vector<T> dlabels(hlabels);
seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator());
EXPECT_EQ(seg_sorter.NumItems(), group_indices.back());
EXPECT_EQ(seg_sorter.NumGroups(), group_indices.size() - 1);
// Check the labels
dh::device_vector<T> sorted_dlabels(seg_sorter.NumItems());
sorted_dlabels.assign(thrust::device_ptr<const T>(seg_sorter.Items()),
thrust::device_ptr<const T>(seg_sorter.Items())
+ seg_sorter.NumItems());
thrust::host_vector<T> sorted_hlabels(sorted_dlabels);
EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels);
// Check the indices
dh::device_vector<uint32_t> dorig_pos(seg_sorter.NumItems());
dorig_pos.assign(thrust::device_ptr<const uint32_t>(seg_sorter.OriginalPositions()),
thrust::device_ptr<const uint32_t>(seg_sorter.OriginalPositions())
+ seg_sorter.NumItems());
dh::device_vector<uint32_t> horig_pos(dorig_pos);
EXPECT_EQ(expected_orig_pos, horig_pos);
return seg_sorter_ptr;
}
TEST(Objective, RankSegmentSorterTest) {
RankSegmentSorterTestImpl({0, 2, 4, 7, 10, 14, 18, 22, 26}, // Groups
{1, 1, // Labels
1, 2,
3, 2, 1,
1, 2, 1,
1, 3, 4, 2,
1, 2, 1, 1,
1, 2, 2, 3,
3, 3, 1, 2},
{1, 1, // Expected sorted labels
2, 1,
3, 2, 1,
2, 1, 1,
4, 3, 2, 1,
2, 1, 1, 1,
3, 2, 2, 1,
3, 3, 2, 1},
{0, 1, // Expected original positions
3, 2,
4, 5, 6,
8, 7, 9,
12, 11, 13, 10,
15, 14, 16, 17,
21, 19, 20, 18,
22, 23, 25, 24});
}
TEST(Objective, RankSegmentSorterSingleGroupTest) {
RankSegmentSorterTestImpl({0, 7}, // Groups
{6, 1, 4, 3, 0, 5, 2}, // Labels
{6, 5, 4, 3, 2, 1, 0}, // Expected sorted labels
{0, 5, 2, 3, 6, 1, 4}); // Expected original positions
}
TEST(Objective, RankSegmentSorterAscendingTest) {
RankSegmentSorterTestImpl<uint32_t, thrust::less<uint32_t>>(
{0, 4, 7}, // Groups
{3, 1, 4, 2, // Labels
6, 5, 7},
{1, 2, 3, 4, // Expected sorted labels
5, 6, 7},
{1, 3, 0, 2, // Expected original positions
5, 4, 6});
}
using CountFunctor = uint32_t (*)(const int *, uint32_t, int);
void RankItemCountImpl(const std::vector<int> &sorted_items, CountFunctor f,
int find_val, uint32_t exp_val) {
EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end());
EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val);
}
TEST(Objective, RankItemCountOnLeft) {
// Items sorted descendingly
std::vector<int> sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
10, static_cast<uint32_t>(0));
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
6, static_cast<uint32_t>(2));
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
4, static_cast<uint32_t>(3));
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
1, static_cast<uint32_t>(7));
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
0, static_cast<uint32_t>(12));
}
TEST(Objective, RankItemCountOnRight) {
// Items sorted descendingly
std::vector<int> sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
10, static_cast<uint32_t>(11));
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
6, static_cast<uint32_t>(10));
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
4, static_cast<uint32_t>(6));
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
1, static_cast<uint32_t>(1));
RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
0, static_cast<uint32_t>(0));
}
TEST(Objective, NDCGLambdaWeightComputerTest) {
auto segment_label_sorter = RankSegmentSorterTestImpl<float>(
{0, 4, 7, 12}, // Groups
{3.1f, 1.2f, 2.3f, 4.4f, // Labels
7.8f, 5.01f, 6.96f,
10.3f, 8.7f, 11.4f, 9.45f, 11.4f},
{4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels
7.8f, 6.96f, 5.01f,
11.4f, 11.4f, 10.3f, 9.45f, 8.7f},
{3, 0, 2, 1, // Expected original positions
4, 6, 5,
9, 11, 7, 10, 8});
// Created segmented predictions for the labels from above
std::vector<bst_float> hpreds{-9.78f, 24.367f, 0.908f, -11.47f,
-1.03f, -2.79f, -3.1f,
104.22f, 103.1f, -101.7f, 100.5f, 45.1f};
dh::device_vector<bst_float> dpreds(hpreds);
xgboost::obj::NDCGLambdaWeightComputer ndcg_lw_computer(dpreds.data().get(),
dpreds.size(),
*segment_label_sorter);
// Where will the predictions move from its current position, if they were sorted
// descendingly?
auto dsorted_pred_pos = ndcg_lw_computer.GetSortedPredPos();
thrust::host_vector<uint32_t> hsorted_pred_pos(dsorted_pred_pos);
std::vector<uint32_t> expected_sorted_pred_pos{2, 0, 1, 3,
4, 5, 6,
7, 8, 11, 9, 10};
EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos);
}
} // namespace xgboost

View File

@ -0,0 +1,143 @@
import numpy as np
from scipy.sparse import csr_matrix
import xgboost
import os
import math
import unittest
import itertools
import shutil
import urllib.request
import zipfile
class TestRanking(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""
Download and setup the test fixtures
"""
from sklearn.datasets import load_svmlight_files
# download the test data
cls.dpath = 'demo/rank/'
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
target = cls.dpath + '/MQ2008.zip'
if os.path.exists(cls.dpath) and os.path.exists(target):
print ("Skipping dataset download...")
else:
urllib.request.urlretrieve(url=src, filename=target)
with zipfile.ZipFile(target, 'r') as f:
f.extractall(path=cls.dpath)
(x_train, y_train, qid_train, x_test, y_test, qid_test,
x_valid, y_valid, qid_valid) = load_svmlight_files(
(cls.dpath + "MQ2008/Fold1/train.txt",
cls.dpath + "MQ2008/Fold1/test.txt",
cls.dpath + "MQ2008/Fold1/vali.txt"),
query_id=True, zero_based=False)
# instantiate the matrices
cls.dtrain = xgboost.DMatrix(x_train, y_train)
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
cls.dtest = xgboost.DMatrix(x_test, y_test)
# set the group counts from the query IDs
cls.dtrain.set_group([len(list(items))
for _key, items in itertools.groupby(qid_train)])
cls.dtest.set_group([len(list(items))
for _key, items in itertools.groupby(qid_test)])
cls.dvalid.set_group([len(list(items))
for _key, items in itertools.groupby(qid_valid)])
# save the query IDs for testing
cls.qid_train = qid_train
cls.qid_test = qid_test
cls.qid_valid = qid_valid
# model training parameters
cls.params = {'booster': 'gbtree',
'tree_method': 'gpu_hist',
'gpu_id': 0,
'predictor': 'gpu_predictor'
}
cls.cpu_params = {'booster': 'gbtree',
'tree_method': 'hist',
'gpu_id': -1,
'predictor': 'cpu_predictor'
}
@classmethod
def tearDownClass(cls):
"""
Cleanup test artifacts from download and unpacking
:return:
"""
os.remove(cls.dpath + "MQ2008.zip")
shutil.rmtree(cls.dpath + "MQ2008")
@classmethod
def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolerance=1e-02):
"""
Internal method that trains the dataset using the rank objective on GPU and CPU, evaluates
the metric and determines if the delta between the metric is within the tolerance level
:return:
"""
# specify validations set to watch performance
watchlist = [(cls.dtest, 'eval'), (cls.dtrain, 'train')]
num_trees=2500
check_metric_improvement_rounds=10
evals_result = {}
cls.params['objective'] = rank_objective
cls.params['eval_metric'] = metric_name
bst = xgboost.train(cls.params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result)
gpu_map_metric = evals_result['train'][metric_name][-1]
evals_result = {}
cls.cpu_params['objective'] = rank_objective
cls.cpu_params['eval_metric'] = metric_name
bstc = xgboost.train(cls.cpu_params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result)
cpu_map_metric = evals_result['train'][metric_name][-1]
print("{0} gpu {1} metric {2}".format(rank_objective, metric_name, gpu_map_metric))
print("{0} cpu {1} metric {2}".format(rank_objective, metric_name, cpu_map_metric))
print("gpu best score {0} cpu best score {1}".format(bst.best_score, bstc.best_score))
assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, tolerance)
assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance)
def test_training_rank_pairwise_map_metric(self):
"""
Train an XGBoost ranking model with pairwise objective function and compare map metric
"""
self.__test_training_with_rank_objective('rank:pairwise', 'map')
def test_training_rank_pairwise_auc_metric(self):
"""
Train an XGBoost ranking model with pairwise objective function and compare auc metric
"""
self.__test_training_with_rank_objective('rank:pairwise', 'auc')
def test_training_rank_pairwise_ndcg_metric(self):
"""
Train an XGBoost ranking model with pairwise objective function and compare ndcg metric
"""
self.__test_training_with_rank_objective('rank:pairwise', 'ndcg')
def test_training_rank_ndcg_map(self):
"""
Train an XGBoost ranking model with ndcg objective function and compare map metric
"""
self.__test_training_with_rank_objective('rank:ndcg', 'map')
def test_training_rank_ndcg_auc(self):
"""
Train an XGBoost ranking model with ndcg objective function and compare auc metric
"""
self.__test_training_with_rank_objective('rank:ndcg', 'auc')
def test_training_rank_ndcg_ndcg(self):
"""
Train an XGBoost ranking model with ndcg objective function and compare ndcg metric
"""
self.__test_training_with_rank_objective('rank:ndcg', 'ndcg')