add new evaluation metric mlogloss for multi-class classification logloss
This commit is contained in:
parent
8025b338a8
commit
e1538ae615
@ -27,8 +27,9 @@ struct EvalEWiseBase : public IEvaluator {
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
utils::Check(preds.size() == info.labels.size(),
|
||||
"label and prediction size not match"\
|
||||
"hint: use merror or mlogloss for multi-class classification");
|
||||
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
|
||||
@ -50,7 +51,6 @@ struct EvalEWiseBase : public IEvaluator {
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
* \param weight weight of current instance
|
||||
*/
|
||||
inline static float EvalRow(float label, float pred);
|
||||
/*!
|
||||
@ -98,15 +98,84 @@ struct EvalError : public EvalEWiseBase<EvalError> {
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief base class of multi-class evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalMClassBase : public IEvaluator {
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) const {
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
const size_t nclass = preds.size() / info.labels.size();
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
|
||||
float sum = 0.0, wsum = 0.0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
sum += Derived::EvalRow(info.labels[i],
|
||||
BeginPtr(preds) + i * nclass,
|
||||
nclass) * wt;
|
||||
wsum += wt;
|
||||
}
|
||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
* \param nclass number of class in the prediction
|
||||
*/
|
||||
inline static float EvalRow(float label,
|
||||
const float *pred,
|
||||
size_t nclass);
|
||||
/*!
|
||||
* \brief to be overide by subclas, final trasnformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static float GetFinal(float esum, float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
};
|
||||
/*! \brief match error */
|
||||
struct EvalMatchError : public EvalEWiseBase<EvalMatchError> {
|
||||
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
||||
virtual const char *Name(void) const {
|
||||
return "merror";
|
||||
}
|
||||
inline static float EvalRow(float label, float pred) {
|
||||
return static_cast<int>(pred) != static_cast<int>(label);
|
||||
inline static float EvalRow(float label,
|
||||
const float *pred,
|
||||
size_t nclass) {
|
||||
return FindMaxIndex(pred, nclass) != static_cast<int>(label);
|
||||
}
|
||||
};
|
||||
/*! \brief match error */
|
||||
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
|
||||
virtual const char *Name(void) const {
|
||||
return "mlogloss";
|
||||
}
|
||||
inline static float EvalRow(float label,
|
||||
const float *pred,
|
||||
size_t nclass) {
|
||||
size_t k = static_cast<size_t>(label);
|
||||
utils::Check(k < nclass, "mlogloss: label must be in [0, num_class)");
|
||||
if (pred[k] > eps) {
|
||||
return -std::log(pred[k]);
|
||||
} else {
|
||||
return -std::log(eps);
|
||||
}
|
||||
}
|
||||
const static float eps = 1e-16;
|
||||
};
|
||||
|
||||
/*! \brief ctest */
|
||||
struct EvalCTest: public IEvaluator {
|
||||
|
||||
@ -45,6 +45,7 @@ inline IEvaluator* CreateEvaluator(const char *name) {
|
||||
if (!strcmp(name, "error")) return new EvalError();
|
||||
if (!strcmp(name, "merror")) return new EvalMatchError();
|
||||
if (!strcmp(name, "logloss")) return new EvalLogLoss();
|
||||
if (!strcmp(name, "mlogloss")) return new EvalMultiLogLoss();
|
||||
if (!strcmp(name, "auc")) return new EvalAuc();
|
||||
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
|
||||
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
||||
|
||||
@ -27,21 +27,28 @@ inline static void Softmax(std::vector<float>* p_rec) {
|
||||
rec[i] /= static_cast<float>(wsum);
|
||||
}
|
||||
}
|
||||
// simple helper function to do softmax
|
||||
inline static int FindMaxIndex(const std::vector<float>& rec) {
|
||||
|
||||
inline static int FindMaxIndex(const float *rec, size_t size) {
|
||||
size_t mxid = 0;
|
||||
for (size_t i = 1; i < rec.size(); ++i) {
|
||||
if (rec[i] > rec[mxid] + 1e-6f) {
|
||||
for (size_t i = 1; i < size; ++i) {
|
||||
if (rec[i] > rec[mxid]) {
|
||||
mxid = i;
|
||||
}
|
||||
}
|
||||
return static_cast<int>(mxid);
|
||||
}
|
||||
|
||||
// simple helper function to do softmax
|
||||
inline static int FindMaxIndex(const std::vector<float>& rec) {
|
||||
return FindMaxIndex(BeginPtr(rec), rec.size());
|
||||
}
|
||||
|
||||
|
||||
inline static bool CmpFirst(const std::pair<float, unsigned> &a,
|
||||
const std::pair<float, unsigned> &b) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
|
||||
inline static bool CmpSecond(const std::pair<float, unsigned> &a,
|
||||
const std::pair<float, unsigned> &b) {
|
||||
return a.second > b.second;
|
||||
|
||||
@ -225,7 +225,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
this->Transform(io_preds, output_prob);
|
||||
}
|
||||
virtual void EvalTransform(std::vector<float> *io_preds) {
|
||||
this->Transform(io_preds, 0);
|
||||
this->Transform(io_preds, 1);
|
||||
}
|
||||
virtual const char* DefaultEvalMetric(void) const {
|
||||
return "merror";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user