Porting elementwise metrics to GPU. (#3952)
* Port elementwise metrics to GPU. * All elementwise metrics are converted to static polymorphic. * Create a reducer for metrics reduction. * Remove const of Metric::Eval to accommodate CubMemory.
This commit is contained in:
parent
a9d684db18
commit
48dddfd635
@ -11,8 +11,11 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "./data.h"
|
#include "./data.h"
|
||||||
#include "./base.h"
|
#include "./base.h"
|
||||||
|
#include "../../src/common/host_device_vector.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
/*!
|
/*!
|
||||||
@ -21,6 +24,23 @@ namespace xgboost {
|
|||||||
*/
|
*/
|
||||||
class Metric {
|
class Metric {
|
||||||
public:
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief Configure the Metric with the specified parameters.
|
||||||
|
* \param args arguments to the objective function.
|
||||||
|
*/
|
||||||
|
virtual void Configure(
|
||||||
|
const std::vector<std::pair<std::string, std::string> >& args) {}
|
||||||
|
/*!
|
||||||
|
* \brief set configuration from pair iterators.
|
||||||
|
* \param begin The beginning iterator.
|
||||||
|
* \param end The end iterator.
|
||||||
|
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||||
|
*/
|
||||||
|
template<typename PairIter>
|
||||||
|
inline void Configure(PairIter begin, PairIter end) {
|
||||||
|
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||||
|
this->Configure(vec);
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief evaluate a specific metric
|
* \brief evaluate a specific metric
|
||||||
* \param preds prediction
|
* \param preds prediction
|
||||||
@ -29,9 +49,9 @@ class Metric {
|
|||||||
* the average statistics across all the node,
|
* the average statistics across all the node,
|
||||||
* this is only supported by some metrics
|
* this is only supported by some metrics
|
||||||
*/
|
*/
|
||||||
virtual bst_float Eval(const std::vector<bst_float>& preds,
|
virtual bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
||||||
const MetaInfo& info,
|
const MetaInfo& info,
|
||||||
bool distributed) const = 0;
|
bool distributed) = 0;
|
||||||
/*! \return name of metric */
|
/*! \return name of metric */
|
||||||
virtual const char* Name() const = 0;
|
virtual const char* Name() const = 0;
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
|
|||||||
@ -127,7 +127,7 @@ inline bool CheckNAN(T v) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
template<typename T>
|
template<typename T>
|
||||||
inline T LogGamma(T v) {
|
XGBOOST_DEVICE inline T LogGamma(T v) {
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
#if _MSC_VER >= 1800
|
#if _MSC_VER >= 1800
|
||||||
return lgamma(v);
|
return lgamma(v);
|
||||||
|
|||||||
@ -310,6 +310,10 @@ class LearnerImpl : public Learner {
|
|||||||
if (obj_ != nullptr) {
|
if (obj_ != nullptr) {
|
||||||
obj_->Configure(cfg_.begin(), cfg_.end());
|
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto& p_metric : metrics_) {
|
||||||
|
p_metric->Configure(cfg_.begin(), cfg_.end());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitModel() override { this->LazyInitModel(); }
|
void InitModel() override { this->LazyInitModel(); }
|
||||||
@ -407,6 +411,10 @@ class LearnerImpl : public Learner {
|
|||||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||||
obj_->Configure(cfg_.begin(), cfg_.end());
|
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||||
|
|
||||||
|
for (auto& p_metric : metrics_) {
|
||||||
|
p_metric->Configure(cfg_.begin(), cfg_.end());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// rabit save model to rabit checkpoint
|
// rabit save model to rabit checkpoint
|
||||||
@ -503,13 +511,14 @@ class LearnerImpl : public Learner {
|
|||||||
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
|
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
|
||||||
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
|
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
|
||||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||||
|
metrics_.back()->Configure(cfg_.begin(), cfg_.end());
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||||
this->PredictRaw(data_sets[i], &preds_);
|
this->PredictRaw(data_sets[i], &preds_);
|
||||||
obj_->EvalTransform(&preds_);
|
obj_->EvalTransform(&preds_);
|
||||||
for (auto& ev : metrics_) {
|
for (auto& ev : metrics_) {
|
||||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||||
<< ev->Eval(preds_.ConstHostVector(), data_sets[i]->Info(),
|
<< ev->Eval(preds_, data_sets[i]->Info(),
|
||||||
tparam_.dsplit == DataSplitMode::kRow);
|
tparam_.dsplit == DataSplitMode::kRow);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -553,7 +562,7 @@ class LearnerImpl : public Learner {
|
|||||||
this->PredictRaw(data, &preds_);
|
this->PredictRaw(data, &preds_);
|
||||||
obj_->EvalTransform(&preds_);
|
obj_->EvalTransform(&preds_);
|
||||||
return std::make_pair(metric,
|
return std::make_pair(metric,
|
||||||
ev->Eval(preds_.ConstHostVector(), data->Info(),
|
ev->Eval(preds_, data->Info(),
|
||||||
tparam_.dsplit == DataSplitMode::kRow));
|
tparam_.dsplit == DataSplitMode::kRow));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,227 +1,8 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2015 by Contributors
|
* Copyright 2018 XGBoost contributors
|
||||||
* \file elementwise_metric.cc
|
|
||||||
* \brief evaluation metrics for elementwise binary or regression.
|
|
||||||
* \author Kailong Chen, Tianqi Chen
|
|
||||||
*/
|
*/
|
||||||
#include <rabit/rabit.h>
|
// Dummy file to keep the CUDA conditional compile trick.
|
||||||
#include <xgboost/metric.h>
|
|
||||||
#include <dmlc/registry.h>
|
|
||||||
#include <cmath>
|
|
||||||
#include "../common/math.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
namespace metric {
|
#include "elementwise_metric.cu"
|
||||||
// tag the this file, used by force static link later.
|
#endif
|
||||||
DMLC_REGISTRY_FILE_TAG(elementwise_metric);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief base class of element-wise evaluation
|
|
||||||
* \tparam Derived the name of subclass
|
|
||||||
*/
|
|
||||||
template<typename Derived>
|
|
||||||
struct EvalEWiseBase : public Metric {
|
|
||||||
bst_float Eval(const std::vector<bst_float>& preds,
|
|
||||||
const MetaInfo& info,
|
|
||||||
bool distributed) const override {
|
|
||||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
|
||||||
CHECK_EQ(preds.size(), info.labels_.Size())
|
|
||||||
<< "label and prediction size not match, "
|
|
||||||
<< "hint: use merror or mlogloss for multi-class classification";
|
|
||||||
const auto ndata = static_cast<omp_ulong>(info.labels_.Size());
|
|
||||||
double sum = 0.0, wsum = 0.0;
|
|
||||||
const auto& labels = info.labels_.HostVector();
|
|
||||||
const auto& weights = info.weights_.HostVector();
|
|
||||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
|
||||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
|
||||||
const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
|
|
||||||
sum += static_cast<const Derived*>(this)->EvalRow(labels[i], preds[i]) * wt;
|
|
||||||
wsum += wt;
|
|
||||||
}
|
|
||||||
double dat[2]; dat[0] = sum, dat[1] = wsum;
|
|
||||||
if (distributed) {
|
|
||||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
|
||||||
}
|
|
||||||
return Derived::GetFinal(dat[0], dat[1]);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief to be implemented by subclass,
|
|
||||||
* get evaluation result from one row
|
|
||||||
* \param label label of current instance
|
|
||||||
* \param pred prediction value of current instance
|
|
||||||
*/
|
|
||||||
inline bst_float EvalRow(bst_float label, bst_float pred) const;
|
|
||||||
/*!
|
|
||||||
* \brief to be overridden by subclass, final transformation
|
|
||||||
* \param esum the sum statistics returned by EvalRow
|
|
||||||
* \param wsum sum of weight
|
|
||||||
*/
|
|
||||||
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
|
||||||
return esum / wsum;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct EvalRMSE : public EvalEWiseBase<EvalRMSE> {
|
|
||||||
const char *Name() const override {
|
|
||||||
return "rmse";
|
|
||||||
}
|
|
||||||
inline bst_float EvalRow(bst_float label, bst_float pred) const {
|
|
||||||
bst_float diff = label - pred;
|
|
||||||
return diff * diff;
|
|
||||||
}
|
|
||||||
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
|
||||||
return std::sqrt(esum / wsum);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct EvalMAE : public EvalEWiseBase<EvalMAE> {
|
|
||||||
const char *Name() const override {
|
|
||||||
return "mae";
|
|
||||||
}
|
|
||||||
inline bst_float EvalRow(bst_float label, bst_float pred) const {
|
|
||||||
return std::abs(label - pred);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct EvalLogLoss : public EvalEWiseBase<EvalLogLoss> {
|
|
||||||
const char *Name() const override {
|
|
||||||
return "logloss";
|
|
||||||
}
|
|
||||||
inline bst_float EvalRow(bst_float y, bst_float py) const {
|
|
||||||
const bst_float eps = 1e-16f;
|
|
||||||
const bst_float pneg = 1.0f - py;
|
|
||||||
if (py < eps) {
|
|
||||||
return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
|
|
||||||
} else if (pneg < eps) {
|
|
||||||
return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps);
|
|
||||||
} else {
|
|
||||||
return -y * std::log(py) - (1.0f - y) * std::log(pneg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct EvalError : public EvalEWiseBase<EvalError> {
|
|
||||||
explicit EvalError(const char* param) {
|
|
||||||
if (param != nullptr) {
|
|
||||||
std::ostringstream os;
|
|
||||||
os << "error";
|
|
||||||
CHECK_EQ(sscanf(param, "%f", &threshold_), 1)
|
|
||||||
<< "unable to parse the threshold value for the error metric";
|
|
||||||
if (threshold_ != 0.5f) os << '@' << threshold_;
|
|
||||||
name_ = os.str();
|
|
||||||
} else {
|
|
||||||
threshold_ = 0.5f;
|
|
||||||
name_ = "error";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const char *Name() const override {
|
|
||||||
return name_.c_str();
|
|
||||||
}
|
|
||||||
inline bst_float EvalRow(bst_float label, bst_float pred) const {
|
|
||||||
// assume label is in [0,1]
|
|
||||||
return pred > threshold_ ? 1.0f - label : label;
|
|
||||||
}
|
|
||||||
protected:
|
|
||||||
bst_float threshold_;
|
|
||||||
std::string name_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct EvalPoissonNegLogLik : public EvalEWiseBase<EvalPoissonNegLogLik> {
|
|
||||||
const char *Name() const override {
|
|
||||||
return "poisson-nloglik";
|
|
||||||
}
|
|
||||||
inline bst_float EvalRow(bst_float y, bst_float py) const {
|
|
||||||
const bst_float eps = 1e-16f;
|
|
||||||
if (py < eps) py = eps;
|
|
||||||
return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct EvalGammaDeviance : public EvalEWiseBase<EvalGammaDeviance> {
|
|
||||||
const char *Name() const override {
|
|
||||||
return "gamma-deviance";
|
|
||||||
}
|
|
||||||
inline bst_float EvalRow(bst_float label, bst_float pred) const {
|
|
||||||
bst_float epsilon = 1.0e-9;
|
|
||||||
bst_float tmp = label / (pred + epsilon);
|
|
||||||
return tmp - std::log(tmp) - 1;
|
|
||||||
}
|
|
||||||
inline static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
|
||||||
return 2 * esum;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct EvalGammaNLogLik: public EvalEWiseBase<EvalGammaNLogLik> {
|
|
||||||
const char *Name() const override {
|
|
||||||
return "gamma-nloglik";
|
|
||||||
}
|
|
||||||
inline bst_float EvalRow(bst_float y, bst_float py) const {
|
|
||||||
bst_float psi = 1.0;
|
|
||||||
bst_float theta = -1. / py;
|
|
||||||
bst_float a = psi;
|
|
||||||
bst_float b = -std::log(-theta);
|
|
||||||
bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
|
|
||||||
return -((y * theta - b) / a + c);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct EvalTweedieNLogLik: public EvalEWiseBase<EvalTweedieNLogLik> {
|
|
||||||
explicit EvalTweedieNLogLik(const char* param) {
|
|
||||||
CHECK(param != nullptr)
|
|
||||||
<< "tweedie-nloglik must be in format tweedie-nloglik@rho";
|
|
||||||
rho_ = atof(param);
|
|
||||||
CHECK(rho_ < 2 && rho_ >= 1)
|
|
||||||
<< "tweedie variance power must be in interval [1, 2)";
|
|
||||||
std::ostringstream os;
|
|
||||||
os << "tweedie-nloglik@" << rho_;
|
|
||||||
name_ = os.str();
|
|
||||||
}
|
|
||||||
const char *Name() const override {
|
|
||||||
return name_.c_str();
|
|
||||||
}
|
|
||||||
inline bst_float EvalRow(bst_float y, bst_float p) const {
|
|
||||||
bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
|
|
||||||
bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
|
|
||||||
return -a + b;
|
|
||||||
}
|
|
||||||
protected:
|
|
||||||
std::string name_;
|
|
||||||
bst_float rho_;
|
|
||||||
};
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_METRIC(RMSE, "rmse")
|
|
||||||
.describe("Rooted mean square error.")
|
|
||||||
.set_body([](const char* param) { return new EvalRMSE(); });
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_METRIC(MAE, "mae")
|
|
||||||
.describe("Mean absolute error.")
|
|
||||||
.set_body([](const char* param) { return new EvalMAE(); });
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_METRIC(LogLoss, "logloss")
|
|
||||||
.describe("Negative loglikelihood for logistic regression.")
|
|
||||||
.set_body([](const char* param) { return new EvalLogLoss(); });
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_METRIC(Error, "error")
|
|
||||||
.describe("Binary classification error.")
|
|
||||||
.set_body([](const char* param) { return new EvalError(param); });
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
|
|
||||||
.describe("Negative loglikelihood for poisson regression.")
|
|
||||||
.set_body([](const char* param) { return new EvalPoissonNegLogLik(); });
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_METRIC(GammaDeviance, "gamma-deviance")
|
|
||||||
.describe("Residual deviance for gamma regression.")
|
|
||||||
.set_body([](const char* param) { return new EvalGammaDeviance(); });
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_METRIC(GammaNLogLik, "gamma-nloglik")
|
|
||||||
.describe("Negative log-likelihood for gamma regression.")
|
|
||||||
.set_body([](const char* param) { return new EvalGammaNLogLik(); });
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
|
|
||||||
.describe("tweedie-nloglik@rho for tweedie regression.")
|
|
||||||
.set_body([](const char* param) {
|
|
||||||
return new EvalTweedieNLogLik(param);
|
|
||||||
});
|
|
||||||
|
|
||||||
} // namespace metric
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
406
src/metric/elementwise_metric.cu
Normal file
406
src/metric/elementwise_metric.cu
Normal file
@ -0,0 +1,406 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2015-2018 by Contributors
|
||||||
|
* \file elementwise_metric.cc
|
||||||
|
* \brief evaluation metrics for elementwise binary or regression.
|
||||||
|
* \author Kailong Chen, Tianqi Chen
|
||||||
|
*/
|
||||||
|
#include <rabit/rabit.h>
|
||||||
|
#include <xgboost/metric.h>
|
||||||
|
#include <dmlc/registry.h>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "metric_param.h"
|
||||||
|
#include "../common/math.h"
|
||||||
|
#include "../common/common.h"
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
#include <thrust/iterator/counting_iterator.h>
|
||||||
|
#include <thrust/transform_reduce.h>
|
||||||
|
#include <thrust/execution_policy.h>
|
||||||
|
#include <thrust/functional.h> // thrust::plus<>
|
||||||
|
|
||||||
|
#include "../common/device_helpers.cuh"
|
||||||
|
#endif // XGBOOST_USE_CUDA
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace metric {
|
||||||
|
// tag the this file, used by force static link later.
|
||||||
|
DMLC_REGISTRY_FILE_TAG(elementwise_metric);
|
||||||
|
|
||||||
|
struct PackedReduceResult {
|
||||||
|
double residue_sum_;
|
||||||
|
double weights_sum_;
|
||||||
|
|
||||||
|
XGBOOST_DEVICE PackedReduceResult() : residue_sum_{0}, weights_sum_{0} {}
|
||||||
|
XGBOOST_DEVICE PackedReduceResult(double residue, double weight) :
|
||||||
|
residue_sum_{residue}, weights_sum_{weight} {}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE
|
||||||
|
PackedReduceResult operator+(PackedReduceResult const& other) const {
|
||||||
|
return PackedReduceResult { residue_sum_ + other.residue_sum_,
|
||||||
|
weights_sum_ + other.weights_sum_ };
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename EvalRow>
|
||||||
|
class MetricsReduction {
|
||||||
|
public:
|
||||||
|
explicit MetricsReduction(EvalRow policy) :
|
||||||
|
policy_(std::move(policy)) {}
|
||||||
|
|
||||||
|
PackedReduceResult CpuReduceMetrics(
|
||||||
|
const HostDeviceVector<bst_float>& weights,
|
||||||
|
const HostDeviceVector<bst_float>& labels,
|
||||||
|
const HostDeviceVector<bst_float>& preds) const {
|
||||||
|
size_t ndata = labels.Size();
|
||||||
|
|
||||||
|
const auto& h_labels = labels.HostVector();
|
||||||
|
const auto& h_weights = weights.HostVector();
|
||||||
|
const auto& h_preds = preds.HostVector();
|
||||||
|
|
||||||
|
bst_float residue_sum = 0;
|
||||||
|
bst_float weights_sum = 0;
|
||||||
|
|
||||||
|
#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
|
||||||
|
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||||
|
const bst_float wt = h_weights.size() > 0 ? h_weights[i] : 1.0f;
|
||||||
|
residue_sum += policy_.EvalRow(h_labels[i], h_preds[i]) * wt;
|
||||||
|
weights_sum += wt;
|
||||||
|
}
|
||||||
|
PackedReduceResult res { residue_sum, weights_sum };
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
|
||||||
|
PackedReduceResult DeviceReduceMetrics(
|
||||||
|
GPUSet::GpuIdType device_id,
|
||||||
|
size_t device_index,
|
||||||
|
const HostDeviceVector<bst_float>& weights,
|
||||||
|
const HostDeviceVector<bst_float>& labels,
|
||||||
|
const HostDeviceVector<bst_float>& preds) {
|
||||||
|
size_t n_data = preds.DeviceSize(device_id);
|
||||||
|
|
||||||
|
thrust::counting_iterator<size_t> begin(0);
|
||||||
|
thrust::counting_iterator<size_t> end = begin + n_data;
|
||||||
|
|
||||||
|
auto s_label = labels.DeviceSpan(device_id);
|
||||||
|
auto s_preds = preds.DeviceSpan(device_id);
|
||||||
|
auto s_weights = weights.DeviceSpan(device_id);
|
||||||
|
|
||||||
|
bool const is_null_weight = weights.Size() == 0;
|
||||||
|
|
||||||
|
auto d_policy = policy_;
|
||||||
|
|
||||||
|
PackedReduceResult result = thrust::transform_reduce(
|
||||||
|
thrust::cuda::par(allocators_.at(device_index)),
|
||||||
|
begin, end,
|
||||||
|
[=] XGBOOST_DEVICE(size_t idx) {
|
||||||
|
bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
|
||||||
|
|
||||||
|
bst_float residue = d_policy.EvalRow(s_label[idx], s_preds[idx]);
|
||||||
|
residue *= weight;
|
||||||
|
return PackedReduceResult{ residue, weight };
|
||||||
|
},
|
||||||
|
PackedReduceResult(),
|
||||||
|
thrust::plus<PackedReduceResult>());
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // XGBOOST_USE_CUDA
|
||||||
|
|
||||||
|
PackedReduceResult Reduce(
|
||||||
|
GPUSet devices,
|
||||||
|
const HostDeviceVector<bst_float>& weights,
|
||||||
|
const HostDeviceVector<bst_float>& labels,
|
||||||
|
const HostDeviceVector<bst_float>& preds) {
|
||||||
|
PackedReduceResult result;
|
||||||
|
|
||||||
|
if (devices.IsEmpty()) {
|
||||||
|
result = CpuReduceMetrics(weights, labels, preds);
|
||||||
|
}
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
else { // NOLINT
|
||||||
|
if (allocators_.size() != devices.Size()) {
|
||||||
|
allocators_.clear();
|
||||||
|
allocators_.resize(devices.Size());
|
||||||
|
}
|
||||||
|
preds.Reshard(devices);
|
||||||
|
labels.Reshard(devices);
|
||||||
|
weights.Reshard(devices);
|
||||||
|
std::vector<PackedReduceResult> res_per_device(devices.Size());
|
||||||
|
|
||||||
|
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
|
||||||
|
for (GPUSet::GpuIdType id = *devices.begin(); id < *devices.end(); ++id) {
|
||||||
|
dh::safe_cuda(cudaSetDevice(id));
|
||||||
|
size_t index = devices.Index(id);
|
||||||
|
res_per_device.at(index) =
|
||||||
|
DeviceReduceMetrics(id, index, weights, labels, preds);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < devices.Size(); ++i) {
|
||||||
|
result.residue_sum_ += res_per_device[i].residue_sum_;
|
||||||
|
result.weights_sum_ += res_per_device[i].weights_sum_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
EvalRow policy_;
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
std::vector<dh::CubMemory> allocators_;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
struct EvalRowRMSE {
|
||||||
|
char const *Name() const {
|
||||||
|
return "rmse";
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||||
|
bst_float diff = label - pred;
|
||||||
|
return diff * diff;
|
||||||
|
}
|
||||||
|
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||||
|
return std::sqrt(esum / wsum);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct EvalRowMAE {
|
||||||
|
const char *Name() const {
|
||||||
|
return "mae";
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||||
|
return std::abs(label - pred);
|
||||||
|
}
|
||||||
|
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||||
|
return esum / wsum;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct EvalRowLogLoss {
|
||||||
|
const char *Name() const {
|
||||||
|
return "logloss";
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
|
||||||
|
const bst_float eps = 1e-16f;
|
||||||
|
const bst_float pneg = 1.0f - py;
|
||||||
|
if (py < eps) {
|
||||||
|
return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
|
||||||
|
} else if (pneg < eps) {
|
||||||
|
return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps);
|
||||||
|
} else {
|
||||||
|
return -y * std::log(py) - (1.0f - y) * std::log(pneg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||||
|
return esum / wsum;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct EvalError {
|
||||||
|
explicit EvalError(const char* param) {
|
||||||
|
if (param != nullptr) {
|
||||||
|
CHECK_EQ(sscanf(param, "%f", &threshold_), 1)
|
||||||
|
<< "unable to parse the threshold value for the error metric";
|
||||||
|
has_param_ = true;
|
||||||
|
} else {
|
||||||
|
threshold_ = 0.5f;
|
||||||
|
has_param_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const char *Name() const {
|
||||||
|
static std::string name;
|
||||||
|
if (has_param_) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "error";
|
||||||
|
if (threshold_ != 0.5f) os << '@' << threshold_;
|
||||||
|
name = os.str();
|
||||||
|
return name.c_str();
|
||||||
|
} else {
|
||||||
|
return "error";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bst_float EvalRow(
|
||||||
|
bst_float label, bst_float pred) const {
|
||||||
|
// assume label is in [0,1]
|
||||||
|
return pred > threshold_ ? 1.0f - label : label;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||||
|
return esum / wsum;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bst_float threshold_;
|
||||||
|
bool has_param_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct EvalPoissonNegLogLik {
|
||||||
|
const char *Name() const {
|
||||||
|
return "poisson-nloglik";
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
|
||||||
|
const bst_float eps = 1e-16f;
|
||||||
|
if (py < eps) py = eps;
|
||||||
|
return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||||
|
return esum / wsum;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct EvalGammaDeviance {
|
||||||
|
const char *Name() const {
|
||||||
|
return "gamma-deviance";
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||||
|
bst_float epsilon = 1.0e-9;
|
||||||
|
bst_float tmp = label / (pred + epsilon);
|
||||||
|
return tmp - std::log(tmp) - 1;
|
||||||
|
}
|
||||||
|
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||||
|
return 2 * esum;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct EvalGammaNLogLik {
|
||||||
|
static const char *Name() {
|
||||||
|
return "gamma-nloglik";
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
|
||||||
|
bst_float psi = 1.0;
|
||||||
|
bst_float theta = -1. / py;
|
||||||
|
bst_float a = psi;
|
||||||
|
bst_float b = -std::log(-theta);
|
||||||
|
bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
|
||||||
|
return -((y * theta - b) / a + c);
|
||||||
|
}
|
||||||
|
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||||
|
return esum / wsum;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct EvalTweedieNLogLik {
|
||||||
|
explicit EvalTweedieNLogLik(const char* param) {
|
||||||
|
CHECK(param != nullptr)
|
||||||
|
<< "tweedie-nloglik must be in format tweedie-nloglik@rho";
|
||||||
|
rho_ = atof(param);
|
||||||
|
CHECK(rho_ < 2 && rho_ >= 1)
|
||||||
|
<< "tweedie variance power must be in interval [1, 2)";
|
||||||
|
}
|
||||||
|
const char *Name() const {
|
||||||
|
static std::string name;
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "tweedie-nloglik@" << rho_;
|
||||||
|
name = os.str();
|
||||||
|
return name.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float p) const {
|
||||||
|
bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
|
||||||
|
bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
|
||||||
|
return -a + b;
|
||||||
|
}
|
||||||
|
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||||
|
return esum / wsum;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bst_float rho_;
|
||||||
|
};
|
||||||
|
/*!
|
||||||
|
* \brief base class of element-wise evaluation
|
||||||
|
* \tparam Derived the name of subclass
|
||||||
|
*/
|
||||||
|
template<typename Policy>
|
||||||
|
struct EvalEWiseBase : public Metric {
|
||||||
|
EvalEWiseBase() : policy_{}, reducer_{policy_} {}
|
||||||
|
explicit EvalEWiseBase(char const* policy_param) :
|
||||||
|
policy_{policy_param}, reducer_{policy_} {}
|
||||||
|
|
||||||
|
void Configure(
|
||||||
|
const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
|
param_.InitAllowUnknown(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
||||||
|
const MetaInfo& info,
|
||||||
|
bool distributed) override {
|
||||||
|
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||||
|
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||||
|
<< "label and prediction size not match, "
|
||||||
|
<< "hint: use merror or mlogloss for multi-class classification";
|
||||||
|
const auto ndata = static_cast<omp_ulong>(info.labels_.Size());
|
||||||
|
// Dealing with ndata < n_gpus.
|
||||||
|
GPUSet devices = GPUSet::All(param_.gpu_id, param_.n_gpus, ndata);
|
||||||
|
|
||||||
|
PackedReduceResult result =
|
||||||
|
reducer_.Reduce(devices, info.weights_, info.labels_, preds);
|
||||||
|
|
||||||
|
double dat[2] { result.residue_sum_, result.weights_sum_ };
|
||||||
|
if (distributed) {
|
||||||
|
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||||
|
}
|
||||||
|
return Policy::GetFinal(dat[0], dat[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* Name() const override {
|
||||||
|
return policy_.Name();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Policy policy_;
|
||||||
|
|
||||||
|
MetricParam param_;
|
||||||
|
|
||||||
|
MetricsReduction<Policy> reducer_;
|
||||||
|
};
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(RMSE, "rmse")
|
||||||
|
.describe("Rooted mean square error.")
|
||||||
|
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowRMSE>(); });
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(MAE, "mae")
|
||||||
|
.describe("Mean absolute error.")
|
||||||
|
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowMAE>(); });
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(LogLoss, "logloss")
|
||||||
|
.describe("Negative loglikelihood for logistic regression.")
|
||||||
|
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowLogLoss>(); });
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
|
||||||
|
.describe("Negative loglikelihood for poisson regression.")
|
||||||
|
.set_body([](const char* param) { return new EvalEWiseBase<EvalPoissonNegLogLik>(); });
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(GammaDeviance, "gamma-deviance")
|
||||||
|
.describe("Residual deviance for gamma regression.")
|
||||||
|
.set_body([](const char* param) { return new EvalEWiseBase<EvalGammaDeviance>(); });
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(GammaNLogLik, "gamma-nloglik")
|
||||||
|
.describe("Negative log-likelihood for gamma regression.")
|
||||||
|
.set_body([](const char* param) { return new EvalEWiseBase<EvalGammaNLogLik>(); });
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(Error, "error")
|
||||||
|
.describe("Binary classification error.")
|
||||||
|
.set_body([](const char* param) { return new EvalEWiseBase<EvalError>(param); });
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
|
||||||
|
.describe("tweedie-nloglik@rho for tweedie regression.")
|
||||||
|
.set_body([](const char* param) {
|
||||||
|
return new EvalEWiseBase<EvalTweedieNLogLik>(param);
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace metric
|
||||||
|
} // namespace xgboost
|
||||||
@ -6,6 +6,8 @@
|
|||||||
#include <xgboost/metric.h>
|
#include <xgboost/metric.h>
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
|
|
||||||
|
#include "metric_param.h"
|
||||||
|
|
||||||
namespace dmlc {
|
namespace dmlc {
|
||||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
||||||
}
|
}
|
||||||
@ -34,6 +36,8 @@ Metric* Metric::Create(const std::string& name) {
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace metric {
|
namespace metric {
|
||||||
|
DMLC_REGISTER_PARAMETER(MetricParam);
|
||||||
|
|
||||||
// List of files that will be force linked in static links.
|
// List of files that will be force linked in static links.
|
||||||
DMLC_REGISTRY_LINK_TAG(elementwise_metric);
|
DMLC_REGISTRY_LINK_TAG(elementwise_metric);
|
||||||
DMLC_REGISTRY_LINK_TAG(multiclass_metric);
|
DMLC_REGISTRY_LINK_TAG(multiclass_metric);
|
||||||
|
|||||||
31
src/metric/metric_param.h
Normal file
31
src/metric/metric_param.h
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2018 by Contributors
|
||||||
|
* \file metric_param.cc
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_METRIC_METRIC_PARAM_H_
|
||||||
|
#define XGBOOST_METRIC_METRIC_PARAM_H_
|
||||||
|
|
||||||
|
#include <dmlc/parameter.h>
|
||||||
|
#include "../common/common.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace metric {
|
||||||
|
|
||||||
|
// Created exclusively for GPU.
|
||||||
|
struct MetricParam : public dmlc::Parameter<MetricParam> {
|
||||||
|
int n_gpus;
|
||||||
|
int gpu_id;
|
||||||
|
DMLC_DECLARE_PARAMETER(MetricParam) {
|
||||||
|
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(GPUSet::kAll)
|
||||||
|
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||||
|
DMLC_DECLARE_FIELD(gpu_id)
|
||||||
|
.set_lower_bound(0)
|
||||||
|
.set_default(0)
|
||||||
|
.describe("gpu to use for objective function evaluation");
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace metric
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
#endif // XGBOOST_METRIC_METRIC_PARAM_H_
|
||||||
@ -20,13 +20,13 @@ DMLC_REGISTRY_FILE_TAG(multiclass_metric);
|
|||||||
*/
|
*/
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
struct EvalMClassBase : public Metric {
|
struct EvalMClassBase : public Metric {
|
||||||
bst_float Eval(const std::vector<bst_float> &preds,
|
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
bool distributed) const override {
|
bool distributed) override {
|
||||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||||
CHECK(preds.size() % info.labels_.Size() == 0)
|
CHECK(preds.Size() % info.labels_.Size() == 0)
|
||||||
<< "label and prediction size not match";
|
<< "label and prediction size not match";
|
||||||
const size_t nclass = preds.size() / info.labels_.Size();
|
const size_t nclass = preds.Size() / info.labels_.Size();
|
||||||
CHECK_GE(nclass, 1U)
|
CHECK_GE(nclass, 1U)
|
||||||
<< "mlogloss and merror are only used for multi-class classification,"
|
<< "mlogloss and merror are only used for multi-class classification,"
|
||||||
<< " use logloss for binary classification";
|
<< " use logloss for binary classification";
|
||||||
@ -36,14 +36,15 @@ struct EvalMClassBase : public Metric {
|
|||||||
|
|
||||||
const auto& labels = info.labels_.HostVector();
|
const auto& labels = info.labels_.HostVector();
|
||||||
const auto& weights = info.weights_.HostVector();
|
const auto& weights = info.weights_.HostVector();
|
||||||
|
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||||
|
|
||||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
|
const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f;
|
||||||
auto label = static_cast<int>(labels[i]);
|
auto label = static_cast<int>(labels[i]);
|
||||||
if (label >= 0 && label < static_cast<int>(nclass)) {
|
if (label >= 0 && label < static_cast<int>(nclass)) {
|
||||||
sum += Derived::EvalRow(label,
|
sum += Derived::EvalRow(label,
|
||||||
preds.data() + i * nclass,
|
h_preds.data() + i * nclass,
|
||||||
nclass) * wt;
|
nclass) * wt;
|
||||||
wsum += wt;
|
wsum += wt;
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -8,6 +8,10 @@
|
|||||||
#include <xgboost/metric.h>
|
#include <xgboost/metric.h>
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../common/host_device_vector.h"
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -26,18 +30,20 @@ struct EvalAMS : public Metric {
|
|||||||
os << "ams@" << ratio_;
|
os << "ams@" << ratio_;
|
||||||
name_ = os.str();
|
name_ = os.str();
|
||||||
}
|
}
|
||||||
bst_float Eval(const std::vector<bst_float> &preds,
|
|
||||||
|
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
bool distributed) const override {
|
bool distributed) override {
|
||||||
CHECK(!distributed) << "metric AMS do not support distributed evaluation";
|
CHECK(!distributed) << "metric AMS do not support distributed evaluation";
|
||||||
using namespace std; // NOLINT(*)
|
using namespace std; // NOLINT(*)
|
||||||
|
|
||||||
const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
|
const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
|
||||||
std::vector<std::pair<bst_float, unsigned> > rec(ndata);
|
std::vector<std::pair<bst_float, unsigned> > rec(ndata);
|
||||||
|
|
||||||
#pragma omp parallel for schedule(static)
|
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
rec[i] = std::make_pair(preds[i], i);
|
rec[i] = std::make_pair(h_preds[i], i);
|
||||||
}
|
}
|
||||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||||
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||||
@ -82,11 +88,11 @@ struct EvalAMS : public Metric {
|
|||||||
|
|
||||||
/*! \brief Area Under Curve, for both classification and rank */
|
/*! \brief Area Under Curve, for both classification and rank */
|
||||||
struct EvalAuc : public Metric {
|
struct EvalAuc : public Metric {
|
||||||
bst_float Eval(const std::vector<bst_float> &preds,
|
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
bool distributed) const override {
|
bool distributed) override {
|
||||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||||
CHECK_EQ(preds.size(), info.labels_.Size())
|
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||||
<< "label size predict size not match";
|
<< "label size predict size not match";
|
||||||
std::vector<unsigned> tgptr(2, 0);
|
std::vector<unsigned> tgptr(2, 0);
|
||||||
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
|
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
|
||||||
@ -101,10 +107,11 @@ struct EvalAuc : public Metric {
|
|||||||
// each thread takes a local rec
|
// each thread takes a local rec
|
||||||
std::vector< std::pair<bst_float, unsigned> > rec;
|
std::vector< std::pair<bst_float, unsigned> > rec;
|
||||||
const auto& labels = info.labels_.HostVector();
|
const auto& labels = info.labels_.HostVector();
|
||||||
|
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||||
rec.clear();
|
rec.clear();
|
||||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||||
rec.emplace_back(preds[j], j);
|
rec.emplace_back(h_preds[j], j);
|
||||||
}
|
}
|
||||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||||
// calculate AUC
|
// calculate AUC
|
||||||
@ -155,23 +162,25 @@ struct EvalAuc : public Metric {
|
|||||||
/*! \brief Evaluate rank list */
|
/*! \brief Evaluate rank list */
|
||||||
struct EvalRankList : public Metric {
|
struct EvalRankList : public Metric {
|
||||||
public:
|
public:
|
||||||
bst_float Eval(const std::vector<bst_float> &preds,
|
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
bool distributed) const override {
|
bool distributed) override {
|
||||||
CHECK_EQ(preds.size(), info.labels_.Size())
|
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||||
<< "label size predict size not match";
|
<< "label size predict size not match";
|
||||||
// quick consistency when group is not available
|
// quick consistency when group is not available
|
||||||
std::vector<unsigned> tgptr(2, 0);
|
std::vector<unsigned> tgptr(2, 0);
|
||||||
tgptr[1] = static_cast<unsigned>(preds.size());
|
tgptr[1] = static_cast<unsigned>(preds.Size());
|
||||||
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
|
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
|
||||||
CHECK_NE(gptr.size(), 0U) << "must specify group when constructing rank file";
|
CHECK_NE(gptr.size(), 0U) << "must specify group when constructing rank file";
|
||||||
CHECK_EQ(gptr.back(), preds.size())
|
CHECK_EQ(gptr.back(), preds.Size())
|
||||||
<< "EvalRanklist: group structure must match number of prediction";
|
<< "EvalRanklist: group structure must match number of prediction";
|
||||||
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||||
// sum statistics
|
// sum statistics
|
||||||
double sum_metric = 0.0f;
|
double sum_metric = 0.0f;
|
||||||
const auto& labels = info.labels_.HostVector();
|
const auto& labels = info.labels_.HostVector();
|
||||||
#pragma omp parallel reduction(+:sum_metric)
|
|
||||||
|
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||||
|
#pragma omp parallel reduction(+:sum_metric)
|
||||||
{
|
{
|
||||||
// each thread takes a local rec
|
// each thread takes a local rec
|
||||||
std::vector< std::pair<bst_float, unsigned> > rec;
|
std::vector< std::pair<bst_float, unsigned> > rec;
|
||||||
@ -179,7 +188,7 @@ struct EvalRankList : public Metric {
|
|||||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||||
rec.clear();
|
rec.clear();
|
||||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||||
rec.emplace_back(preds[j], static_cast<int>(labels[j]));
|
rec.emplace_back(h_preds[j], static_cast<int>(labels[j]));
|
||||||
}
|
}
|
||||||
sum_metric += this->EvalMetric(rec);
|
sum_metric += this->EvalMetric(rec);
|
||||||
}
|
}
|
||||||
@ -311,9 +320,9 @@ struct EvalMAP : public EvalRankList {
|
|||||||
struct EvalCox : public Metric {
|
struct EvalCox : public Metric {
|
||||||
public:
|
public:
|
||||||
EvalCox() = default;
|
EvalCox() = default;
|
||||||
bst_float Eval(const std::vector<bst_float> &preds,
|
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
bool distributed) const override {
|
bool distributed) override {
|
||||||
CHECK(!distributed) << "Cox metric does not support distributed evaluation";
|
CHECK(!distributed) << "Cox metric does not support distributed evaluation";
|
||||||
using namespace std; // NOLINT(*)
|
using namespace std; // NOLINT(*)
|
||||||
|
|
||||||
@ -322,8 +331,10 @@ struct EvalCox : public Metric {
|
|||||||
|
|
||||||
// pre-compute a sum for the denominator
|
// pre-compute a sum for the denominator
|
||||||
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
||||||
|
|
||||||
|
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||||
exp_p_sum += preds[i];
|
exp_p_sum += h_preds[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
double out = 0;
|
double out = 0;
|
||||||
@ -334,12 +345,12 @@ struct EvalCox : public Metric {
|
|||||||
const size_t ind = label_order[i];
|
const size_t ind = label_order[i];
|
||||||
const auto label = labels[ind];
|
const auto label = labels[ind];
|
||||||
if (label > 0) {
|
if (label > 0) {
|
||||||
out -= log(preds[ind]) - log(exp_p_sum);
|
out -= log(h_preds[ind]) - log(exp_p_sum);
|
||||||
++num_events;
|
++num_events;
|
||||||
}
|
}
|
||||||
|
|
||||||
// only update the denominator after we move forward in time (labels are sorted)
|
// only update the denominator after we move forward in time (labels are sorted)
|
||||||
accumulated_sum += preds[ind];
|
accumulated_sum += h_preds[ind];
|
||||||
if (i == ndata - 1 || std::abs(label) < std::abs(labels[label_order[i + 1]])) {
|
if (i == ndata - 1 || std::abs(label) < std::abs(labels[label_order[i + 1]])) {
|
||||||
exp_p_sum -= accumulated_sum;
|
exp_p_sum -= accumulated_sum;
|
||||||
accumulated_sum = 0;
|
accumulated_sum = 0;
|
||||||
@ -360,10 +371,10 @@ struct EvalAucPR : public Metric {
|
|||||||
// translated from PRROC R Package
|
// translated from PRROC R Package
|
||||||
// see https://doi.org/10.1371/journal.pone.0092209
|
// see https://doi.org/10.1371/journal.pone.0092209
|
||||||
|
|
||||||
bst_float Eval(const std::vector<bst_float> &preds, const MetaInfo &info,
|
bst_float Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
|
||||||
bool distributed) const override {
|
bool distributed) override {
|
||||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||||
CHECK_EQ(preds.size(), info.labels_.Size())
|
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||||
<< "label size predict size not match";
|
<< "label size predict size not match";
|
||||||
std::vector<unsigned> tgptr(2, 0);
|
std::vector<unsigned> tgptr(2, 0);
|
||||||
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
|
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
|
||||||
@ -377,15 +388,17 @@ struct EvalAucPR : public Metric {
|
|||||||
int auc_error = 0, auc_gt_one = 0;
|
int auc_error = 0, auc_gt_one = 0;
|
||||||
// each thread takes a local rec
|
// each thread takes a local rec
|
||||||
std::vector<std::pair<bst_float, unsigned>> rec;
|
std::vector<std::pair<bst_float, unsigned>> rec;
|
||||||
const auto& labels = info.labels_.HostVector();
|
const auto& h_labels = info.labels_.HostVector();
|
||||||
|
const std::vector<bst_float>& h_preds = preds.HostVector();
|
||||||
|
|
||||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||||
double total_pos = 0.0;
|
double total_pos = 0.0;
|
||||||
double total_neg = 0.0;
|
double total_neg = 0.0;
|
||||||
rec.clear();
|
rec.clear();
|
||||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||||
total_pos += info.GetWeight(j) * labels[j];
|
total_pos += info.GetWeight(j) * h_labels[j];
|
||||||
total_neg += info.GetWeight(j) * (1.0f - labels[j]);
|
total_neg += info.GetWeight(j) * (1.0f - h_labels[j]);
|
||||||
rec.emplace_back(preds[j], j);
|
rec.emplace_back(h_preds[j], j);
|
||||||
}
|
}
|
||||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||||
// we need pos > 0 && neg > 0
|
// we need pos > 0 && neg > 0
|
||||||
@ -395,8 +408,8 @@ struct EvalAucPR : public Metric {
|
|||||||
// calculate AUC
|
// calculate AUC
|
||||||
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
|
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
|
||||||
for (size_t j = 0; j < rec.size(); ++j) {
|
for (size_t j = 0; j < rec.size(); ++j) {
|
||||||
tp += info.GetWeight(rec[j].second) * labels[rec[j].second];
|
tp += info.GetWeight(rec[j].second) * h_labels[rec[j].second];
|
||||||
fp += info.GetWeight(rec[j].second) * (1.0f - labels[rec[j].second]);
|
fp += info.GetWeight(rec[j].second) * (1.0f - h_labels[rec[j].second]);
|
||||||
if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
|
if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
|
||||||
if (tp == prevtp) {
|
if (tp == prevtp) {
|
||||||
a = 1.0;
|
a = 1.0;
|
||||||
@ -471,4 +484,3 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
|
|||||||
.set_body([](const char* param) { return new EvalCox(); });
|
.set_body([](const char* param) { return new EvalCox(); });
|
||||||
} // namespace metric
|
} // namespace metric
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -85,13 +85,14 @@ void CheckRankingObjFunction(xgboost::ObjFunction * obj,
|
|||||||
|
|
||||||
|
|
||||||
xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
|
xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
|
||||||
std::vector<xgboost::bst_float> preds,
|
xgboost::HostDeviceVector<xgboost::bst_float> preds,
|
||||||
std::vector<xgboost::bst_float> labels,
|
std::vector<xgboost::bst_float> labels,
|
||||||
std::vector<xgboost::bst_float> weights) {
|
std::vector<xgboost::bst_float> weights) {
|
||||||
xgboost::MetaInfo info;
|
xgboost::MetaInfo info;
|
||||||
info.num_row_ = labels.size();
|
info.num_row_ = labels.size();
|
||||||
info.labels_.HostVector() = labels;
|
info.labels_.HostVector() = labels;
|
||||||
info.weights_.HostVector() = weights;
|
info.weights_.HostVector() = weights;
|
||||||
|
|
||||||
return metric->Eval(preds, info, false);
|
return metric->Eval(preds, info, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ void CheckRankingObjFunction(xgboost::ObjFunction * obj,
|
|||||||
|
|
||||||
xgboost::bst_float GetMetricEval(
|
xgboost::bst_float GetMetricEval(
|
||||||
xgboost::Metric * metric,
|
xgboost::Metric * metric,
|
||||||
std::vector<xgboost::bst_float> preds,
|
xgboost::HostDeviceVector<xgboost::bst_float> preds,
|
||||||
std::vector<xgboost::bst_float> labels,
|
std::vector<xgboost::bst_float> labels,
|
||||||
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float> ());
|
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float> ());
|
||||||
|
|
||||||
|
|||||||
@ -1,21 +1,34 @@
|
|||||||
// Copyright by Contributors
|
/*!
|
||||||
|
* Copyright 2018 XGBoost contributors
|
||||||
|
*/
|
||||||
#include <xgboost/metric.h>
|
#include <xgboost/metric.h>
|
||||||
|
#include <map>
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
TEST(Metric, RMSE) {
|
using Arg = std::pair<std::string, std::string>;
|
||||||
|
|
||||||
|
#if defined(__CUDACC__)
|
||||||
|
#define N_GPU() Arg{"n_gpus", "1"}
|
||||||
|
#else
|
||||||
|
#define N_GPU() Arg{"n_gpus", "0"}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
TEST(Metric, DeclareUnifiedTest(RMSE)) {
|
||||||
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
||||||
|
metric->Configure({N_GPU()});
|
||||||
ASSERT_STREQ(metric->Name(), "rmse");
|
ASSERT_STREQ(metric->Name(), "rmse");
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric,
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||||
{ 0, 0, 1, 1}),
|
{ 0, 0, 1, 1},
|
||||||
|
{ 0, 1, 2, 3}),
|
||||||
0.6403f, 0.001f);
|
0.6403f, 0.001f);
|
||||||
delete metric;
|
delete metric;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Metric, MAE) {
|
TEST(Metric, DeclareUnifiedTest(MAE)) {
|
||||||
xgboost::Metric * metric = xgboost::Metric::Create("mae");
|
xgboost::Metric * metric = xgboost::Metric::Create("mae");
|
||||||
|
metric->Configure({N_GPU()});
|
||||||
ASSERT_STREQ(metric->Name(), "mae");
|
ASSERT_STREQ(metric->Name(), "mae");
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric,
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
@ -25,8 +38,9 @@ TEST(Metric, MAE) {
|
|||||||
delete metric;
|
delete metric;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Metric, LogLoss) {
|
TEST(Metric, DeclareUnifiedTest(LogLoss)) {
|
||||||
xgboost::Metric * metric = xgboost::Metric::Create("logloss");
|
xgboost::Metric * metric = xgboost::Metric::Create("logloss");
|
||||||
|
metric->Configure({N_GPU()});
|
||||||
ASSERT_STREQ(metric->Name(), "logloss");
|
ASSERT_STREQ(metric->Name(), "logloss");
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric,
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
@ -36,8 +50,9 @@ TEST(Metric, LogLoss) {
|
|||||||
delete metric;
|
delete metric;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Metric, Error) {
|
TEST(Metric, DeclareUnifiedTest(Error)) {
|
||||||
xgboost::Metric * metric = xgboost::Metric::Create("error");
|
xgboost::Metric * metric = xgboost::Metric::Create("error");
|
||||||
|
metric->Configure({N_GPU()});
|
||||||
ASSERT_STREQ(metric->Name(), "error");
|
ASSERT_STREQ(metric->Name(), "error");
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric,
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
@ -47,11 +62,15 @@ TEST(Metric, Error) {
|
|||||||
|
|
||||||
EXPECT_ANY_THROW(xgboost::Metric::Create("error@abc"));
|
EXPECT_ANY_THROW(xgboost::Metric::Create("error@abc"));
|
||||||
delete metric;
|
delete metric;
|
||||||
|
|
||||||
metric = xgboost::Metric::Create("error@0.5f");
|
metric = xgboost::Metric::Create("error@0.5f");
|
||||||
|
metric->Configure({N_GPU()});
|
||||||
EXPECT_STREQ(metric->Name(), "error");
|
EXPECT_STREQ(metric->Name(), "error");
|
||||||
|
|
||||||
delete metric;
|
delete metric;
|
||||||
|
|
||||||
metric = xgboost::Metric::Create("error@0.1");
|
metric = xgboost::Metric::Create("error@0.1");
|
||||||
|
metric->Configure({N_GPU()});
|
||||||
ASSERT_STREQ(metric->Name(), "error@0.1");
|
ASSERT_STREQ(metric->Name(), "error@0.1");
|
||||||
EXPECT_STREQ(metric->Name(), "error@0.1");
|
EXPECT_STREQ(metric->Name(), "error@0.1");
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||||
@ -62,8 +81,9 @@ TEST(Metric, Error) {
|
|||||||
delete metric;
|
delete metric;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Metric, PoissionNegLogLik) {
|
TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
|
||||||
xgboost::Metric * metric = xgboost::Metric::Create("poisson-nloglik");
|
xgboost::Metric * metric = xgboost::Metric::Create("poisson-nloglik");
|
||||||
|
metric->Configure({N_GPU()});
|
||||||
ASSERT_STREQ(metric->Name(), "poisson-nloglik");
|
ASSERT_STREQ(metric->Name(), "poisson-nloglik");
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0.5f, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0.5f, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric,
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
@ -72,3 +92,31 @@ TEST(Metric, PoissionNegLogLik) {
|
|||||||
1.1280f, 0.001f);
|
1.1280f, 0.001f);
|
||||||
delete metric;
|
delete metric;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
|
||||||
|
TEST(Metric, MGPU_RMSE) {
|
||||||
|
{
|
||||||
|
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
||||||
|
metric->Configure({Arg{"n_gpus", "-1"}});
|
||||||
|
ASSERT_STREQ(metric->Name(), "rmse");
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric, {0}, {0}), 0, 1e-10);
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
|
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||||
|
{ 0, 0, 1, 1}),
|
||||||
|
0.6403f, 0.001f);
|
||||||
|
delete metric;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
xgboost::Metric * metric = xgboost::Metric::Create("rmse");
|
||||||
|
metric->Configure({Arg{"n_gpus", "-1"}, Arg{"gpu_id", "1"}});
|
||||||
|
ASSERT_STREQ(metric->Name(), "rmse");
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
|
||||||
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
|
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||||
|
{ 0, 0, 1, 1}),
|
||||||
|
0.6403f, 0.001f);
|
||||||
|
delete metric;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|||||||
5
tests/cpp/metric/test_elementwise_metric.cu
Normal file
5
tests/cpp/metric/test_elementwise_metric.cu
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2018 XGBoost contributors
|
||||||
|
*/
|
||||||
|
// Dummy file to keep the CUDA conditional compile trick.
|
||||||
|
#include "test_elementwise_metric.cc"
|
||||||
@ -88,7 +88,9 @@ TEST(Metric, NDCG) {
|
|||||||
xgboost::Metric * metric = xgboost::Metric::Create("ndcg");
|
xgboost::Metric * metric = xgboost::Metric::Create("ndcg");
|
||||||
ASSERT_STREQ(metric->Name(), "ndcg");
|
ASSERT_STREQ(metric->Name(), "ndcg");
|
||||||
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
|
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {}, {}), 1, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
|
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||||
|
{}), 1, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric,
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||||
@ -107,7 +109,9 @@ TEST(Metric, NDCG) {
|
|||||||
delete metric;
|
delete metric;
|
||||||
metric = xgboost::Metric::Create("ndcg@-");
|
metric = xgboost::Metric::Create("ndcg@-");
|
||||||
ASSERT_STREQ(metric->Name(), "ndcg@-");
|
ASSERT_STREQ(metric->Name(), "ndcg@-");
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {}, {}), 0, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
|
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||||
|
{}), 0, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
|
||||||
EXPECT_NEAR(GetMetricEval(metric,
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||||
@ -134,12 +138,16 @@ TEST(Metric, MAP) {
|
|||||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||||
{ 0, 0, 1, 1}),
|
{ 0, 0, 1, 1}),
|
||||||
0.5f, 0.001f);
|
0.5f, 0.001f);
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {}, {}), 1, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
|
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||||
|
std::vector<xgboost::bst_float>{}), 1, 1e-10);
|
||||||
|
|
||||||
delete metric;
|
delete metric;
|
||||||
metric = xgboost::Metric::Create("map@-");
|
metric = xgboost::Metric::Create("map@-");
|
||||||
ASSERT_STREQ(metric->Name(), "map@-");
|
ASSERT_STREQ(metric->Name(), "map@-");
|
||||||
EXPECT_NEAR(GetMetricEval(metric, {}, {}), 0, 1e-10);
|
EXPECT_NEAR(GetMetricEval(metric,
|
||||||
|
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||||
|
{}), 0, 1e-10);
|
||||||
|
|
||||||
delete metric;
|
delete metric;
|
||||||
metric = xgboost::Metric::Create("map@2");
|
metric = xgboost::Metric::Create("map@2");
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user