Fix DispatchScan: use rocprim::inclusive_scan on the HIP/ROCm path
This commit is contained in:
parent
bdcb036592
commit
a45005863b
@ -1,8 +1,10 @@
|
|||||||
#include "hip/hip_runtime.h"
|
|
||||||
/**
|
/**
|
||||||
* Copyright 2017-2023 XGBoost contributors
|
* Copyright 2017-2023 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "hip/hip_runtime.h"
|
||||||
|
|
||||||
#include <thrust/binary_search.h> // thrust::upper_bound
|
#include <thrust/binary_search.h> // thrust::upper_bound
|
||||||
#include <thrust/device_malloc_allocator.h>
|
#include <thrust/device_malloc_allocator.h>
|
||||||
#include <thrust/device_ptr.h>
|
#include <thrust/device_ptr.h>
|
||||||
@ -22,8 +24,11 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
|
|
||||||
#include <hipcub/hipcub.hpp>
|
#include <hipcub/hipcub.hpp>
|
||||||
#include <hipcub/util_allocator.hpp>
|
#include <hipcub/util_allocator.hpp>
|
||||||
|
#include <rocprim/rocprim.hpp>
|
||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -1153,6 +1158,7 @@ template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
|
|||||||
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||||
OffsetT num_items) {
|
OffsetT num_items) {
|
||||||
size_t bytes = 0;
|
size_t bytes = 0;
|
||||||
|
#if 0
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
safe_cuda((
|
safe_cuda((
|
||||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
||||||
@ -1165,7 +1171,14 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
|||||||
hipcub::NullType(), num_items, nullptr,
|
hipcub::NullType(), num_items, nullptr,
|
||||||
false)));
|
false)));
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
safe_cuda((rocprim::inclusive_scan<InputIteratorT, OutputIteratorT, ScanOpT>(nullptr,
|
||||||
|
bytes, d_in, d_out, num_items, scan_op)));
|
||||||
|
|
||||||
TemporaryArray<char> storage(bytes);
|
TemporaryArray<char> storage(bytes);
|
||||||
|
|
||||||
|
#if 0
|
||||||
#if THRUST_MAJOR_VERSION >= 2
|
#if THRUST_MAJOR_VERSION >= 2
|
||||||
safe_cuda((
|
safe_cuda((
|
||||||
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
hipcub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, hipcub::NullType,
|
||||||
@ -1179,6 +1192,10 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
|||||||
d_out, scan_op, hipcub::NullType(),
|
d_out, scan_op, hipcub::NullType(),
|
||||||
num_items, nullptr, false)));
|
num_items, nullptr, false)));
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
safe_cuda((rocprim::inclusive_scan<InputIteratorT, OutputIteratorT, ScanOpT>(
|
||||||
|
storage.data().get(), bytes, d_in, d_out, num_items, scan_op)));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename InIt, typename OutIt, typename Predicate>
|
template <typename InIt, typename OutIt, typename Predicate>
|
||||||
|
|||||||
@ -13,6 +13,10 @@
|
|||||||
#include "gradient_index.h"
|
#include "gradient_index.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
|
||||||
|
#if defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#include <rocprim/rocprim.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
|
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
|
||||||
@ -235,6 +239,8 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
|||||||
// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
|
// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
|
||||||
// So we don't crash on n > 2^31
|
// So we don't crash on n > 2^31
|
||||||
size_t temp_storage_bytes = 0;
|
size_t temp_storage_bytes = 0;
|
||||||
|
|
||||||
|
#if defined(__CUDACC__)
|
||||||
using DispatchScan =
|
using DispatchScan =
|
||||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||||
@ -257,6 +263,19 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
|||||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||||
cub::NullType(), batch.Size(), nullptr, false);
|
cub::NullType(), batch.Size(), nullptr, false);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#elif defined (__HIP_PLATFORM_AMD__)
|
||||||
|
|
||||||
|
rocprim::inclusive_scan<decltype(key_value_index_iter), decltype(out), TupleScanOp<Tuple>>
|
||||||
|
(nullptr, temp_storage_bytes, key_value_index_iter, out, batch.Size(), TupleScanOp<Tuple>());
|
||||||
|
|
||||||
|
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
||||||
|
|
||||||
|
rocprim::inclusive_scan<decltype(key_value_index_iter), decltype(out), TupleScanOp<Tuple>>
|
||||||
|
(temp_storage.data().get(), temp_storage_bytes, key_value_index_iter, out, batch.Size(),
|
||||||
|
TupleScanOp<Tuple>());
|
||||||
|
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user