Change reduce operation from thrust to cub. Fix for cuda 9.1 error (#3218)
* Change reduce operation from thrust to cub. Fix for cuda 9.1 runtime error * Unit test sum reduce
This commit is contained in:
@@ -797,6 +797,29 @@ void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
|
||||
in.data(), out.data(), nVals));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to perform device-wide sum-reduction, returns to the
|
||||
* host
|
||||
* @param tmp_mem cub temporary memory info
|
||||
* @param in the input array to be reduced
|
||||
* @param nVals number of elements in the input array
|
||||
*/
|
||||
template <typename T>
|
||||
T sumReduction(dh::CubMemory &tmp_mem, T *in, int nVals) {
|
||||
size_t tmpSize;
|
||||
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
|
||||
// Allocate small extra memory for the return value
|
||||
tmp_mem.LazyAllocate(tmpSize + sizeof(T));
|
||||
auto ptr = reinterpret_cast<T *>(tmp_mem.d_temp_storage) + 1;
|
||||
dh::safe_cuda(cub::DeviceReduce::Sum(
|
||||
reinterpret_cast<void *>(ptr), tmpSize, in,
|
||||
reinterpret_cast<T *>(tmp_mem.d_temp_storage), nVals));
|
||||
T sum;
|
||||
dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(T),
|
||||
cudaMemcpyDeviceToHost));
|
||||
return sum;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fill a given constant value across all elements in the buffer
|
||||
* @param out the buffer to be filled
|
||||
|
||||
Reference in New Issue
Block a user