Support multi-target, fit intercept for hinge. (#9850)

This commit is contained in:
Jiaming Yuan 2023-12-08 05:50:41 +08:00 committed by GitHub
parent 39c637ee19
commit 42de9206fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 221 additions and 155 deletions

View File

@ -1,31 +1,48 @@
/*! /**
* Copyright 2021-2022 by XGBoost Contributors * Copyright 2021-2023, XGBoost Contributors
*/ */
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_ #ifndef XGBOOST_COMMON_LINALG_OP_CUH_
#define XGBOOST_COMMON_LINALG_OP_CUH_ #define XGBOOST_COMMON_LINALG_OP_CUH_
#include "device_helpers.cuh" #include <cstdint> // for int32_t
#include <cstdlib> // for size_t
#include <tuple> // for apply
#include "device_helpers.cuh" // for LaunchN
#include "linalg_op.h" #include "linalg_op.h"
#include "xgboost/context.h" #include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" #include "xgboost/linalg.h" // for TensorView
namespace xgboost { namespace xgboost {
namespace linalg { namespace linalg {
template <typename T, int32_t D, typename Fn> namespace cuda_impl {
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) { // Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended
dh::safe_cuda(cudaSetDevice(t.Device().ordinal)); // lambda inside constexpr if
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value, template <typename T, std::int32_t D>
"For function with return, use transform instead."); struct ElementWiseImpl {
if (t.Contiguous()) { template <typename Fn>
auto ptr = t.Values().data(); void operator()(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { fn(i, ptr[i]); }); static_assert(D > 1);
} else { dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { std::apply(fn, linalg::UnravelIndex(i, t.Shape()));
T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
fn(i, v);
}); });
} }
};
template <typename T>
struct ElementWiseImpl<T, 1> {
template <typename Fn>
void operator()(linalg::TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); });
} }
};
template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernel(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
}
} // namespace cuda_impl
template <typename T, int32_t D, typename Fn> template <typename T, int32_t D, typename Fn>
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) { void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
@ -42,7 +59,8 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_
template <typename T, int32_t D, typename Fn> template <typename T, int32_t D, typename Fn>
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) { void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
ctx->IsCUDA() ? ElementWiseKernelDevice(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn); ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
: ElementWiseKernelHost(t, ctx->Threads(), fn);
} }
} // namespace linalg } // namespace linalg
} // namespace xgboost } // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2021-2022 by XGBoost Contributors * Copyright 2021-2023, XGBoost Contributors
*/ */
#ifndef XGBOOST_COMMON_LINALG_OP_H_ #ifndef XGBOOST_COMMON_LINALG_OP_H_
#define XGBOOST_COMMON_LINALG_OP_H_ #define XGBOOST_COMMON_LINALG_OP_H_
@ -27,17 +27,23 @@ void ElementWiseTransformHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&
} }
} }
template <typename T, int32_t D, typename Fn> template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) { void ElementWiseKernelHost(linalg::TensorView<T, D> t, std::int32_t n_threads, Fn &&fn) {
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value, if constexpr (D == 1) {
"For function with return, use transform instead."); common::ParallelFor(t.Size(), n_threads, [&](std::size_t i) { fn(i); });
if (t.Contiguous()) { } else if (D == 2 && t.CContiguous() && t.Shape(0) > t.Shape(1) * 64) {
auto ptr = t.Values().data(); // Heuristic. Tall, c-contiguous matrix,
common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); }); auto n_rows = t.Shape(0);
auto n_columns = t.Shape(1);
common::ParallelFor(n_rows, n_threads, [&](std::size_t i) {
for (std::size_t j = 0; j < n_columns; ++j) {
fn(i, j);
}
});
} else { } else {
common::ParallelFor(t.Size(), n_threads, [&](size_t i) { common::ParallelFor(t.Size(), n_threads, [&](std::size_t i) {
auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); auto idx = linalg::UnravelIndex(i, t.Shape());
fn(i, v); std::apply(fn, idx);
}); });
} }
} }

View File

@ -4,71 +4,85 @@
* \brief Provides an implementation of the hinge loss function * \brief Provides an implementation of the hinge loss function
* \author Henry Gouk * \author Henry Gouk
*/ */
#include "xgboost/objective.h" #include <algorithm> // for max
#include "xgboost/json.h" #include <cstddef> // for size_t
#include "xgboost/span.h" #include <cstdint> // for int32_t
#include "xgboost/host_device_vector.h"
#include "../common/math.h" #include "../common/common.h" // for Range
#include "../common/transform.h" #if defined(XGBOOST_USE_CUDA)
#include "../common/common.h" #include "../common/linalg_op.cuh"
#endif
#include "../common/linalg_op.h"
#include "../common/optional_weight.h" // for OptionalWeights
#include "../common/transform.h" // for Transform
#include "init_estimation.h" // for FitIntercept
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/json.h" // for Json
#include "xgboost/linalg.h" // for UnravelIndex
#include "xgboost/span.h" // for Span
namespace xgboost::obj { namespace xgboost::obj {
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu); DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
class HingeObj : public ObjFunction { class HingeObj : public FitIntercept {
public: public:
HingeObj() = default; HingeObj() = default;
void Configure(Args const &) override {} void Configure(Args const &) override {}
ObjInfo Task() const override { return ObjInfo::kRegression; } ObjInfo Task() const override { return ObjInfo::kRegression; }
void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info, [[nodiscard]] bst_target_t Targets(MetaInfo const &info) const override {
std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override { // Multi-target regression.
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
CHECK_EQ(preds.Size(), info.labels.Size()) }
<< "labels are not correctly provided"
<< "preds.size=" << preds.Size()
<< ", label.size=" << info.labels.Size();
const size_t ndata = preds.Size(); void GetGradient(HostDeviceVector<float> const &preds, MetaInfo const &info,
const bool is_null_weight = info.weights_.Size() == 0; std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
if (!is_null_weight) { CheckInitInputs(info);
CHECK_EQ(info.weights_.Size(), ndata) CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), info.num_row_)
<< "Number of weights should be equal to number of data points."; << "Number of weights should be equal to number of data points.";
} }
CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for `binary:hinge` is not yet supported.";
out_gpair->Reshape(ndata, 1); bst_target_t n_targets = this->Targets(info);
common::Transform<>::Init( out_gpair->Reshape(info.num_row_, n_targets);
[=] XGBOOST_DEVICE(size_t _idx, auto gpair = out_gpair->View(ctx_->Device());
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds, preds.SetDevice(ctx_->Device());
common::Span<const bst_float> _labels, auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);
common::Span<const bst_float> _weights) {
bst_float p = _preds[_idx]; auto labels = info.labels.View(ctx_->Device());
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx] * 2.0 - 1.0; info.weights_.SetDevice(ctx_->Device());
bst_float g, h; common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};
linalg::ElementWiseKernel(this->ctx_, labels,
[=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
auto w = weight[i];
auto p = predt(i, j);
auto y = labels(i, j) * 2.0 - 1.0;
float g, h;
if (p * y < 1.0) { if (p * y < 1.0) {
g = -y * w; g = -y * w;
h = w; h = w;
} else { } else {
g = 0.0; g = 0.0;
h = std::numeric_limits<bst_float>::min(); h = std::numeric_limits<float>::min();
} }
_out_gpair[_idx] = GradientPair(g, h); gpair(i, j) = GradientPair{g, h};
}, });
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(),
ctx_->Device()).Eval(
out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
} }
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override { void PredTransform(HostDeviceVector<float> *io_preds) const override {
common::Transform<>::Init( common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [] XGBOOST_DEVICE(std::size_t _idx, common::Span<float> _preds) {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0; _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, this->ctx_->Threads(), common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, this->ctx_->Threads(),
@ -76,9 +90,7 @@ class HingeObj : public ObjFunction {
.Eval(io_preds); .Eval(io_preds);
} }
[[nodiscard]] const char* DefaultEvalMetric() const override { [[nodiscard]] const char *DefaultEvalMetric() const override { return "error"; }
return "error";
}
void SaveConfig(Json *p_out) const override { void SaveConfig(Json *p_out) const override {
auto &out = *p_out; auto &out = *p_out;

View File

@ -75,26 +75,23 @@ class QuantileRegression : public ObjFunction {
: info.weights_.ConstHostSpan()}; : info.weights_.ConstHostSpan()};
preds.SetDevice(ctx_->Device()); preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds); auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);
auto n_samples = info.num_row_;
alpha_.SetDevice(ctx_->Device()); alpha_.SetDevice(ctx_->Device());
auto alpha = ctx_->IsCUDA() ? alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan(); auto alpha = ctx_->IsCUDA() ? alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan();
linalg::ElementWiseKernel( linalg::ElementWiseKernel(ctx_, gpair,
ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable { [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
auto [sample_id, quantile_id, target_id] = // j is the quantile index
linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size()); // 0 is the target index
assert(target_id == 0); auto d = predt(i, j) - labels(i, 0);
auto h = weight[i];
auto d = predt(i) - labels(sample_id, target_id);
auto h = weight[sample_id];
if (d >= 0) { if (d >= 0) {
auto g = (1.0f - alpha[quantile_id]) * weight[sample_id]; auto g = (1.0f - alpha[j]) * weight[i];
gpair(sample_id, quantile_id) = GradientPair{g, h}; gpair(i, j) = GradientPair{g, h};
} else { } else {
auto g = (-alpha[quantile_id] * weight[sample_id]); auto g = (-alpha[j] * weight[i]);
gpair(sample_id, quantile_id) = GradientPair{g, h}; gpair(i, j) = GradientPair{g, h};
} }
}); });
} }

View File

@ -255,22 +255,22 @@ class PseudoHuberRegression : public FitIntercept {
auto gpair = out_gpair->View(ctx_->Device()); auto gpair = out_gpair->View(ctx_->Device());
preds.SetDevice(ctx_->Device()); preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds); auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
info.weights_.SetDevice(ctx_->Device()); info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan() common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()}; : info.weights_.ConstHostSpan()};
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { linalg::ElementWiseKernel(
auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
const float z = predt(i) - y; float z = predt(i, j) - labels(i, j);
const float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope)); float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
float grad = z / scale_sqrt; float grad = z / scale_sqrt;
auto scale = common::Sqr(slope) + common::Sqr(z); auto scale = common::Sqr(slope) + common::Sqr(z);
float hess = common::Sqr(slope) / (scale * scale_sqrt); float hess = common::Sqr(slope) / (scale * scale_sqrt);
auto w = weight[sample_id]; auto w = weight[i];
gpair(i) = {grad * w, hess * w}; gpair(i) = {grad * w, hess * w};
}); });
} }
@ -635,19 +635,20 @@ class MeanAbsoluteError : public ObjFunction {
auto gpair = out_gpair->View(ctx_->Device()); auto gpair = out_gpair->View(ctx_->Device());
preds.SetDevice(ctx_->Device()); preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds); auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
info.weights_.SetDevice(ctx_->Device()); info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan() common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()}; : info.weights_.ConstHostSpan()};
linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable { linalg::ElementWiseKernel(
ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
auto sign = [](auto x) { auto sign = [](auto x) {
return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0)); return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
}; };
auto [sample_id, target_id] = linalg::UnravelIndex(i, labels.Shape()); auto y = labels(i, j);
auto grad = sign(predt(i) - y) * weight[sample_id]; auto hess = weight[i];
auto hess = weight[sample_id]; auto grad = sign(predt(i, j) - y) * hess;
gpair(sample_id, target_id) = GradientPair{grad, hess}; gpair(i, j) = GradientPair{grad, hess};
}); });
} }

View File

@ -23,7 +23,7 @@ void TestElementWiseKernel() {
ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; }); ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
// CPU view // CPU view
t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All()); t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
size_t k = 0; std::size_t k = 0;
for (size_t i = 0; i < l.Shape(0); ++i) { for (size_t i = 0; i < l.Shape(0); ++i) {
for (size_t j = 0; j < l.Shape(2); ++j) { for (size_t j = 0; j < l.Shape(2); ++j) {
ASSERT_EQ(k++, t(i, j)); ASSERT_EQ(k++, t(i, j));
@ -31,7 +31,15 @@ void TestElementWiseKernel() {
} }
t = l.View(device).Slice(linalg::All(), 1, linalg::All()); t = l.View(device).Slice(linalg::All(), 1, linalg::All());
ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); }); cuda_impl::ElementWiseKernel(
t, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable { t(i, j) = i + j; });
t = l.Slice(linalg::All(), 1, linalg::All());
for (size_t i = 0; i < l.Shape(0); ++i) {
for (size_t j = 0; j < l.Shape(2); ++j) {
ASSERT_EQ(i + j, t(i, j));
}
}
} }
{ {

View File

@ -31,12 +31,10 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
auto const& h_result = info.labels.View(DeviceOrd::CPU()); auto const& h_result = info.labels.View(DeviceOrd::CPU());
ASSERT_EQ(h_result.Shape().size(), 2); ASSERT_EQ(h_result.Shape().size(), 2);
auto in_labels = labels.View(DeviceOrd::CPU()); auto in_labels = labels.View(DeviceOrd::CPU());
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) { linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, std::size_t j) {
auto tup = linalg::UnravelIndex(i, h_result.Shape());
auto i0 = std::get<0>(tup);
auto i1 = std::get<1>(tup);
// Sliced at second dimension. // Sliced at second dimension.
auto v_1 = in_labels(i0, 0, i1); auto v_0 = h_result(i, j);
auto v_1 = in_labels(i, 0, j);
CHECK_EQ(v_0, v_1); CHECK_EQ(v_0, v_1);
}); });
} }
@ -65,12 +63,11 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
auto const& h_result = info.base_margin_.View(DeviceOrd::CPU()); auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
ASSERT_EQ(h_result.Shape().size(), 2); ASSERT_EQ(h_result.Shape().size(), 2);
auto in_margin = base_margin.View(DeviceOrd::CPU()); auto in_margin = base_margin.View(DeviceOrd::CPU());
linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) { linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(),
auto tup = linalg::UnravelIndex(i, h_result.Shape()); [&](std::size_t i, std::size_t j) {
auto i0 = std::get<0>(tup);
auto i1 = std::get<1>(tup);
// Sliced at second dimension. // Sliced at second dimension.
auto v_1 = in_margin(i0, 0, i1); auto v_0 = h_result(i, j);
auto v_1 = in_margin(i, 0, j);
CHECK_EQ(v_0, v_1); CHECK_EQ(v_0, v_1);
}); });
} }

View File

@ -1,28 +1,55 @@
// Copyright by Contributors /**
* Copyright 2018-2023, XGBoost Contributors
*/
#include <xgboost/objective.h> #include <xgboost/objective.h>
#include <xgboost/context.h> #include <xgboost/context.h>
#include <limits> #include <limits>
#include "../helpers.h" #include "../helpers.h"
#include "../../../src/common/linalg_op.h"
namespace xgboost { namespace xgboost {
TEST(Objective, DeclareUnifiedTest(HingeObj)) { TEST(Objective, DeclareUnifiedTest(HingeObj)) {
Context ctx = MakeCUDACtx(GPUIDX); Context ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj{ObjFunction::Create("binary:hinge", &ctx)}; std::unique_ptr<ObjFunction> obj{ObjFunction::Create("binary:hinge", &ctx)};
float eps = std::numeric_limits<xgboost::bst_float>::min(); float eps = std::numeric_limits<xgboost::bst_float>::min();
CheckObjFunction(obj, std::vector<float> predt{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f};
{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f}, std::vector<float> label{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f};
{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f}, std::vector<float> grad{0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f};
{ 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, std::vector<float> hess{eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps};
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
CheckObjFunction(obj,
{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f},
{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f},
{}, // Empty weight.
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
ASSERT_NO_THROW(obj->DefaultEvalMetric()); CheckObjFunction(obj, predt, label, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, grad, hess);
CheckObjFunction(obj, predt, label, {/* Empty weight. */}, grad, hess);
ASSERT_EQ(obj->DefaultEvalMetric(), StringView{"error"});
MetaInfo info;
info.num_row_ = label.size();
info.labels.Reshape(info.num_row_, 3);
ASSERT_EQ(obj->Targets(info), 3);
auto h_labels = info.labels.HostView();
for (std::size_t j = 0; j < obj->Targets(info); ++j) {
for (std::size_t i = 0; i < info.num_row_; ++i) {
h_labels(i, j) = label[i];
}
}
linalg::Tensor<float, 2> t_predt{};
t_predt.Reshape(info.labels.Shape());
for (std::size_t j = 0; j < obj->Targets(info); ++j) {
for (std::size_t i = 0; i < info.num_row_; ++i) {
t_predt(i, j) = predt[i];
}
}
linalg::Matrix<GradientPair> out_gpair;
obj->GetGradient(*t_predt.Data(), info, 0, &out_gpair);
for (std::size_t j = 0; j < obj->Targets(info); ++j) {
auto gh = out_gpair.Slice(linalg::All(), j);
ASSERT_EQ(gh.Size(), info.num_row_);
for (std::size_t i = 0; i < gh.Size(); ++i) {
ASSERT_EQ(gh(i).GetGrad(), grad[i]);
ASSERT_EQ(gh(i).GetHess(), hess[i]);
}
}
} }
} // namespace xgboost } // namespace xgboost