fix DispatchScan
This commit is contained in:
@@ -13,6 +13,10 @@
|
||||
#include "gradient_index.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
#include <rocprim/rocprim.hpp>
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
|
||||
@@ -235,6 +239,8 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
|
||||
// So we don't crash on n > 2^31
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
using DispatchScan =
|
||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||
@@ -257,6 +263,19 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr, false);
|
||||
#endif
|
||||
|
||||
#elif defined (__HIP_PLATFORM_AMD__)
|
||||
|
||||
rocprim::inclusive_scan<decltype(key_value_index_iter), decltype(out), TupleScanOp<Tuple>>
|
||||
(nullptr, temp_storage_bytes, key_value_index_iter, out, batch.Size(), TupleScanOp<Tuple>());
|
||||
|
||||
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
||||
|
||||
rocprim::inclusive_scan<decltype(key_value_index_iter), decltype(out), TupleScanOp<Tuple>>
|
||||
(temp_storage.data().get(), temp_storage_bytes, key_value_index_iter, out, batch.Size(),
|
||||
TupleScanOp<Tuple>());
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
|
||||
Reference in New Issue
Block a user