From 48dddfd635b0fa6575b9934334254533e5d8708d Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Sat, 1 Dec 2018 18:46:45 +1300
Subject: [PATCH] Porting elementwise metrics to GPU. (#3952)

* Port elementwise metrics to GPU.

* All elementwise metrics are converted to static polymorphism.

* Create a reducer for metrics reduction.

* Remove const of Metric::Eval to accommodate CubMemory.
---
 include/xgboost/metric.h                    |  24 +-
 src/common/math.h                           |   2 +-
 src/learner.cc                              |  13 +-
 src/metric/elementwise_metric.cc            | 229 +----------
 src/metric/elementwise_metric.cu            | 406 ++++++++++++++++++++
 src/metric/metric.cc                        |   4 +
 src/metric/metric_param.h                   |  31 ++
 src/metric/multiclass_metric.cc             |  13 +-
 src/metric/rank_metric.cc                   |  72 ++--
 tests/cpp/helpers.cc                        |   3 +-
 tests/cpp/helpers.h                         |   2 +-
 tests/cpp/metric/test_elementwise_metric.cc |  64 ++-
 tests/cpp/metric/test_elementwise_metric.cu |   5 +
 tests/cpp/metric/test_rank_metric.cc        |  16 +-
 14 files changed, 605 insertions(+), 279 deletions(-)
 create mode 100644 src/metric/elementwise_metric.cu
 create mode 100644 src/metric/metric_param.h
 create mode 100644 tests/cpp/metric/test_elementwise_metric.cu

diff --git a/include/xgboost/metric.h b/include/xgboost/metric.h
index 80adec194..56ecebfbf 100644
--- a/include/xgboost/metric.h
+++ b/include/xgboost/metric.h
@@ -11,8 +11,11 @@
 #include <vector>
 #include <string>
 #include <functional>
+#include <utility>
+
 #include "./data.h"
 #include "./base.h"
+#include "../../src/common/host_device_vector.h"
 
 namespace xgboost {
 /*!
@@ -21,6 +24,23 @@ namespace xgboost {
  */
 class Metric {
  public:
+  /*!
+   * \brief Configure the Metric with the specified parameters.
+   * \param args arguments passed to the metric.
+   */
+  virtual void Configure(
+      const std::vector<std::pair<std::string, std::string> >& args) {}
+  /*!
+   * \brief set configuration from pair iterators.
+   * \param begin The beginning iterator.
+   * \param end The end iterator.
+   * \tparam PairIter iterator<std::pair<std::string, std::string> >
+   */
+  template <typename PairIter>
+  inline void Configure(PairIter begin, PairIter end) {
+    std::vector<std::pair<std::string, std::string> > vec(begin, end);
+    this->Configure(vec);
+  }
   /*!
    * \brief evaluate a specific metric
    * \param preds prediction
    * \param info information, including label etc.
    * \param distributed whether a call to Allreduce is needed to gather
    *        the average statistics across all the nodes;
    *        this is only supported by some metrics
    */
-  virtual bst_float Eval(const std::vector<bst_float>& preds,
+  virtual bst_float Eval(const HostDeviceVector<bst_float>& preds,
                          const MetaInfo& info,
-                         bool distributed) const = 0;
+                         bool distributed) = 0;
   /*! \return name of metric */
   virtual const char* Name() const = 0;
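
To see how the reworked interface fits together, here is a minimal caller-side sketch (illustrative only, not part of the patch; the function name EvalOnce is made up): Metric::Create builds the metric, the new Configure forwards key/value arguments such as n_gpus, and the now non-const Eval consumes a HostDeviceVector instead of a host std::vector.

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <xgboost/metric.h>

xgboost::bst_float EvalOnce(
    const xgboost::HostDeviceVector<xgboost::bst_float>& preds,
    const xgboost::MetaInfo& info) {
  std::unique_ptr<xgboost::Metric> metric(xgboost::Metric::Create("rmse"));
  // New in this patch: metrics accept configuration, e.g. GPU placement.
  std::vector<std::pair<std::string, std::string>> args{{"n_gpus", "1"}};
  metric->Configure(args);
  // Eval is no longer const, so the metric may cache device allocations.
  return metric->Eval(preds, info, /*distributed=*/false);
}
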
  /*! \brief virtual destructor */
diff --git a/src/common/math.h b/src/common/math.h
index 53e48bf4e..14342cadf 100644
--- a/src/common/math.h
+++ b/src/common/math.h
@@ -127,7 +127,7 @@ inline bool CheckNAN(T v) {
 #endif
 }
 template <typename T>
-inline T LogGamma(T v) {
+XGBOOST_DEVICE inline T LogGamma(T v) {
 #ifdef _MSC_VER
 #if _MSC_VER >= 1800
   return lgamma(v);
diff --git a/src/learner.cc b/src/learner.cc
index dd187a9a6..7155f5d96 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -310,6 +310,10 @@ class LearnerImpl : public Learner {
     if (obj_ != nullptr) {
       obj_->Configure(cfg_.begin(), cfg_.end());
     }
+
+    for (auto& p_metric : metrics_) {
+      p_metric->Configure(cfg_.begin(), cfg_.end());
+    }
   }
 
   void InitModel() override { this->LazyInitModel(); }
@@ -407,6 +411,10 @@ class LearnerImpl : public Learner {
     cfg_["num_class"] = common::ToString(mparam_.num_class);
     cfg_["num_feature"] = common::ToString(mparam_.num_feature);
     obj_->Configure(cfg_.begin(), cfg_.end());
+
+    for (auto& p_metric : metrics_) {
+      p_metric->Configure(cfg_.begin(), cfg_.end());
+    }
   }
 
   // rabit save model to rabit checkpoint
@@ -503,13 +511,14 @@ class LearnerImpl : public Learner {
     os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
     if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
       metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
+      metrics_.back()->Configure(cfg_.begin(), cfg_.end());
     }
     for (size_t i = 0; i < data_sets.size(); ++i) {
       this->PredictRaw(data_sets[i], &preds_);
       obj_->EvalTransform(&preds_);
       for (auto& ev : metrics_) {
         os << '\t' << data_names[i] << '-' << ev->Name() << ':'
-           << ev->Eval(preds_.ConstHostVector(), data_sets[i]->Info(),
+           << ev->Eval(preds_, data_sets[i]->Info(),
                        tparam_.dsplit == DataSplitMode::kRow);
       }
     }
@@ -553,7 +562,7 @@ class LearnerImpl : public Learner {
     this->PredictRaw(data, &preds_);
     obj_->EvalTransform(&preds_);
     return std::make_pair(metric,
-                          ev->Eval(preds_.ConstHostVector(), data->Info(),
+                          ev->Eval(preds_, data->Info(),
                                    tparam_.dsplit == DataSplitMode::kRow));
   }
 
diff --git a/src/metric/elementwise_metric.cc b/src/metric/elementwise_metric.cc
index 6ca022e0e..9773f525c 100644
--- a/src/metric/elementwise_metric.cc
+++ b/src/metric/elementwise_metric.cc
@@ -1,227 +1,8 @@
 /*!
- * Copyright 2015 by Contributors
- * \file elementwise_metric.cc
- * \brief evaluation metrics for elementwise binary or regression.
- * \author Kailong Chen, Tianqi Chen
+ * Copyright 2018 XGBoost contributors
  */
-#include <xgboost/metric.h>
-#include <rabit/rabit.h>
-#include <cmath>
-#include <sstream>
-#include "../common/math.h"
+// Dummy file to keep the CUDA conditional compile trick.
-namespace xgboost {
-namespace metric {
-// tag the this file, used by force static link later.
-DMLC_REGISTRY_FILE_TAG(elementwise_metric);
-
-/*!
- * \brief base class of element-wise evaluation
- * \tparam Derived the name of subclass
- */
-template<typename Derived>
-struct EvalEWiseBase : public Metric {
-  bst_float Eval(const std::vector<bst_float>& preds,
-                 const MetaInfo& info,
-                 bool distributed) const override {
-    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
-    CHECK_EQ(preds.size(), info.labels_.Size())
-        << "label and prediction size not match, "
-        << "hint: use merror or mlogloss for multi-class classification";
-    const auto ndata = static_cast<omp_ulong>(info.labels_.Size());
-    double sum = 0.0, wsum = 0.0;
-    const auto& labels = info.labels_.HostVector();
-    const auto& weights = info.weights_.HostVector();
-    #pragma omp parallel for reduction(+: sum, wsum) schedule(static)
-    for (omp_ulong i = 0; i < ndata; ++i) {
-      const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
-      sum += static_cast<const Derived*>(this)->EvalRow(labels[i], preds[i]) * wt;
-      wsum += wt;
-    }
-    double dat[2]; dat[0] = sum, dat[1] = wsum;
-    if (distributed) {
-      rabit::Allreduce<rabit::op::Sum>(dat, 2);
-    }
-    return Derived::GetFinal(dat[0], dat[1]);
-  }
-  /*!
-   * \brief to be implemented by subclass,
-   *   get evaluation result from one row
-   * \param label label of current instance
-   * \param pred prediction value of current instance
-   */
-  inline bst_float EvalRow(bst_float label, bst_float pred) const;
-  /*!
-   * \brief to be overridden by subclass, final transformation
-   * \param esum the sum statistics returned by EvalRow
-   * \param wsum sum of weight
-   */
-  inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
-  }
-};
-
-struct EvalRMSE : public EvalEWiseBase<EvalRMSE> {
-  const char *Name() const override {
-    return "rmse";
-  }
-  inline bst_float EvalRow(bst_float label, bst_float pred) const {
-    bst_float diff = label - pred;
-    return diff * diff;
-  }
-  inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return std::sqrt(esum / wsum);
-  }
-};
-
-struct EvalMAE : public EvalEWiseBase<EvalMAE> {
-  const char *Name() const override {
-    return "mae";
-  }
-  inline bst_float EvalRow(bst_float label, bst_float pred) const {
-    return std::abs(label - pred);
-  }
-};
-
-struct EvalLogLoss : public EvalEWiseBase<EvalLogLoss> {
-  const char *Name() const override {
-    return "logloss";
-  }
-  inline bst_float EvalRow(bst_float y, bst_float py) const {
-    const bst_float eps = 1e-16f;
-    const bst_float pneg = 1.0f - py;
-    if (py < eps) {
-      return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
-    } else if (pneg < eps) {
-      return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps);
-    } else {
-      return -y * std::log(py) - (1.0f - y) * std::log(pneg);
-    }
-  }
-};
-
-struct EvalError : public EvalEWiseBase<EvalError> {
-  explicit EvalError(const char* param) {
-    if (param != nullptr) {
-      std::ostringstream os;
-      os << "error";
-      CHECK_EQ(sscanf(param, "%f", &threshold_), 1)
-          << "unable to parse the threshold value for the error metric";
-      if (threshold_ != 0.5f) os << '@' << threshold_;
-      name_ = os.str();
-    } else {
-      threshold_ = 0.5f;
-      name_ = "error";
-    }
-  }
-  const char *Name() const override {
-    return name_.c_str();
-  }
-  inline bst_float EvalRow(bst_float label, bst_float pred) const {
-    // assume label is in [0,1]
-    return pred > threshold_ ? 1.0f - label : label;
-  }
-  }
- protected:
-  bst_float threshold_;
-  std::string name_;
-};
-
-struct EvalPoissonNegLogLik : public EvalEWiseBase<EvalPoissonNegLogLik> {
-  const char *Name() const override {
-    return "poisson-nloglik";
-  }
-  inline bst_float EvalRow(bst_float y, bst_float py) const {
-    const bst_float eps = 1e-16f;
-    if (py < eps) py = eps;
-    return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
-  }
-};
-
-struct EvalGammaDeviance : public EvalEWiseBase<EvalGammaDeviance> {
-  const char *Name() const override {
-    return "gamma-deviance";
-  }
-  inline bst_float EvalRow(bst_float label, bst_float pred) const {
-    bst_float epsilon = 1.0e-9;
-    bst_float tmp = label / (pred + epsilon);
-    return tmp - std::log(tmp) - 1;
-  }
-  inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return 2 * esum;
-  }
-};
-
-struct EvalGammaNLogLik: public EvalEWiseBase<EvalGammaNLogLik> {
-  const char *Name() const override {
-    return "gamma-nloglik";
-  }
-  inline bst_float EvalRow(bst_float y, bst_float py) const {
-    bst_float psi = 1.0;
-    bst_float theta = -1. / py;
-    bst_float a = psi;
-    bst_float b = -std::log(-theta);
-    bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
-    return -((y * theta - b) / a + c);
-  }
-};
-
-struct EvalTweedieNLogLik: public EvalEWiseBase<EvalTweedieNLogLik> {
-  explicit EvalTweedieNLogLik(const char* param) {
-    CHECK(param != nullptr)
-        << "tweedie-nloglik must be in format tweedie-nloglik@rho";
-    rho_ = atof(param);
-    CHECK(rho_ < 2 && rho_ >= 1)
-        << "tweedie variance power must be in interval [1, 2)";
-    std::ostringstream os;
-    os << "tweedie-nloglik@" << rho_;
-    name_ = os.str();
-  }
-  const char *Name() const override {
-    return name_.c_str();
-  }
-  inline bst_float EvalRow(bst_float y, bst_float p) const {
-    bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
-    bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
-    return -a + b;
-  }
- protected:
-  std::string name_;
-  bst_float rho_;
-};
-
-XGBOOST_REGISTER_METRIC(RMSE, "rmse")
-.describe("Rooted mean square error.")
-.set_body([](const char* param) { return new EvalRMSE(); });
-
-XGBOOST_REGISTER_METRIC(MAE, "mae")
-.describe("Mean absolute error.")
-.set_body([](const char* param) { return new EvalMAE(); });
-
-XGBOOST_REGISTER_METRIC(LogLoss, "logloss")
-.describe("Negative loglikelihood for logistic regression.")
-.set_body([](const char* param) { return new EvalLogLoss(); });
-
-XGBOOST_REGISTER_METRIC(Error, "error")
-.describe("Binary classification error.")
-.set_body([](const char* param) { return new EvalError(param); });
-
-XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
-.describe("Negative loglikelihood for poisson regression.")
-.set_body([](const char* param) { return new EvalPoissonNegLogLik(); });
-
-XGBOOST_REGISTER_METRIC(GammaDeviance, "gamma-deviance")
-.describe("Residual deviance for gamma regression.")
-.set_body([](const char* param) { return new EvalGammaDeviance(); });
-
-XGBOOST_REGISTER_METRIC(GammaNLogLik, "gamma-nloglik")
-.describe("Negative log-likelihood for gamma regression.")
-.set_body([](const char* param) { return new EvalGammaNLogLik(); });
-
-XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
-.describe("tweedie-nloglik@rho for tweedie regression.")
-.set_body([](const char* param) {
-  return new EvalTweedieNLogLik(param);
-});
-
-}  // namespace metric
-}  // namespace xgboost
+#if !defined(XGBOOST_USE_CUDA)
+#include "elementwise_metric.cu"
+#endif
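
The "dummy file" above deserves a note: the implementation now lives in elementwise_metric.cu, and when XGBoost is built without CUDA the .cc translation unit simply #includes the .cu, so the same source compiles as plain C++ (XGBOOST_DEVICE then expands to nothing). A stripped-down sketch of the pattern, with made-up file and macro names (illustrative, not part of the patch):

// metric_impl.cu -- compiled by nvcc when CUDA is on, or textually
// included by the .cc below when CUDA is off.
#if defined(XGBOOST_USE_CUDA)
#define MY_DEVICE __host__ __device__
#else
#define MY_DEVICE  // expands to nothing in the host-only build
#endif
MY_DEVICE inline float SquaredError(float label, float pred) {
  float d = label - pred;
  return d * d;
}

// metric_impl.cc -- keeps the non-CUDA build compiling the same symbols.
#if !defined(XGBOOST_USE_CUDA)
#include "metric_impl.cu"
#endif
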
diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu
new file mode 100644
index 000000000..0675f1a23
--- /dev/null
+++ b/src/metric/elementwise_metric.cu
@@ -0,0 +1,406 @@
+/*!
+ * Copyright 2015-2018 by Contributors
+ * \file elementwise_metric.cu
+ * \brief evaluation metrics for elementwise binary or regression.
+ * \author Kailong Chen, Tianqi Chen
+ */
+#include <rabit/rabit.h>
+#include <xgboost/metric.h>
+#include <cmath>
+#include <sstream>
+
+#include "metric_param.h"
+#include "../common/math.h"
+#include "../common/common.h"
+
+#if defined(XGBOOST_USE_CUDA)
+#include <thrust/execution_policy.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/transform_reduce.h>
+#include <thrust/functional.h>  // thrust::plus<>
+
+#include "../common/device_helpers.cuh"
+#endif  // XGBOOST_USE_CUDA
+
+namespace xgboost {
+namespace metric {
+// tag this file, used by force static link later.
+DMLC_REGISTRY_FILE_TAG(elementwise_metric);
+
+struct PackedReduceResult {
+  double residue_sum_;
+  double weights_sum_;
+
+  XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {}
+  XGBOOST_DEVICE PackedReduceResult(double residue, double weight) :
+      residue_sum_{residue}, weights_sum_{weight} {}
+
+  XGBOOST_DEVICE
+  PackedReduceResult operator+(PackedReduceResult const& other) const {
+    return PackedReduceResult { residue_sum_ + other.residue_sum_,
+                                weights_sum_ + other.weights_sum_ };
+  }
+};
+
+template <typename EvalRow>
+class MetricsReduction {
+ public:
+  explicit MetricsReduction(EvalRow policy) :
+      policy_(std::move(policy)) {}
+
+  PackedReduceResult CpuReduceMetrics(
+      const HostDeviceVector<bst_float>& weights,
+      const HostDeviceVector<bst_float>& labels,
+      const HostDeviceVector<bst_float>& preds) const {
+    size_t ndata = labels.Size();
+
+    const auto& h_labels = labels.HostVector();
+    const auto& h_weights = weights.HostVector();
+    const auto& h_preds = preds.HostVector();
+
+    bst_float residue_sum = 0;
+    bst_float weights_sum = 0;
+
+#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
+    for (omp_ulong i = 0; i < ndata; ++i) {
+      const bst_float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
+      residue_sum += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
+      weights_sum += wt;
+    }
+    PackedReduceResult res { residue_sum, weights_sum };
+    return res;
+  }
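
DeviceReduceMetrics below leans on thrust::transform_reduce to fuse the per-row evaluation with the summation into a single pass over the data. A self-contained sketch of the same pattern (illustrative only; build with nvcc --extended-lambda):

#include <cstdio>
#include <vector>

#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform_reduce.h>

int main() {
  std::vector<float> h_labels{0.f, 1.f};
  std::vector<float> h_preds{0.1f, 0.9f};
  thrust::device_vector<float> labels(h_labels.begin(), h_labels.end());
  thrust::device_vector<float> preds(h_preds.begin(), h_preds.end());
  const float* d_labels = labels.data().get();
  const float* d_preds = preds.data().get();

  thrust::counting_iterator<size_t> begin(0);
  // Map index -> squared error, then sum, all in one fused reduction.
  float residue = thrust::transform_reduce(
      begin, begin + preds.size(),
      [=] __device__(size_t i) {
        float d = d_labels[i] - d_preds[i];
        return d * d;
      },
      0.0f, thrust::plus<float>());
  std::printf("residue sum: %f\n", residue);  // 0.02
  return 0;
}
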
+
+#if defined(XGBOOST_USE_CUDA)
+
+  PackedReduceResult DeviceReduceMetrics(
+      GPUSet::GpuIdType device_id,
+      size_t device_index,
+      const HostDeviceVector<bst_float>& weights,
+      const HostDeviceVector<bst_float>& labels,
+      const HostDeviceVector<bst_float>& preds) {
+    size_t n_data = preds.DeviceSize(device_id);
+
+    thrust::counting_iterator<size_t> begin(0);
+    thrust::counting_iterator<size_t> end = begin + n_data;
+
+    auto s_label = labels.DeviceSpan(device_id);
+    auto s_preds = preds.DeviceSpan(device_id);
+    auto s_weights = weights.DeviceSpan(device_id);
+
+    bool const is_null_weight = weights.Size() == 0;
+
+    auto d_policy = policy_;
+
+    PackedReduceResult result = thrust::transform_reduce(
+        thrust::cuda::par(allocators_.at(device_index)),
+        begin, end,
+        [=] XGBOOST_DEVICE(size_t idx) {
+          bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
+
+          bst_float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]);
+          residue *= weight;
+          return PackedReduceResult{ residue, weight };
+        },
+        PackedReduceResult(),
+        thrust::plus<PackedReduceResult>());
+
+    return result;
+  }
+
+#endif  // XGBOOST_USE_CUDA
+
+  PackedReduceResult Reduce(
+      GPUSet devices,
+      const HostDeviceVector<bst_float>& weights,
+      const HostDeviceVector<bst_float>& labels,
+      const HostDeviceVector<bst_float>& preds) {
+    PackedReduceResult result;
+
+    if (devices.IsEmpty()) {
+      result = CpuReduceMetrics(weights, labels, preds);
+    }
+#if defined(XGBOOST_USE_CUDA)
+    else {  // NOLINT
+      if (allocators_.size() != devices.Size()) {
+        allocators_.clear();
+        allocators_.resize(devices.Size());
+      }
+      preds.Reshard(devices);
+      labels.Reshard(devices);
+      weights.Reshard(devices);
+      std::vector<PackedReduceResult> res_per_device(devices.Size());
+
+#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
+      for (GPUSet::GpuIdType id = *devices.begin(); id < *devices.end(); ++id) {
+        dh::safe_cuda(cudaSetDevice(id));
+        size_t index = devices.Index(id);
+        res_per_device.at(index) =
+            DeviceReduceMetrics(id, index, weights, labels, preds);
+      }
+
+      for (size_t i = 0; i < devices.Size(); ++i) {
+        result.residue_sum_ += res_per_device[i].residue_sum_;
+        result.weights_sum_ += res_per_device[i].weights_sum_;
+      }
+    }
+#endif
+    return result;
+  }
+
+ private:
+  EvalRow policy_;
+#if defined(XGBOOST_USE_CUDA)
+  std::vector<dh::CubMemory> allocators_;
+#endif
+};
+
+struct EvalRowRMSE {
+  char const *Name() const {
+    return "rmse";
+  }
+
+  XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
+    bst_float diff = label - pred;
+    return diff * diff;
+  }
+  static bst_float GetFinal(bst_float esum, bst_float wsum) {
+    return std::sqrt(esum / wsum);
+  }
+};
+
+struct EvalRowMAE {
+  const char *Name() const {
+    return "mae";
+  }
+
+  XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
+    return std::abs(label - pred);
+  }
+  static bst_float GetFinal(bst_float esum, bst_float wsum) {
+    return esum / wsum;
+  }
+};
+
+struct EvalRowLogLoss {
+  const char *Name() const {
+    return "logloss";
+  }
+
+  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
+    const bst_float eps = 1e-16f;
+    const bst_float pneg = 1.0f - py;
+    if (py < eps) {
+      return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
+    } else if (pneg < eps) {
+      return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps);
+    } else {
+      return -y * std::log(py) - (1.0f - y) * std::log(pneg);
+    }
+  }
+
+  static bst_float GetFinal(bst_float esum, bst_float wsum) {
+    return esum / wsum;
+  }
+};
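
These EvalRow structs are the "static polymorphism" from the commit message: plain policy objects with no virtual dispatch, so the same EvalRow body can be inlined into both the OpenMP loop and the device lambda. A host-only sketch of the idea (illustrative, not part of the patch; SquaredError and Evaluate are made-up names):

#include <cmath>
#include <cstdio>
#include <vector>

struct SquaredError {  // stands in for EvalRowRMSE
  float EvalRow(float label, float pred) const {
    float d = label - pred;
    return d * d;
  }
  static float GetFinal(float esum, float wsum) {
    return std::sqrt(esum / wsum);
  }
};

// Mirrors MetricsReduction: the policy is a template parameter, so
// policy.EvalRow() is resolved at compile time and can be inlined.
template <typename Policy>
float Evaluate(Policy policy, const std::vector<float>& labels,
               const std::vector<float>& preds) {
  double esum = 0.0;
  for (size_t i = 0; i < labels.size(); ++i) {
    esum += policy.EvalRow(labels[i], preds[i]);
  }
  return Policy::GetFinal(static_cast<float>(esum), labels.size());
}

int main() {
  std::printf("%f\n", Evaluate(SquaredError{}, {0.f, 0.f, 1.f, 1.f},
                               {0.1f, 0.9f, 0.1f, 0.9f}));  // ~0.6403
  return 0;
}
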
+
+struct EvalError {
+  explicit EvalError(const char* param) {
+    if (param != nullptr) {
+      CHECK_EQ(sscanf(param, "%f", &threshold_), 1)
+          << "unable to parse the threshold value for the error metric";
+      has_param_ = true;
+    } else {
+      threshold_ = 0.5f;
+      has_param_ = false;
+    }
+  }
+  const char *Name() const {
+    static std::string name;
+    if (has_param_) {
+      std::ostringstream os;
+      os << "error";
+      if (threshold_ != 0.5f) os << '@' << threshold_;
+      name = os.str();
+      return name.c_str();
+    } else {
+      return "error";
+    }
+  }
+
+  XGBOOST_DEVICE bst_float EvalRow(
+      bst_float label, bst_float pred) const {
+    // assume label is in [0,1]
+    return pred > threshold_ ? 1.0f - label : label;
+  }
+
+  static bst_float GetFinal(bst_float esum, bst_float wsum) {
+    return esum / wsum;
+  }
+
+ private:
+  bst_float threshold_;
+  bool has_param_;
+};
+
+struct EvalPoissonNegLogLik {
+  const char *Name() const {
+    return "poisson-nloglik";
+  }
+
+  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
+    const bst_float eps = 1e-16f;
+    if (py < eps) py = eps;
+    return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
+  }
+
+  static bst_float GetFinal(bst_float esum, bst_float wsum) {
+    return esum / wsum;
+  }
+};
+
+struct EvalGammaDeviance {
+  const char *Name() const {
+    return "gamma-deviance";
+  }
+
+  XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
+    bst_float epsilon = 1.0e-9;
+    bst_float tmp = label / (pred + epsilon);
+    return tmp - std::log(tmp) - 1;
+  }
+  static bst_float GetFinal(bst_float esum, bst_float wsum) {
+    return 2 * esum;
+  }
+};
+
+struct EvalGammaNLogLik {
+  static const char *Name() {
+    return "gamma-nloglik";
+  }
+
+  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
+    bst_float psi = 1.0;
+    bst_float theta = -1. / py;
+    bst_float a = psi;
+    bst_float b = -std::log(-theta);
+    bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
+    return -((y * theta - b) / a + c);
+  }
+  static bst_float GetFinal(bst_float esum, bst_float wsum) {
+    return esum / wsum;
+  }
+};
+
+struct EvalTweedieNLogLik {
+  explicit EvalTweedieNLogLik(const char* param) {
+    CHECK(param != nullptr)
+        << "tweedie-nloglik must be in format tweedie-nloglik@rho";
+    rho_ = atof(param);
+    CHECK(rho_ < 2 && rho_ >= 1)
+        << "tweedie variance power must be in interval [1, 2)";
+  }
+  const char *Name() const {
+    static std::string name;
+    std::ostringstream os;
+    os << "tweedie-nloglik@" << rho_;
+    name = os.str();
+    return name.c_str();
+  }
+
+  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float p) const {
+    bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
+    bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
+    return -a + b;
+  }
+  static bst_float GetFinal(bst_float esum, bst_float wsum) {
+    return esum / wsum;
+  }
+
+ protected:
+  bst_float rho_;
+};
+/*!
+ * \brief base class of element-wise evaluation
+ * \tparam Policy the policy class that implements the per-row evaluation
+ */
+template <typename Policy>
+struct EvalEWiseBase : public Metric {
+  EvalEWiseBase() : policy_{}, reducer_{policy_} {}
+  explicit EvalEWiseBase(char const* policy_param) :
+      policy_{policy_param}, reducer_{policy_} {}
+
+  void Configure(
+      const std::vector<std::pair<std::string, std::string> >& args) override {
+    param_.InitAllowUnknown(args);
+  }
+
+  bst_float Eval(const HostDeviceVector<bst_float>& preds,
+                 const MetaInfo& info,
+                 bool distributed) override {
+    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
+    CHECK_EQ(preds.Size(), info.labels_.Size())
+        << "label and prediction size not match, "
+        << "hint: use merror or mlogloss for multi-class classification";
+    const auto ndata = static_cast<omp_ulong>(info.labels_.Size());
+    // Dealing with ndata < n_gpus.
+    GPUSet devices = GPUSet::All(param_.gpu_id, param_.n_gpus, ndata);
+
+    PackedReduceResult result =
+        reducer_.Reduce(devices, info.weights_, info.labels_, preds);
+
+    double dat[2] { result.residue_sum_, result.weights_sum_ };
+    if (distributed) {
+      rabit::Allreduce<rabit::op::Sum>(dat, 2);
+    }
+    return Policy::GetFinal(dat[0], dat[1]);
+  }
+
+  const char* Name() const override {
+    return policy_.Name();
+  }
+
+ private:
+  Policy policy_;
+
+  MetricParam param_;
+
+  MetricsReduction<Policy> reducer_;
+};
+
+XGBOOST_REGISTER_METRIC(RMSE, "rmse")
+.describe("Root mean square error.")
+.set_body([](const char* param) { return new EvalEWiseBase<EvalRowRMSE>(); });
+
+XGBOOST_REGISTER_METRIC(MAE, "mae")
+.describe("Mean absolute error.")
+.set_body([](const char* param) { return new EvalEWiseBase<EvalRowMAE>(); });
+
+XGBOOST_REGISTER_METRIC(LogLoss, "logloss")
+.describe("Negative loglikelihood for logistic regression.")
+.set_body([](const char* param) { return new EvalEWiseBase<EvalRowLogLoss>(); });
+
+XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
+.describe("Negative loglikelihood for poisson regression.")
+.set_body([](const char* param) { return new EvalEWiseBase<EvalPoissonNegLogLik>(); });
+
+XGBOOST_REGISTER_METRIC(GammaDeviance, "gamma-deviance")
+.describe("Residual deviance for gamma regression.")
+.set_body([](const char* param) { return new EvalEWiseBase<EvalGammaDeviance>(); });
+
+XGBOOST_REGISTER_METRIC(GammaNLogLik, "gamma-nloglik")
+.describe("Negative log-likelihood for gamma regression.")
+.set_body([](const char* param) { return new EvalEWiseBase<EvalGammaNLogLik>(); });
+
+XGBOOST_REGISTER_METRIC(Error, "error")
+.describe("Binary classification error.")
+.set_body([](const char* param) { return new EvalEWiseBase<EvalError>(param); });
+
+XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
+.describe("tweedie-nloglik@rho for tweedie regression.")
+.set_body([](const char* param) {
+  return new EvalEWiseBase<EvalTweedieNLogLik>(param);
+});
+
+}  // namespace metric
+}  // namespace xgboost
diff --git a/src/metric/metric.cc b/src/metric/metric.cc
index 7986dec6b..076a0ce91 100644
--- a/src/metric/metric.cc
+++ b/src/metric/metric.cc
@@ -6,6 +6,8 @@
 #include <xgboost/metric.h>
 #include <dmlc/registry.h>
 
+#include "metric_param.h"
+
 namespace dmlc {
 DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
 }
@@ -34,6 +36,8 @@ Metric* Metric::Create(const std::string& name) {
 namespace xgboost {
 namespace metric {
 
+DMLC_REGISTER_PARAMETER(MetricParam);
+
 // List of files that will be force linked in static links.
 DMLC_REGISTRY_LINK_TAG(elementwise_metric);
 DMLC_REGISTRY_LINK_TAG(multiclass_metric);
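
MetricParam, defined in the new header below, follows the usual dmlc::Parameter recipe, and param_.InitAllowUnknown(args) in EvalEWiseBase::Configure silently skips keys the metric does not declare. A compact sketch of that mechanism (illustrative; DemoParam is a made-up name):

#include <cassert>
#include <string>
#include <utility>
#include <vector>

#include <dmlc/parameter.h>

struct DemoParam : public dmlc::Parameter<DemoParam> {
  int n_gpus;
  int gpu_id;
  DMLC_DECLARE_PARAMETER(DemoParam) {
    DMLC_DECLARE_FIELD(n_gpus).set_default(1);
    DMLC_DECLARE_FIELD(gpu_id).set_default(0);
  }
};
DMLC_REGISTER_PARAMETER(DemoParam);

int main() {
  DemoParam param;
  std::vector<std::pair<std::string, std::string>> args{
      {"n_gpus", "-1"}, {"objective", "reg:linear"}};  // extra key is ignored
  param.InitAllowUnknown(args);
  assert(param.n_gpus == -1 && param.gpu_id == 0);
  return 0;
}
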
diff --git a/src/metric/metric_param.h b/src/metric/metric_param.h
new file mode 100644
index 000000000..3fd51a9f5
--- /dev/null
+++ b/src/metric/metric_param.h
@@ -0,0 +1,31 @@
+/*!
+ * Copyright 2018 by Contributors
+ * \file metric_param.h
+ */
+#ifndef XGBOOST_METRIC_METRIC_PARAM_H_
+#define XGBOOST_METRIC_METRIC_PARAM_H_
+
+#include <dmlc/parameter.h>
+#include "../common/common.h"
+
+namespace xgboost {
+namespace metric {
+
+// Created exclusively for GPU.
+struct MetricParam : public dmlc::Parameter<MetricParam> {
+  int n_gpus;
+  int gpu_id;
+  DMLC_DECLARE_PARAMETER(MetricParam) {
+    DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(GPUSet::kAll)
+        .describe("Number of GPUs to use for multi-gpu algorithms.");
+    DMLC_DECLARE_FIELD(gpu_id)
+        .set_lower_bound(0)
+        .set_default(0)
+        .describe("gpu to use for metric evaluation");
+  };
+};
+
+}  // namespace metric
+}  // namespace xgboost
+
+#endif  // XGBOOST_METRIC_METRIC_PARAM_H_
diff --git a/src/metric/multiclass_metric.cc b/src/metric/multiclass_metric.cc
index be6279980..31a402755 100644
--- a/src/metric/multiclass_metric.cc
+++ b/src/metric/multiclass_metric.cc
@@ -20,13 +20,13 @@ DMLC_REGISTRY_FILE_TAG(multiclass_metric);
  */
 template<typename Derived>
 struct EvalMClassBase : public Metric {
-  bst_float Eval(const std::vector<bst_float> &preds,
+  bst_float Eval(const HostDeviceVector<bst_float> &preds,
                  const MetaInfo &info,
-                 bool distributed) const override {
+                 bool distributed) override {
     CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
-    CHECK(preds.size() % info.labels_.Size() == 0)
+    CHECK(preds.Size() % info.labels_.Size() == 0)
         << "label and prediction size not match";
-    const size_t nclass = preds.size() / info.labels_.Size();
+    const size_t nclass = preds.Size() / info.labels_.Size();
     CHECK_GE(nclass, 1U)
         << "mlogloss and merror are only used for multi-class classification,"
         << " use logloss for binary classification";
@@ -36,14 +36,15 @@ struct EvalMClassBase : public Metric {
 
     const auto& labels = info.labels_.HostVector();
     const auto& weights = info.weights_.HostVector();
+    const std::vector<bst_float>& h_preds = preds.HostVector();
 
-    #pragma omp parallel for reduction(+: sum, wsum) schedule(static)
+#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
     for (bst_omp_uint i = 0; i < ndata; ++i) {
       const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
       auto label = static_cast<int>(labels[i]);
       if (label >= 0 && label < static_cast<int>(nclass)) {
         sum += Derived::EvalRow(label,
-                                preds.data() + i * nclass,
+                                h_preds.data() + i * nclass,
                                 nclass) * wt;
         wsum += wt;
       } else {
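
Context for the h_preds.data() + i * nclass change above: multi-class predictions are stored row-major, nclass scores per instance, so each row is just a pointer offset into the flat host vector. A quick illustration (not part of the patch):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const size_t nclass = 3;
  // Two rows, three class scores each, flattened row-major.
  std::vector<float> h_preds{0.1f, 0.7f, 0.2f,   // row 0
                             0.5f, 0.2f, 0.3f};  // row 1
  for (size_t i = 0; i < h_preds.size() / nclass; ++i) {
    const float* row = h_preds.data() + i * nclass;
    size_t best = std::max_element(row, row + nclass) - row;
    std::printf("row %zu -> predicted class %zu\n", i, best);
  }
  return 0;
}
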
diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc
index d1f6af909..43a5a2333 100644
--- a/src/metric/rank_metric.cc
+++ b/src/metric/rank_metric.cc
@@ -8,6 +8,10 @@
 #include <xgboost/metric.h>
 #include <dmlc/registry.h>
 #include <cmath>
+
+#include <vector>
+
+#include "../common/host_device_vector.h"
 #include "../common/math.h"
 
 namespace xgboost {
@@ -26,18 +30,20 @@ struct EvalAMS : public Metric {
     os << "ams@" << ratio_;
     name_ = os.str();
   }
-  bst_float Eval(const std::vector<bst_float> &preds,
+
+  bst_float Eval(const HostDeviceVector<bst_float> &preds,
                  const MetaInfo &info,
-                 bool distributed) const override {
+                 bool distributed) override {
     CHECK(!distributed) << "metric AMS does not support distributed evaluation";
     using namespace std;  // NOLINT(*)
 
     const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
     std::vector<std::pair<bst_float, unsigned> > rec(ndata);
 
-    #pragma omp parallel for schedule(static)
+    const std::vector<bst_float>& h_preds = preds.HostVector();
+#pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < ndata; ++i) {
-      rec[i] = std::make_pair(preds[i], i);
+      rec[i] = std::make_pair(h_preds[i], i);
     }
     std::sort(rec.begin(), rec.end(), common::CmpFirst);
     auto ntop = static_cast<unsigned>(ratio_ * ndata);
@@ -82,11 +88,11 @@ struct EvalAMS : public Metric {
 
 /*! \brief Area Under Curve, for both classification and rank */
 struct EvalAuc : public Metric {
-  bst_float Eval(const std::vector<bst_float> &preds,
+  bst_float Eval(const HostDeviceVector<bst_float> &preds,
                  const MetaInfo &info,
-                 bool distributed) const override {
+                 bool distributed) override {
     CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
-    CHECK_EQ(preds.size(), info.labels_.Size())
+    CHECK_EQ(preds.Size(), info.labels_.Size())
         << "label size predict size not match";
     std::vector<unsigned> tgptr(2, 0);
     tgptr[1] = static_cast<unsigned>(info.labels_.Size());
@@ -101,10 +107,11 @@ struct EvalAuc : public Metric {
       // each thread takes a local rec
       std::vector< std::pair<bst_float, unsigned> > rec;
       const auto& labels = info.labels_.HostVector();
+      const std::vector<bst_float>& h_preds = preds.HostVector();
       for (bst_omp_uint k = 0; k < ngroup; ++k) {
         rec.clear();
         for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
-          rec.emplace_back(preds[j], j);
+          rec.emplace_back(h_preds[j], j);
         }
         XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
         // calculate AUC
@@ -155,23 +162,25 @@ struct EvalAuc : public Metric {
 /*! \brief Evaluate rank list */
 struct EvalRankList : public Metric {
  public:
-  bst_float Eval(const std::vector<bst_float> &preds,
+  bst_float Eval(const HostDeviceVector<bst_float> &preds,
                  const MetaInfo &info,
-                 bool distributed) const override {
-    CHECK_EQ(preds.size(), info.labels_.Size())
+                 bool distributed) override {
+    CHECK_EQ(preds.Size(), info.labels_.Size())
         << "label size predict size not match";
     // quick consistency when group is not available
     std::vector<unsigned> tgptr(2, 0);
-    tgptr[1] = static_cast<unsigned>(preds.size());
+    tgptr[1] = static_cast<unsigned>(preds.Size());
     const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
     CHECK_NE(gptr.size(), 0U) << "must specify group when constructing rank file";
-    CHECK_EQ(gptr.back(), preds.size())
+    CHECK_EQ(gptr.back(), preds.Size())
         << "EvalRanklist: group structure must match number of prediction";
     const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
     // sum statistics
     double sum_metric = 0.0f;
     const auto& labels = info.labels_.HostVector();
-    #pragma omp parallel reduction(+:sum_metric)
+
+    const std::vector<bst_float>& h_preds = preds.HostVector();
+#pragma omp parallel reduction(+:sum_metric)
     {
       // each thread takes a local rec
       std::vector< std::pair<bst_float, unsigned> > rec;
@@ -179,7 +188,7 @@ struct EvalRankList : public Metric {
       for (bst_omp_uint k = 0; k < ngroup; ++k) {
         rec.clear();
         for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
-          rec.emplace_back(preds[j], static_cast<int>(labels[j]));
+          rec.emplace_back(h_preds[j], static_cast<int>(labels[j]));
         }
         sum_metric += this->EvalMetric(rec);
       }
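
A note on the tgptr idiom seen in EvalAuc and EvalRankList above: when MetaInfo carries no query groups, a two-entry dummy pointer [0, ndata) makes the whole dataset one group, so the grouped and ungrouped paths share one loop. Sketch (illustrative):

#include <cstdio>
#include <vector>

int main() {
  const unsigned ndata = 5;
  std::vector<unsigned> group_ptr;    // empty: no ranking groups given
  std::vector<unsigned> tgptr(2, 0);  // dummy: one group spanning all rows
  tgptr[1] = ndata;
  const std::vector<unsigned>& gptr =
      group_ptr.empty() ? tgptr : group_ptr;
  for (size_t k = 0; k + 1 < gptr.size(); ++k) {
    std::printf("group %zu covers rows [%u, %u)\n", k, gptr[k], gptr[k + 1]);
  }
  return 0;
}
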
@@ -311,9 +320,9 @@ struct EvalMAP : public EvalRankList {
 struct EvalCox : public Metric {
  public:
  EvalCox() = default;
-  bst_float Eval(const std::vector<bst_float> &preds,
+  bst_float Eval(const HostDeviceVector<bst_float> &preds,
                  const MetaInfo &info,
-                 bool distributed) const override {
+                 bool distributed) override {
     CHECK(!distributed) << "Cox metric does not support distributed evaluation";
     using namespace std;  // NOLINT(*)
 
@@ -322,8 +331,10 @@ struct EvalCox : public Metric {
     // pre-compute a sum for the denominator
     double exp_p_sum = 0;  // we use double because we might need the precision with large datasets
+
+    const std::vector<bst_float>& h_preds = preds.HostVector();
     for (omp_ulong i = 0; i < ndata; ++i) {
-      exp_p_sum += preds[i];
+      exp_p_sum += h_preds[i];
     }
 
     double out = 0;
@@ -334,12 +345,12 @@ struct EvalCox : public Metric {
       const size_t ind = label_order[i];
       const auto label = labels[ind];
       if (label > 0) {
-        out -= log(preds[ind]) - log(exp_p_sum);
+        out -= log(h_preds[ind]) - log(exp_p_sum);
         ++num_events;
       }
 
       // only update the denominator after we move forward in time (labels are sorted)
-      accumulated_sum += preds[ind];
+      accumulated_sum += h_preds[ind];
       if (i == ndata - 1 || std::abs(label) < std::abs(labels[label_order[i + 1]])) {
         exp_p_sum -= accumulated_sum;
         accumulated_sum = 0;
@@ -360,10 +371,10 @@ struct EvalAucPR : public Metric {
   // translated from PRROC R Package
   // see https://doi.org/10.1371/journal.pone.0092209
-  bst_float Eval(const std::vector<bst_float> &preds, const MetaInfo &info,
-                 bool distributed) const override {
+  bst_float Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
+                 bool distributed) override {
     CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
-    CHECK_EQ(preds.size(), info.labels_.Size())
+    CHECK_EQ(preds.Size(), info.labels_.Size())
         << "label size predict size not match";
     std::vector<unsigned> tgptr(2, 0);
     tgptr[1] = static_cast<unsigned>(info.labels_.Size());
@@ -377,15 +388,17 @@ struct EvalAucPR : public Metric {
     int auc_error = 0, auc_gt_one = 0;
     // each thread takes a local rec
     std::vector<std::pair<bst_float, size_t>> rec;
-    const auto& labels = info.labels_.HostVector();
+    const auto& h_labels = info.labels_.HostVector();
+    const std::vector<bst_float>& h_preds = preds.HostVector();
+
     for (bst_omp_uint k = 0; k < ngroup; ++k) {
       double total_pos = 0.0;
       double total_neg = 0.0;
       rec.clear();
       for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
-        total_pos += info.GetWeight(j) * labels[j];
-        total_neg += info.GetWeight(j) * (1.0f - labels[j]);
-        rec.emplace_back(preds[j], j);
+        total_pos += info.GetWeight(j) * h_labels[j];
+        total_neg += info.GetWeight(j) * (1.0f - h_labels[j]);
+        rec.emplace_back(h_preds[j], j);
       }
       XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
       // we need pos > 0 && neg > 0
@@ -395,8 +408,8 @@ struct EvalAucPR : public Metric {
       // calculate AUC
       double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
       for (size_t j = 0; j < rec.size(); ++j) {
-        tp += info.GetWeight(rec[j].second) * labels[rec[j].second];
-        fp += info.GetWeight(rec[j].second) * (1.0f - labels[rec[j].second]);
+        tp += info.GetWeight(rec[j].second) * h_labels[rec[j].second];
+        fp += info.GetWeight(rec[j].second) * (1.0f - h_labels[rec[j].second]);
         if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
           if (tp == prevtp) {
             a = 1.0;
@@ -471,4 +484,3 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
 .set_body([](const char* param) { return new EvalCox(); });
 }  // namespace metric
 }  // namespace xgboost
-
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 04bb48198..2eb53ad26 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -85,13 +85,14 @@ void CheckRankingObjFunction(xgboost::ObjFunction * obj,
 
 xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
-                                 std::vector<xgboost::bst_float> preds,
+                                 xgboost::HostDeviceVector<xgboost::bst_float> preds,
                                  std::vector<xgboost::bst_float> labels,
                                  std::vector<xgboost::bst_float> weights) {
   xgboost::MetaInfo info;
   info.num_row_ = labels.size();
   info.labels_.HostVector() = labels;
   info.weights_.HostVector() = weights;
+
   return metric->Eval(preds, info, false);
 }
 
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index 641eec3be..9d64114e3 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -49,7 +49,7 @@ void CheckRankingObjFunction(xgboost::ObjFunction * obj,
 
 xgboost::bst_float GetMetricEval(
     xgboost::Metric * metric,
-    std::vector<xgboost::bst_float> preds,
+    xgboost::HostDeviceVector<xgboost::bst_float> preds,
    std::vector<xgboost::bst_float> labels,
    std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float>());
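
With the helper's first parameter now a HostDeviceVector, the braced lists in the tests below construct device-capable buffers while labels and weights stay plain std::vectors. A usage sketch (illustrative only; it assumes, as the tests below do, that HostDeviceVector is constructible from a braced initializer list):

#include <memory>

#include <xgboost/metric.h>
#include "../helpers.h"

void SketchMaeCheck() {
  std::unique_ptr<xgboost::Metric> metric(xgboost::Metric::Create("mae"));
  metric->Configure({{"n_gpus", "0"}});  // force the CPU reduction path
  // {0.1f, 0.9f} becomes the HostDeviceVector, {0, 1} the label vector.
  xgboost::bst_float mae = GetMetricEval(metric.get(), {0.1f, 0.9f}, {0, 1});
  // mae == 0.1f: mean of |0 - 0.1| and |1 - 0.9|.
  (void)mae;
}
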
diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc
index 3e0cb4c43..f4e3f4137 100644
--- a/tests/cpp/metric/test_elementwise_metric.cc
+++ b/tests/cpp/metric/test_elementwise_metric.cc
@@ -1,21 +1,34 @@
-// Copyright by Contributors
+/*!
+ * Copyright 2018 XGBoost contributors
+ */
 #include <xgboost/metric.h>
-
+#include <string>
 #include "../helpers.h"
 
-TEST(Metric, RMSE) {
+using Arg = std::pair<std::string, std::string>;
+
+#if defined(__CUDACC__)
+#define N_GPU() Arg{"n_gpus", "1"}
+#else
+#define N_GPU() Arg{"n_gpus", "0"}
+#endif
+
+TEST(Metric, DeclareUnifiedTest(RMSE)) {
   xgboost::Metric * metric = xgboost::Metric::Create("rmse");
+  metric->Configure({N_GPU()});
   ASSERT_STREQ(metric->Name(), "rmse");
   EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric,
                             {0.1f, 0.9f, 0.1f, 0.9f},
-                            {  0,   0,   1,   1}),
+                            {  0,   0,   1,   1},
+                            {  0,   1,   2,   3}),
               0.6403f, 0.001f);
   delete metric;
 }
 
-TEST(Metric, MAE) {
+TEST(Metric, DeclareUnifiedTest(MAE)) {
   xgboost::Metric * metric = xgboost::Metric::Create("mae");
+  metric->Configure({N_GPU()});
   ASSERT_STREQ(metric->Name(), "mae");
   EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric,
@@ -25,8 +38,9 @@
   delete metric;
 }
 
-TEST(Metric, LogLoss) {
+TEST(Metric, DeclareUnifiedTest(LogLoss)) {
   xgboost::Metric * metric = xgboost::Metric::Create("logloss");
+  metric->Configure({N_GPU()});
   ASSERT_STREQ(metric->Name(), "logloss");
   EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric,
@@ -36,8 +50,9 @@
   delete metric;
 }
 
-TEST(Metric, Error) {
+TEST(Metric, DeclareUnifiedTest(Error)) {
   xgboost::Metric * metric = xgboost::Metric::Create("error");
+  metric->Configure({N_GPU()});
   ASSERT_STREQ(metric->Name(), "error");
   EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric,
@@ -47,11 +62,15 @@
 
   EXPECT_ANY_THROW(xgboost::Metric::Create("error@abc"));
   delete metric;
+
   metric = xgboost::Metric::Create("error@0.5f");
+  metric->Configure({N_GPU()});
   EXPECT_STREQ(metric->Name(), "error");
   delete metric;
+
   metric = xgboost::Metric::Create("error@0.1");
+  metric->Configure({N_GPU()});
   ASSERT_STREQ(metric->Name(), "error@0.1");
   EXPECT_STREQ(metric->Name(), "error@0.1");
   EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
@@ -62,8 +81,9 @@
   delete metric;
 }
 
-TEST(Metric, PoissionNegLogLik) {
+TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
   xgboost::Metric * metric = xgboost::Metric::Create("poisson-nloglik");
+  metric->Configure({N_GPU()});
   ASSERT_STREQ(metric->Name(), "poisson-nloglik");
   EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0.5f, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric,
@@ -72,3 +92,31 @@
               1.1280f, 0.001f);
   delete metric;
 }
+
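
The DeclareUnifiedTest/N_GPU pair lets this one source file serve both builds: compiled by the host compiler it registers CPU tests with n_gpus=0, and recompiled through the .cu wrapper below it registers GPU variants with n_gpus=1. One plausible shape for the macro (an assumption -- the real definition lives in the test helpers, which this patch does not show):

#if defined(__CUDACC__)
#define DeclareUnifiedTest(name) GPU ## name  // e.g. TEST(Metric, GPURMSE)
#else
#define DeclareUnifiedTest(name) name         // e.g. TEST(Metric, RMSE)
#endif
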
+#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
+TEST(Metric, MGPU_RMSE) {
+  {
+    xgboost::Metric * metric = xgboost::Metric::Create("rmse");
+    metric->Configure({Arg{"n_gpus", "-1"}});
+    ASSERT_STREQ(metric->Name(), "rmse");
+    EXPECT_NEAR(GetMetricEval(metric, {0}, {0}), 0, 1e-10);
+    EXPECT_NEAR(GetMetricEval(metric,
+                              {0.1f, 0.9f, 0.1f, 0.9f},
+                              {  0,   0,   1,   1}),
+                0.6403f, 0.001f);
+    delete metric;
+  }
+
+  {
+    xgboost::Metric * metric = xgboost::Metric::Create("rmse");
+    metric->Configure({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}});
+    ASSERT_STREQ(metric->Name(), "rmse");
+    EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
+    EXPECT_NEAR(GetMetricEval(metric,
+                              {0.1f, 0.9f, 0.1f, 0.9f},
+                              {  0,   0,   1,   1}),
+                0.6403f, 0.001f);
+    delete metric;
+  }
+}
+#endif
diff --git a/tests/cpp/metric/test_elementwise_metric.cu b/tests/cpp/metric/test_elementwise_metric.cu
new file mode 100644
index 000000000..c45db8f7f
--- /dev/null
+++ b/tests/cpp/metric/test_elementwise_metric.cu
@@ -0,0 +1,5 @@
+/*!
+ * Copyright 2018 XGBoost contributors
+ */
+// Dummy file to keep the CUDA conditional compile trick.
+#include "test_elementwise_metric.cc"
\ No newline at end of file
diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc
index 5a235b110..e8082fc67 100644
--- a/tests/cpp/metric/test_rank_metric.cc
+++ b/tests/cpp/metric/test_rank_metric.cc
@@ -88,7 +88,9 @@ TEST(Metric, NDCG) {
   xgboost::Metric * metric = xgboost::Metric::Create("ndcg");
   ASSERT_STREQ(metric->Name(), "ndcg");
   EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
-  EXPECT_NEAR(GetMetricEval(metric, {}, {}), 1, 1e-10);
+  EXPECT_NEAR(GetMetricEval(metric,
+                            xgboost::HostDeviceVector<xgboost::bst_float>{},
+                            {}), 1, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric,
                             {0.1f, 0.9f, 0.1f, 0.9f},
@@ -107,7 +109,9 @@ TEST(Metric, NDCG) {
   delete metric;
   metric = xgboost::Metric::Create("ndcg@-");
   ASSERT_STREQ(metric->Name(), "ndcg@-");
-  EXPECT_NEAR(GetMetricEval(metric, {}, {}), 0, 1e-10);
+  EXPECT_NEAR(GetMetricEval(metric,
+                            xgboost::HostDeviceVector<xgboost::bst_float>{},
+                            {}), 0, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
   EXPECT_NEAR(GetMetricEval(metric,
                             {0.1f, 0.9f, 0.1f, 0.9f},
@@ -134,12 +138,16 @@ TEST(Metric, MAP) {
                             {0.1f, 0.9f, 0.1f, 0.9f},
                             {  0,   0,   1,   1}),
               0.5f, 0.001f);
-  EXPECT_NEAR(GetMetricEval(metric, {}, {}), 1, 1e-10);
+  EXPECT_NEAR(GetMetricEval(metric,
+                            xgboost::HostDeviceVector<xgboost::bst_float>{},
+                            std::vector<xgboost::bst_float>{}), 1, 1e-10);
   delete metric;
 
   metric = xgboost::Metric::Create("map@-");
   ASSERT_STREQ(metric->Name(), "map@-");
-  EXPECT_NEAR(GetMetricEval(metric, {}, {}), 0, 1e-10);
+  EXPECT_NEAR(GetMetricEval(metric,
+                            xgboost::HostDeviceVector<xgboost::bst_float>{},
+                            {}), 0, 1e-10);
   delete metric;
 
   metric = xgboost::Metric::Create("map@2");