fix aft_obj.hip
This commit is contained in:
parent
b71c1b50de
commit
a2bab03205
@ -2,9 +2,6 @@
|
|||||||
* Copyright 2017-2023 XGBoost contributors
|
* Copyright 2017-2023 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "hip/hip_runtime.h"
|
|
||||||
|
|
||||||
#include <thrust/binary_search.h> // thrust::upper_bound
|
#include <thrust/binary_search.h> // thrust::upper_bound
|
||||||
#include <thrust/device_malloc_allocator.h>
|
#include <thrust/device_malloc_allocator.h>
|
||||||
#include <thrust/device_ptr.h>
|
#include <thrust/device_ptr.h>
|
||||||
@ -24,11 +21,9 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
|
|
||||||
#include <hipcub/hipcub.hpp>
|
#include <hipcub/hipcub.hpp>
|
||||||
#include <hipcub/util_allocator.hpp>
|
#include <hipcub/util_allocator.hpp>
|
||||||
#include <rocprim/rocprim.hpp>
|
#include <rocprim/rocprim.hpp>
|
||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -1158,41 +1153,9 @@ template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
|
|||||||
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||||
OffsetT num_items) {
|
OffsetT num_items) {
|
||||||
size_t bytes = 0;
|
size_t bytes = 0;
|
||||||
#if 0
|
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
|
||||||
safe_cuda((
|
|
||||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
|
||||||
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
|
||||||
hipcub::NullType(), num_items, nullptr)));
|
|
||||||
#else
|
|
||||||
safe_cuda((
|
|
||||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
|
||||||
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
|
|
||||||
hipcub::NullType(), num_items, nullptr,
|
|
||||||
false)));
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
safe_cuda((rocprim::inclusive_scan(nullptr, bytes, d_in, d_out, (size_t) num_items, scan_op)));
|
safe_cuda((rocprim::inclusive_scan(nullptr, bytes, d_in, d_out, (size_t) num_items, scan_op)));
|
||||||
|
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
|
|
||||||
#if 0
|
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
|
||||||
safe_cuda((
|
|
||||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
|
||||||
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
|
||||||
d_out, scan_op, hipcub::NullType(),
|
|
||||||
num_items, nullptr)));
|
|
||||||
#else
|
|
||||||
safe_cuda((
|
|
||||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
|
||||||
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
|
|
||||||
d_out, scan_op, hipcub::NullType(),
|
|
||||||
num_items, nullptr, false)));
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
safe_cuda((rocprim::inclusive_scan(storage.data().get(), bytes, d_in, d_out, (size_t) num_items, scan_op)));
|
safe_cuda((rocprim::inclusive_scan(storage.data().get(), bytes, d_in, d_out, (size_t) num_items, scan_op)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1233,73 +1196,23 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
if (accending) {
|
if (accending) {
|
||||||
void *d_temp_storage = nullptr;
|
void *d_temp_storage = nullptr;
|
||||||
|
|
||||||
#if 0
|
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
|
||||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
|
||||||
sizeof(KeyT) * 8, false, nullptr)));
|
|
||||||
#else
|
|
||||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
||||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8)));
|
sizeof(KeyT) * 8)));
|
||||||
|
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
d_temp_storage = storage.data().get();
|
d_temp_storage = storage.data().get();
|
||||||
|
|
||||||
#if 0
|
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
|
||||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
|
||||||
sizeof(KeyT) * 8, false, nullptr)));
|
|
||||||
#else
|
|
||||||
safe_cuda((hipcub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
||||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8)));
|
sizeof(KeyT) * 8)));
|
||||||
} else {
|
} else {
|
||||||
void *d_temp_storage = nullptr;
|
void *d_temp_storage = nullptr;
|
||||||
|
|
||||||
#if 0
|
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
|
||||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
|
||||||
sizeof(KeyT) * 8, false, nullptr)));
|
|
||||||
#else
|
|
||||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
|
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
|
||||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8)));
|
sizeof(KeyT) * 8)));
|
||||||
|
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
d_temp_storage = storage.data().get();
|
d_temp_storage = storage.data().get();
|
||||||
|
|
||||||
#if 0
|
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
|
||||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
|
||||||
sizeof(KeyT) * 8, false, nullptr)));
|
|
||||||
#else
|
|
||||||
safe_cuda((hipcub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
|
||||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
|
||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
|
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
|
||||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8)));
|
sizeof(KeyT) * 8)));
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
|
|
||||||
#if !defined(XGBOOST_USE_HIP)
|
#if defined(XGBOOST_USE_HIP)
|
||||||
#include "aft_obj.cu"
|
#include "aft_obj.cu"
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -144,6 +144,7 @@ TEST(GpuPredictor, LesserFeatures) {
|
|||||||
TestPredictionWithLesserFeatures("gpu_predictor");
|
TestPredictionWithLesserFeatures("gpu_predictor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
// Very basic test of empty model
|
// Very basic test of empty model
|
||||||
TEST(GPUPredictor, ShapStump) {
|
TEST(GPUPredictor, ShapStump) {
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
@ -212,7 +213,7 @@ TEST(GPUPredictor, Shap) {
|
|||||||
TEST(GPUPredictor, IterationRange) {
|
TEST(GPUPredictor, IterationRange) {
|
||||||
TestIterationRange("gpu_predictor");
|
TestIterationRange("gpu_predictor");
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
TEST(GPUPredictor, CategoricalPrediction) {
|
TEST(GPUPredictor, CategoricalPrediction) {
|
||||||
TestCategoricalPrediction("gpu_predictor");
|
TestCategoricalPrediction("gpu_predictor");
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user