Add WARP_SIZE macro (WAVEFRONT_SIZE on HIP, 32 on CUDA)

This commit is contained in:
amdsc21 2023-03-15 22:00:26 +01:00
parent 4484c7f073
commit a79a35c22c

View File

@ -18,6 +18,12 @@
#include "evaluate_splits.cuh" #include "evaluate_splits.cuh"
#include "expand_entry.cuh" #include "expand_entry.cuh"
// WARP_SIZE: single backend-independent name for the warp width used by the
// kernels in this file (the split-evaluation kernel asserts that its block size
// equals WARP_SIZE, and the launcher uses it as the threads-per-block count).
//   - HIP build  (XGBOOST_USE_HIP):  aliases WAVEFRONT_SIZE, which is expected to
//     be provided elsewhere -- NOTE(review): confirm which header defines it.
//   - CUDA build (XGBOOST_USE_CUDA): fixed at 32.
// NOTE: if neither backend macro is defined, WARP_SIZE is left undefined and any
// later use will fail to compile.
#if defined(XGBOOST_USE_HIP)
#define WARP_SIZE WAVEFRONT_SIZE
#elif defined(XGBOOST_USE_CUDA)
#define WARP_SIZE 32
#endif
namespace xgboost { namespace xgboost {
#if defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_HIP)
namespace cub = hipcub; namespace cub = hipcub;
@ -97,11 +103,7 @@ class EvaluateSplitAgent {
param(shared_inputs.param), evaluator(evaluator), param(shared_inputs.param), evaluator(evaluator),
missing(parent_sum - ReduceFeature()) { missing(parent_sum - ReduceFeature()) {
static_assert( static_assert(
#if defined(XGBOOST_USE_HIP) kBlockSize == WARP_SIZE,
kBlockSize == WAVEFRONT_SIZE,
#elif defined(XGBOOST_USE_CUDA)
kBlockSize == 32,
#endif
"This kernel relies on the assumption block_size == warp_size"); "This kernel relies on the assumption block_size == warp_size");
// There should be no missing value gradients for a dense matrix // There should be no missing value gradients for a dense matrix
KERNEL_CHECK(!shared_inputs.is_dense || missing.GetQuantisedHess() == 0); KERNEL_CHECK(!shared_inputs.is_dense || missing.GetQuantisedHess() == 0);
@ -393,11 +395,7 @@ void GPUHistEvaluator::LaunchEvaluateSplits(
combined_num_features, DeviceSplitCandidate()); combined_num_features, DeviceSplitCandidate());
// One block for each feature // One block for each feature
#if defined(XGBOOST_USE_HIP) uint32_t constexpr kBlockThreads = WARP_SIZE;
uint32_t constexpr kBlockThreads = WAVEFRONT_SIZE;
#elif defined(XGBOOST_USE_CUDA)
uint32_t constexpr kBlockThreads = 32;
#endif
dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads, dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads,
0}( 0}(
EvaluateSplitsKernel<kBlockThreads>, max_active_features, d_inputs, EvaluateSplitsKernel<kBlockThreads>, max_active_features, d_inputs,