Calculate base_score based on input labels for mae. (#8107)

Fit an intercept as base score for abs loss.
This commit is contained in:
Jiaming Yuan
2022-09-20 20:53:54 +08:00
committed by GitHub
parent 4f42aa5f12
commit fffb1fca52
42 changed files with 999 additions and 343 deletions

View File

@@ -7,6 +7,7 @@
#include <limits>
#include <vector>
#include "../common/common.h"
#include "rabit/rabit.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h"

View File

@@ -1,10 +1,10 @@
/*!
* Copyright 2015 by Contributors
* Copyright 2015-2022 by Contributors
* \file objective.cc
* \brief Registry of all objective functions.
*/
#include <xgboost/objective.h>
#include <dmlc/registry.h>
#include <xgboost/objective.h>
#include <sstream>
@@ -31,6 +31,11 @@ ObjFunction* ObjFunction::Create(const std::string& name, GenericParameter const
return pobj;
}
/*!
 * \brief Fallback base-score estimation that ignores the data.
 *
 * The MetaInfo argument is unused. The output tensor is reshaped via Reshape(1)
 * and its first element is set to the value returned by DefaultBaseScore().
 *
 * \param base_score Output tensor; checked to be non-null.
 */
void ObjFunction::InitEstimation(MetaInfo const&, linalg::Tensor<float, 1>* base_score) const {
  CHECK(base_score);
  auto& estimate = *base_score;
  estimate.Reshape(1);
  estimate(0) = DefaultBaseScore();
}
} // namespace xgboost
namespace xgboost {

View File

@@ -15,7 +15,9 @@
#include "../common/common.h"
#include "../common/linalg_op.h"
#include "../common/numeric.h" // Reduce
#include "../common/pseudo_huber.h"
#include "../common/stats.h"
#include "../common/threading_utils.h"
#include "../common/transform.h"
#include "./regression_loss.h"
@@ -37,14 +39,18 @@
namespace xgboost {
namespace obj {
namespace {
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
// Validates the parts of MetaInfo that do not depend on predictions: the first label
// dimension must equal the number of rows and, when weights are present, there must be
// exactly one weight per row.
void CheckInitInputs(MetaInfo const& info) {
CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels.";
// NOTE(review): `preds` is not a parameter of this function, and the identical check
// appears in CheckRegInputs -- this line looks like a leftover from the previous
// version of the code; confirm before relying on it.
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
// Weights are optional; their count is only validated when some were supplied.
if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), info.num_row_)
<< "Number of weights should be equal to number of data points.";
}
}
// Validates inputs for code paths that consume predictions: runs the prediction-independent
// checks in CheckInitInputs first, then requires the total label count to equal the
// prediction count.
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
CheckInitInputs(info);
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
}
} // anonymous namespace
#if defined(XGBOOST_USE_CUDA)
@@ -698,6 +704,33 @@ class MeanAbsoluteError : public ObjFunction {
});
}
/*!
 * \brief Estimate the initial score from the labels for absolute-error loss.
 *
 * Each worker contributes the value returned by common::Median for its local labels
 * (and weights), multiplied by its local weight total. Both the scaled values and the
 * weight totals are summed across workers, and the ratio is stored as the result.
 * If the global weight total is zero (for example, no worker has any rows), the
 * default base score is stored instead, because the ratio would not be finite.
 *
 * \param info        Labels, optional sample weights and the local row count.
 * \param base_margin Output tensor, reshaped to a single element; must not be null.
 */
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override {
  CHECK(base_margin);
  CheckInitInputs(info);
  base_margin->Reshape(1);
  auto out = base_margin->HostView();

  // Local weight total: the row count when no weights are given, else the sum of weights.
  double w{0.0};
  if (info.weights_.Empty()) {
    w = static_cast<double>(info.num_row_);
  } else {
    w = common::Reduce(ctx_, info.weights_);
  }

  if (info.num_row_ == 0) {
    // A worker without rows contributes nothing to the global sum.
    out(0) = 0;
  } else {
    // Scale by the local weight total so the cross-worker sum forms a weighted average.
    out(0) = common::Median(ctx_, info.labels, info.weights_) * w;
  }

  // Weighted average base score across all workers
  rabit::Allreduce<rabit::op::Sum>(out.Values().data(), out.Values().size());
  rabit::Allreduce<rabit::op::Sum>(&w, 1);

  // Guard against dividing by a zero weight total. This check must stay below both
  // collective calls so every worker still takes part in them; it uses the reduced `w`.
  // NOTE(review): assumes DefaultBaseScore() is accessible from derived classes -- confirm.
  if (w == 0.0) {
    out(0) = ObjFunction::DefaultBaseScore();
    return;
  }

  std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
                 [w](float v) { return v / w; });
}
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, RegTree* p_tree) const override {
if (ctx_->IsCPU()) {