Calculate base_score based on input labels for mae. (#8107)
Fit an intercept as base score for abs loss.
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* Copyright 2015-2022 by Contributors
|
||||
* \file objective.cc
|
||||
* \brief Registry of all objective functions.
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/objective.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
@@ -31,6 +31,11 @@ ObjFunction* ObjFunction::Create(const std::string& name, GenericParameter const
|
||||
return pobj;
|
||||
}
|
||||
|
||||
void ObjFunction::InitEstimation(MetaInfo const&, linalg::Tensor<float, 1>* base_score) const {
|
||||
CHECK(base_score);
|
||||
base_score->Reshape(1);
|
||||
(*base_score)(0) = DefaultBaseScore();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/linalg_op.h"
|
||||
#include "../common/numeric.h" // Reduce
|
||||
#include "../common/pseudo_huber.h"
|
||||
#include "../common/stats.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform.h"
|
||||
#include "./regression_loss.h"
|
||||
@@ -37,14 +39,18 @@
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
namespace {
|
||||
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
|
||||
void CheckInitInputs(MetaInfo const& info) {
|
||||
CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels.";
|
||||
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
|
||||
if (!info.weights_.Empty()) {
|
||||
CHECK_EQ(info.weights_.Size(), info.num_row_)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
}
|
||||
|
||||
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
|
||||
CheckInitInputs(info);
|
||||
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
@@ -698,6 +704,33 @@ class MeanAbsoluteError : public ObjFunction {
|
||||
});
|
||||
}
|
||||
|
||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override {
|
||||
CheckInitInputs(info);
|
||||
base_margin->Reshape(1);
|
||||
auto out = base_margin->HostView();
|
||||
|
||||
double w{0.0};
|
||||
if (info.weights_.Empty()) {
|
||||
w = static_cast<double>(info.num_row_);
|
||||
} else {
|
||||
w = common::Reduce(ctx_, info.weights_);
|
||||
}
|
||||
|
||||
if (info.num_row_ == 0) {
|
||||
out(0) = 0;
|
||||
} else {
|
||||
// weighted avg
|
||||
out(0) = common::Median(ctx_, info.labels, info.weights_) * w;
|
||||
}
|
||||
|
||||
// Weighted average base score across all workers
|
||||
rabit::Allreduce<rabit::op::Sum>(out.Values().data(), out.Values().size());
|
||||
rabit::Allreduce<rabit::op::Sum>(&w, 1);
|
||||
|
||||
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
|
||||
[w](float v) { return v / w; });
|
||||
}
|
||||
|
||||
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
|
||||
HostDeviceVector<float> const& prediction, RegTree* p_tree) const override {
|
||||
if (ctx_->IsCPU()) {
|
||||
|
||||
Reference in New Issue
Block a user