Move segment sorter to common (#5378)

- move segment sorter to common
- this is the first of a handful of PRs that split the larger PR #5326
- it moves this facility to common (from the ranking objective class), so that it can
    also be used for metric computation
- it also wraps all the bare device pointers in spans.
This commit is contained in:
sriramch
2020-02-28 23:42:07 -08:00
committed by GitHub
parent 2ba8c13b69
commit b81f8cbbc0
4 changed files with 275 additions and 260 deletions

View File

@@ -5,36 +5,34 @@
namespace xgboost {
template <typename T = uint32_t, typename Comparator = thrust::greater<T>>
std::unique_ptr<xgboost::obj::SegmentSorter<T>>
std::unique_ptr<dh::SegmentSorter<T>>
RankSegmentSorterTestImpl(const std::vector<uint32_t> &group_indices,
const std::vector<T> &hlabels,
const std::vector<T> &expected_sorted_hlabels,
const std::vector<uint32_t> &expected_orig_pos
) {
std::unique_ptr<xgboost::obj::SegmentSorter<T>> seg_sorter_ptr(
new xgboost::obj::SegmentSorter<T>);
xgboost::obj::SegmentSorter<T> &seg_sorter(*seg_sorter_ptr);
std::unique_ptr<dh::SegmentSorter<T>> seg_sorter_ptr(new dh::SegmentSorter<T>);
dh::SegmentSorter<T> &seg_sorter(*seg_sorter_ptr);
// Create a bunch of unsorted labels on the device and sort it via the segment sorter
dh::device_vector<T> dlabels(hlabels);
seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator());
EXPECT_EQ(seg_sorter.GetNumItems(), group_indices.back());
auto num_items = seg_sorter.GetItemsSpan().size();
EXPECT_EQ(num_items, group_indices.back());
EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1);
// Check the labels
dh::device_vector<T> sorted_dlabels(seg_sorter.GetNumItems());
sorted_dlabels.assign(thrust::device_ptr<const T>(seg_sorter.GetItemsPtr()),
thrust::device_ptr<const T>(seg_sorter.GetItemsPtr())
+ seg_sorter.GetNumItems());
dh::device_vector<T> sorted_dlabels(num_items);
sorted_dlabels.assign(dh::tcbegin(seg_sorter.GetItemsSpan()),
dh::tcend(seg_sorter.GetItemsSpan()));
thrust::host_vector<T> sorted_hlabels(sorted_dlabels);
EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels);
// Check the indices
dh::device_vector<uint32_t> dorig_pos(seg_sorter.GetNumItems());
dorig_pos.assign(thrust::device_ptr<const uint32_t>(seg_sorter.GetOriginalPositionsPtr()),
thrust::device_ptr<const uint32_t>(seg_sorter.GetOriginalPositionsPtr())
+ seg_sorter.GetNumItems());
dh::device_vector<uint32_t> dorig_pos(num_items);
dorig_pos.assign(dh::tcbegin(seg_sorter.GetOriginalPositionsSpan()),
dh::tcend(seg_sorter.GetOriginalPositionsSpan()));
dh::device_vector<uint32_t> horig_pos(dorig_pos);
EXPECT_EQ(expected_orig_pos, horig_pos);
@@ -152,18 +150,22 @@ TEST(Objective, NDCGLambdaWeightComputerTest) {
// Where will the predictions move from its current position, if they were sorted
// descendingly?
auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositions();
thrust::host_vector<uint32_t> hsorted_pred_pos(dsorted_pred_pos);
auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositionsSpan();
std::vector<uint32_t> hsorted_pred_pos(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&hsorted_pred_pos, dsorted_pred_pos);
std::vector<uint32_t> expected_sorted_pred_pos{2, 0, 1, 3,
4, 5, 6,
7, 8, 11, 9, 10};
EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos);
// Check group DCG values
thrust::host_vector<float> hgroup_dcgs(ndcg_lw_computer.GetGroupDcgs());
thrust::host_vector<uint32_t> hgroups(segment_label_sorter->GetGroups());
thrust::host_vector<float> hsorted_labels(segment_label_sorter->GetItems());
std::vector<float> hgroup_dcgs(segment_label_sorter->GetNumGroups());
dh::CopyDeviceSpanToVector(&hgroup_dcgs, ndcg_lw_computer.GetGroupDcgsSpan());
std::vector<uint32_t> hgroups(segment_label_sorter->GetNumGroups() + 1);
dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan());
EXPECT_EQ(hgroup_dcgs.size(), segment_label_sorter->GetNumGroups());
std::vector<float> hsorted_labels(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&hsorted_labels, segment_label_sorter->GetItemsSpan());
for (auto i = 0; i < hgroup_dcgs.size(); ++i) {
// Compute group DCG value on CPU and compare
auto gbegin = hgroups[i];
@@ -193,7 +195,9 @@ TEST(Objective, IndexableSortedItemsTest) {
9, 11, 7, 10, 8});
segment_label_sorter->CreateIndexableSortedPositions();
thrust::host_vector<uint32_t> sorted_indices(segment_label_sorter->GetIndexableSortedPositions());
std::vector<uint32_t> sorted_indices(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&sorted_indices,
segment_label_sorter->GetIndexableSortedPositionsSpan());
std::vector<uint32_t> expected_sorted_indices = {
1, 3, 2, 0,
4, 6, 5,
@@ -228,11 +232,13 @@ TEST(Objective, ComputeAndCompareMAPStatsTest) {
*segment_label_sorter);
// Get the device MAP stats on host
thrust::host_vector<xgboost::obj::MAPLambdaWeightComputer::MAPStats> dmap_stats(
map_lw_computer.GetMapStats());
std::vector<xgboost::obj::MAPLambdaWeightComputer::MAPStats> dmap_stats(
segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&dmap_stats, map_lw_computer.GetMapStatsSpan());
// Compute the MAP stats on host next to compare
thrust::host_vector<uint32_t> hgroups(segment_label_sorter->GetGroups());
std::vector<uint32_t> hgroups(segment_label_sorter->GetNumGroups() + 1);
dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan());
for (auto i = 0; i < hgroups.size() - 1; ++i) {
auto gbegin = hgroups[i];

View File

@@ -40,9 +40,9 @@ void VerifySampling(size_t page_size,
EXPECT_EQ(sample.page->matrix.n_rows, kRows);
EXPECT_EQ(sample.gpair.size(), kRows);
} else {
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.012f);
EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.012f);
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.012f);
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.016f);
EXPECT_NEAR(sample.page->matrix.n_rows, sample_rows, kRows * 0.016f);
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.016f);
}
GradientPair sum_sampled_gpair{};
@@ -52,11 +52,11 @@ void VerifySampling(size_t page_size,
sum_sampled_gpair += gp;
}
if (check_sum) {
EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.02f * kRows);
EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.02f * kRows);
EXPECT_NEAR(sum_gpair.GetGrad(), sum_sampled_gpair.GetGrad(), 0.03f * kRows);
EXPECT_NEAR(sum_gpair.GetHess(), sum_sampled_gpair.GetHess(), 0.03f * kRows);
} else {
EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.02f);
EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.02f);
EXPECT_NEAR(sum_gpair.GetGrad() / kRows, sum_sampled_gpair.GetGrad() / sample_rows, 0.03f);
EXPECT_NEAR(sum_gpair.GetHess() / kRows, sum_sampled_gpair.GetHess() / sample_rows, 0.03f);
}
}