Multi-target support for L1 error. (#8652)
- Add matrix support to the median function. - Iterate through each target for quantile computation.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_STATS_H_
|
||||
#define XGBOOST_COMMON_STATS_H_
|
||||
@@ -95,13 +95,15 @@ float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
|
||||
}
|
||||
|
||||
namespace cuda_impl {
|
||||
float Median(Context const* ctx, linalg::TensorView<float const, 2> t, OptionalWeights weights);
|
||||
void Median(Context const* ctx, linalg::TensorView<float const, 2> t, OptionalWeights weights,
|
||||
linalg::Tensor<float, 1>* out);
|
||||
|
||||
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline float Median(Context const*, linalg::TensorView<float const, 2>, OptionalWeights) {
|
||||
inline void Median(Context const*, linalg::TensorView<float const, 2>, OptionalWeights,
|
||||
linalg::Tensor<float, 1>*) {
|
||||
common::AssertGPUSupport();
|
||||
return 0;
|
||||
}
|
||||
inline void Mean(Context const*, linalg::VectorView<float const>, linalg::VectorView<float>) {
|
||||
common::AssertGPUSupport();
|
||||
@@ -109,8 +111,11 @@ inline void Mean(Context const*, linalg::VectorView<float const>, linalg::Vector
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace cuda_impl
|
||||
|
||||
float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
|
||||
HostDeviceVector<float> const& weights);
|
||||
/**
|
||||
* \brief Calculate medians for each column of the input matrix.
|
||||
*/
|
||||
void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
|
||||
HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out);
|
||||
|
||||
void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out);
|
||||
} // namespace common
|
||||
|
||||
Reference in New Issue
Block a user