Support multi-target, fit intercept for hinge. (#9850)
This commit is contained in:
parent
39c637ee19
commit
42de9206fc
@ -1,31 +1,48 @@
|
||||
/*!
|
||||
* Copyright 2021-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
#define XGBOOST_COMMON_LINALG_OP_CUH_
|
||||
|
||||
#include "device_helpers.cuh"
|
||||
#include <cstdint> // for int32_t
|
||||
#include <cstdlib> // for size_t
|
||||
#include <tuple> // for apply
|
||||
|
||||
#include "device_helpers.cuh" // for LaunchN
|
||||
#include "linalg_op.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/linalg.h" // for TensorView
|
||||
|
||||
namespace xgboost {
|
||||
namespace linalg {
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
||||
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
|
||||
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
|
||||
"For function with return, use transform instead.");
|
||||
if (t.Contiguous()) {
|
||||
auto ptr = t.Values().data();
|
||||
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { fn(i, ptr[i]); });
|
||||
} else {
|
||||
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
|
||||
T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
|
||||
fn(i, v);
|
||||
namespace cuda_impl {
|
||||
// Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended
|
||||
// lambda inside constexpr if
|
||||
template <typename T, std::int32_t D>
|
||||
struct ElementWiseImpl {
|
||||
template <typename Fn>
|
||||
void operator()(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
|
||||
static_assert(D > 1);
|
||||
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
|
||||
std::apply(fn, linalg::UnravelIndex(i, t.Shape()));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ElementWiseImpl<T, 1> {
|
||||
template <typename Fn>
|
||||
void operator()(linalg::TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
|
||||
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); });
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, std::int32_t D, typename Fn>
|
||||
void ElementWiseKernel(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
||||
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
|
||||
cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
|
||||
@ -42,7 +59,8 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_
|
||||
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
|
||||
ctx->IsCUDA() ? ElementWiseKernelDevice(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn);
|
||||
ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
|
||||
: ElementWiseKernelHost(t, ctx->Threads(), fn);
|
||||
}
|
||||
} // namespace linalg
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2021-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_LINALG_OP_H_
|
||||
#define XGBOOST_COMMON_LINALG_OP_H_
|
||||
@ -27,17 +27,23 @@ void ElementWiseTransformHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int32_t D, typename Fn>
|
||||
void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
|
||||
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
|
||||
"For function with return, use transform instead.");
|
||||
if (t.Contiguous()) {
|
||||
auto ptr = t.Values().data();
|
||||
common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); });
|
||||
template <typename T, std::int32_t D, typename Fn>
|
||||
void ElementWiseKernelHost(linalg::TensorView<T, D> t, std::int32_t n_threads, Fn &&fn) {
|
||||
if constexpr (D == 1) {
|
||||
common::ParallelFor(t.Size(), n_threads, [&](std::size_t i) { fn(i); });
|
||||
} else if (D == 2 && t.CContiguous() && t.Shape(0) > t.Shape(1) * 64) {
|
||||
// Heuristic. Tall, c-contiguous matrix,
|
||||
auto n_rows = t.Shape(0);
|
||||
auto n_columns = t.Shape(1);
|
||||
common::ParallelFor(n_rows, n_threads, [&](std::size_t i) {
|
||||
for (std::size_t j = 0; j < n_columns; ++j) {
|
||||
fn(i, j);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
|
||||
auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
|
||||
fn(i, v);
|
||||
common::ParallelFor(t.Size(), n_threads, [&](std::size_t i) {
|
||||
auto idx = linalg::UnravelIndex(i, t.Shape());
|
||||
std::apply(fn, idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,71 +4,85 @@
|
||||
* \brief Provides an implementation of the hinge loss function
|
||||
* \author Henry Gouk
|
||||
*/
|
||||
#include "xgboost/objective.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include <algorithm> // for max
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
|
||||
#include "../common/math.h"
|
||||
#include "../common/transform.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/common.h" // for Range
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include "../common/linalg_op.cuh"
|
||||
#endif
|
||||
#include "../common/linalg_op.h"
|
||||
#include "../common/optional_weight.h" // for OptionalWeights
|
||||
#include "../common/transform.h" // for Transform
|
||||
#include "init_estimation.h" // for FitIntercept
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "xgboost/linalg.h" // for UnravelIndex
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::obj {
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
class HingeObj : public ObjFunction {
|
||||
class HingeObj : public FitIntercept {
|
||||
public:
|
||||
HingeObj() = default;
|
||||
|
||||
void Configure(Args const&) override {}
|
||||
void Configure(Args const &) override {}
|
||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
|
||||
std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels.Size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size()
|
||||
<< ", label.size=" << info.labels.Size();
|
||||
[[nodiscard]] bst_target_t Targets(MetaInfo const &info) const override {
|
||||
// Multi-target regression.
|
||||
return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
|
||||
}
|
||||
|
||||
const size_t ndata = preds.Size();
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
void GetGradient(HostDeviceVector<float> const &preds, MetaInfo const &info,
|
||||
std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
|
||||
CheckInitInputs(info);
|
||||
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
|
||||
if (!info.weights_.Empty()) {
|
||||
CHECK_EQ(info.weights_.Size(), info.num_row_)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for `binary:hinge` is not yet supported.";
|
||||
out_gpair->Reshape(ndata, 1);
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
common::Span<const bst_float> _preds,
|
||||
common::Span<const bst_float> _labels,
|
||||
common::Span<const bst_float> _weights) {
|
||||
bst_float p = _preds[_idx];
|
||||
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float y = _labels[_idx] * 2.0 - 1.0;
|
||||
bst_float g, h;
|
||||
|
||||
bst_target_t n_targets = this->Targets(info);
|
||||
out_gpair->Reshape(info.num_row_, n_targets);
|
||||
auto gpair = out_gpair->View(ctx_->Device());
|
||||
|
||||
preds.SetDevice(ctx_->Device());
|
||||
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);
|
||||
|
||||
auto labels = info.labels.View(ctx_->Device());
|
||||
|
||||
info.weights_.SetDevice(ctx_->Device());
|
||||
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
|
||||
: info.weights_.ConstHostSpan()};
|
||||
|
||||
linalg::ElementWiseKernel(this->ctx_, labels,
|
||||
[=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
|
||||
auto w = weight[i];
|
||||
|
||||
auto p = predt(i, j);
|
||||
auto y = labels(i, j) * 2.0 - 1.0;
|
||||
|
||||
float g, h;
|
||||
if (p * y < 1.0) {
|
||||
g = -y * w;
|
||||
h = w;
|
||||
} else {
|
||||
g = 0.0;
|
||||
h = std::numeric_limits<bst_float>::min();
|
||||
h = std::numeric_limits<float>::min();
|
||||
}
|
||||
_out_gpair[_idx] = GradientPair(g, h);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(),
|
||||
ctx_->Device()).Eval(
|
||||
out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
|
||||
gpair(i, j) = GradientPair{g, h};
|
||||
});
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||
void PredTransform(HostDeviceVector<float> *io_preds) const override {
|
||||
common::Transform<>::Init(
|
||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||
[] XGBOOST_DEVICE(std::size_t _idx, common::Span<float> _preds) {
|
||||
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, this->ctx_->Threads(),
|
||||
@ -76,12 +90,10 @@ class HingeObj : public ObjFunction {
|
||||
.Eval(io_preds);
|
||||
}
|
||||
|
||||
[[nodiscard]] const char* DefaultEvalMetric() const override {
|
||||
return "error";
|
||||
}
|
||||
[[nodiscard]] const char *DefaultEvalMetric() const override { return "error"; }
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
void SaveConfig(Json *p_out) const override {
|
||||
auto &out = *p_out;
|
||||
out["name"] = String("binary:hinge");
|
||||
}
|
||||
void LoadConfig(Json const &) override {}
|
||||
@ -89,7 +101,7 @@ class HingeObj : public ObjFunction {
|
||||
|
||||
// register the objective functions
|
||||
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
|
||||
.describe("Hinge loss. Expects labels to be in [0,1f]")
|
||||
.set_body([]() { return new HingeObj(); });
|
||||
.describe("Hinge loss. Expects labels to be in [0,1f]")
|
||||
.set_body([]() { return new HingeObj(); });
|
||||
|
||||
} // namespace xgboost::obj
|
||||
|
||||
@ -75,26 +75,23 @@ class QuantileRegression : public ObjFunction {
|
||||
: info.weights_.ConstHostSpan()};
|
||||
|
||||
preds.SetDevice(ctx_->Device());
|
||||
auto predt = linalg::MakeVec(&preds);
|
||||
auto n_samples = info.num_row_;
|
||||
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);
|
||||
|
||||
alpha_.SetDevice(ctx_->Device());
|
||||
auto alpha = ctx_->IsCUDA() ? alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan();
|
||||
|
||||
linalg::ElementWiseKernel(
|
||||
ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
|
||||
auto [sample_id, quantile_id, target_id] =
|
||||
linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size());
|
||||
assert(target_id == 0);
|
||||
|
||||
auto d = predt(i) - labels(sample_id, target_id);
|
||||
auto h = weight[sample_id];
|
||||
linalg::ElementWiseKernel(ctx_, gpair,
|
||||
[=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
|
||||
// j is the quantile index
|
||||
// 0 is the target index
|
||||
auto d = predt(i, j) - labels(i, 0);
|
||||
auto h = weight[i];
|
||||
if (d >= 0) {
|
||||
auto g = (1.0f - alpha[quantile_id]) * weight[sample_id];
|
||||
gpair(sample_id, quantile_id) = GradientPair{g, h};
|
||||
auto g = (1.0f - alpha[j]) * weight[i];
|
||||
gpair(i, j) = GradientPair{g, h};
|
||||
} else {
|
||||
auto g = (-alpha[quantile_id] * weight[sample_id]);
|
||||
gpair(sample_id, quantile_id) = GradientPair{g, h};
|
||||
auto g = (-alpha[j] * weight[i]);
|
||||
gpair(i, j) = GradientPair{g, h};
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -255,22 +255,22 @@ class PseudoHuberRegression : public FitIntercept {
|
||||
auto gpair = out_gpair->View(ctx_->Device());
|
||||
|
||||
preds.SetDevice(ctx_->Device());
|
||||
auto predt = linalg::MakeVec(&preds);
|
||||
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
|
||||
|
||||
info.weights_.SetDevice(ctx_->Device());
|
||||
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
|
||||
: info.weights_.ConstHostSpan()};
|
||||
|
||||
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable {
|
||||
auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
|
||||
const float z = predt(i) - y;
|
||||
const float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
|
||||
linalg::ElementWiseKernel(
|
||||
ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
|
||||
float z = predt(i, j) - labels(i, j);
|
||||
float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
|
||||
float grad = z / scale_sqrt;
|
||||
|
||||
auto scale = common::Sqr(slope) + common::Sqr(z);
|
||||
float hess = common::Sqr(slope) / (scale * scale_sqrt);
|
||||
|
||||
auto w = weight[sample_id];
|
||||
auto w = weight[i];
|
||||
gpair(i) = {grad * w, hess * w};
|
||||
});
|
||||
}
|
||||
@ -635,19 +635,20 @@ class MeanAbsoluteError : public ObjFunction {
|
||||
auto gpair = out_gpair->View(ctx_->Device());
|
||||
|
||||
preds.SetDevice(ctx_->Device());
|
||||
auto predt = linalg::MakeVec(&preds);
|
||||
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
|
||||
info.weights_.SetDevice(ctx_->Device());
|
||||
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
|
||||
: info.weights_.ConstHostSpan()};
|
||||
|
||||
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable {
|
||||
linalg::ElementWiseKernel(
|
||||
ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
|
||||
auto sign = [](auto x) {
|
||||
return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
|
||||
};
|
||||
auto [sample_id, target_id] = linalg::UnravelIndex(i, labels.Shape());
|
||||
auto grad = sign(predt(i) - y) * weight[sample_id];
|
||||
auto hess = weight[sample_id];
|
||||
gpair(sample_id, target_id) = GradientPair{grad, hess};
|
||||
auto y = labels(i, j);
|
||||
auto hess = weight[i];
|
||||
auto grad = sign(predt(i, j) - y) * hess;
|
||||
gpair(i, j) = GradientPair{grad, hess};
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ void TestElementWiseKernel() {
|
||||
ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
|
||||
// CPU view
|
||||
t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
|
||||
size_t k = 0;
|
||||
std::size_t k = 0;
|
||||
for (size_t i = 0; i < l.Shape(0); ++i) {
|
||||
for (size_t j = 0; j < l.Shape(2); ++j) {
|
||||
ASSERT_EQ(k++, t(i, j));
|
||||
@ -31,7 +31,15 @@ void TestElementWiseKernel() {
|
||||
}
|
||||
|
||||
t = l.View(device).Slice(linalg::All(), 1, linalg::All());
|
||||
ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); });
|
||||
cuda_impl::ElementWiseKernel(
|
||||
t, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable { t(i, j) = i + j; });
|
||||
|
||||
t = l.Slice(linalg::All(), 1, linalg::All());
|
||||
for (size_t i = 0; i < l.Shape(0); ++i) {
|
||||
for (size_t j = 0; j < l.Shape(2); ++j) {
|
||||
ASSERT_EQ(i + j, t(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
@ -31,12 +31,10 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
|
||||
auto const& h_result = info.labels.View(DeviceOrd::CPU());
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_labels = labels.View(DeviceOrd::CPU());
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
|
||||
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
||||
auto i0 = std::get<0>(tup);
|
||||
auto i1 = std::get<1>(tup);
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, std::size_t j) {
|
||||
// Sliced at second dimension.
|
||||
auto v_1 = in_labels(i0, 0, i1);
|
||||
auto v_0 = h_result(i, j);
|
||||
auto v_1 = in_labels(i, 0, j);
|
||||
CHECK_EQ(v_0, v_1);
|
||||
});
|
||||
}
|
||||
@ -65,12 +63,11 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
|
||||
auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_margin = base_margin.View(DeviceOrd::CPU());
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
|
||||
auto tup = linalg::UnravelIndex(i, h_result.Shape());
|
||||
auto i0 = std::get<0>(tup);
|
||||
auto i1 = std::get<1>(tup);
|
||||
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(),
|
||||
[&](std::size_t i, std::size_t j) {
|
||||
// Sliced at second dimension.
|
||||
auto v_1 = in_margin(i0, 0, i1);
|
||||
auto v_0 = h_result(i, j);
|
||||
auto v_1 = in_margin(i, 0, j);
|
||||
CHECK_EQ(v_0, v_1);
|
||||
});
|
||||
}
|
||||
|
||||
@ -1,28 +1,55 @@
|
||||
// Copyright by Contributors
|
||||
/**
|
||||
* Copyright 2018-2023, XGBoost Contributors
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include <xgboost/context.h>
|
||||
#include <limits>
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/linalg_op.h"
|
||||
namespace xgboost {
|
||||
TEST(Objective, DeclareUnifiedTest(HingeObj)) {
|
||||
Context ctx = MakeCUDACtx(GPUIDX);
|
||||
std::unique_ptr<ObjFunction> obj{ObjFunction::Create("binary:hinge", &ctx)};
|
||||
|
||||
float eps = std::numeric_limits<xgboost::bst_float>::min();
|
||||
CheckObjFunction(obj,
|
||||
{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f},
|
||||
{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||
{ 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
|
||||
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
|
||||
CheckObjFunction(obj,
|
||||
{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f},
|
||||
{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||
{}, // Empty weight.
|
||||
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
|
||||
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
|
||||
std::vector<float> predt{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f};
|
||||
std::vector<float> label{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f};
|
||||
std::vector<float> grad{0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f};
|
||||
std::vector<float> hess{eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps};
|
||||
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
CheckObjFunction(obj, predt, label, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, grad, hess);
|
||||
CheckObjFunction(obj, predt, label, {/* Empty weight. */}, grad, hess);
|
||||
|
||||
ASSERT_EQ(obj->DefaultEvalMetric(), StringView{"error"});
|
||||
|
||||
MetaInfo info;
|
||||
info.num_row_ = label.size();
|
||||
info.labels.Reshape(info.num_row_, 3);
|
||||
ASSERT_EQ(obj->Targets(info), 3);
|
||||
auto h_labels = info.labels.HostView();
|
||||
for (std::size_t j = 0; j < obj->Targets(info); ++j) {
|
||||
for (std::size_t i = 0; i < info.num_row_; ++i) {
|
||||
h_labels(i, j) = label[i];
|
||||
}
|
||||
}
|
||||
linalg::Tensor<float, 2> t_predt{};
|
||||
t_predt.Reshape(info.labels.Shape());
|
||||
for (std::size_t j = 0; j < obj->Targets(info); ++j) {
|
||||
for (std::size_t i = 0; i < info.num_row_; ++i) {
|
||||
t_predt(i, j) = predt[i];
|
||||
}
|
||||
}
|
||||
linalg::Matrix<GradientPair> out_gpair;
|
||||
obj->GetGradient(*t_predt.Data(), info, 0, &out_gpair);
|
||||
|
||||
for (std::size_t j = 0; j < obj->Targets(info); ++j) {
|
||||
auto gh = out_gpair.Slice(linalg::All(), j);
|
||||
ASSERT_EQ(gh.Size(), info.num_row_);
|
||||
for (std::size_t i = 0; i < gh.Size(); ++i) {
|
||||
ASSERT_EQ(gh(i).GetGrad(), grad[i]);
|
||||
ASSERT_EQ(gh(i).GetHess(), hess[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user