fix aft_obj.hip
This commit is contained in:
parent
b71c1b50de
commit
a2bab03205
@ -2,9 +2,6 @@
|
||||
* Copyright 2017-2023 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "hip/hip_runtime.h"
|
||||
|
||||
#include <thrust/binary_search.h> // thrust::upper_bound
|
||||
#include <thrust/device_malloc_allocator.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
@ -24,11 +21,9 @@
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cstddef> // for size_t
|
||||
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hipcub/util_allocator.hpp>
|
||||
#include <rocprim/rocprim.hpp>
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@ -1158,41 +1153,9 @@ template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
|
||||
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||
OffsetT num_items) {
|
||||
size_t bytes = 0;
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((
|
||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
||||
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
||||
hipcub::NullType(), num_items, nullptr)));
|
||||
#else
|
||||
safe_cuda((
|
||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
||||
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
||||
hipcub::NullType(), num_items, nullptr,
|
||||
false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::inclusive_scan(nullptr, bytes, d_in, d_out, (size_t) num_items, scan_op)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((
|
||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
||||
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
||||
d_out, scan_op, hipcub::NullType(),
|
||||
num_items, nullptr)));
|
||||
#else
|
||||
safe_cuda((
|
||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
||||
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
||||
d_out, scan_op, hipcub::NullType(),
|
||||
num_items, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::inclusive_scan(storage.data().get(), bytes, d_in, d_out, (size_t) num_items, scan_op)));
|
||||
}
|
||||
|
||||
@ -1233,73 +1196,23 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
if (accending) {
|
||||
void *d_temp_storage = nullptr;
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
} else {
|
||||
void *d_temp_storage = nullptr;
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
|
||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
|
||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8)));
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
|
||||
#if !defined(XGBOOST_USE_HIP)
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
#include "aft_obj.cu"
|
||||
#endif
|
||||
|
||||
@ -144,6 +144,7 @@ TEST(GpuPredictor, LesserFeatures) {
|
||||
TestPredictionWithLesserFeatures("gpu_predictor");
|
||||
}
|
||||
|
||||
#if 0
|
||||
// Very basic test of empty model
|
||||
TEST(GPUPredictor, ShapStump) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
@ -212,7 +213,7 @@ TEST(GPUPredictor, Shap) {
|
||||
TEST(GPUPredictor, IterationRange) {
|
||||
TestIterationRange("gpu_predictor");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TEST(GPUPredictor, CategoricalPrediction) {
|
||||
TestCategoricalPrediction("gpu_predictor");
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user