[METRIC] all metric move finished
This commit is contained in:
parent
dedd87662b
commit
b4d0bb5a6d
75
include/xgboost/metric.h
Normal file
75
include/xgboost/metric.h
Normal file
@ -0,0 +1,75 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file metric.h
|
||||
* \brief interface of evaluation metric function supported in xgboost.
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
*/
|
||||
#ifndef XGBOOST_METRIC_H_
|
||||
#define XGBOOST_METRIC_H_
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "./data.h"
|
||||
#include "./base.h"
|
||||
|
||||
namespace xgboost {
|
||||
/*!
|
||||
* \brief interface of evaluation metric used to evaluate model performance.
|
||||
* This has nothing to do with training, but merely act as evaluation purpose.
|
||||
*/
|
||||
class Metric {
|
||||
public:
|
||||
/*!
|
||||
* \brief evaluate a specific metric
|
||||
* \param preds prediction
|
||||
* \param info information, including label etc.
|
||||
* \param distributed whether a call to Allreduce is needed to gather
|
||||
* the average statistics across all the node,
|
||||
* this is only supported by some metrics
|
||||
*/
|
||||
virtual float Eval(const std::vector<float>& preds,
|
||||
const MetaInfo& info,
|
||||
bool distributed) const = 0;
|
||||
/*! \return name of metric */
|
||||
virtual const char* Name() const = 0;
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~Metric() {}
|
||||
/*!
|
||||
* \brief create a metric according to name.
|
||||
* \param name name of the metric.
|
||||
* name can be in form metric@param
|
||||
* and the name will be matched in the registry.
|
||||
* \return the created metric.
|
||||
*/
|
||||
static Metric* Create(const char *name);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Registry entry for Metric factory functions.
|
||||
* The additional parameter const char* param gives the value after @, can be null.
|
||||
* For example, metric map@3, then: param == "3".
|
||||
*/
|
||||
struct MetricReg
|
||||
: public dmlc::FunctionRegEntryBase<MetricReg,
|
||||
std::function<Metric* (const char*)> > {
|
||||
};
|
||||
|
||||
/*!
 * \brief Macro to register a metric.
 *
 *  UniqueId is pasted into the name of the static registry reference that the
 *  macro declares. Name is turned into a string by the macro itself (#Name),
 *  so it is passed as a bare token, without quotes.
 *
 * \code
 * // example of registering the metric ndcg@k
 * XGBOOST_REGISTER_METRIC(NDCG, ndcg)
 * .describe("Normalized discounted cumulative gain.")
 * .set_body([](const char* param) {
 *     // param is the text after '@'; it can be null, real code must check it
 *     int at_k = atoi(param);
 *     return new NDCG(at_k);
 *   });
 * \endcode
 */
#define XGBOOST_REGISTER_METRIC(UniqueId, Name)                         \
  static ::xgboost::MetricReg & __make_ ## MetricReg ## _ ## UniqueId ## __ = \
      ::dmlc::Registry< ::xgboost::MetricReg>::Get()->__REGISTER__(#Name)
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_METRIC_H_
|
||||
@ -70,7 +70,7 @@ class ObjFunction {
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Registry entry for DataIterator factory functions.
|
||||
* \brief Registry entry for objective factory functions.
|
||||
*/
|
||||
struct ObjFunctionReg
|
||||
: public dmlc::FunctionRegEntryBase<ObjFunctionReg,
|
||||
@ -78,7 +78,7 @@ struct ObjFunctionReg
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register objective
|
||||
* \brief Macro to register objective function.
|
||||
*
|
||||
* \code
|
||||
* // example of registering a objective
|
||||
|
||||
13
include/xgboost/sync.h
Normal file
13
include/xgboost/sync.h
Normal file
@ -0,0 +1,13 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file sync.h
|
||||
* \brief the synchronization module of rabit
|
||||
* redirects to rabit header
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_SYNC_H_
|
||||
#define XGBOOST_SYNC_H_
|
||||
|
||||
#include <rabit.h>
|
||||
|
||||
#endif // XGBOOST_SYNC_H_
|
||||
@ -1,589 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file xgboost_evaluation-inl.hpp
|
||||
* \brief evaluation metrics for regression and classification and rank
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_LEARNER_EVALUATION_INL_HPP_
|
||||
#define XGBOOST_LEARNER_EVALUATION_INL_HPP_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <cmath>
|
||||
#include <climits>
|
||||
#include <algorithm>
|
||||
#include "../sync/sync.h"
|
||||
#include "../utils/math.h"
|
||||
#include "./evaluation.h"
|
||||
#include "./helper_utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace learner {
|
||||
/*!
|
||||
* \brief base class of element-wise evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalEWiseBase : public IEvaluator {
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() == info.labels.size(),
|
||||
"label and prediction size not match"\
|
||||
"hint: use merror or mlogloss for multi-class classification");
|
||||
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
|
||||
float sum = 0.0, wsum = 0.0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
sum += Derived::EvalRow(info.labels[i], preds[i]) * wt;
|
||||
wsum += wt;
|
||||
}
|
||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
*/
|
||||
inline static float EvalRow(float label, float pred);
|
||||
/*!
|
||||
* \brief to be overridden by subclass, final transformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static float GetFinal(float esum, float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief RMSE */
|
||||
struct EvalRMSE : public EvalEWiseBase<EvalRMSE> {
|
||||
virtual const char *Name(void) const {
|
||||
return "rmse";
|
||||
}
|
||||
inline static float EvalRow(float label, float pred) {
|
||||
float diff = label - pred;
|
||||
return diff * diff;
|
||||
}
|
||||
inline static float GetFinal(float esum, float wsum) {
|
||||
return std::sqrt(esum / wsum);
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief logloss */
|
||||
struct EvalLogLoss : public EvalEWiseBase<EvalLogLoss> {
|
||||
virtual const char *Name(void) const {
|
||||
return "logloss";
|
||||
}
|
||||
inline static float EvalRow(float y, float py) {
|
||||
const float eps = 1e-16f;
|
||||
const float pneg = 1.0f - py;
|
||||
if (py < eps) {
|
||||
return -y * std::log(eps) - (1.0f - y) * std::log(1.0f - eps);
|
||||
} else if (pneg < eps) {
|
||||
return -y * std::log(1.0f - eps) - (1.0f - y) * std::log(eps);
|
||||
} else {
|
||||
return -y * std::log(py) - (1.0f - y) * std::log(pneg);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief error */
|
||||
struct EvalError : public EvalEWiseBase<EvalError> {
|
||||
virtual const char *Name(void) const {
|
||||
return "error";
|
||||
}
|
||||
inline static float EvalRow(float label, float pred) {
|
||||
// assume label is in [0,1]
|
||||
return pred > 0.5f ? 1.0f - label : label;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief log-likelihood of Poission distribution */
|
||||
struct EvalPoissionNegLogLik : public EvalEWiseBase<EvalPoissionNegLogLik> {
|
||||
virtual const char *Name(void) const {
|
||||
return "poisson-nloglik";
|
||||
}
|
||||
inline static float EvalRow(float y, float py) {
|
||||
const float eps = 1e-16f;
|
||||
if (py < eps) py = eps;
|
||||
return utils::LogGamma(y + 1.0f) + py - std::log(py) * y;
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief base class of multi-class evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalMClassBase : public IEvaluator {
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
const size_t nclass = preds.size() / info.labels.size();
|
||||
utils::Check(nclass > 1,
|
||||
"mlogloss and merror are only used for multi-class classification,"\
|
||||
" use logloss for binary classification");
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
float sum = 0.0, wsum = 0.0;
|
||||
int label_error = 0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
int label = static_cast<int>(info.labels[i]);
|
||||
if (label >= 0 && label < static_cast<int>(nclass)) {
|
||||
sum += Derived::EvalRow(label,
|
||||
BeginPtr(preds) + i * nclass,
|
||||
nclass) * wt;
|
||||
wsum += wt;
|
||||
} else {
|
||||
label_error = label;
|
||||
}
|
||||
}
|
||||
utils::Check(label_error >= 0 && label_error < static_cast<int>(nclass),
|
||||
"MultiClassEvaluation: label must be in [0, num_class)," \
|
||||
" num_class=%d but found %d in label",
|
||||
static_cast<int>(nclass), label_error);
|
||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
* \param nclass number of class in the prediction
|
||||
*/
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass);
|
||||
/*!
|
||||
* \brief to be overridden by subclass, final transformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static float GetFinal(float esum, float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
// used to store error message
|
||||
const char *error_msg_;
|
||||
};
|
||||
/*! \brief match error */
|
||||
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
||||
virtual const char *Name(void) const {
|
||||
return "merror";
|
||||
}
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass) {
|
||||
return FindMaxIndex(pred, nclass) != static_cast<int>(label);
|
||||
}
|
||||
};
|
||||
/*! \brief match error */
|
||||
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
|
||||
virtual const char *Name(void) const {
|
||||
return "mlogloss";
|
||||
}
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass) {
|
||||
const float eps = 1e-16f;
|
||||
size_t k = static_cast<size_t>(label);
|
||||
if (pred[k] > eps) {
|
||||
return -std::log(pred[k]);
|
||||
} else {
|
||||
return -std::log(eps);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief ctest */
|
||||
struct EvalCTest: public IEvaluator {
|
||||
EvalCTest(IEvaluator *base, const char *name)
|
||||
: base_(base), name_(name) {}
|
||||
virtual ~EvalCTest(void) {
|
||||
delete base_;
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return name_.c_str();
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(!distributed, "metric %s do not support distributed evaluation", name_.c_str());
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
size_t ngroup = preds.size() / info.labels.size() - 1;
|
||||
const unsigned ndata = static_cast<unsigned>(info.labels.size());
|
||||
utils::Check(ngroup > 1, "pred size does not meet requirement");
|
||||
utils::Check(ndata == info.info.fold_index.size(), "need fold index");
|
||||
double wsum = 0.0;
|
||||
for (size_t k = 0; k < ngroup; ++k) {
|
||||
std::vector<float> tpred;
|
||||
MetaInfo tinfo;
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
if (info.info.fold_index[i] == k) {
|
||||
tpred.push_back(preds[i + (k + 1) * ndata]);
|
||||
tinfo.labels.push_back(info.labels[i]);
|
||||
tinfo.weights.push_back(info.GetWeight(i));
|
||||
}
|
||||
}
|
||||
wsum += base_->Eval(tpred, tinfo);
|
||||
}
|
||||
return static_cast<float>(wsum / ngroup);
|
||||
}
|
||||
|
||||
private:
|
||||
IEvaluator *base_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
/*! \brief AMS: also records best threshold */
|
||||
struct EvalAMS : public IEvaluator {
|
||||
public:
|
||||
explicit EvalAMS(const char *name) {
|
||||
name_ = name;
|
||||
// note: ams@0 will automatically select which ratio to go
|
||||
utils::Check(std::sscanf(name, "ams@%f", &ratio_) == 1, "invalid ams format");
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(!distributed, "metric AMS do not support distributed evaluation");
|
||||
using namespace std;
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
|
||||
utils::Check(info.weights.size() == ndata, "we need weight to evaluate ams");
|
||||
std::vector< std::pair<float, unsigned> > rec(ndata);
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
rec[i] = std::make_pair(preds[i], i);
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), CmpFirst);
|
||||
unsigned ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||
if (ntop == 0) ntop = ndata;
|
||||
const double br = 10.0;
|
||||
unsigned thresindex = 0;
|
||||
double s_tp = 0.0, b_fp = 0.0, tams = 0.0;
|
||||
for (unsigned i = 0; i < static_cast<unsigned>(ndata-1) && i < ntop; ++i) {
|
||||
const unsigned ridx = rec[i].second;
|
||||
const float wt = info.weights[ridx];
|
||||
if (info.labels[ridx] > 0.5f) {
|
||||
s_tp += wt;
|
||||
} else {
|
||||
b_fp += wt;
|
||||
}
|
||||
if (rec[i].first != rec[i+1].first) {
|
||||
double ams = sqrt(2*((s_tp+b_fp+br) * log(1.0 + s_tp/(b_fp+br)) - s_tp));
|
||||
if (tams < ams) {
|
||||
thresindex = i;
|
||||
tams = ams;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ntop == ndata) {
|
||||
utils::Printf("\tams-ratio=%g", static_cast<float>(thresindex) / ndata);
|
||||
return static_cast<float>(tams);
|
||||
} else {
|
||||
return static_cast<float>(sqrt(2*((s_tp+b_fp+br) * log(1.0 + s_tp/(b_fp+br)) - s_tp)));
|
||||
}
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return name_.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
float ratio_;
|
||||
};
|
||||
|
||||
/*! \brief precision with cut off at top percentile */
|
||||
struct EvalPrecisionRatio : public IEvaluator{
|
||||
public:
|
||||
explicit EvalPrecisionRatio(const char *name) : name_(name) {
|
||||
using namespace std;
|
||||
if (sscanf(name, "apratio@%f", &ratio_) == 1) {
|
||||
use_ap = 1;
|
||||
} else {
|
||||
utils::Assert(sscanf(name, "pratio@%f", &ratio_) == 1, "BUG");
|
||||
use_ap = 0;
|
||||
}
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(!distributed, "metric %s do not support distributed evaluation", Name());
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Assert(preds.size() % info.labels.size() == 0,
|
||||
"label size predict size not match");
|
||||
std::vector< std::pair<float, unsigned> > rec;
|
||||
for (size_t j = 0; j < info.labels.size(); ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], static_cast<unsigned>(j)));
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), CmpFirst);
|
||||
double pratio = CalcPRatio(rec, info);
|
||||
return static_cast<float>(pratio);
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return name_.c_str();
|
||||
}
|
||||
|
||||
protected:
|
||||
inline double CalcPRatio(const std::vector< std::pair<float, unsigned> >& rec,
|
||||
const MetaInfo &info) const {
|
||||
size_t cutoff = static_cast<size_t>(ratio_ * rec.size());
|
||||
double wt_hit = 0.0, wsum = 0.0, wt_sum = 0.0;
|
||||
for (size_t j = 0; j < cutoff; ++j) {
|
||||
const float wt = info.GetWeight(j);
|
||||
wt_hit += info.labels[rec[j].second] * wt;
|
||||
wt_sum += wt;
|
||||
wsum += wt_hit / wt_sum;
|
||||
}
|
||||
if (use_ap != 0) {
|
||||
return wsum / cutoff;
|
||||
} else {
|
||||
return wt_hit / wt_sum;
|
||||
}
|
||||
}
|
||||
int use_ap;
|
||||
float ratio_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
/*! \brief Area Under Curve, for both classification and rank */
|
||||
struct EvalAuc : public IEvaluator {
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label size predict size not match");
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(info.labels.size());
|
||||
|
||||
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
utils::Check(gptr.back() == info.labels.size(),
|
||||
"EvalAuc: group structure must match number of prediction");
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statistics
|
||||
double sum_auc = 0.0f;
|
||||
#pragma omp parallel reduction(+:sum_auc)
|
||||
{
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<float, unsigned> > rec;
|
||||
#pragma omp for schedule(static)
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], j));
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), CmpFirst);
|
||||
// calculate AUC
|
||||
double sum_pospair = 0.0;
|
||||
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
|
||||
for (size_t j = 0; j < rec.size(); ++j) {
|
||||
const float wt = info.GetWeight(rec[j].second);
|
||||
const float ctr = info.labels[rec[j].second];
|
||||
// keep bucketing predictions in same bucket
|
||||
if (j != 0 && rec[j].first != rec[j - 1].first) {
|
||||
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
|
||||
sum_npos += buf_pos;
|
||||
sum_nneg += buf_neg;
|
||||
buf_neg = buf_pos = 0.0f;
|
||||
}
|
||||
buf_pos += ctr * wt;
|
||||
buf_neg += (1.0f - ctr) * wt;
|
||||
}
|
||||
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
|
||||
sum_npos += buf_pos;
|
||||
sum_nneg += buf_neg;
|
||||
// check weird conditions
|
||||
utils::Check(sum_npos > 0.0 && sum_nneg > 0.0,
|
||||
"AUC: the dataset only contains pos or neg samples");
|
||||
// this is the AUC
|
||||
sum_auc += sum_pospair / (sum_npos*sum_nneg);
|
||||
}
|
||||
}
|
||||
if (distributed) {
|
||||
float dat[2];
|
||||
dat[0] = static_cast<float>(sum_auc);
|
||||
dat[1] = static_cast<float>(ngroup);
|
||||
// approximately estimate auc using mean
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<float>(sum_auc) / ngroup;
|
||||
}
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return "auc";
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Evaluate rank list */
|
||||
struct EvalRankList : public IEvaluator {
|
||||
public:
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(preds.size() == info.labels.size(),
|
||||
"label size predict size not match");
|
||||
// quick consistency when group is not available
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(preds.size());
|
||||
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
utils::Assert(gptr.size() != 0, "must specify group when constructing rank file");
|
||||
utils::Assert(gptr.back() == preds.size(),
|
||||
"EvalRanklist: group structure must match number of prediction");
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statistics
|
||||
double sum_metric = 0.0f;
|
||||
#pragma omp parallel reduction(+:sum_metric)
|
||||
{
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<float, unsigned> > rec;
|
||||
#pragma omp for schedule(static)
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], static_cast<int>(info.labels[j])));
|
||||
}
|
||||
sum_metric += this->EvalMetric(rec);
|
||||
}
|
||||
}
|
||||
if (distributed) {
|
||||
float dat[2];
|
||||
dat[0] = static_cast<float>(sum_metric);
|
||||
dat[1] = static_cast<float>(ngroup);
|
||||
// approximately estimate the metric using mean
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<float>(sum_metric) / ngroup;
|
||||
}
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return name_.c_str();
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit EvalRankList(const char *name) {
|
||||
using namespace std;
|
||||
name_ = name;
|
||||
minus_ = false;
|
||||
if (sscanf(name, "%*[^@]@%u[-]?", &topn_) != 1) {
|
||||
topn_ = UINT_MAX;
|
||||
}
|
||||
if (name[strlen(name) - 1] == '-') {
|
||||
minus_ = true;
|
||||
}
|
||||
}
|
||||
/*! \return evaluation metric, given the pair_sort record, (pred,label) */
|
||||
virtual float EvalMetric(std::vector< std::pair<float, unsigned> > &pair_sort) const = 0; // NOLINT(*)
|
||||
|
||||
protected:
|
||||
unsigned topn_;
|
||||
std::string name_;
|
||||
bool minus_;
|
||||
};
|
||||
|
||||
/*! \brief Precision at N, for both classification and rank */
|
||||
struct EvalPrecision : public EvalRankList{
|
||||
public:
|
||||
explicit EvalPrecision(const char *name) : EvalRankList(name) {}
|
||||
|
||||
protected:
|
||||
virtual float EvalMetric(std::vector< std::pair<float, unsigned> > &rec) const {
|
||||
// calculate Precision
|
||||
std::sort(rec.begin(), rec.end(), CmpFirst);
|
||||
unsigned nhit = 0;
|
||||
for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) {
|
||||
nhit += (rec[j].second != 0);
|
||||
}
|
||||
return static_cast<float>(nhit) / topn_;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */
|
||||
struct EvalNDCG : public EvalRankList{
|
||||
public:
|
||||
explicit EvalNDCG(const char *name) : EvalRankList(name) {}
|
||||
|
||||
protected:
|
||||
inline float CalcDCG(const std::vector< std::pair<float, unsigned> > &rec) const {
|
||||
double sumdcg = 0.0;
|
||||
for (size_t i = 0; i < rec.size() && i < this->topn_; ++i) {
|
||||
const unsigned rel = rec[i].second;
|
||||
if (rel != 0) {
|
||||
sumdcg += ((1 << rel) - 1) / std::log(i + 2.0);
|
||||
}
|
||||
}
|
||||
return static_cast<float>(sumdcg);
|
||||
}
|
||||
virtual float EvalMetric(std::vector< std::pair<float, unsigned> > &rec) const { // NOLINT(*)
|
||||
std::stable_sort(rec.begin(), rec.end(), CmpFirst);
|
||||
float dcg = this->CalcDCG(rec);
|
||||
std::stable_sort(rec.begin(), rec.end(), CmpSecond);
|
||||
float idcg = this->CalcDCG(rec);
|
||||
if (idcg == 0.0f) {
|
||||
if (minus_) {
|
||||
return 0.0f;
|
||||
} else {
|
||||
return 1.0f;
|
||||
}
|
||||
}
|
||||
return dcg/idcg;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Mean Average Precision at N, for both classification and rank */
|
||||
struct EvalMAP : public EvalRankList {
|
||||
public:
|
||||
explicit EvalMAP(const char *name) : EvalRankList(name) {}
|
||||
|
||||
protected:
|
||||
virtual float EvalMetric(std::vector< std::pair<float, unsigned> > &rec) const {
|
||||
std::sort(rec.begin(), rec.end(), CmpFirst);
|
||||
unsigned nhits = 0;
|
||||
double sumap = 0.0;
|
||||
for (size_t i = 0; i < rec.size(); ++i) {
|
||||
if (rec[i].second != 0) {
|
||||
nhits += 1;
|
||||
if (i < this->topn_) {
|
||||
sumap += static_cast<float>(nhits) / (i+1);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nhits != 0) {
|
||||
sumap /= nhits;
|
||||
return static_cast<float>(sumap);
|
||||
} else {
|
||||
if (minus_) {
|
||||
return 0.0f;
|
||||
} else {
|
||||
return 1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace learner
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_LEARNER_EVALUATION_INL_HPP_
|
||||
@ -1,101 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file evaluation.h
|
||||
* \brief interface of evaluation function supported in xgboost
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
*/
|
||||
#ifndef XGBOOST_LEARNER_EVALUATION_H_
|
||||
#define XGBOOST_LEARNER_EVALUATION_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
#include "../utils/utils.h"
|
||||
#include "./dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace learner {
|
||||
/*! \brief evaluator that evaluates the loss metrics */
|
||||
struct IEvaluator{
|
||||
/*!
|
||||
* \brief evaluate a specific metric
|
||||
* \param preds prediction
|
||||
* \param info information, including label etc.
|
||||
* \param distributed whether a call to Allreduce is needed to gather
|
||||
* the average statistics across all the node,
|
||||
* this is only supported by some metrics
|
||||
*/
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed = false) const = 0;
|
||||
/*! \return name of metric */
|
||||
virtual const char *Name(void) const = 0;
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~IEvaluator(void) {}
|
||||
};
|
||||
} // namespace learner
|
||||
} // namespace xgboost
|
||||
|
||||
// include implementations of evaluation functions
|
||||
#include "evaluation-inl.hpp"
|
||||
// factory function
|
||||
namespace xgboost {
|
||||
namespace learner {
|
||||
/*!
 * \brief factory function: create the evaluator that belongs to a metric name
 * \param name name of the metric; plain names are matched exactly, names that
 *  carry a parameter ("ams@", "pre@", "pratio@", "apratio@", "map", "ndcg")
 *  are matched by prefix, and "ct-<metric>" wraps the evaluator of <metric>
 * \return a newly allocated evaluator, owned by the caller;
 *  utils::Error is called for an unknown name
 */
inline IEvaluator* CreateEvaluator(const char *name) {
  using namespace std;
  if (!strcmp(name, "rmse")) return new EvalRMSE();
  if (!strcmp(name, "error")) return new EvalError();
  if (!strcmp(name, "merror")) return new EvalMatchError();
  if (!strcmp(name, "logloss")) return new EvalLogLoss();
  if (!strcmp(name, "mlogloss")) return new EvalMultiLogLoss();
  if (!strcmp(name, "poisson-nloglik")) return new EvalPoissionNegLogLik();
  if (!strcmp(name, "auc")) return new EvalAuc();
  if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
  if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
  if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
  // EvalPrecisionRatio also parses the "apratio@" form (average precision),
  // which no other prefix test in this function lets through
  if (!strncmp(name, "apratio@", 8)) return new EvalPrecisionRatio(name);
  if (!strncmp(name, "map", 3)) return new EvalMAP(name);
  if (!strncmp(name, "ndcg", 4)) return new EvalNDCG(name);
  // "ct-" prefix: build the evaluator of the remaining name and wrap it
  if (!strncmp(name, "ct-", 3)) return new EvalCTest(CreateEvaluator(name+3), name);

  utils::Error("unknown evaluation metric type: %s", name);
  return NULL;
}
|
||||
|
||||
/*! \brief a set of evaluators */
|
||||
class EvalSet{
|
||||
public:
|
||||
inline void AddEval(const char *name) {
|
||||
using namespace std;
|
||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
||||
if (!strcmp(name, evals_[i]->Name())) return;
|
||||
}
|
||||
evals_.push_back(CreateEvaluator(name));
|
||||
}
|
||||
~EvalSet(void) {
|
||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
||||
delete evals_[i];
|
||||
}
|
||||
}
|
||||
inline std::string Eval(const char *evname,
|
||||
const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed = false) {
|
||||
std::string result = "";
|
||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
||||
float res = evals_[i]->Eval(preds, info, distributed);
|
||||
char tmp[1024];
|
||||
utils::SPrintf(tmp, sizeof(tmp), "\t%s-%s:%f", evname, evals_[i]->Name(), res);
|
||||
result += tmp;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
inline size_t Size(void) const {
|
||||
return evals_.size();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<const IEvaluator*> evals_;
|
||||
};
|
||||
} // namespace learner
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_LEARNER_EVALUATION_H_
|
||||
@ -1,13 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file sync.h
|
||||
* \brief the synchronization module of rabit
|
||||
* redirects to subtree rabit header
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_SYNC_SYNC_H_
|
||||
#define XGBOOST_SYNC_SYNC_H_
|
||||
|
||||
#include "../../subtree/rabit/include/rabit.h"
|
||||
#include "../../subtree/rabit/include/rabit/timer.h"
|
||||
#endif // XGBOOST_SYNC_SYNC_H_
|
||||
@ -1,45 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file math.h
|
||||
* \brief support additional math
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_UTILS_MATH_H_
|
||||
#define XGBOOST_UTILS_MATH_H_
|
||||
|
||||
#include <cmath>
|
||||
|
||||
namespace xgboost {
|
||||
namespace utils {
|
||||
#ifdef XGBOOST_STRICT_CXX98_
|
||||
// check nan
|
||||
bool CheckNAN(double v);
|
||||
double LogGamma(double v);
|
||||
#else
|
||||
/*!
 * \brief test whether a value is NaN
 * \param v the value to test
 * \return true when v is NaN
 */
template<typename T>
inline bool CheckNAN(T v) {
#ifdef _MSC_VER
  return (_isnan(v) != 0);
#else
  // <cmath> only guarantees the name in namespace std; the unqualified global
  // isnan is not declared by every standard library
  return std::isnan(v);
#endif
}
|
||||
/*!
 * \brief logarithm of the gamma function, forwarded to lgamma
 * \param v the argument
 * \return lgamma(v); on Visual Studio older than 2013 lgamma does not exist,
 *  there utils::Error is called and 1 is returned
 */
template<typename T>
inline T LogGamma(T v) {
#if defined(_MSC_VER) && _MSC_VER < 1800
#pragma message("Warning: lgamma function was not available until VS2013"\
                ", poisson regression will be disabled")
  utils::Error("lgamma function was not available until VS2013");
  return static_cast<T>(1.0);
#else
  return lgamma(v);
#endif
}
|
||||
#endif
|
||||
} // namespace utils
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_UTILS_MATH_H_
|
||||
@ -102,6 +102,36 @@ inline static bool CmpSecond(const std::pair<float, unsigned> &a,
|
||||
const std::pair<float, unsigned> &b) {
|
||||
return a.second > b.second;
|
||||
}
|
||||
|
||||
#ifdef XGBOOST_STRICT_R_MODE_
|
||||
// check nan
|
||||
bool CheckNAN(double v);
|
||||
double LogGamma(double v);
|
||||
#else
|
||||
/*!
 * \brief check whether a value is NaN
 * \param v the value to check
 * \return true when v is NaN
 */
template<typename T>
inline bool CheckNAN(T v) {
#ifdef _MSC_VER
  return (_isnan(v) != 0);
#else
  // call the std:: overload explicitly: a global isnan is not guaranteed to
  // be declared by <cmath>
  return std::isnan(v);
#endif
}
|
||||
/*!
 * \brief log of the gamma function, a thin wrapper around lgamma
 * \param v the argument
 * \return lgamma(v); Visual Studio before 2013 has no lgamma, in that case
 *  utils::Error is called and 1 is returned
 */
template<typename T>
inline T LogGamma(T v) {
#if !defined(_MSC_VER) || _MSC_VER >= 1800
  return lgamma(v);
#else
#pragma message("Warning: lgamma function was not available until VS2013"\
                ", poisson regression will be disabled")
  utils::Error("lgamma function was not available until VS2013");
  return static_cast<T>(1.0);
#endif
}
|
||||
#endif // XGBOOST_STRICT_R_MODE_
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_MATH_H_
|
||||
|
||||
53
src/common/metric_set.h
Normal file
53
src/common/metric_set.h
Normal file
@ -0,0 +1,53 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file metric_set.h
|
||||
* \brief helper class that holds a set of evaluation metrics and evaluates them together
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_METRIC_SET_H_
|
||||
#define XGBOOST_COMMON_METRIC_SET_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
/*! \brief helper util to create a set of metrics */
|
||||
class MetricSet {
|
||||
inline void AddEval(const char *name) {
|
||||
using namespace std;
|
||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
||||
if (!strcmp(name, evals_[i]->Name())) return;
|
||||
}
|
||||
evals_.push_back(CreateEvaluator(name));
|
||||
}
|
||||
~EvalSet(void) {
|
||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
||||
delete evals_[i];
|
||||
}
|
||||
}
|
||||
inline std::string Eval(const char *evname,
|
||||
const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed = false) {
|
||||
std::string result = "";
|
||||
for (size_t i = 0; i < evals_.size(); ++i) {
|
||||
float res = evals_[i]->Eval(preds, info, distributed);
|
||||
char tmp[1024];
|
||||
utils::SPrintf(tmp, sizeof(tmp), "\t%s-%s:%f", evname, evals_[i]->Name(), res);
|
||||
result += tmp;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
inline size_t Size(void) const {
|
||||
return evals_.size();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<const IEvaluator*> evals_;
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_METRIC_SET_H_
|
||||
44
src/global.cc
Normal file
44
src/global.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file global.cc
|
||||
* \brief Enable all kinds of global static registry and variables.
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
// Enable the dmlc registries for ObjFunctionReg and MetricReg, these are
// the registries queried by the Create factory functions in this file.
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
// implement factory functions
|
||||
/*!
 * \brief create an objective function according to name.
 *  The name is looked up in the ObjFunctionReg registry, a fatal
 *  error is logged when no entry is found.
 * \param name name of the objective function
 * \return the object returned by the body of the registry entry
 */
ObjFunction* ObjFunction::Create(const char* name) {
  typedef ::dmlc::Registry< ::xgboost::ObjFunctionReg> ObjRegistry;
  const auto *entry = ObjRegistry::Get()->Find(name);
  if (!entry) {
    LOG(FATAL) << "Unknown objective function " << name;
  }
  return entry->body();
}
|
||||
|
||||
/*!
 * \brief create a metric according to name.
 *  The name can be in form metric@param: the part before '@' is looked up
 *  in the MetricReg registry and the part after '@' is passed to the body
 *  of the entry. Without '@' the whole name is looked up and the body
 *  receives nullptr. A fatal error is logged when no entry is found.
 * \param name name of the metric
 * \return the object returned by the body of the registry entry
 */
Metric* Metric::Create(const char* name) {
  typedef ::dmlc::Registry< ::xgboost::MetricReg> MetricRegistry;
  const std::string full_name = name;
  const auto pos = full_name.find('@');
  // substr(0, npos) is the whole string, so this covers both forms
  const std::string key = full_name.substr(0, pos);
  const auto *e = MetricRegistry::Get()->Find(key.c_str());
  if (e == nullptr) {
    LOG(FATAL) << "Unknown metric function " << name;
  }
  if (pos == std::string::npos) {
    return (e->body)(nullptr);
  }
  // named local keeps the parameter text alive for the whole call
  const std::string param = full_name.substr(pos + 1);
  return (e->body)(param.c_str());
}
|
||||
} // namespace xgboost
|
||||
|
||||
127
src/metric/elementwise_metric.cc
Normal file
127
src/metric/elementwise_metric.cc
Normal file
@ -0,0 +1,127 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file elementwise_metric.cc
|
||||
* \brief evaluation metrics for elementwise binary or regression.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/sync.h>
|
||||
#include <cmath>
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
/*!
|
||||
* \brief base class of element-wise evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalEWiseBase : public Metric {
|
||||
float Eval(const std::vector<float>& preds,
|
||||
const MetaInfo& info,
|
||||
bool distributed) const override {
|
||||
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "label and prediction size not match, "
|
||||
<< "hint: use merror or mlogloss for multi-class classification";
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
float sum = 0.0, wsum = 0.0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
sum += Derived::EvalRow(info.labels[i], preds[i]) * wt;
|
||||
wsum += wt;
|
||||
}
|
||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
*/
|
||||
inline static float EvalRow(float label, float pred);
|
||||
/*!
|
||||
* \brief to be overridden by subclass, final transformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static float GetFinal(float esum, float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief "rmse": square root of the weighted mean of squared differences */
struct EvalRMSE : public EvalEWiseBase<EvalRMSE> {
  /*! \return name of metric */
  const char *Name() const override { return "rmse"; }
  /*! \return squared difference between label and pred */
  inline static float EvalRow(float label, float pred) {
    const float residual = label - pred;
    return residual * residual;
  }
  /*! \return square root of esum / wsum */
  inline static float GetFinal(float esum, float wsum) {
    const float mean_square = esum / wsum;
    return std::sqrt(mean_square);
  }
};
|
||||
|
||||
/*! \brief "logloss": negative log likelihood of a predicted probability */
struct EvalLogLoss : public EvalEWiseBase<EvalLogLoss> {
  /*! \return name of metric */
  const char *Name() const override { return "logloss"; }
  /*!
   * \brief negative log likelihood of one row
   * \param y label of current instance
   * \param py predicted probability of current instance
   * \return -y * log(p) - (1 - y) * log(q), where (p, q) is (py, 1 - py),
   *  replaced by (eps, 1 - eps) or (1 - eps, eps) when py or 1 - py
   *  is below eps
   */
  inline static float EvalRow(float y, float py) {
    const float eps = 1e-16f;
    float ppos = py;
    float pneg = 1.0f - py;
    if (py < eps) {
      ppos = eps;
      pneg = 1.0f - eps;
    } else if (pneg < eps) {
      ppos = 1.0f - eps;
      pneg = eps;
    }
    return -y * std::log(ppos) - (1.0f - y) * std::log(pneg);
  }
};
|
||||
|
||||
/*! \brief "error": classification error with the threshold at 0.5 */
struct EvalError : public EvalEWiseBase<EvalError> {
  /*! \return name of metric */
  const char *Name() const override { return "error"; }
  /*!
   * \brief error of one row, assume label is in [0,1]
   * \return 1 - label when pred is above 0.5, label otherwise
   */
  inline static float EvalRow(float label, float pred) {
    if (pred > 0.5f) {
      return 1.0f - label;
    }
    return label;
  }
};
|
||||
|
||||
/*! \brief "poisson-nloglik": negative log likelihood for poisson regression */
struct EvalPoissionNegLogLik : public EvalEWiseBase<EvalPoissionNegLogLik> {
  /*! \return name of metric */
  const char *Name() const override { return "poisson-nloglik"; }
  /*!
   * \brief negative log likelihood of one row
   * \param y label of current instance
   * \param py prediction of current instance, raised to eps when below it
   * \return LogGamma(y + 1) + py - log(py) * y
   */
  inline static float EvalRow(float y, float py) {
    const float eps = 1e-16f;
    const float mean = py < eps ? eps : py;
    const float log_factorial = common::LogGamma(y + 1.0f);
    return log_factorial + mean - std::log(mean) * y;
  }
};
|
||||
|
||||
// Register the element-wise metrics under the names "rmse", "logloss",
// "error" and "poisson-nloglik". The factory bodies do not read param.
XGBOOST_REGISTER_METRIC(RMSE, "rmse")
|
||||
.describe("Rooted mean square error.")
|
||||
.set_body([](const char* param) { return new EvalRMSE(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(LogLoss, "logloss")
|
||||
.describe("Negative loglikelihood for logistic regression.")
|
||||
.set_body([](const char* param) { return new EvalLogLoss(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Error, "error")
|
||||
.describe("Binary classification error.")
|
||||
.set_body([](const char* param) { return new EvalError(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
|
||||
.describe("Negative loglikelihood for poisson regression.")
|
||||
.set_body([](const char* param) { return new EvalPoissionNegLogLik(); });
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
117
src/metric/multiclass_metric.cc
Normal file
117
src/metric/multiclass_metric.cc
Normal file
@ -0,0 +1,117 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file multiclass_metric.cc
|
||||
* \brief evaluation metrics for multiclass classification.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/sync.h>
|
||||
#include <cmath>
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
/*!
|
||||
* \brief base class of multi-class evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalMClassBase : public Metric {
|
||||
float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
|
||||
CHECK(preds.size() % info.labels.size() == 0)
|
||||
<< "label and prediction size not match";
|
||||
const size_t nclass = preds.size() / info.labels.size();
|
||||
CHECK_GE(nclass, 1)
|
||||
<< "mlogloss and merror are only used for multi-class classification,"
|
||||
<< " use logloss for binary classification";
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
float sum = 0.0, wsum = 0.0;
|
||||
int label_error = 0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
int label = static_cast<int>(info.labels[i]);
|
||||
if (label >= 0 && label < static_cast<int>(nclass)) {
|
||||
sum += Derived::EvalRow(label,
|
||||
dmlc::BeginPtr(preds) + i * nclass,
|
||||
nclass) * wt;
|
||||
wsum += wt;
|
||||
} else {
|
||||
label_error = label;
|
||||
}
|
||||
}
|
||||
CHECK(label_error >= 0 && label_error < static_cast<int>(nclass))
|
||||
<< "MultiClassEvaluation: label must be in [0, num_class),"
|
||||
<< " num_class=" << nclass << " but found " << label_error << " in label";
|
||||
|
||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
* \param nclass number of class in the prediction
|
||||
*/
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass);
|
||||
/*!
|
||||
* \brief to be overridden by subclass, final transformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static float GetFinal(float esum, float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
// used to store error message
|
||||
const char *error_msg_;
|
||||
};
|
||||
|
||||
/*! \brief match error */
|
||||
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
||||
const char* Name() const override {
|
||||
return "merror";
|
||||
}
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass) {
|
||||
return common::FindMaxIndex(pred, pred + nclass) != pred + static_cast<int>(label);
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief match error */
|
||||
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
|
||||
const char* Name() const override {
|
||||
return "mlogloss";
|
||||
}
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass) {
|
||||
const float eps = 1e-16f;
|
||||
size_t k = static_cast<size_t>(label);
|
||||
if (pred[k] > eps) {
|
||||
return -std::log(pred[k]);
|
||||
} else {
|
||||
return -std::log(eps);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Register the multiclass metrics under the names "merror" and "mlogloss".
// The factory bodies do not read param.
XGBOOST_REGISTER_METRIC(MatchError, "merror")
|
||||
.describe("Multiclass classification error.")
|
||||
.set_body([](const char* param) { return new EvalMatchError(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss")
|
||||
.describe("Multiclass negative loglikelihood.")
|
||||
.set_body([](const char* param) { return new EvalMultiLogLoss(); });
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
322
src/metric/rank_metric.cc
Normal file
322
src/metric/rank_metric.cc
Normal file
@ -0,0 +1,322 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file rank_metric.cc
|
||||
* \brief prediction rank based metrics.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/sync.h>
|
||||
#include <cmath>
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
/*! \brief AMS: also records best threshold */
|
||||
struct EvalAMS : public Metric {
|
||||
public:
|
||||
explicit EvalAMS(const char* param) {
|
||||
CHECK(param != nullptr)
|
||||
<< "AMS must be in format ams@k";
|
||||
ratio_ = atof(param);
|
||||
std::ostringstream os;
|
||||
os << "ams@" << ratio_;
|
||||
name_ = os.str();
|
||||
}
|
||||
float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
CHECK(!distributed) << "metric AMS do not support distributed evaluation";
|
||||
using namespace std; // NOLINT(*)
|
||||
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
CHECK_EQ(info.weights.size(), ndata) << "we need weight to evaluate ams";
|
||||
std::vector<std::pair<float, unsigned> > rec(ndata);
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
rec[i] = std::make_pair(preds[i], i);
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||
if (ntop == 0) ntop = ndata;
|
||||
const double br = 10.0;
|
||||
unsigned thresindex = 0;
|
||||
double s_tp = 0.0, b_fp = 0.0, tams = 0.0;
|
||||
for (unsigned i = 0; i < static_cast<unsigned>(ndata-1) && i < ntop; ++i) {
|
||||
const unsigned ridx = rec[i].second;
|
||||
const float wt = info.weights[ridx];
|
||||
if (info.labels[ridx] > 0.5f) {
|
||||
s_tp += wt;
|
||||
} else {
|
||||
b_fp += wt;
|
||||
}
|
||||
if (rec[i].first != rec[i + 1].first) {
|
||||
double ams = sqrt(2 * ((s_tp + b_fp + br) * log(1.0 + s_tp / (b_fp + br)) - s_tp));
|
||||
if (tams < ams) {
|
||||
thresindex = i;
|
||||
tams = ams;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ntop == ndata) {
|
||||
LOG(INFO) << "best-ams-ratio=" << static_cast<float>(thresindex) / ndata;
|
||||
return static_cast<float>(tams);
|
||||
} else {
|
||||
return static_cast<float>(
|
||||
sqrt(2 * ((s_tp + b_fp + br) * log(1.0 + s_tp/(b_fp + br)) - s_tp)));
|
||||
}
|
||||
}
|
||||
|
||||
const char* Name() const override {
|
||||
return name_.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
float ratio_;
|
||||
};
|
||||
|
||||
/*! \brief Area Under Curve, for both classification and rank */
|
||||
struct EvalAuc : public Metric {
|
||||
float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "label size predict size not match";
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(info.labels.size());
|
||||
|
||||
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
CHECK_EQ(gptr.back(), info.labels.size())
|
||||
<< "EvalAuc: group structure must match number of prediction";
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statistics
|
||||
double sum_auc = 0.0f;
|
||||
#pragma omp parallel reduction(+:sum_auc)
|
||||
{
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<float, unsigned> > rec;
|
||||
#pragma omp for schedule(static)
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], j));
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// calculate AUC
|
||||
double sum_pospair = 0.0;
|
||||
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
|
||||
for (size_t j = 0; j < rec.size(); ++j) {
|
||||
const float wt = info.GetWeight(rec[j].second);
|
||||
const float ctr = info.labels[rec[j].second];
|
||||
// keep bucketing predictions in same bucket
|
||||
if (j != 0 && rec[j].first != rec[j - 1].first) {
|
||||
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
|
||||
sum_npos += buf_pos;
|
||||
sum_nneg += buf_neg;
|
||||
buf_neg = buf_pos = 0.0f;
|
||||
}
|
||||
buf_pos += ctr * wt;
|
||||
buf_neg += (1.0f - ctr) * wt;
|
||||
}
|
||||
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
|
||||
sum_npos += buf_pos;
|
||||
sum_nneg += buf_neg;
|
||||
// check weird conditions
|
||||
CHECK(sum_npos > 0.0 && sum_nneg > 0.0)
|
||||
<< "AUC: the dataset only contains pos or neg samples";
|
||||
// this is the AUC
|
||||
sum_auc += sum_pospair / (sum_npos*sum_nneg);
|
||||
}
|
||||
}
|
||||
if (distributed) {
|
||||
float dat[2];
|
||||
dat[0] = static_cast<float>(sum_auc);
|
||||
dat[1] = static_cast<float>(ngroup);
|
||||
// approximately estimate auc using mean
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<float>(sum_auc) / ngroup;
|
||||
}
|
||||
}
|
||||
const char* Name() const override {
|
||||
return "auc";
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Evaluate rank list */
|
||||
struct EvalRankList : public Metric {
|
||||
public:
|
||||
float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "label size predict size not match";
|
||||
// quick consistency when group is not available
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(preds.size());
|
||||
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
CHECK_NE(gptr.size(), 0) << "must specify group when constructing rank file";
|
||||
CHECK_EQ(gptr.back(), preds.size())
|
||||
<< "EvalRanklist: group structure must match number of prediction";
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statistics
|
||||
double sum_metric = 0.0f;
|
||||
#pragma omp parallel reduction(+:sum_metric)
|
||||
{
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<float, unsigned> > rec;
|
||||
#pragma omp for schedule(static)
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], static_cast<int>(info.labels[j])));
|
||||
}
|
||||
sum_metric += this->EvalMetric(rec);
|
||||
}
|
||||
}
|
||||
if (distributed) {
|
||||
float dat[2];
|
||||
dat[0] = static_cast<float>(sum_metric);
|
||||
dat[1] = static_cast<float>(ngroup);
|
||||
// approximately estimate the metric using mean
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<float>(sum_metric) / ngroup;
|
||||
}
|
||||
}
|
||||
const char* Name() const override {
|
||||
return name_.c_str();
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit EvalRankList(const char* name, const char* param) {
|
||||
using namespace std; // NOLINT(*)
|
||||
minus_ = false;
|
||||
if (param != nullptr) {
|
||||
std::ostringstream os;
|
||||
os << name << '@' << param;
|
||||
name_ = os.str();
|
||||
if (sscanf(param, "%u[-]?", &topn_) != 1) {
|
||||
topn_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
if (param[strlen(param) - 1] == '-') {
|
||||
minus_ = true;
|
||||
}
|
||||
} else {
|
||||
topn_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
}
|
||||
/*! \return evaluation metric, given the pair_sort record, (pred,label) */
|
||||
virtual float EvalMetric(std::vector<std::pair<float, unsigned> > &pair_sort) const = 0; // NOLINT(*)
|
||||
|
||||
protected:
|
||||
unsigned topn_;
|
||||
std::string name_;
|
||||
bool minus_;
|
||||
};
|
||||
|
||||
/*! \brief Precision at N, for both classification and rank */
|
||||
struct EvalPrecision : public EvalRankList{
|
||||
public:
|
||||
explicit EvalPrecision(const char *name) : EvalRankList("pre", name) {}
|
||||
|
||||
protected:
|
||||
virtual float EvalMetric(std::vector< std::pair<float, unsigned> > &rec) const {
|
||||
// calculate Precision
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned nhit = 0;
|
||||
for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) {
|
||||
nhit += (rec[j].second != 0);
|
||||
}
|
||||
return static_cast<float>(nhit) / topn_;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */
|
||||
struct EvalNDCG : public EvalRankList{
|
||||
public:
|
||||
explicit EvalNDCG(const char *name) : EvalRankList("ndcg", name) {}
|
||||
|
||||
protected:
|
||||
inline float CalcDCG(const std::vector<std::pair<float, unsigned> > &rec) const {
|
||||
double sumdcg = 0.0;
|
||||
for (size_t i = 0; i < rec.size() && i < this->topn_; ++i) {
|
||||
const unsigned rel = rec[i].second;
|
||||
if (rel != 0) {
|
||||
sumdcg += ((1 << rel) - 1) / std::log(i + 2.0);
|
||||
}
|
||||
}
|
||||
return static_cast<float>(sumdcg);
|
||||
}
|
||||
virtual float EvalMetric(std::vector<std::pair<float, unsigned> > &rec) const { // NOLINT(*)
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
float dcg = this->CalcDCG(rec);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpSecond);
|
||||
float idcg = this->CalcDCG(rec);
|
||||
if (idcg == 0.0f) {
|
||||
if (minus_) {
|
||||
return 0.0f;
|
||||
} else {
|
||||
return 1.0f;
|
||||
}
|
||||
}
|
||||
return dcg/idcg;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Mean Average Precision at N, for both classification and rank */
|
||||
struct EvalMAP : public EvalRankList {
|
||||
public:
|
||||
explicit EvalMAP(const char *name) : EvalRankList("map", name) {}
|
||||
|
||||
protected:
|
||||
virtual float EvalMetric(std::vector< std::pair<float, unsigned> > &rec) const {
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned nhits = 0;
|
||||
double sumap = 0.0;
|
||||
for (size_t i = 0; i < rec.size(); ++i) {
|
||||
if (rec[i].second != 0) {
|
||||
nhits += 1;
|
||||
if (i < this->topn_) {
|
||||
sumap += static_cast<float>(nhits) / (i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nhits != 0) {
|
||||
sumap /= nhits;
|
||||
return static_cast<float>(sumap);
|
||||
} else {
|
||||
if (minus_) {
|
||||
return 0.0f;
|
||||
} else {
|
||||
return 1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Register the rank based metrics under the names "ams", "auc", "pre",
// "ndcg" and "map". Every factory body except the one of "auc" forwards
// param to the constructor of the metric.
XGBOOST_REGISTER_METRIC(AMS, "ams")
|
||||
.describe("AMS metric for higgs.")
|
||||
.set_body([](const char* param) { return new EvalAMS(param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Auc, "auc")
|
||||
.describe("Area under curve for both classification and rank.")
|
||||
.set_body([](const char* param) { return new EvalAuc(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Precision, "pre")
|
||||
.describe("precision@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalPrecision(param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(NDCG, "ndcg")
|
||||
.describe("ndcg@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalNDCG(param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MAP, "map")
|
||||
.describe("map@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalMAP(param); });
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
@ -1,17 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file objective.cc
|
||||
* \brief global objective function definition.
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
/*! \brief namespace of objective function */
|
||||
namespace obj {
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
@ -26,7 +26,6 @@ struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
|
||||
DMLC_DECLARE_FIELD(fix_list_weight).set_lower_bound(0.0f).set_default(0.0f)
|
||||
.describe("Normalize the weight of each list by this value,"
|
||||
" if equals 0, no effect will happen");
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@ -40,9 +39,8 @@ class LambdaRankObj : public ObjFunction {
|
||||
const MetaInfo& info,
|
||||
int iter,
|
||||
std::vector<bst_gpair>* out_gpair) override {
|
||||
CHECK_EQ(preds.size(),info.labels.size()) << "label size predict size not match";
|
||||
CHECK_EQ(preds.size(), info.labels.size()) << "label size predict size not match";
|
||||
std::vector<bst_gpair>& gpair = *out_gpair;
|
||||
|
||||
gpair.resize(preds.size());
|
||||
// quick consistency when group is not available
|
||||
std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(info.labels.size());
|
||||
Loading…
x
Reference in New Issue
Block a user