finish stats.cu

This commit is contained in:
amdsc21
2023-03-10 05:38:16 +01:00
parent 911a5d8a60
commit 14cc438a64
3 changed files with 22 additions and 1 deletions

View File

@@ -7,7 +7,13 @@
#include <cstddef> // size_t
#include "cuda_context.cuh" // CUDAContext
#if defined(XGBOOST_USE_CUDA)
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
#elif defined(XGBOOST_USE_HIP)
#include "device_helpers.hip.h" // dh::MakeTransformIterator, tcbegin, tcend
#endif
#include "optional_weight.h" // common::OptionalWeights
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
#include "xgboost/base.h" // XGBOOST_DEVICE
@@ -18,6 +24,11 @@
namespace xgboost {
namespace common {
namespace cuda_impl {
#if defined(XGBOOST_USE_HIP)
namespace cub = hipcub;
#endif
void Median(Context const* ctx, linalg::TensorView<float const, 2> t,
common::OptionalWeights weights, linalg::Tensor<float, 1>* out) {
CHECK_GE(t.Shape(1), 1);