Support multi-target, fit intercept for hinge. (#9850)

This commit is contained in:
Jiaming Yuan
2023-12-08 05:50:41 +08:00
committed by GitHub
parent 39c637ee19
commit 42de9206fc
8 changed files with 221 additions and 155 deletions

View File

@@ -4,71 +4,85 @@
* \brief Provides an implementation of the hinge loss function
* \author Henry Gouk
*/
#include "xgboost/objective.h"
#include "xgboost/json.h"
#include "xgboost/span.h"
#include "xgboost/host_device_vector.h"
#include <algorithm> // for max
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include "../common/math.h"
#include "../common/transform.h"
#include "../common/common.h"
#include "../common/common.h" // for Range
#if defined(XGBOOST_USE_CUDA)
#include "../common/linalg_op.cuh"
#endif
#include "../common/linalg_op.h"
#include "../common/optional_weight.h" // for OptionalWeights
#include "../common/transform.h" // for Transform
#include "init_estimation.h" // for FitIntercept
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/json.h" // for Json
#include "xgboost/linalg.h" // for UnravelIndex
#include "xgboost/span.h" // for Span
namespace xgboost::obj {
#if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA)
class HingeObj : public ObjFunction {
class HingeObj : public FitIntercept {
public:
HingeObj() = default;
void Configure(Args const&) override {}
void Configure(Args const &) override {}
ObjInfo Task() const override { return ObjInfo::kRegression; }
void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.Size()
<< ", label.size=" << info.labels.Size();
const size_t ndata = preds.Size();
const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for `binary:hinge` is not yet supported.";
out_gpair->Reshape(ndata, 1);
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx] * 2.0 - 1.0;
bst_float g, h;
if (p * y < 1.0) {
g = -y * w;
h = w;
} else {
g = 0.0;
h = std::numeric_limits<bst_float>::min();
}
_out_gpair[_idx] = GradientPair(g, h);
},
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(),
ctx_->Device()).Eval(
out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
[[nodiscard]] bst_target_t Targets(MetaInfo const &info) const override {
// Multi-target regression.
return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
void GetGradient(HostDeviceVector<float> const &preds, MetaInfo const &info,
std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
CheckInitInputs(info);
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), info.num_row_)
<< "Number of weights should be equal to number of data points.";
}
bst_target_t n_targets = this->Targets(info);
out_gpair->Reshape(info.num_row_, n_targets);
auto gpair = out_gpair->View(ctx_->Device());
preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);
auto labels = info.labels.View(ctx_->Device());
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};
linalg::ElementWiseKernel(this->ctx_, labels,
[=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
auto w = weight[i];
auto p = predt(i, j);
auto y = labels(i, j) * 2.0 - 1.0;
float g, h;
if (p * y < 1.0) {
g = -y * w;
h = w;
} else {
g = 0.0;
h = std::numeric_limits<float>::min();
}
gpair(i, j) = GradientPair{g, h};
});
}
void PredTransform(HostDeviceVector<float> *io_preds) const override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
[] XGBOOST_DEVICE(std::size_t _idx, common::Span<float> _preds) {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
},
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, this->ctx_->Threads(),
@@ -76,12 +90,10 @@ class HingeObj : public ObjFunction {
.Eval(io_preds);
}
[[nodiscard]] const char* DefaultEvalMetric() const override {
return "error";
}
[[nodiscard]] const char *DefaultEvalMetric() const override { return "error"; }
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
void SaveConfig(Json *p_out) const override {
auto &out = *p_out;
out["name"] = String("binary:hinge");
}
void LoadConfig(Json const &) override {}
@@ -89,7 +101,7 @@ class HingeObj : public ObjFunction {
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
.describe("Hinge loss. Expects labels to be in [0,1f]")
.set_body([]() { return new HingeObj(); });
.describe("Hinge loss. Expects labels to be in [0,1f]")
.set_body([]() { return new HingeObj(); });
} // namespace xgboost::obj

View File

@@ -75,28 +75,25 @@ class QuantileRegression : public ObjFunction {
: info.weights_.ConstHostSpan()};
preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds);
auto n_samples = info.num_row_;
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);
alpha_.SetDevice(ctx_->Device());
auto alpha = ctx_->IsCUDA() ? alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan();
linalg::ElementWiseKernel(
ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
auto [sample_id, quantile_id, target_id] =
linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size());
assert(target_id == 0);
auto d = predt(i) - labels(sample_id, target_id);
auto h = weight[sample_id];
if (d >= 0) {
auto g = (1.0f - alpha[quantile_id]) * weight[sample_id];
gpair(sample_id, quantile_id) = GradientPair{g, h};
} else {
auto g = (-alpha[quantile_id] * weight[sample_id]);
gpair(sample_id, quantile_id) = GradientPair{g, h};
}
});
linalg::ElementWiseKernel(ctx_, gpair,
[=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
// j is the quantile index
// 0 is the target index
auto d = predt(i, j) - labels(i, 0);
auto h = weight[i];
if (d >= 0) {
auto g = (1.0f - alpha[j]) * weight[i];
gpair(i, j) = GradientPair{g, h};
} else {
auto g = (-alpha[j] * weight[i]);
gpair(i, j) = GradientPair{g, h};
}
});
}
void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override {

View File

@@ -255,24 +255,24 @@ class PseudoHuberRegression : public FitIntercept {
auto gpair = out_gpair->View(ctx_->Device());
preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds);
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable {
auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
const float z = predt(i) - y;
const float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
float grad = z / scale_sqrt;
linalg::ElementWiseKernel(
ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
float z = predt(i, j) - labels(i, j);
float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
float grad = z / scale_sqrt;
auto scale = common::Sqr(slope) + common::Sqr(z);
float hess = common::Sqr(slope) / (scale * scale_sqrt);
auto scale = common::Sqr(slope) + common::Sqr(z);
float hess = common::Sqr(slope) / (scale * scale_sqrt);
auto w = weight[sample_id];
gpair(i) = {grad * w, hess * w};
});
auto w = weight[i];
gpair(i) = {grad * w, hess * w};
});
}
[[nodiscard]] const char* DefaultEvalMetric() const override { return "mphe"; }
@@ -635,20 +635,21 @@ class MeanAbsoluteError : public ObjFunction {
auto gpair = out_gpair->View(ctx_->Device());
preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds);
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable {
auto sign = [](auto x) {
return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
};
auto [sample_id, target_id] = linalg::UnravelIndex(i, labels.Shape());
auto grad = sign(predt(i) - y) * weight[sample_id];
auto hess = weight[sample_id];
gpair(sample_id, target_id) = GradientPair{grad, hess};
});
linalg::ElementWiseKernel(
ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
auto sign = [](auto x) {
return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
};
auto y = labels(i, j);
auto hess = weight[i];
auto grad = sign(predt(i, j) - y) * hess;
gpair(i, j) = GradientPair{grad, hess};
});
}
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override {