Add device argsort. (#6749)

This is part of https://github.com/dmlc/xgboost/pull/6747 .
This commit is contained in:
Jiaming Yuan
2021-03-16 16:05:22 +08:00
committed by GitHub
parent 325bc93e16
commit 1a73a28511
3 changed files with 97 additions and 1 deletions

View File

@@ -171,5 +171,28 @@ TEST(Allocator, OOM) {
// Clear last error so we don't fail subsequent tests
cudaGetLastError();
}
TEST(DeviceHelpers, ArgSort) {
dh::device_vector<float> values(20);
dh::Iota(dh::ToSpan(values)); // accending
dh::device_vector<size_t> sorted_idx(20);
dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx)); // sort to descending
ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
sorted_idx.end(), thrust::greater<size_t>{}));
dh::Iota(dh::ToSpan(values));
dh::device_vector<size_t> groups(3);
groups[0] = 0;
groups[1] = 10;
groups[2] = 20;
dh::SegmentedArgSort<false>(dh::ToSpan(values), dh::ToSpan(groups),
dh::ToSpan(sorted_idx));
ASSERT_FALSE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
sorted_idx.end(), thrust::greater<size_t>{}));
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin(), sorted_idx.begin() + 10,
thrust::greater<size_t>{}));
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
thrust::greater<size_t>{}));
}
} // namespace common
} // namespace xgboost