finish stats.cu

This commit is contained in:
amdsc21 2023-03-10 05:38:16 +01:00
parent 911a5d8a60
commit 14cc438a64
3 changed files with 22 additions and 1 deletions

View File

@ -7,7 +7,13 @@
#include <cstddef> // size_t #include <cstddef> // size_t
#include "cuda_context.cuh" // CUDAContext #include "cuda_context.cuh" // CUDAContext
#if defined(XGBOOST_USE_CUDA)
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend #include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
#elif defined(XGBOOST_USE_HIP)
#include "device_helpers.hip.h" // dh::MakeTransformIterator, tcbegin, tcend
#endif
#include "optional_weight.h" // common::OptionalWeights #include "optional_weight.h" // common::OptionalWeights
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile #include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
#include "xgboost/base.h" // XGBOOST_DEVICE #include "xgboost/base.h" // XGBOOST_DEVICE
@ -18,6 +24,11 @@
namespace xgboost { namespace xgboost {
namespace common { namespace common {
namespace cuda_impl { namespace cuda_impl {
#if defined(XGBOOST_USE_HIP)
namespace cub = hipcub;
#endif
void Median(Context const* ctx, linalg::TensorView<float const, 2> t, void Median(Context const* ctx, linalg::TensorView<float const, 2> t,
common::OptionalWeights weights, linalg::Tensor<float, 1>* out) { common::OptionalWeights weights, linalg::Tensor<float, 1>* out) {
CHECK_GE(t.Shape(1), 1); CHECK_GE(t.Shape(1), 1);

View File

@ -19,7 +19,13 @@
#include "algorithm.cuh" // SegmentedArgMergeSort #include "algorithm.cuh" // SegmentedArgMergeSort
#include "cuda_context.cuh" // CUDAContext #include "cuda_context.cuh" // CUDAContext
#if defined(XGBOOST_USE_CUDA)
#include "device_helpers.cuh" #include "device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "device_helpers.hip.h"
#endif
#include "xgboost/context.h" // Context #include "xgboost/context.h" // Context
#include "xgboost/span.h" // Span #include "xgboost/span.h" // Span
@ -220,7 +226,7 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b
#if defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_HIP)
thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights, thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights,
scan_val, weights_cdf.begin()); scan_val, weights_cdf.begin());
#else #elif defined(XGBOOST_USE_CUDA)
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights, thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
scan_val, weights_cdf.begin()); scan_val, weights_cdf.begin());
#endif #endif

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "stats.cu"
#endif