Extract device algorithms. (#8789)

This commit is contained in:
Jiaming Yuan
2023-02-13 20:53:53 +08:00
committed by GitHub
parent 457f704e3d
commit 31d3ec07af
13 changed files with 361 additions and 218 deletions

View File

@@ -172,28 +172,4 @@ TEST(Allocator, OOM) {
// Clear last error so we don't fail subsequent tests
cudaGetLastError();
}
TEST(DeviceHelpers, ArgSort) {
dh::device_vector<float> values(20);
dh::Iota(dh::ToSpan(values)); // accending
dh::device_vector<size_t> sorted_idx(20);
dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx)); // sort to descending
ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
sorted_idx.end(), thrust::greater<size_t>{}));
dh::Iota(dh::ToSpan(values));
dh::device_vector<size_t> groups(3);
groups[0] = 0;
groups[1] = 10;
groups[2] = 20;
dh::SegmentedArgSort<false>(dh::ToSpan(values), dh::ToSpan(groups),
dh::ToSpan(sorted_idx));
ASSERT_FALSE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
sorted_idx.end(), thrust::greater<size_t>{}));
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin(), sorted_idx.begin() + 10,
thrust::greater<size_t>{}));
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
thrust::greater<size_t>{}));
}
} // namespace xgboost