fix DispatchScan
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
#include "hip/hip_runtime.h"
|
||||
/**
|
||||
* Copyright 2017-2023 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "hip/hip_runtime.h"
|
||||
|
||||
#include <thrust/binary_search.h> // thrust::upper_bound
|
||||
#include <thrust/device_malloc_allocator.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
@@ -22,8 +24,11 @@
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cstddef> // for size_t
|
||||
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hipcub/util_allocator.hpp>
|
||||
#include <rocprim/rocprim.hpp>
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -1153,6 +1158,7 @@ template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
|
||||
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||
OffsetT num_items) {
|
||||
size_t bytes = 0;
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((
|
||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
||||
@@ -1165,7 +1171,14 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||
hipcub::NullType(), num_items, nullptr,
|
||||
false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::inclusive_scan<InputIteratorT, OutputIteratorT, ScanOpT>(nullptr,
|
||||
bytes, d_in, d_out, num_items, scan_op)));
|
||||
|
||||
TemporaryArray<char> storage(bytes);
|
||||
|
||||
#if 0
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((
|
||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
||||
@@ -1179,6 +1192,10 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||
d_out, scan_op, hipcub::NullType(),
|
||||
num_items, nullptr, false)));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
safe_cuda((rocprim::inclusive_scan<InputIteratorT, OutputIteratorT, ScanOpT>(
|
||||
storage.data().get(), bytes, d_in, d_out, num_items, scan_op)));
|
||||
}
|
||||
|
||||
template <typename InIt, typename OutIt, typename Predicate>
|
||||
|
||||
Reference in New Issue
Block a user