diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index e31091edf..46203387b 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -76,8 +76,9 @@ #include "../src/common/quantile.cc" #include "../src/common/host_device_vector.cc" #include "../src/common/hist_util.cc" -#include "../src/common/json.cc" #include "../src/common/io.cc" +#include "../src/common/json.cc" +#include "../src/common/pseudo_huber.cc" #include "../src/common/survival_util.cc" #include "../src/common/threading_utils.cc" #include "../src/common/version.cc" diff --git a/doc/model.schema b/doc/model.schema index 86acea967..8233d5509 100644 --- a/doc/model.schema +++ b/doc/model.schema @@ -204,6 +204,14 @@ } } }, + "pseduo_huber_param": { + "type": "object", + "properties": { + "huber_slope": { + "type": "string" + } + } + }, "aft_loss_param": { "type": "object", "properties": { diff --git a/doc/parameter.rst b/doc/parameter.rst index 4c6046af7..781150490 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -338,15 +338,6 @@ Parameters for Linear Booster (``booster=gblinear``) - The number of top features to select in ``greedy`` and ``thrifty`` feature selector. The value of 0 means using all the features. -Parameters for Tweedie Regression (``objective=reg:tweedie``) -============================================================= -* ``tweedie_variance_power`` [default=1.5] - - - Parameter that controls the variance of the Tweedie distribution ``var(y) ~ E(y)^tweedie_variance_power`` - - range: (1,2) - - Set closer to 2 to shift towards a gamma distribution - - Set closer to 1 to shift towards a Poisson distribution. - ************************ Learning Task Parameters ************************ @@ -356,14 +347,14 @@ Specify the learning task and the corresponding learning objective. The objectiv - ``reg:squarederror``: regression with squared loss. - ``reg:squaredlogerror``: regression with squared log loss :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`. All input labels are required to be greater than -1. Also, see metric ``rmsle`` for possible issue with this objective. - - ``reg:logistic``: logistic regression + - ``reg:logistic``: logistic regression. - ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss. - ``binary:logistic``: logistic regression for binary classification, output probability - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation - ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities. - ``count:poisson``: Poisson regression for count data, output mean of Poisson distribution. - - ``max_delta_step`` is set to 0.7 by default in Poisson regression (used to safeguard optimization) + + ``max_delta_step`` is set to 0.7 by default in Poisson regression (used to safeguard optimization) - ``survival:cox``: Cox regression for right censored survival time data (negative values are considered right censored). Note that predictions are returned on the hazard ratio scale (i.e., as HR = exp(marginal_prediction) in the proportional hazard function ``h(t) = h0(t) * HR``). @@ -435,6 +426,20 @@ Specify the learning task and the corresponding learning objective. The objectiv - Seed PRNG determnisticly via iterator number. 
+Parameters for Tweedie Regression (``objective=reg:tweedie``)
+=============================================================
+* ``tweedie_variance_power`` [default=1.5]
+
+  - Parameter that controls the variance of the Tweedie distribution ``var(y) ~ E(y)^tweedie_variance_power``
+  - range: (1,2)
+  - Set closer to 2 to shift towards a gamma distribution
+  - Set closer to 1 to shift towards a Poisson distribution.
+
+Parameter for using Pseudo-Huber (``reg:pseudohubererror``)
+===========================================================
+
+* ``huber_slope`` : A parameter used for Pseudo-Huber loss to define the :math:`\delta` term. [default = 1.0]
+
 ***********************
 Command Line Parameters
 ***********************
diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h
index 97e181289..8f8cd0912 100644
--- a/include/xgboost/generic_parameters.h
+++ b/include/xgboost/generic_parameters.h
@@ -49,6 +49,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
    */
   int32_t Threads() const;
 
+  bool IsCPU() const { return gpu_id == kCpuId; }
+
   // declare parameters
   DMLC_DECLARE_PARAMETER(GenericParameter) {
     DMLC_DECLARE_FIELD(seed).set_default(kDefaultSeed).describe(
diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index 897a73301..32d0f9fb9 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -545,8 +545,19 @@ using VectorView = TensorView<T, 1>;
  */
 template <typename T>
 auto MakeVec(T *ptr, size_t s, int32_t device = -1) {
-  using U = std::remove_const_t<std::remove_pointer_t<decltype(ptr)>> const;
-  return linalg::TensorView<U, 1>{{ptr, s}, {s}, device};
+  return linalg::TensorView<T, 1>{{ptr, s}, {s}, device};
+}
+
+template <typename T>
+auto MakeVec(HostDeviceVector<T> *data) {
+  return MakeVec(data->DeviceIdx() == -1 ? data->HostPointer() : data->DevicePointer(),
+                 data->Size(), data->DeviceIdx());
+}
+
+template <typename T>
+auto MakeVec(HostDeviceVector<T> const *data) {
+  return MakeVec(data->DeviceIdx() == -1 ? data->ConstHostPointer() : data->ConstDevicePointer(),
+                 data->Size(), data->DeviceIdx());
 }
 
 /**
diff --git a/include/xgboost/metric.h b/include/xgboost/metric.h
index 42d517819..0ce0d11ce 100644
--- a/include/xgboost/metric.h
+++ b/include/xgboost/metric.h
@@ -48,7 +48,10 @@ class Metric : public Configurable {
    * override this function to maintain internal configuration
    * \param out pointer to output JSON object
    */
-  void SaveConfig(Json*) const override {}
+  void SaveConfig(Json* p_out) const override {
+    auto& out = *p_out;
+    out["name"] = String(this->Name());
+  }
 
   /*!
    * \brief evaluate a specific metric
diff --git a/src/common/common.h b/src/common/common.h
index 8230e532f..fb7e7fee5 100644
--- a/src/common/common.h
+++ b/src/common/common.h
@@ -188,6 +188,16 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<>{}) {
   XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op);
   return result;
 }
+
+struct OptionalWeights {
+  Span<float const> weights;
+  float dft{1.0f};
+
+  explicit OptionalWeights(Span<float const> w) : weights{w} {}
+  explicit OptionalWeights(float w) : dft{w} {}
+
+  XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
+};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_COMMON_H_
diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh
index dfab58729..f0f89df8a 100644
--- a/src/common/linalg_op.cuh
+++ b/src/common/linalg_op.cuh
@@ -1,15 +1,33 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_LINALG_OP_CUH_
 #define XGBOOST_COMMON_LINALG_OP_CUH_
+
+#include "xgboost/generic_parameters.h"
 #include "device_helpers.cuh"
+#include "linalg_op.h"
 #include "xgboost/linalg.h"
 
 namespace xgboost {
 namespace linalg {
 template <typename T, int32_t D, typename Fn>
 void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
+  static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
+                "For function with return, use transform instead.");
+  if (t.Contiguous()) {
+    auto ptr = t.Values().data();
+    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { fn(i, ptr[i]); });
+  } else {
+    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
+      T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
+      fn(i, v);
+    });
+  }
+}
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
   if (t.Contiguous()) {
     auto ptr = t.Values().data();
     dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
@@ -20,6 +38,11 @@ void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s
     });
   }
 }
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
+  ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn);
+}
 }  // namespace linalg
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_LINALG_OP_CUH_
diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h
index a74b119e7..05f050772 100644
--- a/src/common/linalg_op.h
+++ b/src/common/linalg_op.h
@@ -1,15 +1,19 @@
 /*!
- * Copyright 2021 by XGBoost Contributors
+ * Copyright 2021-2022 by XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_LINALG_OP_H_
 #define XGBOOST_COMMON_LINALG_OP_H_
+#include <type_traits>
+
+#include "common.h"
 #include "threading_utils.h"
+#include "xgboost/generic_parameters.h"
 #include "xgboost/linalg.h"
 
 namespace xgboost {
 namespace linalg {
 template <typename T, int32_t D, typename Fn>
-void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
+void ElementWiseTransformHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
   if (t.Contiguous()) {
     auto ptr = t.Values().data();
     common::ParallelFor(t.Size(), n_threads, [&](size_t i) { ptr[i] = fn(i, ptr[i]); });
@@ -20,6 +24,41 @@ void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& f
     });
   }
 }
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
+  static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
+                "For function with return, use transform instead.");
+  if (t.Contiguous()) {
+    auto ptr = t.Values().data();
+    common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); });
+  } else {
+    common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
+      auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
+      fn(i, v);
+    });
+  }
+}
+
+#if !defined(XGBOOST_USE_CUDA)
+template <typename T, int32_t D, typename Fn>
+void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, void* s = nullptr) {
+  common::AssertGPUSupport();
+}
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, void* s = nullptr) {
+  common::AssertGPUSupport();
+}
+
+template <typename T, int32_t D, typename Fn>
+void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
+  if (!ctx->IsCPU()) {
+    common::AssertGPUSupport();
+  }
+  ElementWiseKernelHost(t, ctx->Threads(), fn);
+}
+#endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace linalg
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_LINALG_OP_H_
diff --git a/src/common/math.h b/src/common/math.h
index 5a98ad329..71a494544 100644
---
a/src/common/math.h +++ b/src/common/math.h @@ -23,7 +23,11 @@ namespace common { * \return the transformed value. */ XGBOOST_DEVICE inline float Sigmoid(float x) { - return 1.0f / (1.0f + expf(-x)); + float constexpr kEps = 1e-16; // avoid 0 div + x = std::min(-x, 88.7f); // avoid exp overflow + auto denom = expf(x) + 1.0f + kEps; + auto y = 1.0f / denom; + return y; } template diff --git a/src/common/pseudo_huber.cc b/src/common/pseudo_huber.cc new file mode 100644 index 000000000..5f58a18b3 --- /dev/null +++ b/src/common/pseudo_huber.cc @@ -0,0 +1,7 @@ +/*! + * Copyright 2022, by XGBoost Contributors + */ +#include "pseudo_huber.h" +namespace xgboost { +DMLC_REGISTER_PARAMETER(PesudoHuberParam); +} diff --git a/src/common/pseudo_huber.h b/src/common/pseudo_huber.h new file mode 100644 index 000000000..9cf604534 --- /dev/null +++ b/src/common/pseudo_huber.h @@ -0,0 +1,19 @@ +#ifndef XGBOOST_COMMON_PSEUDO_HUBER_H_ +#define XGBOOST_COMMON_PSEUDO_HUBER_H_ +/*! + * Copyright 2022, by XGBoost Contributors + */ +#include "xgboost/parameter.h" + +namespace xgboost { +struct PesudoHuberParam : public XGBoostParameter { + float huber_slope{1.0}; + + DMLC_DECLARE_PARAMETER(PesudoHuberParam) { + DMLC_DECLARE_FIELD(huber_slope) + .set_default(1.0f) + .describe("The delta term in Pseudo-Huber loss."); + } +}; +} // namespace xgboost +#endif // XGBOOST_COMMON_PSEUDO_HUBER_H_ diff --git a/src/data/data.cc b/src/data/data.cc index 3d1e3cc28..f25daa6c8 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -431,7 +431,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { auto t = p_out->View(GenericParameter::kCpuId); CHECK(t.CContiguous()); // FIXME(jiamingy): Remove the use of this default thread. - linalg::ElementWiseKernelHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) { + linalg::ElementWiseTransformHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) { return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, t.Shape())); }); } @@ -877,7 +877,7 @@ DMatrix* DMatrix::Load(const std::string& uri, dmat = DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 1, cache_file); } else { - data::FileIterator iter{fname, uint32_t(partid), uint32_t(npart), + data::FileIterator iter{fname, static_cast(partid), static_cast(npart), file_format}; dmat = new data::SparsePageDMatrix{ &iter, diff --git a/src/data/data.cu b/src/data/data.cu index 475d70313..55c1c80d0 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -49,7 +49,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { } p_out->Reshape(array.shape); auto t = p_out->View(ptr_device); - linalg::ElementWiseKernelDevice(t, [=] __device__(size_t i, T) { + linalg::ElementWiseTransformDevice(t, [=] __device__(size_t i, T) { return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, array.shape)); }); } diff --git a/src/learner.cc b/src/learner.cc index f8fad74cf..73447cf2e 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -277,6 +277,21 @@ using LearnerAPIThreadLocalStore = using ThreadLocalPredictionCache = dmlc::ThreadLocalStore>; +namespace { +StringView ModelMsg() { + return StringView{ + R"doc( + If you are loading a serialized model (like pickle in Python, RDS in R) generated by + older XGBoost, please export the model by calling `Booster.save_model` from that version + first, then load it back in current version. See: + + https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html + + for more details about differences between saving model and serializing. 
+)doc"}; +} +} // anonymous namespace + class LearnerConfiguration : public Learner { private: std::mutex config_lock_; @@ -375,7 +390,6 @@ class LearnerConfiguration : public Learner { this->ConfigureGBM(old_tparam, args); generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU()); - this->ConfigureMetrics(args); this->need_configuration_ = false; @@ -418,9 +432,17 @@ class LearnerConfiguration : public Learner { metric_names_.resize(n_metrics); metrics_.resize(n_metrics); for (size_t i = 0; i < n_metrics; ++i) { - metric_names_[i]= get(j_metrics[i]); - metrics_[i] = std::unique_ptr( - Metric::Create(metric_names_[i], &generic_parameters_)); + auto old_serialization = IsA(j_metrics[i]); + if (old_serialization) { + LOG(WARNING) << ModelMsg(); + metric_names_[i] = get(j_metrics[i]); + } else { + metric_names_[i] = get(j_metrics[i]["name"]); + } + metrics_[i] = std::unique_ptr(Metric::Create(metric_names_[i], &generic_parameters_)); + if (!old_serialization) { + metrics_[i]->LoadConfig(j_metrics[i]); + } } FromJson(learner_parameters.at("generic_param"), &generic_parameters_); @@ -448,9 +470,9 @@ class LearnerConfiguration : public Learner { auto& objective_fn = learner_parameters["objective"]; obj_->SaveConfig(&objective_fn); - std::vector metrics(metrics_.size()); + std::vector metrics(metrics_.size(), Json{Object{}}); for (size_t i = 0; i < metrics_.size(); ++i) { - metrics[i] = String(metrics_[i]->Name()); + metrics_[i]->SaveConfig(&metrics[i]); } learner_parameters["metrics"] = Array(std::move(metrics)); @@ -709,21 +731,6 @@ class LearnerConfiguration : public Learner { std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT -namespace { -StringView ModelMsg() { - return StringView{ - R"doc( - If you are loading a serialized model (like pickle in Python, RDS in R) generated by - older XGBoost, please export the model by calling `Booster.save_model` from that version - first, then load it back in current version. See: - - https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html - - for more details about differences between saving model and serializing. 
-)doc"}; -} -} // anonymous namespace - class LearnerIO : public LearnerConfiguration { private: std::set saved_configs_ = {"num_round"}; diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 1957bcc9a..0829db8f1 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -33,7 +33,7 @@ namespace metric { template std::tuple BinaryAUC(common::Span predts, linalg::VectorView labels, - OptionalWeights weights, + common::OptionalWeights weights, std::vector const &sorted_idx, Fn &&area_fn) { CHECK_NE(labels.Size(), 0); CHECK_EQ(labels.Size(), predts.size()); @@ -93,7 +93,7 @@ double MultiClassOVR(common::Span predts, MetaInfo const &info, auto tp = results.Slice(linalg::All(), 1); auto auc = results.Slice(linalg::All(), 2); - auto weights = OptionalWeights{info.weights_.ConstHostSpan()}; + auto weights = common::OptionalWeights{info.weights_.ConstHostSpan()}; auto predts_t = linalg::TensorView( predts, {static_cast(info.num_row_), n_classes}, GenericParameter::kCpuId); @@ -140,7 +140,7 @@ double MultiClassOVR(common::Span predts, MetaInfo const &info, std::tuple BinaryROCAUC(common::Span predts, linalg::VectorView labels, - OptionalWeights weights) { + common::OptionalWeights weights) { auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea); } @@ -186,7 +186,7 @@ double GroupRankingROC(common::Span predts, */ std::tuple BinaryPRAUC(common::Span predts, linalg::VectorView labels, - OptionalWeights weights) { + common::OptionalWeights weights) { auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); double total_pos{0}, total_neg{0}; for (size_t i = 0; i < labels.Size(); ++i) { @@ -238,7 +238,7 @@ std::pair RankingAUC(std::vector const &predts, if (is_roc) { auc = GroupRankingROC(g_predts, g_labels, w); } else { - auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w})); + auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, common::OptionalWeights{w})); } if (std::isnan(auc)) { invalid_groups++; @@ -373,7 +373,7 @@ class EvalROCAUC : public EvalAUC { if (tparam_->gpu_id == GenericParameter::kCpuId) { std::tie(fp, tp, auc) = BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0), - OptionalWeights{info.weights_.ConstHostSpan()}); + common::OptionalWeights{info.weights_.ConstHostSpan()}); } else { std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); @@ -426,7 +426,7 @@ class EvalPRAUC : public EvalAUC { if (tparam_->gpu_id == GenericParameter::kCpuId) { std::tie(pr, re, auc) = BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0), - OptionalWeights{info.weights_.ConstHostSpan()}); + common::OptionalWeights{info.weights_.ConstHostSpan()}); } else { std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_); diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 317ce7db2..be89c015c 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -99,7 +99,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, /** * Linear scan */ - auto get_weight = OptionalWeights{weights}; + auto get_weight = common::OptionalWeights{weights}; auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; @@ -353,7 +353,7 @@ double GPUMultiClassAUCOVR(common::Span predts, * Linear scan */ dh::caching_device_vector d_auc(n_classes, 0); - auto get_weight = OptionalWeights{weights}; + auto get_weight = common::OptionalWeights{weights}; auto 
d_fptp = dh::ToSpan(cache->fptp); auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; @@ -633,7 +633,7 @@ GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, auto labels = info.labels.View(device); auto d_weights = info.weights_.ConstDeviceSpan(); - auto get_weight = OptionalWeights{d_weights}; + auto get_weight = common::OptionalWeights{d_weights}; auto it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { auto w = get_weight[d_sorted_idx[i]]; @@ -687,7 +687,7 @@ double GPUMultiClassPRAUC(common::Span predts, [n_samples] XGBOOST_DEVICE(size_t i) { return i / n_samples; // class id }); - auto get_weight = OptionalWeights{d_weights}; + auto get_weight = common::OptionalWeights{d_weights}; auto val_it = dh::MakeTransformIterator>( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { auto idx = d_sorted_idx[i] % n_samples; @@ -736,7 +736,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, */ size_t n_samples = labels.Shape(0); dh::caching_device_vector d_auc(n_groups, 0); - auto get_weight = OptionalWeights{weights}; + auto get_weight = common::OptionalWeights{weights}; auto d_fptp = dh::ToSpan(cache->fptp); auto get_fp_tp = [=] XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; diff --git a/src/metric/auc.h b/src/metric/auc.h index cde8febf2..c42df6890 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -112,18 +112,6 @@ struct PRAUCLabelInvalid { inline void InvalidLabels() { LOG(FATAL) << "PR-AUC supports only binary relevance for learning to rank."; } - -struct OptionalWeights { - common::Span weights; - float dft { 1.0f }; - - explicit OptionalWeights(common::Span w) : weights{w} {} - explicit OptionalWeights(float w) : dft{w} {} - - XGBOOST_DEVICE float operator[](size_t i) const { - return weights.empty() ? dft : weights[i]; - } -}; } // namespace metric } // namespace xgboost #endif // XGBOOST_METRIC_AUC_H_ diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index abf888e0b..d36196bc3 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -1,20 +1,22 @@ /*! - * Copyright 2015-2019 by Contributors + * Copyright 2015-2022 by XGBoost Contributors * \file elementwise_metric.cc * \brief evaluation metrics for elementwise binary or regression. * \author Kailong Chen, Tianqi Chen * * The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset. */ +#include #include #include -#include + #include -#include "metric_common.h" -#include "../common/math.h" #include "../common/common.h" +#include "../common/math.h" +#include "../common/pseudo_huber.h" #include "../common/threading_utils.h" +#include "metric_common.h" #if defined(XGBOOST_USE_CUDA) #include // thrust::cuda::par @@ -30,109 +32,63 @@ namespace metric { // tag the this file, used by force static link later. 
DMLC_REGISTRY_FILE_TAG(elementwise_metric); -template -class ElementWiseMetricsReduction { - public: - explicit ElementWiseMetricsReduction(EvalRow policy) : policy_(std::move(policy)) {} - - PackedReduceResult - CpuReduceMetrics(const HostDeviceVector &weights, - linalg::TensorView labels, - const HostDeviceVector &preds, - int32_t n_threads) const { - size_t ndata = labels.Size(); - auto n_targets = std::max(labels.Shape(1), static_cast(1)); - auto h_labels = labels.Values(); - - const auto& h_weights = weights.HostVector(); - const auto& h_preds = preds.HostVector(); - +namespace { +/** + * \brief Reduce function for element wise metrics. + * + * The loss function should handle all the computation for each sample, including + * applying the weights. A tuple of {error_i, weight_i} is expected as return. + */ +template +PackedReduceResult Reduce(GenericParameter const* ctx, MetaInfo const& info, Fn&& loss) { + PackedReduceResult result; + auto labels = info.labels.View(ctx->gpu_id); + if (ctx->IsCPU()) { + auto n_threads = ctx->Threads(); std::vector score_tloc(n_threads, 0.0); std::vector weight_tloc(n_threads, 0.0); - // We sum over losses over all samples and targets instead of performing this for each // target since the first one approach more accurate while the second approach is used // for approximation in distributed setting. For rmse: // - sqrt(1/w(sum_t0 + sum_t1 + ... + sum_tm)) // multi-target // - sqrt(avg_t0) + sqrt(avg_t1) + ... sqrt(avg_tm) // distributed - common::ParallelFor(ndata, n_threads, [&](size_t i) { - float wt = h_weights.size() > 0 ? h_weights[i / n_targets] : 1.0f; + common::ParallelFor(info.labels.Size(), ctx->Threads(), [&](size_t i) { auto t_idx = omp_get_thread_num(); - score_tloc[t_idx] += policy_.EvalRow(h_labels[i], h_preds[i]) * wt; + size_t sample_id; + size_t target_id; + std::tie(sample_id, target_id) = linalg::UnravelIndex(i, labels.Shape()); + + float v, wt; + std::tie(v, wt) = loss(i, sample_id, target_id); + score_tloc[t_idx] += v; weight_tloc[t_idx] += wt; }); double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); - - PackedReduceResult res { residue_sum, weights_sum }; - return res; - } - + result = PackedReduceResult{residue_sum, weights_sum}; + } else { #if defined(XGBOOST_USE_CUDA) - - PackedReduceResult DeviceReduceMetrics( - const HostDeviceVector& weights, - linalg::TensorView labels, - const HostDeviceVector& preds) { - size_t n_data = preds.Size(); - auto n_targets = std::max(labels.Shape(1), static_cast(1)); - - thrust::counting_iterator begin(0); - thrust::counting_iterator end = begin + n_data; - - auto s_label = labels.Values(); - auto s_preds = preds.DeviceSpan(); - auto s_weights = weights.DeviceSpan(); - - bool const is_null_weight = weights.Size() == 0; - - auto d_policy = policy_; - dh::XGBCachingDeviceAllocator alloc; - PackedReduceResult result = thrust::transform_reduce( - thrust::cuda::par(alloc), - begin, end, - [=] XGBOOST_DEVICE(size_t idx) { - float weight = is_null_weight ? 
1.0f : s_weights[idx / n_targets]; - - float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]); - residue *= weight; - return PackedReduceResult{ residue, weight }; + thrust::counting_iterator begin(0); + thrust::counting_iterator end = begin + labels.Size(); + result = thrust::transform_reduce( + thrust::cuda::par(alloc), begin, end, + [=] XGBOOST_DEVICE(size_t i) { + auto idx = linalg::UnravelIndex(i, labels.Shape()); + auto sample_id = std::get<0>(idx); + auto target_id = std::get<1>(idx); + auto res = loss(i, sample_id, target_id); + float v{std::get<0>(res)}, wt{std::get<1>(res)}; + return PackedReduceResult{v, wt}; }, - PackedReduceResult(), - thrust::plus()); - - return result; + PackedReduceResult{}, thrust::plus()); +#else + common::AssertGPUSupport(); +#endif // defined(XGBOOST_USE_CUDA) } - -#endif // XGBOOST_USE_CUDA - - PackedReduceResult Reduce(const GenericParameter& ctx, const HostDeviceVector& weights, - linalg::Tensor const& labels, - const HostDeviceVector& preds) { - PackedReduceResult result; - - if (ctx.gpu_id < 0) { - auto n_threads = ctx.Threads(); - result = CpuReduceMetrics(weights, labels.HostView(), preds, n_threads); - } -#if defined(XGBOOST_USE_CUDA) - else { // NOLINT - preds.SetDevice(ctx.gpu_id); - weights.SetDevice(ctx.gpu_id); - - dh::safe_cuda(cudaSetDevice(ctx.gpu_id)); - result = DeviceReduceMetrics(weights, labels.View(ctx.gpu_id), preds); - } -#endif // defined(XGBOOST_USE_CUDA) - return result; - } - - private: - EvalRow policy_; -#if defined(XGBOOST_USE_CUDA) -#endif // defined(XGBOOST_USE_CUDA) -}; + return result; +} +} // anonymous namespace struct EvalRowRMSE { char const *Name() const { @@ -187,38 +143,64 @@ struct EvalRowMAPE { } }; +namespace { +XGBOOST_DEVICE inline float LogLoss(float y, float py) { + auto xlogy = [](float x, float y) { + float eps = 1e-16; + return (x - 0.0f == 0.0f) ? 0.0f : (x * std::log(std::max(y, eps))); + }; + const bst_float pneg = 1.0f - py; + return xlogy(-y, py) + xlogy(-(1.0f - y), pneg); +} +} // anonymous namespace + struct EvalRowLogLoss { const char *Name() const { return "logloss"; } - XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const { - const bst_float eps = 1e-16f; - const bst_float pneg = 1.0f - py; - if (py < eps) { - return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps); - } else if (pneg < eps) { - return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps); - } else { - return -y * std::log(py) - (1.0f - y) * std::log(pneg); - } - } - + XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const { return LogLoss(y, py); } static double GetFinal(double esum, double wsum) { return wsum == 0 ? esum : esum / wsum; } }; -struct EvalRowMPHE { - char const *Name() const { - return "mphe"; +class PseudoErrorLoss : public Metric { + PesudoHuberParam param_; + + public: + const char* Name() const override { return "mphe"; } + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + void LoadConfig(Json const& in) override { FromJson(in["pseduo_huber_param"], ¶m_); } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(this->Name()); + out["pseduo_huber_param"] = ToJson(param_); } - XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const { - bst_float diff = label - pred; - return std::sqrt( 1 + diff * diff) - 1; - } - static double GetFinal(double esum, double wsum) { - return wsum == 0 ? 
esum : esum / wsum; + + double Eval(const HostDeviceVector& preds, const MetaInfo& info, + bool distributed) override { + CHECK_EQ(info.labels.Shape(0), info.num_row_); + auto labels = info.labels.View(tparam_->gpu_id); + preds.SetDevice(tparam_->gpu_id); + auto predts = tparam_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(); + info.weights_.SetDevice(tparam_->gpu_id); + common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan() + : info.weights_.ConstDeviceSpan()); + float slope = this->param_.huber_slope; + CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0."; + PackedReduceResult result = + Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) { + float wt = weights[sample_id]; + auto a = labels(sample_id, target_id) - predts[i]; + auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt; + return std::make_tuple(v, wt); + }); + double dat[2]{result.Residue(), result.Weights()}; + if (distributed) { + rabit::Allreduce(dat, 2); + } + return EvalRowMAPE::GetFinal(dat[0], dat[1]); } }; @@ -355,20 +337,36 @@ struct EvalTweedieNLogLik { * \brief base class of element-wise evaluation * \tparam Derived the name of subclass */ -template +template struct EvalEWiseBase : public Metric { EvalEWiseBase() = default; - explicit EvalEWiseBase(char const* policy_param) : - policy_{policy_param}, reducer_{policy_} {} + explicit EvalEWiseBase(char const* policy_param) : policy_{policy_param} {} - double Eval(const HostDeviceVector &preds, const MetaInfo &info, + double Eval(HostDeviceVector const& preds, const MetaInfo& info, bool distributed) override { CHECK_EQ(preds.Size(), info.labels.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; - auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels, preds); + if (info.labels.Size() != 0) { + CHECK_NE(info.labels.Shape(1), 0); + } + auto labels = info.labels.View(tparam_->gpu_id); + info.weights_.SetDevice(tparam_->gpu_id); + common::OptionalWeights weights(tparam_->IsCPU() ? info.weights_.ConstHostSpan() + : info.weights_.ConstDeviceSpan()); + preds.SetDevice(tparam_->gpu_id); + auto predts = tparam_->IsCPU() ? 
preds.ConstHostSpan() : preds.ConstDeviceSpan(); - double dat[2] { result.Residue(), result.Weights() }; + auto d_policy = policy_; + auto result = + Reduce(tparam_, info, [=] XGBOOST_DEVICE(size_t i, size_t sample_id, size_t target_id) { + float wt = weights[sample_id]; + float residue = d_policy.EvalRow(labels(sample_id, target_id), predts[i]); + residue *= wt; + return std::make_tuple(residue, wt); + }); + + double dat[2]{result.Residue(), result.Weights()}; if (distributed) { rabit::Allreduce(dat, 2); @@ -376,13 +374,10 @@ struct EvalEWiseBase : public Metric { return Policy::GetFinal(dat[0], dat[1]); } - const char* Name() const override { - return policy_.Name(); - } + const char* Name() const override { return policy_.Name(); } private: Policy policy_; - ElementWiseMetricsReduction reducer_{policy_}; }; XGBOOST_REGISTER_METRIC(RMSE, "rmse") @@ -401,14 +396,14 @@ XGBOOST_REGISTER_METRIC(MAPE, "mape") .describe("Mean absolute percentage error.") .set_body([](const char* param) { return new EvalEWiseBase(); }); -XGBOOST_REGISTER_METRIC(MPHE, "mphe") -.describe("Mean Pseudo Huber error.") -.set_body([](const char* param) { return new EvalEWiseBase(); }); - XGBOOST_REGISTER_METRIC(LogLoss, "logloss") .describe("Negative loglikelihood for logistic regression.") .set_body([](const char* param) { return new EvalEWiseBase(); }); +XGBOOST_REGISTER_METRIC(PseudoErrorLoss, "mphe") + .describe("Mean Pseudo-huber error.") + .set_body([](const char* param) { return new PseudoErrorLoss{}; }); + XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik") .describe("Negative loglikelihood for poisson regression.") .set_body([](const char* param) { return new EvalEWiseBase(); }); @@ -430,6 +425,5 @@ XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik") .set_body([](const char* param) { return new EvalEWiseBase(param); }); - } // namespace metric } // namespace xgboost diff --git a/src/objective/regression_loss.h b/src/objective/regression_loss.h index 30605b348..f92dfe2d4 100644 --- a/src/objective/regression_loss.h +++ b/src/objective/regression_loss.h @@ -105,38 +105,6 @@ struct LogisticRegression { static ObjInfo Info() { return {ObjInfo::kRegression, false}; } }; -struct PseudoHuberError { - XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { - return x; - } - XGBOOST_DEVICE static bool CheckLabel(bst_float) { - return true; - } - XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) { - const float z = predt - label; - const float scale_sqrt = std::sqrt(1 + std::pow(z, 2)); - return z/scale_sqrt; - } - XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) { - const float scale = 1 + std::pow(predt - label, 2); - const float scale_sqrt = std::sqrt(scale); - return 1/(scale*scale_sqrt); - } - static bst_float ProbToMargin(bst_float base_score) { - return base_score; - } - static const char* LabelErrorMsg() { - return ""; - } - static const char* DefaultEvalMetric() { - return "mphe"; - } - static const char* Name() { - return "reg:pseudohubererror"; - } - static ObjInfo Info() { return {ObjInfo::kRegression, false}; } -}; - // logistic loss for binary classification task struct LogisticClassification : public LogisticRegression { static const char* DefaultEvalMetric() { return "logloss"; } diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index a07de8e44..fa294a5a5 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -8,23 +8,38 @@ #include #include #include + 
#include #include #include +#include "../common/common.h" +#include "../common/linalg_op.h" +#include "../common/pseudo_huber.h" +#include "../common/threading_utils.h" +#include "../common/transform.h" +#include "./regression_loss.h" #include "xgboost/host_device_vector.h" #include "xgboost/json.h" #include "xgboost/parameter.h" #include "xgboost/span.h" -#include "../common/transform.h" -#include "../common/common.h" -#include "../common/threading_utils.h" -#include "./regression_loss.h" - +#if defined(XGBOOST_USE_CUDA) +#include "../common/linalg_op.cuh" +#endif // defined(XGBOOST_USE_CUDA) namespace xgboost { namespace obj { +namespace { +void CheckRegInputs(MetaInfo const& info, HostDeviceVector const& preds) { + CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels."; + CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels."; + if (!info.weights_.Empty()) { + CHECK_EQ(info.weights_.Size(), info.num_row_) + << "Number of weights should be equal to number of data points."; + } +} +} // anonymous namespace #if defined(XGBOOST_USE_CUDA) DMLC_REGISTRY_FILE_TAG(regression_obj_gpu); @@ -64,20 +79,13 @@ class RegLossObj : public ObjFunction { void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector* out_gpair) override { - CHECK_EQ(preds.Size(), info.labels.Size()) - << " " << "labels are not correctly provided" - << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", " - << "Loss: " << Loss::Name(); + CheckRegInputs(info, preds); size_t const ndata = preds.Size(); out_gpair->Resize(ndata); auto device = ctx_->gpu_id; additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag bool is_null_weight = info.weights_.Size() == 0; - if (!is_null_weight) { - CHECK_EQ(info.weights_.Size(), info.labels.Shape(0)) - << "Number of weights should be equal to number of data points."; - } auto scale_pos_weight = param_.scale_pos_weight; additional_input_.HostVector().begin()[1] = scale_pos_weight; additional_input_.HostVector().begin()[2] = is_null_weight; @@ -179,10 +187,6 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name()) .describe("Logistic regression for probability regression task.") .set_body([]() { return new RegLossObj(); }); -XGBOOST_REGISTER_OBJECTIVE(PseudoHuberError, PseudoHuberError::Name()) -.describe("Regression Pseudo Huber error.") -.set_body([]() { return new RegLossObj(); }); - XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, LogisticClassification::Name()) .describe("Logistic regression for binary classification task.") .set_body([]() { return new RegLossObj(); }); @@ -200,6 +204,70 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") return new RegLossObj(); }); // End deprecated +class PseudoHuberRegression : public ObjFunction { + PesudoHuberParam param_; + + public: + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + struct ObjInfo Task() const override { return {ObjInfo::kRegression, false}; } + uint32_t Targets(MetaInfo const& info) const override { + return std::max(static_cast(1), info.labels.Shape(1)); + } + + void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, int iter, + HostDeviceVector* out_gpair) override { + CheckRegInputs(info, preds); + auto slope = param_.huber_slope; + CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0."; + auto labels = info.labels.View(ctx_->gpu_id); + + out_gpair->SetDevice(ctx_->gpu_id); + out_gpair->Resize(info.labels.Size()); + auto gpair 
= linalg::MakeVec(out_gpair); + + preds.SetDevice(ctx_->gpu_id); + auto predt = linalg::MakeVec(&preds); + + info.weights_.SetDevice(ctx_->gpu_id); + common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() + : info.weights_.ConstDeviceSpan()}; + + linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { + auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); + const float z = predt(i) - y; + const float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope)); + float grad = z / scale_sqrt; + + auto scale = common::Sqr(slope) + common::Sqr(z); + float hess = common::Sqr(slope) / (scale * scale_sqrt); + + auto w = weight[sample_id]; + gpair(i) = {grad * w, hess * w}; + }); + } + + const char* DefaultEvalMetric() const override { return "mphe"; } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String("reg:pseudohubererror"); + out["pseduo_huber_param"] = ToJson(param_); + } + + void LoadConfig(Json const& in) override { + auto const& config = get(in); + if (config.find("pseduo_huber_param") == config.cend()) { + // The parameter is added in 1.6. + return; + } + FromJson(in["pseduo_huber_param"], ¶m_); + } +}; + +XGBOOST_REGISTER_OBJECTIVE(PseudoHuberRegression, "reg:pseudohubererror") + .describe("Regression Pseudo Huber error.") + .set_body([]() { return new PseudoHuberRegression(); }); + // declare parameter struct PoissonRegressionParam : public XGBoostParameter { float max_delta_step; diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc index a4f3e6ab4..110f18fcb 100644 --- a/tests/cpp/common/test_linalg.cc +++ b/tests/cpp/common/test_linalg.cc @@ -314,11 +314,11 @@ TEST(Linalg, Popc) { TEST(Linalg, Stack) { Tensor l{{2, 3, 4}, kCpuId}; - ElementWiseKernelHost(l.View(kCpuId), omp_get_max_threads(), - [=](size_t i, float v) { return i; }); + ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(), + [=](size_t i, float v) { return i; }); Tensor r_0{{2, 3, 4}, kCpuId}; - ElementWiseKernelHost(r_0.View(kCpuId), omp_get_max_threads(), - [=](size_t i, float v) { return i; }); + ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(), + [=](size_t i, float v) { return i; }); Stack(&l, r_0); diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index 9ea6b22dd..ae0eb28a7 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -1,5 +1,5 @@ /*! 
- * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #include @@ -19,7 +19,7 @@ void TestElementWiseKernel() { // GPU view auto t = l.View(0).Slice(linalg::All(), 1, linalg::All()); ASSERT_FALSE(t.CContiguous()); - ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; }); + ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; }); // CPU view t = l.View(GenericParameter::kCpuId).Slice(linalg::All(), 1, linalg::All()); size_t k = 0; @@ -30,10 +30,7 @@ void TestElementWiseKernel() { } t = l.View(0).Slice(linalg::All(), 1, linalg::All()); - ElementWiseKernelDevice(t, [] __device__(size_t i, float v) { - SPAN_CHECK(v == i); - return v; - }); + ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); }); } { @@ -41,7 +38,7 @@ void TestElementWiseKernel() { * Contiguous */ auto t = l.View(0); - ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; }); + ElementWiseTransformDevice(t, [] XGBOOST_DEVICE(size_t i, float) { return i; }); ASSERT_TRUE(t.CContiguous()); // CPU view t = l.View(GenericParameter::kCpuId); diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h index f070e6f81..bb86e16ea 100644 --- a/tests/cpp/data/test_metainfo.h +++ b/tests/cpp/data/test_metainfo.h @@ -29,14 +29,13 @@ inline void TestMetaInfoStridedData(int32_t device) { auto const& h_result = info.labels.View(-1); ASSERT_EQ(h_result.Shape().size(), 2); auto in_labels = labels.View(-1); - linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) { + linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) { auto tup = linalg::UnravelIndex(i, h_result.Shape()); auto i0 = std::get<0>(tup); auto i1 = std::get<1>(tup); // Sliced at second dimension. auto v_1 = in_labels(i0, 0, i1); CHECK_EQ(v_0, v_1); - return v_0; }); } { @@ -71,7 +70,6 @@ inline void TestMetaInfoStridedData(int32_t device) { // Sliced at second dimension. auto v_1 = in_margin(i0, 0, i1); CHECK_EQ(v_0, v_1); - return v_0; }); } } diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index fe32a0593..05c138781 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2016-2020 XGBoost contributors + * Copyright 2016-2022 by XGBoost contributors */ #include #include @@ -136,8 +136,8 @@ void CheckRankingObjFunction(std::unique_ptr const& obj, std::vector out_hess) { xgboost::MetaInfo info; info.num_row_ = labels.size(); - info.labels = - xgboost::linalg::Tensor{labels.cbegin(), labels.cend(), {labels.size()}, -1}; + info.labels = xgboost::linalg::Tensor{ + labels.cbegin(), labels.cend(), {labels.size(), static_cast(1)}, -1}; info.weights_.HostVector() = weights; info.group_ptr_ = groups; diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc index 514b8753c..2cf353bf3 100644 --- a/tests/cpp/metric/test_elementwise_metric.cc +++ b/tests/cpp/metric/test_elementwise_metric.cc @@ -1,12 +1,13 @@ /*! 
- * Copyright 2018-2019 XGBoost contributors + * Copyright 2018-2022 by XGBoost contributors */ -#include #include +#include #include #include +#include "../../../src/common/linalg_op.h" #include "../helpers.h" namespace xgboost { @@ -16,14 +17,17 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) std::unique_ptr metric{Metric::Create(name.c_str(), &lparam)}; HostDeviceVector predts; + size_t n_samples = 2048; + MetaInfo info; + info.labels.Reshape(n_samples, 1); + info.num_row_ = n_samples; auto &h_labels = info.labels.Data()->HostVector(); auto &h_predts = predts.HostVector(); SimpleLCG lcg; SimpleRealUniformDistribution dist{0.0f, 1.0f}; - size_t n_samples = 2048; h_labels.resize(n_samples); h_predts.resize(n_samples); @@ -145,27 +149,33 @@ TEST(Metric, DeclareUnifiedTest(MAPE)) { TEST(Metric, DeclareUnifiedTest(MPHE)) { auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX); - xgboost::Metric * metric = xgboost::Metric::Create("mphe", &lparam); + std::unique_ptr metric{xgboost::Metric::Create("mphe", &lparam)}; metric->Configure({}); ASSERT_STREQ(metric->Name(), "mphe"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10); - EXPECT_NEAR(GetMetricEval(metric, + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}), 0, 1e-10); + EXPECT_NEAR(GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), 0.1751f, 1e-4); - EXPECT_NEAR(GetMetricEval(metric, + EXPECT_NEAR(GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}, { -1, 1, 9, -9}), 3.4037f, 1e-4); - EXPECT_NEAR(GetMetricEval(metric, + EXPECT_NEAR(GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}, { 1, 2, 9, 8}), 0.1922f, 1e-4); - delete metric; xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX); + + metric->Configure({{"huber_slope", "0.1"}}); + EXPECT_NEAR(GetMetricEval(metric.get(), + {0.1f, 0.9f, 0.1f, 0.9f}, + { 0, 0, 1, 1}, + { 1, 2, 9, 8}), + 0.0461686f, 1e-4); } TEST(Metric, DeclareUnifiedTest(LogLoss)) { @@ -277,7 +287,7 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) { 1.5783f, 0.001f); delete metric; - xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"mphe"}, GPUIDX); + xgboost::CheckDeterministicMetricElementWise(xgboost::StringView{"poisson-nloglik"}, GPUIDX); } TEST(Metric, DeclareUnifiedTest(MultiRMSE)) { @@ -288,8 +298,8 @@ TEST(Metric, DeclareUnifiedTest(MultiRMSE)) { HostDeviceVector predt(n_samples * n_targets, 0); - auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX); - std::unique_ptr metric{Metric::Create("rmse", &lparam)}; + auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); + std::unique_ptr metric{Metric::Create("rmse", &ctx)}; metric->Configure({}); auto loss = GetMultiMetricEval(metric.get(), predt, y); diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 6f396ea76..ef4529934 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -57,25 +57,31 @@ TEST(Objective, DeclareUnifiedTest(SquaredLog)) { TEST(Objective, DeclareUnifiedTest(PseudoHuber)) { GenericParameter tparam = CreateEmptyGenericParam(GPUIDX); - std::vector> args; + Args args; - std::unique_ptr obj { ObjFunction::Create("reg:pseudohubererror", &tparam) }; + std::unique_ptr obj{ObjFunction::Create("reg:pseudohubererror", &tparam)}; obj->Configure(args); CheckConfigReload(obj, "reg:pseudohubererror"); - CheckObjFunction(obj, - {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 
// labels - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // weights - {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad - { 0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess - CheckObjFunction(obj, - {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels - {}, // empty weights - {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad - { 0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess + CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // weights + {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad + {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess + CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels + {}, // empty weights + {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad + {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"mphe"}); + + obj->Configure({{"huber_slope", "0.1"}}); + CheckConfigReload(obj, "reg:pseudohubererror"); + CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // weights + {-0.099388f, -0.099228f, -0.098639f, -0.089443f, 0.098639f}, // out_grad + {0.0013467f, 0.001908f, 0.004443f, 0.089443f, 0.004443f}); // out_hess } TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) { @@ -131,7 +137,6 @@ TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) { std::unique_ptr obj { ObjFunction::Create("binary:logitraw", &lparam) }; - obj->Configure(args); CheckObjFunction(obj, @@ -373,5 +378,4 @@ TEST(Objective, CoxRegressionGPair) { { 0, 0, 0, 0.160f, 0.186f, 0.348f, 0.610f, 0.639f}); } #endif - } // namespace xgboost diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index f7e221540..eaba41b6a 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -430,8 +430,8 @@ TEST(Learner, MultiTarget) { size_t constexpr kRows{128}, kCols{10}, kTargets{3}; auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); m->Info().labels.Reshape(kRows, kTargets); - linalg::ElementWiseKernelHost(m->Info().labels.HostView(), omp_get_max_threads(), - [](auto i, auto) { return i; }); + linalg::ElementWiseTransformHost(m->Info().labels.HostView(), omp_get_max_threads(), + [](auto i, auto) { return i; }); { std::unique_ptr learner{Learner::Create({m})};
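
Notes
=====

The gradient and Hessian in ``PseudoHuberRegression::GetGradient`` follow from differentiating the Pseudo-Huber loss with respect to the residual :math:`z = \mathrm{predt} - y`, writing :math:`\delta` for ``huber_slope``:

.. math::

   L_\delta(z) = \delta^2\left(\sqrt{1 + z^2/\delta^2} - 1\right)

.. math::

   \frac{\partial L_\delta}{\partial z} = \frac{z}{\sqrt{1 + z^2/\delta^2}},
   \qquad
   \frac{\partial^2 L_\delta}{\partial z^2}
   = \frac{\delta^2}{(\delta^2 + z^2)\sqrt{1 + z^2/\delta^2}}
   = \left(1 + z^2/\delta^2\right)^{-3/2}

With :math:`\delta = 1` this reduces exactly to the fixed-slope ``FirstOrderGradient``/``SecondOrderGradient`` pair deleted from ``regression_loss.h``, which is why the default behaviour of ``reg:pseudohubererror`` is unchanged.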
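
The expected gradients and Hessians in ``test_regression_obj.cc`` can be reproduced with a standalone restatement of the kernel arithmetic; a minimal sketch, assuming ``common::Sqr`` is plain squaring as in ``src/common/math.h``::

    #include <cmath>
    #include <cstdio>

    float Sqr(float x) { return x * x; }

    // Same arithmetic as the lambda in PseudoHuberRegression::GetGradient,
    // with unit sample weight.
    void GradHess(float predt, float y, float slope) {
      const float z = predt - y;
      const float scale_sqrt = std::sqrt(1.0f + Sqr(z) / Sqr(slope));
      const float grad = z / scale_sqrt;
      const float scale = Sqr(slope) + Sqr(z);
      const float hess = Sqr(slope) / (scale * scale_sqrt);
      std::printf("pred=%.1f grad=%f hess=%f\n", predt, grad, hess);
    }

    int main() {
      const float preds[] = {0.1f, 0.2f, 0.4f, 0.8f, 1.6f};
      // huber_slope = 1.0 reproduces the first CheckObjFunction block
      // (grad[0] = -0.668965, hess[0] = 0.410660, ...).
      for (float p : preds) GradHess(p, 1.0f, 1.0f);
      // huber_slope = 0.1 reproduces the block after Configure({{"huber_slope", "0.1"}})
      // (grad[0] = -0.099388, hess[0] = 0.0013467, ...).
      for (float p : preds) GradHess(p, 1.0f, 0.1f);
      return 0;
    }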
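
The ``mphe`` values asserted in ``test_elementwise_metric.cc`` follow from the same loss: ``PseudoErrorLoss`` computes a weighted mean of per-element Pseudo-Huber errors,

.. math::

   \text{mphe} = \frac{\sum_i w_i \, \delta^2\left(\sqrt{1 + (y_i - \hat{y}_i)^2/\delta^2} - 1\right)}{\sum_i w_i}.

For the unweighted case with predictions :math:`(0.1, 0.9, 0.1, 0.9)`, labels :math:`(0, 0, 1, 1)` and :math:`\delta = 1`, the residual magnitudes are :math:`0.1, 0.9, 0.9, 0.1`, so

.. math::

   \text{mphe} = \frac{2(\sqrt{1.01} - 1) + 2(\sqrt{1.81} - 1)}{4} \approx 0.1751,

matching the ``EXPECT_NEAR`` value in the test.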
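
The rewritten ``Sigmoid`` clamps the exponent at ``88.7f`` because single-precision ``expf`` overflows to ``inf`` just past :math:`\ln(\texttt{FLT\_MAX}) \approx 88.72`; with the clamp, a very negative margin maps to a tiny positive probability rather than to ``1.0f / inf``, which presumably keeps downstream ``log``-based losses finite. A small check of the boundary (printed values are approximate)::

    #include <cfloat>
    #include <cmath>
    #include <cstdio>

    int main() {
      std::printf("FLT_MAX    = %g\n", FLT_MAX);          // ~3.40282e+38
      std::printf("expf(88.7) = %g\n", std::exp(88.7f));  // ~3.3e+38, still finite
      std::printf("expf(89.0) = %g\n", std::exp(89.0f));  // inf: overflowed
      return 0;
    }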
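
``common::OptionalWeights`` (hoisted out of ``src/metric/auc.h`` so that the element-wise metrics and the new objective can share it) lets a single kernel serve weighted and unweighted data: indexing falls back to a default weight when no weights were supplied. A minimal restatement with the ``Span`` dependency stripped::

    #include <cstddef>

    struct OptionalWeights {
      const float* weights{nullptr};  // stands in for Span<float const>
      std::size_t n{0};
      float dft{1.0f};  // default weight when none are provided
      float operator[](std::size_t i) const { return n == 0 ? dft : weights[i]; }
    };

Call sites such as ``weight[sample_id]`` then behave identically whether or not ``info.weights_`` is empty, instead of threading an ``is_null_weight`` flag through every kernel.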
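
The ``linalg`` helpers are now split by intent: ``ElementWiseTransform{Host,Device}`` writes the functor's return value back into the tensor, while ``ElementWiseKernel{Host,Device}`` requires a ``void`` functor (enforced by the ``static_assert``) that mutates the element through its reference, and ``ElementWiseKernel(ctx, t, fn)`` dispatches on ``ctx->IsCPU()``. Hypothetical host-side usage, assuming a ``TensorView<float, 1>`` named ``t`` and a thread count ``n_threads``::

    // Transform: the returned value is assigned back, filling t with 0, 1, 2, ...
    linalg::ElementWiseTransformHost(
        t, n_threads, [](std::size_t i, float) { return static_cast<float>(i); });

    // Kernel: the functor must return void and mutate in place; here it doubles
    // each element through the reference.
    linalg::ElementWiseKernelHost(
        t, n_threads, [](std::size_t, float& v) { v *= 2.0f; });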
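
``Metric::SaveConfig`` now emits an object per metric instead of a bare name, so stateful metrics such as ``PseudoErrorLoss`` can round-trip ``huber_slope``; ``LearnerConfiguration::LoadConfig`` still accepts the old string form, warning via ``ModelMsg()``. Illustrative fragments of the saved learner configuration (hypothetical values)::

    // before this change:
    //   "metrics": ["mphe"]
    // after this change:
    //   "metrics": [{"name": "mphe",
    //                "pseduo_huber_param": {"huber_slope": "0.1"}}]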