Implement slope for Pseudo-Huber. (#7727)

* Add objective and metric.
* Some refactoring for CPU/GPU dispatching using linalg module.
Jiaming Yuan 2022-03-14 21:42:38 +08:00 committed by GitHub
parent 4dafb5fac8
commit 98d6faefd6
28 changed files with 456 additions and 290 deletions
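
For reference, the math this change wires in: with residual z = predt - label and slope delta (the new ``huber_slope``), the Pseudo-Huber loss is delta^2 * (sqrt(1 + (z/delta)^2) - 1), its gradient is z / sqrt(1 + (z/delta)^2), and its hessian is delta^2 / ((delta^2 + z^2) * sqrt(1 + (z/delta)^2)). Below is a minimal standalone sketch of these formulas; the struct and helper names are illustrative and not part of the diff.

#include <cmath>
#include <cstdio>

// Standalone sketch of the Pseudo-Huber math implemented by this commit.
// The names here are illustrative; only the slope (huber_slope) maps to the
// new parameter.  z = predt - label.
struct PseudoHuberSketch {
  float slope{1.0f};  // the delta term, a.k.a. huber_slope

  float Loss(float z) const {
    float r = z / slope;
    return slope * slope * (std::sqrt(1.0f + r * r) - 1.0f);
  }
  float Grad(float z) const {
    return z / std::sqrt(1.0f + (z / slope) * (z / slope));
  }
  float Hess(float z) const {
    float scale_sqrt = std::sqrt(1.0f + (z / slope) * (z / slope));
    return slope * slope / ((slope * slope + z * z) * scale_sqrt);
  }
};

int main() {
  PseudoHuberSketch obj{1.0f};
  // With slope = 1 this reproduces the first expectation in the updated
  // objective test below: pred = 0.1, label = 1.0.
  std::printf("grad=%f hess=%f\n", obj.Grad(0.1f - 1.0f), obj.Hess(0.1f - 1.0f));
  // Expected: grad=-0.668965 hess=0.410660
}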


@@ -76,8 +76,9 @@
 #include "../src/common/quantile.cc"
 #include "../src/common/host_device_vector.cc"
 #include "../src/common/hist_util.cc"
-#include "../src/common/json.cc"
 #include "../src/common/io.cc"
+#include "../src/common/json.cc"
+#include "../src/common/pseudo_huber.cc"
 #include "../src/common/survival_util.cc"
 #include "../src/common/threading_utils.cc"
 #include "../src/common/version.cc"


@@ -204,6 +204,14 @@
       }
     }
   },
+  "pseduo_huber_param": {
+    "type": "object",
+    "properties": {
+      "huber_slope": {
+        "type": "string"
+      }
+    }
+  },
   "aft_loss_param": {
     "type": "object",
     "properties": {


@@ -338,15 +338,6 @@ Parameters for Linear Booster (``booster=gblinear``)
 - The number of top features to select in ``greedy`` and ``thrifty`` feature selector. The value of 0 means using all the features.
 
-Parameters for Tweedie Regression (``objective=reg:tweedie``)
-=============================================================
-
-* ``tweedie_variance_power`` [default=1.5]
-
-  - Parameter that controls the variance of the Tweedie distribution ``var(y) ~ E(y)^tweedie_variance_power``
-  - range: (1,2)
-  - Set closer to 2 to shift towards a gamma distribution
-  - Set closer to 1 to shift towards a Poisson distribution.
-
 ************************
 Learning Task Parameters
 ************************
@@ -356,14 +347,14 @@ Specify the learning task and the corresponding learning objective. The objectiv
 - ``reg:squarederror``: regression with squared loss.
 - ``reg:squaredlogerror``: regression with squared log loss :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`. All input labels are required to be greater than -1. Also, see metric ``rmsle`` for possible issue with this objective.
-- ``reg:logistic``: logistic regression
+- ``reg:logistic``: logistic regression.
 - ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss.
 - ``binary:logistic``: logistic regression for binary classification, output probability
 - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
 - ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities.
 - ``count:poisson``: Poisson regression for count data, output mean of Poisson distribution.
 
-  - ``max_delta_step`` is set to 0.7 by default in Poisson regression (used to safeguard optimization)
+  + ``max_delta_step`` is set to 0.7 by default in Poisson regression (used to safeguard optimization)
 
 - ``survival:cox``: Cox regression for right censored survival time data (negative values are considered right censored).
   Note that predictions are returned on the hazard ratio scale (i.e., as HR = exp(marginal_prediction) in the proportional hazard function ``h(t) = h0(t) * HR``).
@@ -435,6 +426,20 @@ Specify the learning task and the corresponding learning objective. The objectiv
   - Seed PRNG determnisticly via iterator number.
 
+Parameters for Tweedie Regression (``objective=reg:tweedie``)
+=============================================================
+
+* ``tweedie_variance_power`` [default=1.5]
+
+  - Parameter that controls the variance of the Tweedie distribution ``var(y) ~ E(y)^tweedie_variance_power``
+  - range: (1,2)
+  - Set closer to 2 to shift towards a gamma distribution
+  - Set closer to 1 to shift towards a Poisson distribution.
+
+Parameter for using Pseudo-Huber (``reg:pseudohubererror``)
+===========================================================
+
+* ``huber_slope`` : A parameter used for Pseudo-Huber loss to define the :math:`\delta` term. [default = 1.0]
+
 ***********************
 Command Line Parameters
 ***********************
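
As a usage sketch (not part of this diff), the new objective and its slope are configured like any other string-valued parameter; the public C++ Learner API shown here stands in for the equivalent Python/R/CLI settings.

#include <xgboost/learner.h>

// Sketch: select the objective and slope through the trainer configuration.
// Any string-valued parameter interface (CLI config, Python `params` dict,
// R list) works the same way.
void ConfigurePseudoHuber(xgboost::Learner* learner) {
  learner->SetParam("objective", "reg:pseudohubererror");
  learner->SetParam("huber_slope", "0.5");  // the delta term, default 1.0
}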


@@ -49,6 +49,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
    */
   int32_t Threads() const;
 
+  bool IsCPU() const { return gpu_id == kCpuId; }
+
   // declare parameters
   DMLC_DECLARE_PARAMETER(GenericParameter) {
     DMLC_DECLARE_FIELD(seed).set_default(kDefaultSeed).describe(


@@ -545,8 +545,19 @@ using VectorView = TensorView<T, 1>;
  */
 template <typename T>
 auto MakeVec(T *ptr, size_t s, int32_t device = -1) {
-  using U = std::remove_const_t<std::remove_pointer_t<decltype(ptr)>> const;
-  return linalg::TensorView<U, 1>{{ptr, s}, {s}, device};
+  return linalg::TensorView<T, 1>{{ptr, s}, {s}, device};
+}
+
+template <typename T>
+auto MakeVec(HostDeviceVector<T> *data) {
+  return MakeVec(data->DeviceIdx() == -1 ? data->HostPointer() : data->DevicePointer(),
+                 data->Size(), data->DeviceIdx());
+}
+
+template <typename T>
+auto MakeVec(HostDeviceVector<T> const *data) {
+  return MakeVec(data->DeviceIdx() == -1 ? data->ConstHostPointer() : data->ConstDevicePointer(),
+                 data->Size(), data->DeviceIdx());
 }
 
 /**
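
A sketch of what the two new overloads enable; it assumes xgboost's public headers. The returned view binds to device memory when the vector currently lives on a GPU (DeviceIdx() != -1) and to host memory otherwise, so callers no longer branch on placement.

#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"

// Sketch of the new MakeVec(HostDeviceVector*) overload in use.
void Example(xgboost::HostDeviceVector<float>* data) {
  auto vec = xgboost::linalg::MakeVec(data);  // mutable 1D view, host or device
  // vec(i) now indexes the same memory the HostDeviceVector manages.
}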


@@ -48,7 +48,10 @@ class Metric : public Configurable {
    * override this function to maintain internal configuration
    * \param out pointer to output JSON object
    */
-  void SaveConfig(Json*) const override {}
+  void SaveConfig(Json* p_out) const override {
+    auto& out = *p_out;
+    out["name"] = String(this->Name());
+  }
   /*!
    * \brief evaluate a specific metric


@@ -188,6 +188,16 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
   XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op);
   return result;
 }
+
+struct OptionalWeights {
+  Span<float const> weights;
+  float dft{1.0f};
+
+  explicit OptionalWeights(Span<float const> w) : weights{w} {}
+  explicit OptionalWeights(float w) : dft{w} {}
+
+  XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
+};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_COMMON_H_
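
A sketch of the call-site pattern this struct enables, assuming xgboost's internal common.h; with an empty span every lookup yields the default weight, so loops need no empty-weight special case.

#include "../src/common/common.h"  // assumed path to the header above

// Sketch: sum per-sample weights without branching on whether any were given.
float SumWeights(xgboost::common::OptionalWeights const& w, size_t n) {
  float total = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    total += w[i];  // 1.0f for every i when the span is empty
  }
  return total;
}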


@@ -1,15 +1,33 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_LINALG_OP_CUH_
 #define XGBOOST_COMMON_LINALG_OP_CUH_
 
+#include "xgboost/generic_parameters.h"
 #include "device_helpers.cuh"
+#include "linalg_op.h"
 #include "xgboost/linalg.h"
 
 namespace xgboost {
 namespace linalg {
 template <typename T, int32_t D, typename Fn>
 void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
+  static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
+                "For function with return, use transform instead.");
+  if (t.Contiguous()) {
+    auto ptr = t.Values().data();
+    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { fn(i, ptr[i]); });
+  } else {
+    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
+      T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
+      fn(i, v);
+    });
+  }
+}
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
   if (t.Contiguous()) {
     auto ptr = t.Values().data();
     dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
@@ -20,6 +38,11 @@ void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s
     });
   }
 }
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
+  ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn);
+}
 }  // namespace linalg
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_LINALG_OP_CUH_


@@ -1,15 +1,19 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_LINALG_OP_H_
 #define XGBOOST_COMMON_LINALG_OP_H_
+#include <type_traits>
 
+#include "common.h"
 #include "threading_utils.h"
+#include "xgboost/generic_parameters.h"
 #include "xgboost/linalg.h"
 
 namespace xgboost {
 namespace linalg {
 template <typename T, int32_t D, typename Fn>
-void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
+void ElementWiseTransformHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
   if (t.Contiguous()) {
     auto ptr = t.Values().data();
     common::ParallelFor(t.Size(), n_threads, [&](size_t i) { ptr[i] = fn(i, ptr[i]); });
@@ -20,6 +24,41 @@ void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& f
     });
   }
 }
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
+  static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
+                "For function with return, use transform instead.");
+  if (t.Contiguous()) {
+    auto ptr = t.Values().data();
+    common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); });
+  } else {
+    common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
+      auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
+      fn(i, v);
+    });
+  }
+}
+
+#if !defined(XGBOOST_USE_CUDA)
+template <typename T, int32_t D, typename Fn>
+void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, void* s = nullptr) {
+  common::AssertGPUSupport();
+}
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, void* s = nullptr) {
+  common::AssertGPUSupport();
+}
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
+  if (!ctx->IsCPU()) {
+    common::AssertGPUSupport();
+  }
+  ElementWiseKernelHost(t, ctx->Threads(), fn);
+}
+#endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace linalg
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_LINALG_OP_H_
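
The renaming separates two contracts: ElementWiseTransform* writes the functor's return value back (ptr[i] = fn(i, ptr[i])), while ElementWiseKernel* requires a void functor that mutates its reference argument, enforced by the static_assert above. A sketch of both host-side entry points, assuming xgboost's internal headers:

#include "../src/common/linalg_op.h"  // assumed path to the header above

void Example(xgboost::linalg::TensorView<float, 2> t, int32_t n_threads) {
  // Transform: the functor returns the new element value.
  xgboost::linalg::ElementWiseTransformHost(
      t, n_threads, [](size_t i, float v) { return v * 2.0f; });
  // Kernel: the functor returns void and mutates the element in place.
  xgboost::linalg::ElementWiseKernelHost(
      t, n_threads, [](size_t i, float& v) { v += 1.0f; });
}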


@@ -23,7 +23,11 @@ namespace common {
  * \return the transformed value.
  */
 XGBOOST_DEVICE inline float Sigmoid(float x) {
-  return 1.0f / (1.0f + expf(-x));
+  float constexpr kEps = 1e-16;  // avoid 0 div
+  x = std::min(-x, 88.7f);       // avoid exp overflow
+  auto denom = expf(x) + 1.0f + kEps;
+  auto y = 1.0f / denom;
+  return y;
 }
 
 template <typename T>
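
The 88.7f clamp exists because expf overflows single precision just past e^88.72, which is roughly FLT_MAX; without it the old form could produce inf/NaN in downstream gradient code. A quick standalone check of the bound:

#include <cmath>
#include <cstdio>

// Quick check of why the new Sigmoid clamps its argument at 88.7f.
int main() {
  std::printf("expf(88.7f) = %g\n", expf(88.7f));  // ~3.3e38, still finite
  std::printf("expf(89.0f) = %g\n", expf(89.0f));  // inf: past FLT_MAX
}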


@@ -0,0 +1,7 @@
+/*!
+ * Copyright 2022, by XGBoost Contributors
+ */
+#include "pseudo_huber.h"
+namespace xgboost {
+DMLC_REGISTER_PARAMETER(PesudoHuberParam);
+}

src/common/pseudo_huber.h (new file, 19 lines)

@@ -0,0 +1,19 @@
+#ifndef XGBOOST_COMMON_PSEUDO_HUBER_H_
+#define XGBOOST_COMMON_PSEUDO_HUBER_H_
+/*!
+ * Copyright 2022, by XGBoost Contributors
+ */
+#include "xgboost/parameter.h"
+
+namespace xgboost {
+struct PesudoHuberParam : public XGBoostParameter<PesudoHuberParam> {
+  float huber_slope{1.0};
+
+  DMLC_DECLARE_PARAMETER(PesudoHuberParam) {
+    DMLC_DECLARE_FIELD(huber_slope)
+        .set_default(1.0f)
+        .describe("The delta term in Pseudo-Huber loss.");
+  }
+};
+}  // namespace xgboost
+#endif  // XGBOOST_COMMON_PSEUDO_HUBER_H_
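
A sketch of how the parameter is consumed, assuming this header; both the new objective and metric hold a PesudoHuberParam and feed it user arguments through the standard dmlc parameter machinery:

#include "../src/common/pseudo_huber.h"  // assumed path to the header above
#include "xgboost/base.h"                // xgboost::Args

void Example() {
  xgboost::PesudoHuberParam param;
  // Unknown keys are ignored rather than fatal, matching Configure() in the
  // objective and metric below.
  param.UpdateAllowUnknown(xgboost::Args{{"huber_slope", "0.5"}});
  // param.huber_slope == 0.5f
}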


@@ -431,7 +431,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
   auto t = p_out->View(GenericParameter::kCpuId);
   CHECK(t.CContiguous());
   // FIXME(jiamingy): Remove the use of this default thread.
-  linalg::ElementWiseKernelHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) {
+  linalg::ElementWiseTransformHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) {
     return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
   });
 }
@@ -877,7 +877,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
     dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(),
                            1, cache_file);
   } else {
-    data::FileIterator iter{fname, uint32_t(partid), uint32_t(npart),
+    data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
                             file_format};
     dmat = new data::SparsePageDMatrix{
         &iter,


@@ -49,7 +49,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
   }
   p_out->Reshape(array.shape);
   auto t = p_out->View(ptr_device);
-  linalg::ElementWiseKernelDevice(t, [=] __device__(size_t i, T) {
+  linalg::ElementWiseTransformDevice(t, [=] __device__(size_t i, T) {
     return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, array.shape));
   });
 }


@@ -277,6 +277,21 @@ using LearnerAPIThreadLocalStore =
 using ThreadLocalPredictionCache =
     dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
 
+namespace {
+StringView ModelMsg() {
+  return StringView{
+      R"doc(
+  If you are loading a serialized model (like pickle in Python, RDS in R) generated by
+  older XGBoost, please export the model by calling `Booster.save_model` from that version
+  first, then load it back in current version. See:
+
+    https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
+
+  for more details about differences between saving model and serializing.
+)doc"};
+}
+}  // anonymous namespace
+
 class LearnerConfiguration : public Learner {
  private:
   std::mutex config_lock_;
@@ -375,7 +390,6 @@ class LearnerConfiguration : public Learner {
     this->ConfigureGBM(old_tparam, args);
     generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU());
-
     this->ConfigureMetrics(args);
 
     this->need_configuration_ = false;
@@ -418,9 +432,17 @@ class LearnerConfiguration : public Learner {
     metric_names_.resize(n_metrics);
     metrics_.resize(n_metrics);
     for (size_t i = 0; i < n_metrics; ++i) {
-      metric_names_[i]= get<String>(j_metrics[i]);
-      metrics_[i] = std::unique_ptr<Metric>(
-          Metric::Create(metric_names_[i], &generic_parameters_));
+      auto old_serialization = IsA<String>(j_metrics[i]);
+      if (old_serialization) {
+        LOG(WARNING) << ModelMsg();
+        metric_names_[i] = get<String>(j_metrics[i]);
+      } else {
+        metric_names_[i] = get<String>(j_metrics[i]["name"]);
+      }
+      metrics_[i] = std::unique_ptr<Metric>(Metric::Create(metric_names_[i], &generic_parameters_));
+      if (!old_serialization) {
+        metrics_[i]->LoadConfig(j_metrics[i]);
+      }
     }
 
     FromJson(learner_parameters.at("generic_param"), &generic_parameters_);
@@ -448,9 +470,9 @@ class LearnerConfiguration : public Learner {
     auto& objective_fn = learner_parameters["objective"];
     obj_->SaveConfig(&objective_fn);
 
-    std::vector<Json> metrics(metrics_.size());
+    std::vector<Json> metrics(metrics_.size(), Json{Object{}});
     for (size_t i = 0; i < metrics_.size(); ++i) {
-      metrics[i] = String(metrics_[i]->Name());
+      metrics_[i]->SaveConfig(&metrics[i]);
     }
     learner_parameters["metrics"] = Array(std::move(metrics));
 
@@ -709,21 +731,6 @@ class LearnerConfiguration : public Learner {
 std::string const LearnerConfiguration::kEvalMetric {"eval_metric"};  // NOLINT
 
-namespace {
-StringView ModelMsg() {
-  return StringView{
-      R"doc(
-  If you are loading a serialized model (like pickle in Python, RDS in R) generated by
-  older XGBoost, please export the model by calling `Booster.save_model` from that version
-  first, then load it back in current version. See:
-
-    https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
-
-  for more details about differences between saving model and serializing.
-)doc"};
-}
-}  // anonymous namespace
-
 class LearnerIO : public LearnerConfiguration {
  private:
   std::set<std::string> saved_configs_ = {"num_round"};
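
With this change a saved config stores each metric as an object instead of a bare string, so per-metric parameters such as huber_slope survive a save/load round trip; bare strings are still accepted on load but trigger the ModelMsg() warning. The new shape, sketched with xgboost's Json types (the function name is illustrative):

#include "xgboost/json.h"

// Sketch of the new per-metric config shape produced by SaveConfig.
xgboost::Json NewStyleMetricConfig() {
  using xgboost::Json;
  using xgboost::Object;
  using xgboost::String;
  Json metric{Object{}};
  metric["name"] = String{"mphe"};
  metric["pseduo_huber_param"] = Json{Object{}};  // filled in by the metric
  return metric;  // old configs stored just String{"mphe"} here
}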


@@ -33,7 +33,7 @@ namespace metric {
 template <typename Fn>
 std::tuple<double, double, double>
 BinaryAUC(common::Span<float const> predts, linalg::VectorView<float const> labels,
-          OptionalWeights weights,
+          common::OptionalWeights weights,
           std::vector<size_t> const &sorted_idx, Fn &&area_fn) {
   CHECK_NE(labels.Size(), 0);
   CHECK_EQ(labels.Size(), predts.size());
@@ -93,7 +93,7 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
   auto tp = results.Slice(linalg::All(), 1);
   auto auc = results.Slice(linalg::All(), 2);
 
-  auto weights = OptionalWeights{info.weights_.ConstHostSpan()};
+  auto weights = common::OptionalWeights{info.weights_.ConstHostSpan()};
   auto predts_t = linalg::TensorView<float const, 2>(
       predts, {static_cast<size_t>(info.num_row_), n_classes},
       GenericParameter::kCpuId);
@@ -140,7 +140,7 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
 std::tuple<double, double, double> BinaryROCAUC(common::Span<float const> predts,
                                                 linalg::VectorView<float const> labels,
-                                                OptionalWeights weights) {
+                                                common::OptionalWeights weights) {
   auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
   return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea);
 }
@@ -186,7 +186,7 @@ double GroupRankingROC(common::Span<float const> predts,
  */
 std::tuple<double, double, double> BinaryPRAUC(common::Span<float const> predts,
                                                linalg::VectorView<float const> labels,
-                                               OptionalWeights weights) {
+                                               common::OptionalWeights weights) {
   auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
   double total_pos{0}, total_neg{0};
   for (size_t i = 0; i < labels.Size(); ++i) {
@@ -238,7 +238,7 @@ std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
     if (is_roc) {
       auc = GroupRankingROC(g_predts, g_labels, w);
     } else {
-      auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w}));
+      auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, common::OptionalWeights{w}));
     }
     if (std::isnan(auc)) {
       invalid_groups++;
@@ -373,7 +373,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
     if (tparam_->gpu_id == GenericParameter::kCpuId) {
       std::tie(fp, tp, auc) =
           BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0),
-                       OptionalWeights{info.weights_.ConstHostSpan()});
+                       common::OptionalWeights{info.weights_.ConstHostSpan()});
     } else {
       std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
                                               tparam_->gpu_id, &this->d_cache_);
@@ -426,7 +426,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
     if (tparam_->gpu_id == GenericParameter::kCpuId) {
       std::tie(pr, re, auc) =
           BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
-                      OptionalWeights{info.weights_.ConstHostSpan()});
+                      common::OptionalWeights{info.weights_.ConstHostSpan()});
     } else {
       std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
                                              tparam_->gpu_id, &this->d_cache_);


@@ -99,7 +99,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
   /**
    * Linear scan
    */
-  auto get_weight = OptionalWeights{weights};
+  auto get_weight = common::OptionalWeights{weights};
   auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) {
     size_t idx = d_sorted_idx[i];
@@ -353,7 +353,7 @@ double GPUMultiClassAUCOVR(common::Span<float const> predts,
    * Linear scan
    */
   dh::caching_device_vector<double> d_auc(n_classes, 0);
-  auto get_weight = OptionalWeights{weights};
+  auto get_weight = common::OptionalWeights{weights};
   auto d_fptp = dh::ToSpan(cache->fptp);
   auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) {
     size_t idx = d_sorted_idx[i];
@@ -633,7 +633,7 @@ GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
   auto labels = info.labels.View(device);
   auto d_weights = info.weights_.ConstDeviceSpan();
-  auto get_weight = OptionalWeights{d_weights};
+  auto get_weight = common::OptionalWeights{d_weights};
   auto it = dh::MakeTransformIterator<Pair>(
       thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
         auto w = get_weight[d_sorted_idx[i]];
@@ -687,7 +687,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts,
       [n_samples] XGBOOST_DEVICE(size_t i) {
         return i / n_samples;  // class id
       });
-  auto get_weight = OptionalWeights{d_weights};
+  auto get_weight = common::OptionalWeights{d_weights};
   auto val_it = dh::MakeTransformIterator<thrust::pair<double, double>>(
       thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
         auto idx = d_sorted_idx[i] % n_samples;
@@ -736,7 +736,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
    */
   size_t n_samples = labels.Shape(0);
   dh::caching_device_vector<double> d_auc(n_groups, 0);
-  auto get_weight = OptionalWeights{weights};
+  auto get_weight = common::OptionalWeights{weights};
   auto d_fptp = dh::ToSpan(cache->fptp);
   auto get_fp_tp = [=] XGBOOST_DEVICE(size_t i) {
     size_t idx = d_sorted_idx[i];


@@ -112,18 +112,6 @@ struct PRAUCLabelInvalid {
 inline void InvalidLabels() {
   LOG(FATAL) << "PR-AUC supports only binary relevance for learning to rank.";
 }
-
-struct OptionalWeights {
-  common::Span<float const> weights;
-  float dft { 1.0f };
-
-  explicit OptionalWeights(common::Span<float const> w) : weights{w} {}
-  explicit OptionalWeights(float w) : dft{w} {}
-
-  XGBOOST_DEVICE float operator[](size_t i) const {
-    return weights.empty() ? dft : weights[i];
-  }
-};
 }  // namespace metric
 }  // namespace xgboost
 #endif  // XGBOOST_METRIC_AUC_H_


@@ -1,20 +1,22 @@
 /*!
- * Copyright 2015-2019 by Contributors
+ * Copyright 2015-2022 by XGBoost Contributors
  * \file elementwise_metric.cc
  * \brief evaluation metrics for elementwise binary or regression.
  * \author Kailong Chen, Tianqi Chen
  *
  *  The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset.
  */
+#include <dmlc/registry.h>
 #include <rabit/rabit.h>
 #include <xgboost/metric.h>
-#include <dmlc/registry.h>
+
 #include <cmath>
 
-#include "metric_common.h"
-#include "../common/math.h"
 #include "../common/common.h"
+#include "../common/math.h"
+#include "../common/pseudo_huber.h"
 #include "../common/threading_utils.h"
+#include "metric_common.h"
 
 #if defined(XGBOOST_USE_CUDA)
 #include <thrust/execution_policy.h>  // thrust::cuda::par
@@ -30,109 +32,63 @@ namespace metric {
 // tag the this file, used by force static link later.
 DMLC_REGISTRY_FILE_TAG(elementwise_metric);
 
-template <typename EvalRow>
-class ElementWiseMetricsReduction {
- public:
-  explicit ElementWiseMetricsReduction(EvalRow policy) : policy_(std::move(policy)) {}
-
-  PackedReduceResult
-  CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
-                   linalg::TensorView<float const, 2> labels,
-                   const HostDeviceVector<bst_float> &preds,
-                   int32_t n_threads) const {
-    size_t ndata = labels.Size();
-    auto n_targets = std::max(labels.Shape(1), static_cast<size_t>(1));
-    auto h_labels = labels.Values();
-    const auto& h_weights = weights.HostVector();
-    const auto& h_preds = preds.HostVector();
-
+namespace {
+/**
+ * \brief Reduce function for element wise metrics.
+ *
+ * The loss function should handle all the computation for each sample, including
+ * applying the weights.  A tuple of {error_i, weight_i} is expected as return.
+ */
+template <typename Fn>
+PackedReduceResult Reduce(GenericParameter const* ctx, MetaInfo const& info, Fn&& loss) {
+  PackedReduceResult result;
+  auto labels = info.labels.View(ctx->gpu_id);
+  if (ctx->IsCPU()) {
+    auto n_threads = ctx->Threads();
     std::vector<double> score_tloc(n_threads, 0.0);
     std::vector<double> weight_tloc(n_threads, 0.0);
     // We sum over losses over all samples and targets instead of performing this for each
     // target since the first one approach more accurate while the second approach is used
     // for approximation in distributed setting.  For rmse:
     // - sqrt(1/w(sum_t0 + sum_t1 + ... + sum_tm))       // multi-target
     // - sqrt(avg_t0) + sqrt(avg_t1) + ... sqrt(avg_tm)  // distributed
-    common::ParallelFor(ndata, n_threads, [&](size_t i) {
-      float wt = h_weights.size() > 0 ? h_weights[i / n_targets] : 1.0f;
+    common::ParallelFor(info.labels.Size(), ctx->Threads(), [&](size_t i) {
       auto t_idx = omp_get_thread_num();
-      score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
+      size_t sample_id;
+      size_t target_id;
+      std::tie(sample_id, target_id) = linalg::UnravelIndex(i, labels.Shape());
+
+      float v, wt;
+      std::tie(v, wt) = loss(i, sample_id, target_id);
+      score_tloc[t_idx] += v;
       weight_tloc[t_idx] += wt;
     });
     double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0);
     double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0);
-
-    PackedReduceResult res { residue_sum, weights_sum };
-    return res;
-  }
-
+    result = PackedReduceResult{residue_sum, weights_sum};
+  } else {
 #if defined(XGBOOST_USE_CUDA)
-  PackedReduceResult DeviceReduceMetrics(
-      const HostDeviceVector<bst_float>& weights,
-      linalg::TensorView<float const, 2> labels,
-      const HostDeviceVector<bst_float>& preds) {
-    size_t n_data = preds.Size();
-    auto n_targets = std::max(labels.Shape(1), static_cast<size_t>(1));
-
-    thrust::counting_iterator<size_t> begin(0);
-    thrust::counting_iterator<size_t> end = begin + n_data;
-
-    auto s_label = labels.Values();
-    auto s_preds = preds.DeviceSpan();
-    auto s_weights = weights.DeviceSpan();
-
-    bool const is_null_weight = weights.Size() == 0;
-    auto d_policy = policy_;
-
     dh::XGBCachingDeviceAllocator<char> alloc;
-    PackedReduceResult result = thrust::transform_reduce(
-        thrust::cuda::par(alloc),
-        begin, end,
-        [=] XGBOOST_DEVICE(size_t idx) {
-          float weight = is_null_weight ? 1.0f : s_weights[idx / n_targets];
-
-          float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]);
-          residue *= weight;
-          return PackedReduceResult{ residue, weight };
+    thrust::counting_iterator<size_t> begin(0);
+    thrust::counting_iterator<size_t> end = begin + labels.Size();
+    result = thrust::transform_reduce(
+        thrust::cuda::par(alloc), begin, end,
+        [=] XGBOOST_DEVICE(size_t i) {
+          auto idx = linalg::UnravelIndex(i, labels.Shape());
+          auto sample_id = std::get<0>(idx);
+          auto target_id = std::get<1>(idx);
+          auto res = loss(i, sample_id, target_id);
+          float v{std::get<0>(res)}, wt{std::get<1>(res)};
+          return PackedReduceResult{v, wt};
         },
-        PackedReduceResult(),
-        thrust::plus<PackedReduceResult>());
-
-    return result;
-  }
-
-#endif  // XGBOOST_USE_CUDA
-
-  PackedReduceResult Reduce(const GenericParameter& ctx, const HostDeviceVector<bst_float>& weights,
-                            linalg::Tensor<float, 2> const& labels,
-                            const HostDeviceVector<bst_float>& preds) {
-    PackedReduceResult result;
-    if (ctx.gpu_id < 0) {
-      auto n_threads = ctx.Threads();
-      result = CpuReduceMetrics(weights, labels.HostView(), preds, n_threads);
-    }
-#if defined(XGBOOST_USE_CUDA)
-    else {  // NOLINT
-      preds.SetDevice(ctx.gpu_id);
-      weights.SetDevice(ctx.gpu_id);
-      dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
-      result = DeviceReduceMetrics(weights, labels.View(ctx.gpu_id), preds);
-    }
+        PackedReduceResult{}, thrust::plus<PackedReduceResult>());
+#else
+    common::AssertGPUSupport();
 #endif  // defined(XGBOOST_USE_CUDA)
-    return result;
   }
-
- private:
-  EvalRow policy_;
-#if defined(XGBOOST_USE_CUDA)
-#endif  // defined(XGBOOST_USE_CUDA)
-};
+  return result;
+}
+}  // anonymous namespace
 
 struct EvalRowRMSE {
   char const *Name() const {
@@ -187,38 +143,64 @@ struct EvalRowMAPE {
   }
 };
 
+namespace {
+XGBOOST_DEVICE inline float LogLoss(float y, float py) {
+  auto xlogy = [](float x, float y) {
+    float eps = 1e-16;
+    return (x - 0.0f == 0.0f) ? 0.0f : (x * std::log(std::max(y, eps)));
+  };
+  const bst_float pneg = 1.0f - py;
+  return xlogy(-y, py) + xlogy(-(1.0f - y), pneg);
+}
+}  // anonymous namespace
+
 struct EvalRowLogLoss {
   const char *Name() const {
     return "logloss";
   }
 
-  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
-    const bst_float eps = 1e-16f;
-    const bst_float pneg = 1.0f - py;
-    if (py < eps) {
-      return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
-    } else if (pneg < eps) {
-      return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps);
-    } else {
-      return -y * std::log(py) - (1.0f - y) * std::log(pneg);
-    }
-  }
+  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const { return LogLoss(y, py); }
 
   static double GetFinal(double esum, double wsum) {
     return wsum == 0 ? esum : esum / wsum;
   }
 };
 
-struct EvalRowMPHE {
-  char const *Name() const {
-    return "mphe";
-  }
-  XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
-    bst_float diff = label - pred;
-    return std::sqrt( 1 + diff * diff) - 1;
-  }
-  static double GetFinal(double esum, double wsum) {
-    return wsum == 0 ? esum : esum / wsum;
+class PseudoErrorLoss : public Metric {
+  PesudoHuberParam param_;
+
+ public:
+  const char* Name() const override { return "mphe"; }
+  void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }
+  void LoadConfig(Json const& in) override { FromJson(in["pseduo_huber_param"], &param_); }
+  void SaveConfig(Json* p_out) const override {
+    auto& out = *p_out;
+    out["name"] = String(this->Name());
+    out["pseduo_huber_param"] = ToJson(param_);
   }
+
+  double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info,
+              bool distributed) override {
+    CHECK_EQ(info.labels.Shape(0), info.num_row_);
+    auto labels = info.labels.View(tparam_->gpu_id);
+    preds.SetDevice(tparam_->gpu_id);
+    auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
+    info.weights_.SetDevice(tparam_->gpu_id);
+    common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan()
+                                                     : info.weights_.ConstDeviceSpan());
+    float slope = this->param_.huber_slope;
+    CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
+    PackedReduceResult result =
+        Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
+          float wt = weights[sample_id];
+          auto a = labels(sample_id, target_id) - predts[i];
+          auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt;
+          return std::make_tuple(v, wt);
+        });
+    double dat[2]{result.Residue(), result.Weights()};
+    if (distributed) {
+      rabit::Allreduce<rabit::op::Sum>(dat, 2);
+    }
+    return EvalRowMAPE::GetFinal(dat[0], dat[1]);
+  }
 };
@@ -355,20 +337,36 @@ struct EvalTweedieNLogLik {
  * \brief base class of element-wise evaluation
  * \tparam Derived the name of subclass
  */
-template<typename Policy>
+template <typename Policy>
 struct EvalEWiseBase : public Metric {
   EvalEWiseBase() = default;
-  explicit EvalEWiseBase(char const* policy_param) :
-    policy_{policy_param}, reducer_{policy_} {}
+  explicit EvalEWiseBase(char const* policy_param) : policy_{policy_param} {}
 
-  double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
+  double Eval(HostDeviceVector<bst_float> const& preds, const MetaInfo& info,
               bool distributed) override {
     CHECK_EQ(preds.Size(), info.labels.Size())
         << "label and prediction size not match, "
         << "hint: use merror or mlogloss for multi-class classification";
-    auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels, preds);
+    if (info.labels.Size() != 0) {
+      CHECK_NE(info.labels.Shape(1), 0);
+    }
+    auto labels = info.labels.View(tparam_->gpu_id);
+    info.weights_.SetDevice(tparam_->gpu_id);
+    common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan()
+                                                     : info.weights_.ConstDeviceSpan());
+    preds.SetDevice(tparam_->gpu_id);
+    auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan();
 
-    double dat[2] { result.Residue(), result.Weights() };
+    auto d_policy = policy_;
+    auto result =
+        Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) {
+          float wt = weights[sample_id];
+          float residue = d_policy.EvalRow(labels(sample_id, target_id), predts[i]);
+          residue *= wt;
+          return std::make_tuple(residue, wt);
+        });
+
+    double dat[2]{result.Residue(), result.Weights()};
     if (distributed) {
       rabit::Allreduce<rabit::op::Sum>(dat, 2);
@@ -376,13 +374,10 @@ struct EvalEWiseBase : public Metric {
     return Policy::GetFinal(dat[0], dat[1]);
   }
 
-  const char* Name() const override {
-    return policy_.Name();
-  }
+  const char* Name() const override { return policy_.Name(); }
 
  private:
   Policy policy_;
-  ElementWiseMetricsReduction<Policy> reducer_{policy_};
 };
 XGBOOST_REGISTER_METRIC(RMSE, "rmse")
@@ -401,14 +396,14 @@ XGBOOST_REGISTER_METRIC(MAPE, "mape")
 .describe("Mean absolute percentage error.")
 .set_body([](const char* param) { return new EvalEWiseBase<EvalRowMAPE>(); });
 
-XGBOOST_REGISTER_METRIC(MPHE, "mphe")
-.describe("Mean Pseudo Huber error.")
-.set_body([](const char* param) { return new EvalEWiseBase<EvalRowMPHE>(); });
-
 XGBOOST_REGISTER_METRIC(LogLoss, "logloss")
 .describe("Negative loglikelihood for logistic regression.")
 .set_body([](const char* param) { return new EvalEWiseBase<EvalRowLogLoss>(); });
 
+XGBOOST_REGISTER_METRIC(PseudoErrorLoss, "mphe")
+.describe("Mean Pseudo-huber error.")
+.set_body([](const char* param) { return new PseudoErrorLoss{}; });
+
 XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
 .describe("Negative loglikelihood for poisson regression.")
 .set_body([](const char* param) { return new EvalEWiseBase<EvalPoissonNegLogLik>(); });
@@ -430,6 +425,5 @@ XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
 .set_body([](const char* param) {
     return new EvalEWiseBase<EvalTweedieNLogLik>(param);
   });
-
 }  // namespace metric
 }  // namespace xgboost


@@ -105,38 +105,6 @@ struct LogisticRegression {
   static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
 };
 
-struct PseudoHuberError {
-  XGBOOST_DEVICE static bst_float PredTransform(bst_float x) {
-    return x;
-  }
-
-  XGBOOST_DEVICE static bool CheckLabel(bst_float) {
-    return true;
-  }
-
-  XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
-    const float z = predt - label;
-    const float scale_sqrt = std::sqrt(1 + std::pow(z, 2));
-    return z/scale_sqrt;
-  }
-
-  XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
-    const float scale = 1 + std::pow(predt - label, 2);
-    const float scale_sqrt = std::sqrt(scale);
-    return 1/(scale*scale_sqrt);
-  }
-
-  static bst_float ProbToMargin(bst_float base_score) {
-    return base_score;
-  }
-
-  static const char* LabelErrorMsg() {
-    return "";
-  }
-
-  static const char* DefaultEvalMetric() {
-    return "mphe";
-  }
-
-  static const char* Name() {
-    return "reg:pseudohubererror";
-  }
-
-  static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
-};
-
 // logistic loss for binary classification task
 struct LogisticClassification : public LogisticRegression {
   static const char* DefaultEvalMetric() { return "logloss"; }


@@ -8,23 +8,38 @@
 #include <dmlc/omp.h>
 #include <xgboost/logging.h>
 #include <xgboost/objective.h>
 
 #include <cmath>
 #include <memory>
 #include <vector>
 
+#include "../common/common.h"
+#include "../common/linalg_op.h"
+#include "../common/pseudo_huber.h"
+#include "../common/threading_utils.h"
+#include "../common/transform.h"
+#include "./regression_loss.h"
 #include "xgboost/host_device_vector.h"
 #include "xgboost/json.h"
 #include "xgboost/parameter.h"
 #include "xgboost/span.h"
 
-#include "../common/transform.h"
-#include "../common/common.h"
-#include "../common/threading_utils.h"
-#include "./regression_loss.h"
+#if defined(XGBOOST_USE_CUDA)
+#include "../common/linalg_op.cuh"
+#endif  // defined(XGBOOST_USE_CUDA)
 
 namespace xgboost {
 namespace obj {
 
+namespace {
+void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
+  CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels.";
+  CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
+  if (!info.weights_.Empty()) {
+    CHECK_EQ(info.weights_.Size(), info.num_row_)
+        << "Number of weights should be equal to number of data points.";
+  }
+}
+}  // anonymous namespace
+
 #if defined(XGBOOST_USE_CUDA)
 DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
@@ -64,20 +79,13 @@ class RegLossObj : public ObjFunction {
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo &info, int,
                    HostDeviceVector<GradientPair>* out_gpair) override {
-    CHECK_EQ(preds.Size(), info.labels.Size())
-        << " " << "labels are not correctly provided"
-        << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
-        << "Loss: " << Loss::Name();
+    CheckRegInputs(info, preds);
     size_t const ndata = preds.Size();
     out_gpair->Resize(ndata);
     auto device = ctx_->gpu_id;
     additional_input_.HostVector().begin()[0] = 1;  // Fill the label_correct flag
 
     bool is_null_weight = info.weights_.Size() == 0;
-    if (!is_null_weight) {
-      CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
-          << "Number of weights should be equal to number of data points.";
-    }
-
     auto scale_pos_weight = param_.scale_pos_weight;
     additional_input_.HostVector().begin()[1] = scale_pos_weight;
     additional_input_.HostVector().begin()[2] = is_null_weight;
@@ -179,10 +187,6 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name())
 .describe("Logistic regression for probability regression task.")
 .set_body([]() { return new RegLossObj<LogisticRegression>(); });
 
-XGBOOST_REGISTER_OBJECTIVE(PseudoHuberError, PseudoHuberError::Name())
-.describe("Regression Pseudo Huber error.")
-.set_body([]() { return new RegLossObj<PseudoHuberError>(); });
-
 XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, LogisticClassification::Name())
 .describe("Logistic regression for binary classification task.")
 .set_body([]() { return new RegLossObj<LogisticClassification>(); });
@@ -200,6 +204,70 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
       return new RegLossObj<LinearSquareLoss>(); });
 // End deprecated
 
+class PseudoHuberRegression : public ObjFunction {
+  PesudoHuberParam param_;
+
+ public:
+  void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }
+  struct ObjInfo Task() const override { return {ObjInfo::kRegression, false}; }
+  uint32_t Targets(MetaInfo const& info) const override {
+    return std::max(static_cast<size_t>(1), info.labels.Shape(1));
+  }
+
+  void GetGradient(HostDeviceVector<bst_float> const& preds, const MetaInfo& info, int iter,
+                   HostDeviceVector<GradientPair>* out_gpair) override {
+    CheckRegInputs(info, preds);
+    auto slope = param_.huber_slope;
+    CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0.";
+    auto labels = info.labels.View(ctx_->gpu_id);
+
+    out_gpair->SetDevice(ctx_->gpu_id);
+    out_gpair->Resize(info.labels.Size());
+    auto gpair = linalg::MakeVec(out_gpair);
+
+    preds.SetDevice(ctx_->gpu_id);
+    auto predt = linalg::MakeVec(&preds);
+
+    info.weights_.SetDevice(ctx_->gpu_id);
+    common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
+                                                 : info.weights_.ConstDeviceSpan()};
+
+    linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable {
+      auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
+
+      const float z = predt(i) - y;
+      const float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
+      float grad = z / scale_sqrt;
+
+      auto scale = common::Sqr(slope) + common::Sqr(z);
+      float hess = common::Sqr(slope) / (scale * scale_sqrt);
+
+      auto w = weight[sample_id];
+      gpair(i) = {grad * w, hess * w};
+    });
+  }
+
+  const char* DefaultEvalMetric() const override { return "mphe"; }
+
+  void SaveConfig(Json* p_out) const override {
+    auto& out = *p_out;
+    out["name"] = String("reg:pseudohubererror");
+    out["pseduo_huber_param"] = ToJson(param_);
+  }
+
+  void LoadConfig(Json const& in) override {
+    auto const& config = get<Object const>(in);
+    if (config.find("pseduo_huber_param") == config.cend()) {
+      // The parameter is added in 1.6.
+      return;
+    }
+    FromJson(in["pseduo_huber_param"], &param_);
+  }
+};
+
+XGBOOST_REGISTER_OBJECTIVE(PseudoHuberRegression, "reg:pseudohubererror")
+    .describe("Regression Pseudo Huber error.")
+    .set_body([]() { return new PseudoHuberRegression(); });
+
 // declare parameter
 struct PoissonRegressionParam : public XGBoostParameter<PoissonRegressionParam> {
   float max_delta_step;
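
The slope feeds the unit-test expectations directly; with huber_slope = 0.1 and the first sample of the updated objective test (pred 0.1, label 1.0) the formulas above give the asserted gradient and hessian. A quick standalone check (variable names illustrative):

#include <cmath>
#include <cstdio>

// Verifies the huber_slope = 0.1 expectations in the updated objective test.
int main() {
  float slope = 0.1f, predt = 0.1f, label = 1.0f;
  float z = predt - label;                                      // -0.9
  float scale_sqrt = std::sqrt(1 + (z * z) / (slope * slope));  // sqrt(82)
  float grad = z / scale_sqrt;                                  // -0.099388
  float scale = slope * slope + z * z;                          // 0.82
  float hess = slope * slope / (scale * scale_sqrt);            // 0.0013467
  std::printf("grad=%f hess=%f\n", grad, hess);
}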


@@ -314,11 +314,11 @@ TEST(Linalg, Popc) {
 
 TEST(Linalg, Stack) {
   Tensor<float, 3> l{{2, 3, 4}, kCpuId};
-  ElementWiseKernelHost(l.View(kCpuId), omp_get_max_threads(),
-                        [=](size_t i, float v) { return i; });
+  ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(),
+                           [=](size_t i, float v) { return i; });
 
   Tensor<float, 3> r_0{{2, 3, 4}, kCpuId};
-  ElementWiseKernelHost(r_0.View(kCpuId), omp_get_max_threads(),
-                        [=](size_t i, float v) { return i; });
+  ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(),
+                           [=](size_t i, float v) { return i; });
 
   Stack(&l, r_0);


@@ -1,5 +1,5 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #include <gtest/gtest.h>
@@ -19,7 +19,7 @@ void TestElementWiseKernel() {
     // GPU view
     auto t = l.View(0).Slice(linalg::All(), 1, linalg::All());
     ASSERT_FALSE(t.CContiguous());
-    ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; });
+    ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
     // CPU view
     t = l.View(GenericParameter::kCpuId).Slice(linalg::All(), 1, linalg::All());
     size_t k = 0;
@@ -30,10 +30,7 @@ void TestElementWiseKernel() {
     }
 
     t = l.View(0).Slice(linalg::All(), 1, linalg::All());
-    ElementWiseKernelDevice(t, [] __device__(size_t i, float v) {
-      SPAN_CHECK(v == i);
-      return v;
-    });
+    ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); });
   }
 
   {
@@ -41,7 +38,7 @@ void TestElementWiseKernel() {
      * Contiguous
      */
     auto t = l.View(0);
-    ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; });
+    ElementWiseTransformDevice(t, [] XGBOOST_DEVICE(size_t i, float) { return i; });
     ASSERT_TRUE(t.CContiguous());
     // CPU view
     t = l.View(GenericParameter::kCpuId);


@@ -29,14 +29,13 @@ inline void TestMetaInfoStridedData(int32_t device) {
     auto const& h_result = info.labels.View(-1);
     ASSERT_EQ(h_result.Shape().size(), 2);
     auto in_labels = labels.View(-1);
-    linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
+    linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
       auto tup = linalg::UnravelIndex(i, h_result.Shape());
       auto i0 = std::get<0>(tup);
       auto i1 = std::get<1>(tup);
       // Sliced at second dimension.
       auto v_1 = in_labels(i0, 0, i1);
       CHECK_EQ(v_0, v_1);
-      return v_0;
     });
   }
   {
@@ -71,7 +70,6 @@ inline void TestMetaInfoStridedData(int32_t device) {
       // Sliced at second dimension.
       auto v_1 = in_margin(i0, 0, i1);
       CHECK_EQ(v_0, v_1);
-      return v_0;
     });
   }
 }


@@ -1,5 +1,5 @@
 /*!
- * Copyright 2016-2020 XGBoost contributors
+ * Copyright 2016-2022 by XGBoost contributors
  */
 #include <dmlc/filesystem.h>
 #include <xgboost/logging.h>
@@ -136,8 +136,8 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
                              std::vector<xgboost::bst_float> out_hess) {
   xgboost::MetaInfo info;
   info.num_row_ = labels.size();
-  info.labels =
-      xgboost::linalg::Tensor<float, 2>{labels.cbegin(), labels.cend(), {labels.size()}, -1};
+  info.labels = xgboost::linalg::Tensor<float, 2>{
+      labels.cbegin(), labels.cend(), {labels.size(), static_cast<size_t>(1)}, -1};
 
   info.weights_.HostVector() = weights;
   info.group_ptr_ = groups;


@@ -1,12 +1,13 @@
 /*!
- * Copyright 2018-2019 XGBoost contributors
+ * Copyright 2018-2022 by XGBoost contributors
  */
-#include <xgboost/metric.h>
 #include <xgboost/json.h>
+#include <xgboost/metric.h>
 
 #include <map>
 #include <memory>
 
+#include "../../../src/common/linalg_op.h"
 #include "../helpers.h"
 
 namespace xgboost {
@@ -16,14 +17,17 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
   std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &lparam)};
 
   HostDeviceVector<float> predts;
+  size_t n_samples = 2048;
+
   MetaInfo info;
+  info.labels.Reshape(n_samples, 1);
+  info.num_row_ = n_samples;
+
   auto &h_labels = info.labels.Data()->HostVector();
   auto &h_predts = predts.HostVector();
 
   SimpleLCG lcg;
   SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
 
-  size_t n_samples = 2048;
   h_labels.resize(n_samples);
   h_predts.resize(n_samples);
@@ -145,27 +149,33 @@ TEST(Metric, DeclareUnifiedTest(MAPE)) {
 TEST(Metric, DeclareUnifiedTest(MPHE)) {
   auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
-  xgboost::Metric * metric = xgboost::Metric::Create("mphe", &lparam);
+  std::unique_ptr<xgboost::Metric> metric{xgboost::Metric::Create("mphe", &lparam)};
   metric->Configure({});
   ASSERT_STREQ(metric->Name(), "mphe");
-  EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
-  EXPECT_NEAR(GetMetricEval(metric,
+  EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}), 0, 1e-10);
+  EXPECT_NEAR(GetMetricEval(metric.get(),
                             {0.1f, 0.9f, 0.1f, 0.9f},
                             {  0,   0,   1,   1}),
               0.1751f, 1e-4);
-  EXPECT_NEAR(GetMetricEval(metric,
+  EXPECT_NEAR(GetMetricEval(metric.get(),
                             {0.1f, 0.9f, 0.1f, 0.9f},
                             {  0,   0,   1,   1},
                             { -1,   1,   9,  -9}),
               3.4037f, 1e-4);
-  EXPECT_NEAR(GetMetricEval(metric,
+  EXPECT_NEAR(GetMetricEval(metric.get(),
                             {0.1f, 0.9f, 0.1f, 0.9f},
                             {  0,   0,   1,   1},
                             {  1,   2,   9,   8}),
               0.1922f, 1e-4);
-  delete metric;
 
   xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX);
+
+  metric->Configure({{"huber_slope", "0.1"}});
+  EXPECT_NEAR(GetMetricEval(metric.get(),
+                            {0.1f, 0.9f, 0.1f, 0.9f},
+                            {  0,   0,   1,   1},
+                            {  1,   2,   9,   8}),
+              0.0461686f, 1e-4);
 }
 
 TEST(Metric, DeclareUnifiedTest(LogLoss)) {
@@ -277,7 +287,7 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
               1.5783f, 0.001f);
   delete metric;
 
-  xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX);
+  xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"poisson-nloglik"}, GPUIDX);
 }
 
 TEST(Metric, DeclareUnifiedTest(MultiRMSE)) {
@@ -288,8 +298,8 @@ TEST(Metric, DeclareUnifiedTest(MultiRMSE)) {
   HostDeviceVector<float> predt(n_samples * n_targets, 0);
 
-  auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
-  std::unique_ptr<Metric> metric{Metric::Create("rmse", &lparam)};
+  auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
+  std::unique_ptr<Metric> metric{Metric::Create("rmse", &ctx)};
   metric->Configure({});
 
   auto loss = GetMultiMetricEval(metric.get(), predt, y);


@@ -57,25 +57,31 @@ TEST(Objective, DeclareUnifiedTest(SquaredLog)) {
 TEST(Objective, DeclareUnifiedTest(PseudoHuber)) {
   GenericParameter tparam = CreateEmptyGenericParam(GPUIDX);
-  std::vector<std::pair<std::string, std::string>> args;
-  std::unique_ptr<ObjFunction> obj { ObjFunction::Create("reg:pseudohubererror", &tparam) };
+  Args args;
+  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:pseudohubererror", &tparam)};
 
   obj->Configure(args);
   CheckConfigReload(obj, "reg:pseudohubererror");
 
-  CheckObjFunction(obj,
-                   {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                               // pred
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // weights
-                   {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f},  // out_grad
-                   { 0.410660f,  0.476140f,  0.630510f, 0.9428660f, 0.630510f}); // out_hess
-  CheckObjFunction(obj,
-                   {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                               // pred
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
-                   {},                                                           // empty weights
-                   {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f},  // out_grad
-                   { 0.410660f,  0.476140f,  0.630510f, 0.9428660f, 0.630510f}); // out_hess
+  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // weights
+                   {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f},  // out_grad
+                   {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f});    // out_hess
+  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
+                   {},                                                           // empty weights
+                   {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f},  // out_grad
+                   {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f});    // out_hess
   ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"mphe"});
+
+  obj->Configure({{"huber_slope", "0.1"}});
+  CheckConfigReload(obj, "reg:pseudohubererror");
+  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // weights
+                   {-0.099388f, -0.099228f, -0.098639f, -0.089443f, 0.098639f},  // out_grad
+                   {0.0013467f, 0.001908f, 0.004443f, 0.089443f, 0.004443f});    // out_hess
 }
 
 TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) {
@@ -131,7 +137,6 @@ TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) {
   std::unique_ptr<ObjFunction> obj {
     ObjFunction::Create("binary:logitraw", &lparam)
   };
-
   obj->Configure(args);
 
   CheckObjFunction(obj,
@@ -373,5 +378,4 @@ TEST(Objective, CoxRegressionGPair) {
                    {    0,     0,      0, 0.160f, 0.186f, 0.348f, 0.610f, 0.639f});
 }
 #endif
-
 }  // namespace xgboost


@@ -430,8 +430,8 @@ TEST(Learner, MultiTarget) {
   size_t constexpr kRows{128}, kCols{10}, kTargets{3};
   auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
   m->Info().labels.Reshape(kRows, kTargets);
-  linalg::ElementWiseKernelHost(m->Info().labels.HostView(), omp_get_max_threads(),
-                                [](auto i, auto) { return i; });
+  linalg::ElementWiseTransformHost(m->Info().labels.HostView(), omp_get_max_threads(),
+                                   [](auto i, auto) { return i; });
 
   {
     std::unique_ptr<Learner> learner{Learner::Create({m})};
std::unique_ptr<Learner> learner{Learner::Create({m})}; std::unique_ptr<Learner> learner{Learner::Create({m})};