[back port] Copy output data for argsort. (#6866) (#6870)

Fix GPU AUC.
This commit is contained in:
Jiaming Yuan 2021-04-17 01:33:22 +08:00 committed by GitHub
parent 9f5e2c52ce
commit d05c47dcf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 8 deletions

View File

@ -1321,15 +1321,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
TemporaryArray<KeyT> out(keys.size()); TemporaryArray<KeyT> out(keys.size());
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()), cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
out.data().get()); out.data().get());
TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()), cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
sorted_idx.data()); sorted_idx_out.data().get());
if (accending) { if (accending) {
void *d_temp_storage = nullptr; void *d_temp_storage = nullptr;
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch( safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false))); sizeof(KeyT) * 8, false, nullptr, false)));
dh::TemporaryArray<char> storage(bytes); TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get(); d_temp_storage = storage.data().get();
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch( safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
@ -1339,12 +1340,15 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch( safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false))); sizeof(KeyT) * 8, false, nullptr, false)));
dh::TemporaryArray<char> storage(bytes); TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get(); d_temp_storage = storage.data().get();
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch( safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false))); sizeof(KeyT) * 8, false, nullptr, false)));
} }
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
} }
namespace detail { namespace detail {
@ -1379,14 +1383,19 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
size_t bytes = 0; size_t bytes = 0;
Iota(sorted_idx); Iota(sorted_idx);
TemporaryArray<std::remove_const_t<U>> values_out(values.size()); TemporaryArray<std::remove_const_t<U>> values_out(values.size());
TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());
detail::DeviceSegmentedRadixSortPair<!accending>( detail::DeviceSegmentedRadixSortPair<!accending>(
nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(), nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
sorted_idx.data(), sorted_idx.size(), n_groups, group_ptr.data(), sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
group_ptr.data() + 1); group_ptr.data() + 1);
dh::TemporaryArray<xgboost::common::byte> temp_storage(bytes); TemporaryArray<xgboost::common::byte> temp_storage(bytes);
detail::DeviceSegmentedRadixSortPair<!accending>( detail::DeviceSegmentedRadixSortPair<!accending>(
temp_storage.data().get(), bytes, values.data(), values_out.data().get(), temp_storage.data().get(), bytes, values.data(), values_out.data().get(),
sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups, sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(),
group_ptr.data(), group_ptr.data() + 1); n_groups, group_ptr.data(), group_ptr.data() + 1);
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
} }
} // namespace dh } // namespace dh

View File

@ -269,7 +269,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
}); });
// unique values are sparse, so we need a CSR style indptr // unique values are sparse, so we need a CSR style indptr
dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size() + 1); dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size());
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr); auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
auto n_uniques = dh::SegmentedUniqueByKey( auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc), thrust::cuda::par(alloc),