Multi-target support for L1 error. (#8652)

- Add matrix support to the median function.
- Iterate through each target for quantile computation.
This commit is contained in:
Jiaming Yuan
2023-01-11 05:51:14 +08:00
committed by GitHub
parent badeff1d74
commit cfa994d57f
19 changed files with 430 additions and 215 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2022 by XGBoost Contributors
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_STATS_H_
#define XGBOOST_COMMON_STATS_H_
@@ -95,13 +95,15 @@ float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
}
namespace cuda_impl {
float Median(Context const* ctx, linalg::TensorView<float const, 2> t, OptionalWeights weights);
void Median(Context const* ctx, linalg::TensorView<float const, 2> t, OptionalWeights weights,
linalg::Tensor<float, 1>* out);
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out);
#if !defined(XGBOOST_USE_CUDA)
inline float Median(Context const*, linalg::TensorView<float const, 2>, OptionalWeights) {
inline void Median(Context const*, linalg::TensorView<float const, 2>, OptionalWeights,
linalg::Tensor<float, 1>*) {
common::AssertGPUSupport();
return 0;
}
inline void Mean(Context const*, linalg::VectorView<float const>, linalg::VectorView<float>) {
common::AssertGPUSupport();
@@ -109,8 +111,11 @@ inline void Mean(Context const*, linalg::VectorView<float const>, linalg::Vector
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl
float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
HostDeviceVector<float> const& weights);
/**
* \brief Calculate medians for each column of the input matrix.
*/
void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out);
void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out);
} // namespace common