finish stats.cu
This commit is contained in:
@@ -7,7 +7,13 @@
|
||||
#include <cstddef> // size_t
|
||||
|
||||
#include "cuda_context.cuh" // CUDAContext
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
#include "device_helpers.hip.h" // dh::MakeTransformIterator, tcbegin, tcend
|
||||
#endif
|
||||
|
||||
#include "optional_weight.h" // common::OptionalWeights
|
||||
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
|
||||
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||
@@ -18,6 +24,11 @@
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace cuda_impl {
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
void Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
||||
common::OptionalWeights weights, linalg::Tensor<float, 1>* out) {
|
||||
CHECK_GE(t.Shape(1), 1);
|
||||
|
||||
Reference in New Issue
Block a user