fix auc.cu
This commit is contained in:
@@ -1208,8 +1208,7 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::inclusive_scan<InputIteratorT, OutputIteratorT, ScanOpT>(nullptr,
|
||||
bytes, d_in, d_out, num_items, scan_op)));
|
||||
safe_cuda((rocprim::inclusive_scan(nullptr, bytes, d_in, d_out, (size_t) num_items, scan_op)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
|
||||
@@ -1229,8 +1228,7 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::inclusive_scan<InputIteratorT, OutputIteratorT, ScanOpT>(
|
||||
storage.data().get(), bytes, d_in, d_out, num_items, scan_op)));
|
||||
safe_cuda((rocprim::inclusive_scan(storage.data().get(), bytes, d_in, d_out, (size_t) num_items, scan_op)));
|
||||
}
|
||||
|
||||
template <typename InIt, typename OutIt, typename Predicate>
|
||||
@@ -1262,11 +1260,7 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
using ValueT = std::remove_const_t<IdxT>;
|
||||
|
||||
TemporaryArray<KeyT> out(keys.size());
|
||||
hipcub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
|
||||
out.data().get());
|
||||
TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
|
||||
hipcub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
|
||||
sorted_idx_out.data().get());
|
||||
|
||||
// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
|
||||
using OffsetT = std::conditional_t<!BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
|
||||
@@ -1286,8 +1280,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
@@ -1305,8 +1299,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
} else {
|
||||
void *d_temp_storage = nullptr;
|
||||
@@ -1323,8 +1317,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
@@ -1341,8 +1335,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
|
||||
bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user