From a45005863b19e849db7ad9986325374e15d80fd4 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 20:15:33 +0100 Subject: [PATCH] fix DispatchScan --- src/common/device_helpers.hip.h | 19 ++++++++++++++++++- src/data/ellpack_page.cu | 19 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h index 0452d6626..3ac3f6b6a 100644 --- a/src/common/device_helpers.hip.h +++ b/src/common/device_helpers.hip.h @@ -1,8 +1,10 @@ -#include "hip/hip_runtime.h" /** * Copyright 2017-2023 XGBoost contributors */ #pragma once + +#include "hip/hip_runtime.h" + #include // thrust::upper_bound #include #include @@ -22,8 +24,11 @@ #include #include #include // for size_t + #include #include +#include + #include #include #include @@ -1153,6 +1158,7 @@ template = 2 safe_cuda(( hipcub::DispatchScan(nullptr, + bytes, d_in, d_out, num_items, scan_op))); + TemporaryArray storage(bytes); + +#if 0 #if THRUST_MAJOR_VERSION >= 2 safe_cuda(( hipcub::DispatchScan( + storage.data().get(), bytes, d_in, d_out, num_items, scan_op))); } template diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 99e17d886..ed84d532f 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -13,6 +13,10 @@ #include "gradient_index.h" #include "xgboost/data.h" +#if defined(__HIP_PLATFORM_AMD__) +#include +#endif + namespace xgboost { EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {} @@ -235,6 +239,8 @@ void CopyDataToEllpack(const AdapterBatchT &batch, // Go one level down into cub::DeviceScan API to set OffsetT as 64 bit // So we don't crash on n > 2^31 size_t temp_storage_bytes = 0; + +#if defined(__CUDACC__) using DispatchScan = cub::DispatchScan, cub::NullType, int64_t>; @@ -257,6 +263,19 @@ void CopyDataToEllpack(const AdapterBatchT &batch, key_value_index_iter, out, TupleScanOp(), cub::NullType(), batch.Size(), nullptr, false); #endif + +#elif defined (__HIP_PLATFORM_AMD__) + + rocprim::inclusive_scan> + (nullptr, temp_storage_bytes, key_value_index_iter, out, batch.Size(), TupleScanOp()); + + dh::TemporaryArray temp_storage(temp_storage_bytes); + + rocprim::inclusive_scan> + (temp_storage.data().get(), temp_storage_bytes, key_value_index_iter, out, batch.Size(), + TupleScanOp()); + +#endif } void WriteNullValues(EllpackPageImpl* dst, int device_idx,