finish stats.cu
This commit is contained in:
parent
911a5d8a60
commit
14cc438a64
@ -7,7 +7,13 @@
|
|||||||
#include <cstddef> // size_t
|
#include <cstddef> // size_t
|
||||||
|
|
||||||
#include "cuda_context.cuh" // CUDAContext
|
#include "cuda_context.cuh" // CUDAContext
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
|
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
#include "device_helpers.hip.h" // dh::MakeTransformIterator, tcbegin, tcend
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "optional_weight.h" // common::OptionalWeights
|
#include "optional_weight.h" // common::OptionalWeights
|
||||||
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
|
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
|
||||||
#include "xgboost/base.h" // XGBOOST_DEVICE
|
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||||
@ -18,6 +24,11 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
namespace cuda_impl {
|
namespace cuda_impl {
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
namespace cub = hipcub;
|
||||||
|
#endif
|
||||||
|
|
||||||
void Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
void Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
||||||
common::OptionalWeights weights, linalg::Tensor<float, 1>* out) {
|
common::OptionalWeights weights, linalg::Tensor<float, 1>* out) {
|
||||||
CHECK_GE(t.Shape(1), 1);
|
CHECK_GE(t.Shape(1), 1);
|
||||||
|
|||||||
@ -19,7 +19,13 @@
|
|||||||
|
|
||||||
#include "algorithm.cuh" // SegmentedArgMergeSort
|
#include "algorithm.cuh" // SegmentedArgMergeSort
|
||||||
#include "cuda_context.cuh" // CUDAContext
|
#include "cuda_context.cuh" // CUDAContext
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
#include "device_helpers.hip.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "xgboost/context.h" // Context
|
#include "xgboost/context.h" // Context
|
||||||
#include "xgboost/span.h" // Span
|
#include "xgboost/span.h" // Span
|
||||||
|
|
||||||
@ -220,7 +226,7 @@ void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_b
|
|||||||
#if defined(XGBOOST_USE_HIP)
|
#if defined(XGBOOST_USE_HIP)
|
||||||
thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights,
|
thrust::inclusive_scan_by_key(thrust::hip::par(caching), scan_key, scan_key + n_weights,
|
||||||
scan_val, weights_cdf.begin());
|
scan_val, weights_cdf.begin());
|
||||||
#else
|
#elif defined(XGBOOST_USE_CUDA)
|
||||||
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
|
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
|
||||||
scan_val, weights_cdf.begin());
|
scan_val, weights_cdf.begin());
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
#include "stats.cu"
|
||||||
|
#endif
|
||||||
Loading…
x
Reference in New Issue
Block a user