add new evaluation metric mlogloss for multi-class classification logloss
This commit is contained in:
parent
8025b338a8
commit
e1538ae615
@ -27,8 +27,9 @@ struct EvalEWiseBase : public IEvaluator {
|
|||||||
const MetaInfo &info,
|
const MetaInfo &info,
|
||||||
bool distributed) const {
|
bool distributed) const {
|
||||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||||
utils::Check(preds.size() % info.labels.size() == 0,
|
utils::Check(preds.size() == info.labels.size(),
|
||||||
"label and prediction size not match");
|
"label and prediction size not match"\
|
||||||
|
"hint: use merror or mlogloss for multi-class classification");
|
||||||
|
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||||
|
|
||||||
@ -50,7 +51,6 @@ struct EvalEWiseBase : public IEvaluator {
|
|||||||
* get evaluation result from one row
|
* get evaluation result from one row
|
||||||
* \param label label of current instance
|
* \param label label of current instance
|
||||||
* \param pred prediction value of current instance
|
* \param pred prediction value of current instance
|
||||||
* \param weight weight of current instance
|
|
||||||
*/
|
*/
|
||||||
inline static float EvalRow(float label, float pred);
|
inline static float EvalRow(float label, float pred);
|
||||||
/*!
|
/*!
|
||||||
@ -98,15 +98,84 @@ struct EvalError : public EvalEWiseBase<EvalError> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief base class of multi-class evaluation
|
||||||
|
* \tparam Derived the name of subclass
|
||||||
|
*/
|
||||||
|
template<typename Derived>
|
||||||
|
struct EvalMClassBase : public IEvaluator {
|
||||||
|
virtual float Eval(const std::vector<float> &preds,
|
||||||
|
const MetaInfo &info,
|
||||||
|
bool distributed) const {
|
||||||
|
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||||
|
utils::Check(preds.size() % info.labels.size() == 0,
|
||||||
|
"label and prediction size not match");
|
||||||
|
const size_t nclass = preds.size() / info.labels.size();
|
||||||
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||||
|
|
||||||
|
float sum = 0.0, wsum = 0.0;
|
||||||
|
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||||
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
|
const float wt = info.GetWeight(i);
|
||||||
|
sum += Derived::EvalRow(info.labels[i],
|
||||||
|
BeginPtr(preds) + i * nclass,
|
||||||
|
nclass) * wt;
|
||||||
|
wsum += wt;
|
||||||
|
}
|
||||||
|
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||||
|
if (distributed) {
|
||||||
|
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||||
|
}
|
||||||
|
return Derived::GetFinal(dat[0], dat[1]);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief to be implemented by subclass,
|
||||||
|
* get evaluation result from one row
|
||||||
|
* \param label label of current instance
|
||||||
|
* \param pred prediction value of current instance
|
||||||
|
* \param nclass number of class in the prediction
|
||||||
|
*/
|
||||||
|
inline static float EvalRow(float label,
|
||||||
|
const float *pred,
|
||||||
|
size_t nclass);
|
||||||
|
/*!
|
||||||
|
* \brief to be overide by subclas, final trasnformation
|
||||||
|
* \param esum the sum statistics returned by EvalRow
|
||||||
|
* \param wsum sum of weight
|
||||||
|
*/
|
||||||
|
inline static float GetFinal(float esum, float wsum) {
|
||||||
|
return esum / wsum;
|
||||||
|
}
|
||||||
|
};
|
||||||
/*! \brief match error */
|
/*! \brief match error */
|
||||||
struct EvalMatchError : public EvalEWiseBase<EvalMatchError> {
|
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
||||||
virtual const char *Name(void) const {
|
virtual const char *Name(void) const {
|
||||||
return "merror";
|
return "merror";
|
||||||
}
|
}
|
||||||
inline static float EvalRow(float label, float pred) {
|
inline static float EvalRow(float label,
|
||||||
return static_cast<int>(pred) != static_cast<int>(label);
|
const float *pred,
|
||||||
|
size_t nclass) {
|
||||||
|
return FindMaxIndex(pred, nclass) != static_cast<int>(label);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
/*! \brief match error */
|
||||||
|
struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
|
||||||
|
virtual const char *Name(void) const {
|
||||||
|
return "mlogloss";
|
||||||
|
}
|
||||||
|
inline static float EvalRow(float label,
|
||||||
|
const float *pred,
|
||||||
|
size_t nclass) {
|
||||||
|
size_t k = static_cast<size_t>(label);
|
||||||
|
utils::Check(k < nclass, "mlogloss: label must be in [0, num_class)");
|
||||||
|
if (pred[k] > eps) {
|
||||||
|
return -std::log(pred[k]);
|
||||||
|
} else {
|
||||||
|
return -std::log(eps);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const static float eps = 1e-16;
|
||||||
|
};
|
||||||
|
|
||||||
/*! \brief ctest */
|
/*! \brief ctest */
|
||||||
struct EvalCTest: public IEvaluator {
|
struct EvalCTest: public IEvaluator {
|
||||||
|
|||||||
@ -45,6 +45,7 @@ inline IEvaluator* CreateEvaluator(const char *name) {
|
|||||||
if (!strcmp(name, "error")) return new EvalError();
|
if (!strcmp(name, "error")) return new EvalError();
|
||||||
if (!strcmp(name, "merror")) return new EvalMatchError();
|
if (!strcmp(name, "merror")) return new EvalMatchError();
|
||||||
if (!strcmp(name, "logloss")) return new EvalLogLoss();
|
if (!strcmp(name, "logloss")) return new EvalLogLoss();
|
||||||
|
if (!strcmp(name, "mlogloss")) return new EvalMultiLogLoss();
|
||||||
if (!strcmp(name, "auc")) return new EvalAuc();
|
if (!strcmp(name, "auc")) return new EvalAuc();
|
||||||
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
|
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
|
||||||
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
||||||
|
|||||||
@ -27,21 +27,28 @@ inline static void Softmax(std::vector<float>* p_rec) {
|
|||||||
rec[i] /= static_cast<float>(wsum);
|
rec[i] /= static_cast<float>(wsum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// simple helper function to do softmax
|
|
||||||
inline static int FindMaxIndex(const std::vector<float>& rec) {
|
inline static int FindMaxIndex(const float *rec, size_t size) {
|
||||||
size_t mxid = 0;
|
size_t mxid = 0;
|
||||||
for (size_t i = 1; i < rec.size(); ++i) {
|
for (size_t i = 1; i < size; ++i) {
|
||||||
if (rec[i] > rec[mxid] + 1e-6f) {
|
if (rec[i] > rec[mxid]) {
|
||||||
mxid = i;
|
mxid = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return static_cast<int>(mxid);
|
return static_cast<int>(mxid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// simple helper function to do softmax
|
||||||
|
inline static int FindMaxIndex(const std::vector<float>& rec) {
|
||||||
|
return FindMaxIndex(BeginPtr(rec), rec.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
inline static bool CmpFirst(const std::pair<float, unsigned> &a,
|
inline static bool CmpFirst(const std::pair<float, unsigned> &a,
|
||||||
const std::pair<float, unsigned> &b) {
|
const std::pair<float, unsigned> &b) {
|
||||||
return a.first > b.first;
|
return a.first > b.first;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline static bool CmpSecond(const std::pair<float, unsigned> &a,
|
inline static bool CmpSecond(const std::pair<float, unsigned> &a,
|
||||||
const std::pair<float, unsigned> &b) {
|
const std::pair<float, unsigned> &b) {
|
||||||
return a.second > b.second;
|
return a.second > b.second;
|
||||||
|
|||||||
@ -225,7 +225,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
|||||||
this->Transform(io_preds, output_prob);
|
this->Transform(io_preds, output_prob);
|
||||||
}
|
}
|
||||||
virtual void EvalTransform(std::vector<float> *io_preds) {
|
virtual void EvalTransform(std::vector<float> *io_preds) {
|
||||||
this->Transform(io_preds, 0);
|
this->Transform(io_preds, 1);
|
||||||
}
|
}
|
||||||
virtual const char* DefaultEvalMetric(void) const {
|
virtual const char* DefaultEvalMetric(void) const {
|
||||||
return "merror";
|
return "merror";
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user