Add device argsort. (#6749)
This is part of https://github.com/dmlc/xgboost/pull/6747 .
This commit is contained in:
@@ -171,5 +171,28 @@ TEST(Allocator, OOM) {
|
||||
// Clear last error so we don't fail subsequent tests
|
||||
cudaGetLastError();
|
||||
}
|
||||
|
||||
TEST(DeviceHelpers, ArgSort) {
|
||||
dh::device_vector<float> values(20);
|
||||
dh::Iota(dh::ToSpan(values)); // accending
|
||||
dh::device_vector<size_t> sorted_idx(20);
|
||||
dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx)); // sort to descending
|
||||
ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
|
||||
sorted_idx.end(), thrust::greater<size_t>{}));
|
||||
|
||||
dh::Iota(dh::ToSpan(values));
|
||||
dh::device_vector<size_t> groups(3);
|
||||
groups[0] = 0;
|
||||
groups[1] = 10;
|
||||
groups[2] = 20;
|
||||
dh::SegmentedArgSort<false>(dh::ToSpan(values), dh::ToSpan(groups),
|
||||
dh::ToSpan(sorted_idx));
|
||||
ASSERT_FALSE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
|
||||
sorted_idx.end(), thrust::greater<size_t>{}));
|
||||
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin(), sorted_idx.begin() + 10,
|
||||
thrust::greater<size_t>{}));
|
||||
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
|
||||
thrust::greater<size_t>{}));
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user