This commit is contained in:
Hendrik Groove 2024-10-21 11:43:29 +02:00
parent bb2feab0b2
commit 8c15f3b665

View File

@ -11,9 +11,9 @@
#include "evaluate_splits.cuh"
#include "expand_entry.cuh"
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
#define WARP_SIZE 32
#elif defined(XGBOOST_USE_HIP2)
#elif defined(XGBOOST_USE_HIP)
#include <hip/hip_cooperative_groups.h>
#ifdef __AMDGCN_WAVEFRONT_SIZE
@ -110,10 +110,10 @@ class EvaluateSplitAgent {
}
local_sum = SumReduceT(temp_storage->sum_reduce).Sum(local_sum); // NOLINT
// Broadcast result from thread 0
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
return {__shfl_sync(0xffffffff, local_sum.GetQuantisedGrad(), 0),
__shfl_sync(0xffffffff, local_sum.GetQuantisedHess(), 0)};
#elif defined(XGBOOST_USE_HIP2)
#elif defined(XGBOOST_USE_HIP)
return {__shfl(local_sum.GetQuantisedGrad(), 0),
__shfl(local_sum.GetQuantisedHess(), 0)};
#endif
@ -144,9 +144,9 @@ class EvaluateSplitAgent {
// This reduce result is only valid in thread 0
// broadcast to the rest of the warp
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
#elif defined(XGBOOST_USE_HIP2)
#elif defined(XGBOOST_USE_HIP)
auto best_thread = __shfl(best.key, 0);
#endif
@ -181,9 +181,9 @@ class EvaluateSplitAgent {
auto best = MaxReduceT(temp_storage->max_reduce).Reduce({(int)threadIdx.x, gain}, cub::ArgMax());
// This reduce result is only valid in thread 0
// broadcast to the rest of the warp
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
#elif defined(XGBOOST_USE_HIP2)
#elif defined(XGBOOST_USE_HIP)
auto best_thread = __shfl(best.key, 0);
#endif
@ -215,9 +215,9 @@ class EvaluateSplitAgent {
auto best = MaxReduceT(temp_storage->max_reduce).Reduce({(int)threadIdx.x, gain}, cub::ArgMax());
// This reduce result is only valid in thread 0
// broadcast to the rest of the warp
#if defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_CUDA)
auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
#elif defined(XGBOOST_USE_HIP2)
#elif defined(XGBOOST_USE_HIP)
auto best_thread = __shfl(best.key, 0);
#endif