Calculate base_score based on input labels for mae. (#8107)
Fit an intercept as base score for abs loss.
This commit is contained in:
@@ -8,7 +8,8 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "common.h" // AssertGPUSupport
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -90,6 +91,44 @@ float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
|
||||
idx = std::min(idx, static_cast<size_t>(n - 1));
|
||||
return val(idx);
|
||||
}
|
||||
|
||||
namespace cuda {
|
||||
float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
||||
common::OptionalWeights weights);
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline float Median(Context const*, linalg::TensorView<float const, 2>, common::OptionalWeights) {
|
||||
AssertGPUSupport();
|
||||
return 0;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace cuda
|
||||
|
||||
inline float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
|
||||
HostDeviceVector<float> const& weights) {
|
||||
if (!ctx->IsCPU()) {
|
||||
weights.SetDevice(ctx->gpu_id);
|
||||
auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
|
||||
auto t_v = t.View(ctx->gpu_id);
|
||||
return cuda::Median(ctx, t_v, opt_weights);
|
||||
}
|
||||
|
||||
auto opt_weights = OptionalWeights(weights.ConstHostSpan());
|
||||
auto t_v = t.HostView();
|
||||
auto iter = common::MakeIndexTransformIter(
|
||||
[&](size_t i) { return linalg::detail::Apply(t_v, linalg::UnravelIndex(i, t_v.Shape())); });
|
||||
float q{0};
|
||||
if (opt_weights.Empty()) {
|
||||
q = common::Quantile(0.5, iter, iter + t_v.Size());
|
||||
} else {
|
||||
CHECK_NE(t_v.Shape(1), 0);
|
||||
auto w_it = common::MakeIndexTransformIter([&](size_t i) {
|
||||
auto sample_idx = i / t_v.Shape(1);
|
||||
return opt_weights[sample_idx];
|
||||
});
|
||||
q = common::WeightedQuantile(0.5, iter, iter + t_v.Size(), w_it);
|
||||
}
|
||||
return q;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_STATS_H_
|
||||
|
||||
Reference in New Issue
Block a user