add new evaluation metric mlogloss for multi-class classification logloss

This commit is contained in:
tqchen 2015-03-19 11:34:38 -07:00
parent 8025b338a8
commit e1538ae615
4 changed files with 88 additions and 11 deletions

View File

@ -27,8 +27,9 @@ struct EvalEWiseBase : public IEvaluator {
const MetaInfo &info,
bool distributed) const {
utils::Check(info.labels.size() != 0, "label set cannot be empty");
utils::Check(preds.size() % info.labels.size() == 0,
"label and prediction size not match");
utils::Check(preds.size() == info.labels.size(),
"label and prediction size not match"\
"hint: use merror or mlogloss for multi-class classification");
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
@ -50,7 +51,6 @@ struct EvalEWiseBase : public IEvaluator {
* get evaluation result from one row
* \param label label of current instance
* \param pred prediction value of current instance
* \param weight weight of current instance
*/
inline static float EvalRow(float label, float pred);
/*!
@ -98,15 +98,84 @@ struct EvalError : public EvalEWiseBase<EvalError> {
}
};
/*!
* \brief base class of multi-class evaluation
* \tparam Derived the name of subclass
*/
template<typename Derived>
struct EvalMClassBase : public IEvaluator {
virtual float Eval(const std::vector<float> &preds,
const MetaInfo &info,
bool distributed) const {
utils::Check(info.labels.size() != 0, "label set cannot be empty");
utils::Check(preds.size() % info.labels.size() == 0,
"label and prediction size not match");
const size_t nclass = preds.size() / info.labels.size();
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
float sum = 0.0, wsum = 0.0;
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
const float wt = info.GetWeight(i);
sum += Derived::EvalRow(info.labels[i],
BeginPtr(preds) + i * nclass,
nclass) * wt;
wsum += wt;
}
float dat[2]; dat[0] = sum, dat[1] = wsum;
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
return Derived::GetFinal(dat[0], dat[1]);
}
/*!
* \brief to be implemented by subclass,
* get evaluation result from one row
* \param label label of current instance
* \param pred prediction value of current instance
* \param nclass number of class in the prediction
*/
inline static float EvalRow(float label,
const float *pred,
size_t nclass);
/*!
* \brief to be overide by subclas, final trasnformation
* \param esum the sum statistics returned by EvalRow
* \param wsum sum of weight
*/
inline static float GetFinal(float esum, float wsum) {
return esum / wsum;
}
};
/*! \brief match error */
struct EvalMatchError : public EvalEWiseBase<EvalMatchError> {
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
virtual const char *Name(void) const {
return "merror";
}
inline static float EvalRow(float label, float pred) {
return static_cast<int>(pred) != static_cast<int>(label);
inline static float EvalRow(float label,
const float *pred,
size_t nclass) {
return FindMaxIndex(pred, nclass) != static_cast<int>(label);
}
};
/*! \brief match error */
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
virtual const char *Name(void) const {
return "mlogloss";
}
inline static float EvalRow(float label,
const float *pred,
size_t nclass) {
size_t k = static_cast<size_t>(label);
utils::Check(k < nclass, "mlogloss: label must be in [0, num_class)");
if (pred[k] > eps) {
return -std::log(pred[k]);
} else {
return -std::log(eps);
}
}
const static float eps = 1e-16;
};
/*! \brief ctest */
struct EvalCTest: public IEvaluator {

View File

@ -45,6 +45,7 @@ inline IEvaluator* CreateEvaluator(const char *name) {
if (!strcmp(name, "error")) return new EvalError();
if (!strcmp(name, "merror")) return new EvalMatchError();
if (!strcmp(name, "logloss")) return new EvalLogLoss();
if (!strcmp(name, "mlogloss")) return new EvalMultiLogLoss();
if (!strcmp(name, "auc")) return new EvalAuc();
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);

View File

@ -27,21 +27,28 @@ inline static void Softmax(std::vector<float>* p_rec) {
rec[i] /= static_cast<float>(wsum);
}
}
// simple helper function to do softmax
inline static int FindMaxIndex(const std::vector<float>& rec) {
inline static int FindMaxIndex(const float *rec, size_t size) {
size_t mxid = 0;
for (size_t i = 1; i < rec.size(); ++i) {
if (rec[i] > rec[mxid] + 1e-6f) {
for (size_t i = 1; i < size; ++i) {
if (rec[i] > rec[mxid]) {
mxid = i;
}
}
return static_cast<int>(mxid);
}
// simple helper function to do softmax
inline static int FindMaxIndex(const std::vector<float>& rec) {
return FindMaxIndex(BeginPtr(rec), rec.size());
}
inline static bool CmpFirst(const std::pair<float, unsigned> &a,
const std::pair<float, unsigned> &b) {
return a.first > b.first;
}
inline static bool CmpSecond(const std::pair<float, unsigned> &a,
const std::pair<float, unsigned> &b) {
return a.second > b.second;

View File

@ -225,7 +225,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
this->Transform(io_preds, output_prob);
}
virtual void EvalTransform(std::vector<float> *io_preds) {
this->Transform(io_preds, 0);
this->Transform(io_preds, 1);
}
virtual const char* DefaultEvalMetric(void) const {
return "merror";