Porting elementwise metrics to GPU. (#3952)
* Port elementwise metrics to GPU. * All elementwise metrics are converted to static polymorphic. * Create a reducer for metrics reduction. * Remove const of Metric::Eval to accommodate CubMemory.
This commit is contained in:
parent
a9d684db18
commit
48dddfd635
@ -11,8 +11,11 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "./data.h"
|
||||
#include "./base.h"
|
||||
#include "../../src/common/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
/*!
|
||||
@ -21,6 +24,23 @@ namespace xgboost {
|
||||
*/
|
||||
class Metric {
|
||||
public:
|
||||
/*!
|
||||
* \brief Configure the Metric with the specified parameters.
|
||||
* \param args arguments to the objective function.
|
||||
*/
|
||||
virtual void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> >& args) {}
|
||||
/*!
|
||||
* \brief set configuration from pair iterators.
|
||||
* \param begin The beginning iterator.
|
||||
* \param end The end iterator.
|
||||
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||
*/
|
||||
template<typename PairIter>
|
||||
inline void Configure(PairIter begin, PairIter end) {
|
||||
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||
this->Configure(vec);
|
||||
}
|
||||
/*!
|
||||
* \brief evaluate a specific metric
|
||||
* \param preds prediction
|
||||
@ -29,9 +49,9 @@ class Metric {
|
||||
* the average statistics across all the node,
|
||||
* this is only supported by some metrics
|
||||
*/
|
||||
virtual bst_float Eval(const std::vector<bst_float>& preds,
|
||||
virtual bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
bool distributed) const = 0;
|
||||
bool distributed) = 0;
|
||||
/*! \return name of metric */
|
||||
virtual const char* Name() const = 0;
|
||||
/*! \brief virtual destructor */
|
||||
|
||||
@ -127,7 +127,7 @@ inline bool CheckNAN(T v) {
|
||||
#endif
|
||||
}
|
||||
template<typename T>
|
||||
inline T LogGamma(T v) {
|
||||
XGBOOST_DEVICE inline T LogGamma(T v) {
|
||||
#ifdef _MSC_VER
|
||||
#if _MSC_VER >= 1800
|
||||
return lgamma(v);
|
||||
|
||||
@ -310,6 +310,10 @@ class LearnerImpl : public Learner {
|
||||
if (obj_ != nullptr) {
|
||||
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
|
||||
for (auto& p_metric : metrics_) {
|
||||
p_metric->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
}
|
||||
|
||||
void InitModel() override { this->LazyInitModel(); }
|
||||
@ -407,6 +411,10 @@ class LearnerImpl : public Learner {
|
||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||
|
||||
for (auto& p_metric : metrics_) {
|
||||
p_metric->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
}
|
||||
|
||||
// rabit save model to rabit checkpoint
|
||||
@ -503,13 +511,14 @@ class LearnerImpl : public Learner {
|
||||
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
|
||||
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||
metrics_.back()->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
this->PredictRaw(data_sets[i], &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
<< ev->Eval(preds_.ConstHostVector(), data_sets[i]->Info(),
|
||||
<< ev->Eval(preds_, data_sets[i]->Info(),
|
||||
tparam_.dsplit == DataSplitMode::kRow);
|
||||
}
|
||||
}
|
||||
@ -553,7 +562,7 @@ class LearnerImpl : public Learner {
|
||||
this->PredictRaw(data, &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
return std::make_pair(metric,
|
||||
ev->Eval(preds_.ConstHostVector(), data->Info(),
|
||||
ev->Eval(preds_, data->Info(),
|
||||
tparam_.dsplit == DataSplitMode::kRow));
|
||||
}
|
||||
|
||||
|
||||
@ -1,227 +1,8 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file elementwise_metric.cc
|
||||
* \brief evaluation metrics for elementwise binary or regression.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <cmath>
|
||||
#include "../common/math.h"
|
||||
// Dummy file to keep the CUDA conditional compile trick.
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
// tag the this file, used by force static link later.
|
||||
DMLC_REGISTRY_FILE_TAG(elementwise_metric);
|
||||
|
||||
/*!
|
||||
* \brief base class of element-wise evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalEWiseBase : public Metric {
|
||||
bst_float Eval(const std::vector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
bool distributed) const override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels_.Size())
|
||||
<< "label and prediction size not match, "
|
||||
<< "hint: use merror or mlogloss for multi-class classification";
|
||||
const auto ndata = static_cast<omp_ulong>(info.labels_.Size());
|
||||
double sum = 0.0, wsum = 0.0;
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
const auto& weights = info.weights_.HostVector();
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
|
||||
sum += static_cast<const Derived*>(this)->EvalRow(labels[i], preds[i]) * wt;
|
||||
wsum += wt;
|
||||
}
|
||||
double dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
*/
|
||||
inline bst_float EvalRow(bst_float label, bst_float pred) const;
|
||||
/*!
|
||||
* \brief to be overridden by subclass, final transformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalRMSE : public EvalEWiseBase<EvalRMSE> {
|
||||
const char *Name() const override {
|
||||
return "rmse";
|
||||
}
|
||||
inline bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
bst_float diff = label - pred;
|
||||
return diff * diff;
|
||||
}
|
||||
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return std::sqrt(esum / wsum);
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalMAE : public EvalEWiseBase<EvalMAE> {
|
||||
const char *Name() const override {
|
||||
return "mae";
|
||||
}
|
||||
inline bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
return std::abs(label - pred);
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalLogLoss : public EvalEWiseBase<EvalLogLoss> {
|
||||
const char *Name() const override {
|
||||
return "logloss";
|
||||
}
|
||||
inline bst_float EvalRow(bst_float y, bst_float py) const {
|
||||
const bst_float eps = 1e-16f;
|
||||
const bst_float pneg = 1.0f - py;
|
||||
if (py < eps) {
|
||||
return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
|
||||
} else if (pneg < eps) {
|
||||
return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps);
|
||||
} else {
|
||||
return -y * std::log(py) - (1.0f - y) * std::log(pneg);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalError : public EvalEWiseBase<EvalError> {
|
||||
explicit EvalError(const char* param) {
|
||||
if (param != nullptr) {
|
||||
std::ostringstream os;
|
||||
os << "error";
|
||||
CHECK_EQ(sscanf(param, "%f", &threshold_), 1)
|
||||
<< "unable to parse the threshold value for the error metric";
|
||||
if (threshold_ != 0.5f) os << '@' << threshold_;
|
||||
name_ = os.str();
|
||||
} else {
|
||||
threshold_ = 0.5f;
|
||||
name_ = "error";
|
||||
}
|
||||
}
|
||||
const char *Name() const override {
|
||||
return name_.c_str();
|
||||
}
|
||||
inline bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
// assume label is in [0,1]
|
||||
return pred > threshold_ ? 1.0f - label : label;
|
||||
}
|
||||
protected:
|
||||
bst_float threshold_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
struct EvalPoissonNegLogLik : public EvalEWiseBase<EvalPoissonNegLogLik> {
|
||||
const char *Name() const override {
|
||||
return "poisson-nloglik";
|
||||
}
|
||||
inline bst_float EvalRow(bst_float y, bst_float py) const {
|
||||
const bst_float eps = 1e-16f;
|
||||
if (py < eps) py = eps;
|
||||
return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalGammaDeviance : public EvalEWiseBase<EvalGammaDeviance> {
|
||||
const char *Name() const override {
|
||||
return "gamma-deviance";
|
||||
}
|
||||
inline bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
bst_float epsilon = 1.0e-9;
|
||||
bst_float tmp = label / (pred + epsilon);
|
||||
return tmp - std::log(tmp) - 1;
|
||||
}
|
||||
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return 2 * esum;
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalGammaNLogLik: public EvalEWiseBase<EvalGammaNLogLik> {
|
||||
const char *Name() const override {
|
||||
return "gamma-nloglik";
|
||||
}
|
||||
inline bst_float EvalRow(bst_float y, bst_float py) const {
|
||||
bst_float psi = 1.0;
|
||||
bst_float theta = -1. / py;
|
||||
bst_float a = psi;
|
||||
bst_float b = -std::log(-theta);
|
||||
bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
|
||||
return -((y * theta - b) / a + c);
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalTweedieNLogLik: public EvalEWiseBase<EvalTweedieNLogLik> {
|
||||
explicit EvalTweedieNLogLik(const char* param) {
|
||||
CHECK(param != nullptr)
|
||||
<< "tweedie-nloglik must be in format tweedie-nloglik@rho";
|
||||
rho_ = atof(param);
|
||||
CHECK(rho_ < 2 && rho_ >= 1)
|
||||
<< "tweedie variance power must be in interval [1, 2)";
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << rho_;
|
||||
name_ = os.str();
|
||||
}
|
||||
const char *Name() const override {
|
||||
return name_.c_str();
|
||||
}
|
||||
inline bst_float EvalRow(bst_float y, bst_float p) const {
|
||||
bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
|
||||
bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
|
||||
return -a + b;
|
||||
}
|
||||
protected:
|
||||
std::string name_;
|
||||
bst_float rho_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(RMSE, "rmse")
|
||||
.describe("Rooted mean square error.")
|
||||
.set_body([](const char* param) { return new EvalRMSE(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MAE, "mae")
|
||||
.describe("Mean absolute error.")
|
||||
.set_body([](const char* param) { return new EvalMAE(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(LogLoss, "logloss")
|
||||
.describe("Negative loglikelihood for logistic regression.")
|
||||
.set_body([](const char* param) { return new EvalLogLoss(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Error, "error")
|
||||
.describe("Binary classification error.")
|
||||
.set_body([](const char* param) { return new EvalError(param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
|
||||
.describe("Negative loglikelihood for poisson regression.")
|
||||
.set_body([](const char* param) { return new EvalPoissonNegLogLik(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(GammaDeviance, "gamma-deviance")
|
||||
.describe("Residual deviance for gamma regression.")
|
||||
.set_body([](const char* param) { return new EvalGammaDeviance(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(GammaNLogLik, "gamma-nloglik")
|
||||
.describe("Negative log-likelihood for gamma regression.")
|
||||
.set_body([](const char* param) { return new EvalGammaNLogLik(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
|
||||
.describe("tweedie-nloglik@rho for tweedie regression.")
|
||||
.set_body([](const char* param) {
|
||||
return new EvalTweedieNLogLik(param);
|
||||
});
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
#include "elementwise_metric.cu"
|
||||
#endif
|
||||
|
||||
406
src/metric/elementwise_metric.cu
Normal file
406
src/metric/elementwise_metric.cu
Normal file
@ -0,0 +1,406 @@
|
||||
/*!
|
||||
* Copyright 2015-2018 by Contributors
|
||||
* \file elementwise_metric.cc
|
||||
* \brief evaluation metrics for elementwise binary or regression.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <cmath>
|
||||
|
||||
#include "metric_param.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/common.h"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/transform_reduce.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/functional.h> // thrust::plus<>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
// tag the this file, used by force static link later.
|
||||
DMLC_REGISTRY_FILE_TAG(elementwise_metric);
|
||||
|
||||
struct PackedReduceResult {
|
||||
double residue_sum_;
|
||||
double weights_sum_;
|
||||
|
||||
XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {}
|
||||
XGBOOST_DEVICE PackedReduceResult(double residue, double weight) :
|
||||
residue_sum_{residue}, weights_sum_{weight} {}
|
||||
|
||||
XGBOOST_DEVICE
|
||||
PackedReduceResult operator+(PackedReduceResult const& other) const {
|
||||
return PackedReduceResult { residue_sum_ + other.residue_sum_,
|
||||
weights_sum_ + other.weights_sum_ };
|
||||
}
|
||||
};
|
||||
|
||||
template <typename EvalRow>
|
||||
class MetricsReduction {
|
||||
public:
|
||||
explicit MetricsReduction(EvalRow policy) :
|
||||
policy_(std::move(policy)) {}
|
||||
|
||||
PackedReduceResult CpuReduceMetrics(
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds) const {
|
||||
size_t ndata = labels.Size();
|
||||
|
||||
const auto& h_labels = labels.HostVector();
|
||||
const auto& h_weights = weights.HostVector();
|
||||
const auto& h_preds = preds.HostVector();
|
||||
|
||||
bst_float residue_sum = 0;
|
||||
bst_float weights_sum = 0;
|
||||
|
||||
#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
const bst_float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
|
||||
residue_sum += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
|
||||
weights_sum += wt;
|
||||
}
|
||||
PackedReduceResult res { residue_sum, weights_sum };
|
||||
return res;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
|
||||
PackedReduceResult DeviceReduceMetrics(
|
||||
GPUSet::GpuIdType device_id,
|
||||
size_t device_index,
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds) {
|
||||
size_t n_data = preds.DeviceSize(device_id);
|
||||
|
||||
thrust::counting_iterator<size_t> begin(0);
|
||||
thrust::counting_iterator<size_t> end = begin + n_data;
|
||||
|
||||
auto s_label = labels.DeviceSpan(device_id);
|
||||
auto s_preds = preds.DeviceSpan(device_id);
|
||||
auto s_weights = weights.DeviceSpan(device_id);
|
||||
|
||||
bool const is_null_weight = weights.Size() == 0;
|
||||
|
||||
auto d_policy = policy_;
|
||||
|
||||
PackedReduceResult result = thrust::transform_reduce(
|
||||
thrust::cuda::par(allocators_.at(device_index)),
|
||||
begin, end,
|
||||
[=] XGBOOST_DEVICE(size_t idx) {
|
||||
bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
|
||||
|
||||
bst_float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]);
|
||||
residue *= weight;
|
||||
return PackedReduceResult{ residue, weight };
|
||||
},
|
||||
PackedReduceResult(),
|
||||
thrust::plus<PackedReduceResult>());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
|
||||
PackedReduceResult Reduce(
|
||||
GPUSet devices,
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds) {
|
||||
PackedReduceResult result;
|
||||
|
||||
if (devices.IsEmpty()) {
|
||||
result = CpuReduceMetrics(weights, labels, preds);
|
||||
}
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
else { // NOLINT
|
||||
if (allocators_.size() != devices.Size()) {
|
||||
allocators_.clear();
|
||||
allocators_.resize(devices.Size());
|
||||
}
|
||||
preds.Reshard(devices);
|
||||
labels.Reshard(devices);
|
||||
weights.Reshard(devices);
|
||||
std::vector<PackedReduceResult> res_per_device(devices.Size());
|
||||
|
||||
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
|
||||
for (GPUSet::GpuIdType id = *devices.begin(); id < *devices.end(); ++id) {
|
||||
dh::safe_cuda(cudaSetDevice(id));
|
||||
size_t index = devices.Index(id);
|
||||
res_per_device.at(index) =
|
||||
DeviceReduceMetrics(id, index, weights, labels, preds);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < devices.Size(); ++i) {
|
||||
result.residue_sum_ += res_per_device[i].residue_sum_;
|
||||
result.weights_sum_ += res_per_device[i].weights_sum_;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
EvalRow policy_;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
std::vector<dh::CubMemory> allocators_;
|
||||
#endif
|
||||
};
|
||||
|
||||
struct EvalRowRMSE {
|
||||
char const *Name() const {
|
||||
return "rmse";
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
bst_float diff = label - pred;
|
||||
return diff * diff;
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return std::sqrt(esum / wsum);
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalRowMAE {
|
||||
const char *Name() const {
|
||||
return "mae";
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
return std::abs(label - pred);
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalRowLogLoss {
|
||||
const char *Name() const {
|
||||
return "logloss";
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
|
||||
const bst_float eps = 1e-16f;
|
||||
const bst_float pneg = 1.0f - py;
|
||||
if (py < eps) {
|
||||
return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
|
||||
} else if (pneg < eps) {
|
||||
return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps);
|
||||
} else {
|
||||
return -y * std::log(py) - (1.0f - y) * std::log(pneg);
|
||||
}
|
||||
}
|
||||
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalError {
|
||||
explicit EvalError(const char* param) {
|
||||
if (param != nullptr) {
|
||||
CHECK_EQ(sscanf(param, "%f", &threshold_), 1)
|
||||
<< "unable to parse the threshold value for the error metric";
|
||||
has_param_ = true;
|
||||
} else {
|
||||
threshold_ = 0.5f;
|
||||
has_param_ = false;
|
||||
}
|
||||
}
|
||||
const char *Name() const {
|
||||
static std::string name;
|
||||
if (has_param_) {
|
||||
std::ostringstream os;
|
||||
os << "error";
|
||||
if (threshold_ != 0.5f) os << '@' << threshold_;
|
||||
name = os.str();
|
||||
return name.c_str();
|
||||
} else {
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(
|
||||
bst_float label, bst_float pred) const {
|
||||
// assume label is in [0,1]
|
||||
return pred > threshold_ ? 1.0f - label : label;
|
||||
}
|
||||
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
|
||||
private:
|
||||
bst_float threshold_;
|
||||
bool has_param_;
|
||||
};
|
||||
|
||||
struct EvalPoissonNegLogLik {
|
||||
const char *Name() const {
|
||||
return "poisson-nloglik";
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
|
||||
const bst_float eps = 1e-16f;
|
||||
if (py < eps) py = eps;
|
||||
return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
|
||||
}
|
||||
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalGammaDeviance {
|
||||
const char *Name() const {
|
||||
return "gamma-deviance";
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
bst_float epsilon = 1.0e-9;
|
||||
bst_float tmp = label / (pred + epsilon);
|
||||
return tmp - std::log(tmp) - 1;
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return 2 * esum;
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalGammaNLogLik {
|
||||
static const char *Name() {
|
||||
return "gamma-nloglik";
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
|
||||
bst_float psi = 1.0;
|
||||
bst_float theta = -1. / py;
|
||||
bst_float a = psi;
|
||||
bst_float b = -std::log(-theta);
|
||||
bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
|
||||
return -((y * theta - b) / a + c);
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalTweedieNLogLik {
|
||||
explicit EvalTweedieNLogLik(const char* param) {
|
||||
CHECK(param != nullptr)
|
||||
<< "tweedie-nloglik must be in format tweedie-nloglik@rho";
|
||||
rho_ = atof(param);
|
||||
CHECK(rho_ < 2 && rho_ >= 1)
|
||||
<< "tweedie variance power must be in interval [1, 2)";
|
||||
}
|
||||
const char *Name() const {
|
||||
static std::string name;
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << rho_;
|
||||
name = os.str();
|
||||
return name.c_str();
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float p) const {
|
||||
bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
|
||||
bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
|
||||
return -a + b;
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
|
||||
protected:
|
||||
bst_float rho_;
|
||||
};
|
||||
/*!
|
||||
* \brief base class of element-wise evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Policy>
|
||||
struct EvalEWiseBase : public Metric {
|
||||
EvalEWiseBase() : policy_{}, reducer_{policy_} {}
|
||||
explicit EvalEWiseBase(char const* policy_param) :
|
||||
policy_{policy_param}, reducer_{policy_} {}
|
||||
|
||||
void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
}
|
||||
|
||||
bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
bool distributed) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "label and prediction size not match, "
|
||||
<< "hint: use merror or mlogloss for multi-class classification";
|
||||
const auto ndata = static_cast<omp_ulong>(info.labels_.Size());
|
||||
// Dealing with ndata < n_gpus.
|
||||
GPUSet devices = GPUSet::All(param_.gpu_id, param_.n_gpus, ndata);
|
||||
|
||||
PackedReduceResult result =
|
||||
reducer_.Reduce(devices, info.weights_, info.labels_, preds);
|
||||
|
||||
double dat[2] { result.residue_sum_, result.weights_sum_ };
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Policy::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
|
||||
const char* Name() const override {
|
||||
return policy_.Name();
|
||||
}
|
||||
|
||||
private:
|
||||
Policy policy_;
|
||||
|
||||
MetricParam param_;
|
||||
|
||||
MetricsReduction<Policy> reducer_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(RMSE, "rmse")
|
||||
.describe("Rooted mean square error.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowRMSE>(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MAE, "mae")
|
||||
.describe("Mean absolute error.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowMAE>(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(LogLoss, "logloss")
|
||||
.describe("Negative loglikelihood for logistic regression.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowLogLoss>(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
|
||||
.describe("Negative loglikelihood for poisson regression.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalPoissonNegLogLik>(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(GammaDeviance, "gamma-deviance")
|
||||
.describe("Residual deviance for gamma regression.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalGammaDeviance>(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(GammaNLogLik, "gamma-nloglik")
|
||||
.describe("Negative log-likelihood for gamma regression.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalGammaNLogLik>(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Error, "error")
|
||||
.describe("Binary classification error.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalError>(param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
|
||||
.describe("tweedie-nloglik@rho for tweedie regression.")
|
||||
.set_body([](const char* param) {
|
||||
return new EvalEWiseBase<EvalTweedieNLogLik>(param);
|
||||
});
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
@ -6,6 +6,8 @@
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include "metric_param.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
||||
}
|
||||
@ -34,6 +36,8 @@ Metric* Metric::Create(const std::string& name) {
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
DMLC_REGISTER_PARAMETER(MetricParam);
|
||||
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(elementwise_metric);
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_metric);
|
||||
|
||||
31
src/metric/metric_param.h
Normal file
31
src/metric/metric_param.h
Normal file
@ -0,0 +1,31 @@
|
||||
/*!
|
||||
* Copyright 2018 by Contributors
|
||||
* \file metric_param.cc
|
||||
*/
|
||||
#ifndef XGBOOST_METRIC_METRIC_PARAM_H_
|
||||
#define XGBOOST_METRIC_METRIC_PARAM_H_
|
||||
|
||||
#include <dmlc/parameter.h>
|
||||
#include "../common/common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
|
||||
// Created exclusively for GPU.
|
||||
struct MetricParam : public dmlc::Parameter<MetricParam> {
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
DMLC_DECLARE_PARAMETER(MetricParam) {
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(GPUSet::kAll)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_METRIC_METRIC_PARAM_H_
|
||||
@ -20,13 +20,13 @@ DMLC_REGISTRY_FILE_TAG(multiclass_metric);
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalMClassBase : public Metric {
|
||||
bst_float Eval(const std::vector<bst_float> &preds,
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
bool distributed) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK(preds.size() % info.labels_.Size() == 0)
|
||||
CHECK(preds.Size() % info.labels_.Size() == 0)
|
||||
<< "label and prediction size not match";
|
||||
const size_t nclass = preds.size() / info.labels_.Size();
|
||||
const size_t nclass = preds.Size() / info.labels_.Size();
|
||||
CHECK_GE(nclass, 1U)
|
||||
<< "mlogloss and merror are only used for multi-class classification,"
|
||||
<< " use logloss for binary classification";
|
||||
@ -36,14 +36,15 @@ struct EvalMClassBase : public Metric {
|
||||
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
const auto& weights = info.weights_.HostVector();
|
||||
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
|
||||
auto label = static_cast<int>(labels[i]);
|
||||
if (label >= 0 && label < static_cast<int>(nclass)) {
|
||||
sum += Derived::EvalRow(label,
|
||||
preds.data() + i * nclass,
|
||||
h_preds.data() + i * nclass,
|
||||
nclass) * wt;
|
||||
wsum += wt;
|
||||
} else {
|
||||
|
||||
@ -8,6 +8,10 @@
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <cmath>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "../common/host_device_vector.h"
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -26,18 +30,20 @@ struct EvalAMS : public Metric {
|
||||
os << "ams@" << ratio_;
|
||||
name_ = os.str();
|
||||
}
|
||||
bst_float Eval(const std::vector<bst_float> &preds,
|
||||
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
bool distributed) override {
|
||||
CHECK(!distributed) << "metric AMS do not support distributed evaluation";
|
||||
using namespace std; // NOLINT(*)
|
||||
|
||||
const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
|
||||
std::vector<std::pair<bst_float, unsigned> > rec(ndata);
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
rec[i] = std::make_pair(preds[i], i);
|
||||
rec[i] = std::make_pair(h_preds[i], i);
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||
@ -82,11 +88,11 @@ struct EvalAMS : public Metric {
|
||||
|
||||
/*! \brief Area Under Curve, for both classification and rank */
|
||||
struct EvalAuc : public Metric {
|
||||
bst_float Eval(const std::vector<bst_float> &preds,
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
bool distributed) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels_.Size())
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "label size predict size not match";
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
|
||||
@ -101,10 +107,11 @@ struct EvalAuc : public Metric {
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<bst_float, unsigned> > rec;
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.emplace_back(preds[j], j);
|
||||
rec.emplace_back(h_preds[j], j);
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// calculate AUC
|
||||
@ -155,23 +162,25 @@ struct EvalAuc : public Metric {
|
||||
/*! \brief Evaluate rank list */
|
||||
struct EvalRankList : public Metric {
|
||||
public:
|
||||
bst_float Eval(const std::vector<bst_float> &preds,
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
CHECK_EQ(preds.size(), info.labels_.Size())
|
||||
bool distributed) override {
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "label size predict size not match";
|
||||
// quick consistency when group is not available
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(preds.size());
|
||||
tgptr[1] = static_cast<unsigned>(preds.Size());
|
||||
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
|
||||
CHECK_NE(gptr.size(), 0U) << "must specify group when constructing rank file";
|
||||
CHECK_EQ(gptr.back(), preds.size())
|
||||
CHECK_EQ(gptr.back(), preds.Size())
|
||||
<< "EvalRanklist: group structure must match number of prediction";
|
||||
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statistics
|
||||
double sum_metric = 0.0f;
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
#pragma omp parallel reduction(+:sum_metric)
|
||||
|
||||
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||
#pragma omp parallel reduction(+:sum_metric)
|
||||
{
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<bst_float, unsigned> > rec;
|
||||
@ -179,7 +188,7 @@ struct EvalRankList : public Metric {
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.emplace_back(preds[j], static_cast<int>(labels[j]));
|
||||
rec.emplace_back(h_preds[j], static_cast<int>(labels[j]));
|
||||
}
|
||||
sum_metric += this->EvalMetric(rec);
|
||||
}
|
||||
@ -311,9 +320,9 @@ struct EvalMAP : public EvalRankList {
|
||||
struct EvalCox : public Metric {
|
||||
public:
|
||||
EvalCox() = default;
|
||||
bst_float Eval(const std::vector<bst_float> &preds,
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
bool distributed) override {
|
||||
CHECK(!distributed) << "Cox metric does not support distributed evaluation";
|
||||
using namespace std; // NOLINT(*)
|
||||
|
||||
@ -322,8 +331,10 @@ struct EvalCox : public Metric {
|
||||
|
||||
// pre-compute a sum for the denominator
|
||||
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
||||
|
||||
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
exp_p_sum += preds[i];
|
||||
exp_p_sum += h_preds[i];
|
||||
}
|
||||
|
||||
double out = 0;
|
||||
@ -334,12 +345,12 @@ struct EvalCox : public Metric {
|
||||
const size_t ind = label_order[i];
|
||||
const auto label = labels[ind];
|
||||
if (label > 0) {
|
||||
out -= log(preds[ind]) - log(exp_p_sum);
|
||||
out -= log(h_preds[ind]) - log(exp_p_sum);
|
||||
++num_events;
|
||||
}
|
||||
|
||||
// only update the denominator after we move forward in time (labels are sorted)
|
||||
accumulated_sum += preds[ind];
|
||||
accumulated_sum += h_preds[ind];
|
||||
if (i == ndata - 1 || std::abs(label) < std::abs(labels[label_order[i + 1]])) {
|
||||
exp_p_sum -= accumulated_sum;
|
||||
accumulated_sum = 0;
|
||||
@ -360,10 +371,10 @@ struct EvalAucPR : public Metric {
|
||||
// translated from PRROC R Package
|
||||
// see https://doi.org/10.1371/journal.pone.0092209
|
||||
|
||||
bst_float Eval(const std::vector<bst_float> &preds, const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
|
||||
bool distributed) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels_.Size())
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "label size predict size not match";
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
|
||||
@ -377,15 +388,17 @@ struct EvalAucPR : public Metric {
|
||||
int auc_error = 0, auc_gt_one = 0;
|
||||
// each thread takes a local rec
|
||||
std::vector<std::pair<bst_float, unsigned>> rec;
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
const auto& h_labels = info.labels_.HostVector();
|
||||
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
double total_pos = 0.0;
|
||||
double total_neg = 0.0;
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
total_pos += info.GetWeight(j) * labels[j];
|
||||
total_neg += info.GetWeight(j) * (1.0f - labels[j]);
|
||||
rec.emplace_back(preds[j], j);
|
||||
total_pos += info.GetWeight(j) * h_labels[j];
|
||||
total_neg += info.GetWeight(j) * (1.0f - h_labels[j]);
|
||||
rec.emplace_back(h_preds[j], j);
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// we need pos > 0 && neg > 0
|
||||
@ -395,8 +408,8 @@ struct EvalAucPR : public Metric {
|
||||
// calculate AUC
|
||||
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
|
||||
for (size_t j = 0; j < rec.size(); ++j) {
|
||||
tp += info.GetWeight(rec[j].second) * labels[rec[j].second];
|
||||
fp += info.GetWeight(rec[j].second) * (1.0f - labels[rec[j].second]);
|
||||
tp += info.GetWeight(rec[j].second) * h_labels[rec[j].second];
|
||||
fp += info.GetWeight(rec[j].second) * (1.0f - h_labels[rec[j].second]);
|
||||
if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
|
||||
if (tp == prevtp) {
|
||||
a = 1.0;
|
||||
@ -471,4 +484,3 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
|
||||
.set_body([](const char* param) { return new EvalCox(); });
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@ -85,13 +85,14 @@ void CheckRankingObjFunction(xgboost::ObjFunction * obj,
|
||||
|
||||
|
||||
xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights) {
|
||||
xgboost::MetaInfo info;
|
||||
info.num_row_ = labels.size();
|
||||
info.labels_.HostVector() = labels;
|
||||
info.weights_.HostVector() = weights;
|
||||
|
||||
return metric->Eval(preds, info, false);
|
||||
}
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ void CheckRankingObjFunction(xgboost::ObjFunction * obj,
|
||||
|
||||
xgboost::bst_float GetMetricEval(
|
||||
xgboost::Metric * metric,
|
||||
std::vector<xgboost::bst_float> preds,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float> ());
|
||||
|
||||
|
||||
@ -1,21 +1,34 @@
|
||||
// Copyright by Contributors
|
||||
/*!
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
#include <map>
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(Metric, RMSE) {
|
||||
using Arg = std::pair<std::string, std::string>;
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#define N_GPU() Arg{"n_gpus", "1"}
|
||||
#else
|
||||
#define N_GPU() Arg{"n_gpus", "0"}
|
||||
#endif
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(RMSE)) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
||||
metric->Configure({N_GPU()});
|
||||
ASSERT_STREQ(metric->Name(), "rmse");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
{ 0, 0, 1, 1}),
|
||||
{ 0, 0, 1, 1},
|
||||
{ 0, 1, 2, 3}),
|
||||
0.6403f, 0.001f);
|
||||
delete metric;
|
||||
}
|
||||
|
||||
TEST(Metric, MAE) {
|
||||
TEST(Metric, DeclareUnifiedTest(MAE)) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("mae");
|
||||
metric->Configure({N_GPU()});
|
||||
ASSERT_STREQ(metric->Name(), "mae");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
@ -25,8 +38,9 @@ TEST(Metric, MAE) {
|
||||
delete metric;
|
||||
}
|
||||
|
||||
TEST(Metric, LogLoss) {
|
||||
TEST(Metric, DeclareUnifiedTest(LogLoss)) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("logloss");
|
||||
metric->Configure({N_GPU()});
|
||||
ASSERT_STREQ(metric->Name(), "logloss");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
@ -36,8 +50,9 @@ TEST(Metric, LogLoss) {
|
||||
delete metric;
|
||||
}
|
||||
|
||||
TEST(Metric, Error) {
|
||||
TEST(Metric, DeclareUnifiedTest(Error)) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("error");
|
||||
metric->Configure({N_GPU()});
|
||||
ASSERT_STREQ(metric->Name(), "error");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
@ -47,11 +62,15 @@ TEST(Metric, Error) {
|
||||
|
||||
EXPECT_ANY_THROW(xgboost::Metric::Create("error@abc"));
|
||||
delete metric;
|
||||
|
||||
metric = xgboost::Metric::Create("error@0.5f");
|
||||
metric->Configure({N_GPU()});
|
||||
EXPECT_STREQ(metric->Name(), "error");
|
||||
|
||||
delete metric;
|
||||
|
||||
metric = xgboost::Metric::Create("error@0.1");
|
||||
metric->Configure({N_GPU()});
|
||||
ASSERT_STREQ(metric->Name(), "error@0.1");
|
||||
EXPECT_STREQ(metric->Name(), "error@0.1");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||
@ -62,8 +81,9 @@ TEST(Metric, Error) {
|
||||
delete metric;
|
||||
}
|
||||
|
||||
TEST(Metric, PoissionNegLogLik) {
|
||||
TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("poisson-nloglik");
|
||||
metric->Configure({N_GPU()});
|
||||
ASSERT_STREQ(metric->Name(), "poisson-nloglik");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0.5f, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
@ -72,3 +92,31 @@ TEST(Metric, PoissionNegLogLik) {
|
||||
1.1280f, 0.001f);
|
||||
delete metric;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
|
||||
TEST(Metric, MGPU_RMSE) {
|
||||
{
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
||||
metric->Configure({Arg{"n_gpus", "-1"}});
|
||||
ASSERT_STREQ(metric->Name(), "rmse");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0}, {0}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
{ 0, 0, 1, 1}),
|
||||
0.6403f, 0.001f);
|
||||
delete metric;
|
||||
}
|
||||
|
||||
{
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
||||
metric->Configure({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}});
|
||||
ASSERT_STREQ(metric->Name(), "rmse");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
{ 0, 0, 1, 1}),
|
||||
0.6403f, 0.001f);
|
||||
delete metric;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
5
tests/cpp/metric/test_elementwise_metric.cu
Normal file
5
tests/cpp/metric/test_elementwise_metric.cu
Normal file
@ -0,0 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
// Dummy file to keep the CUDA conditional compile trick.
|
||||
#include "test_elementwise_metric.cc"
|
||||
@ -88,7 +88,9 @@ TEST(Metric, NDCG) {
|
||||
xgboost::Metric * metric = xgboost::Metric::Create("ndcg");
|
||||
ASSERT_STREQ(metric->Name(), "ndcg");
|
||||
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
|
||||
EXPECT_NEAR(GetMetricEval(metric, {}, {}), 1, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||
{}), 1, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
@ -107,7 +109,9 @@ TEST(Metric, NDCG) {
|
||||
delete metric;
|
||||
metric = xgboost::Metric::Create("ndcg@-");
|
||||
ASSERT_STREQ(metric->Name(), "ndcg@-");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {}, {}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||
{}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
@ -134,12 +138,16 @@ TEST(Metric, MAP) {
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
{ 0, 0, 1, 1}),
|
||||
0.5f, 0.001f);
|
||||
EXPECT_NEAR(GetMetricEval(metric, {}, {}), 1, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||
std::vector<xgboost::bst_float>{}), 1, 1e-10);
|
||||
|
||||
delete metric;
|
||||
metric = xgboost::Metric::Create("map@-");
|
||||
ASSERT_STREQ(metric->Name(), "map@-");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {}, {}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||
{}), 0, 1e-10);
|
||||
|
||||
delete metric;
|
||||
metric = xgboost::Metric::Create("map@2");
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user