diff --git a/src/learner/evaluation-inl.hpp b/src/learner/evaluation-inl.hpp index 8e63e83ec..c7ef4ed30 100644 --- a/src/learner/evaluation-inl.hpp +++ b/src/learner/evaluation-inl.hpp @@ -120,16 +120,25 @@ struct EvalMClassBase : public IEvaluator { "label and prediction size not match"); const size_t nclass = preds.size() / info.labels.size(); const bst_omp_uint ndata = static_cast(info.labels.size()); - float sum = 0.0, wsum = 0.0; + int label_error = 0; #pragma omp parallel for reduction(+: sum, wsum) schedule(static) - for (bst_omp_uint i = 0; i < ndata; ++i) { + for (bst_omp_uint i = 0; i < ndata; ++i) { const float wt = info.GetWeight(i); - sum += Derived::EvalRow(info.labels[i], - BeginPtr(preds) + i * nclass, - nclass) * wt; - wsum += wt; + int label = static_cast(info.labels[i]); + if (label >= 0 && label < static_cast(nclass)) { + sum += Derived::EvalRow(info.labels[i], + BeginPtr(preds) + i * nclass, + nclass) * wt; + wsum += wt; + } else { + label_error = label; + } } + utils::Check(label_error >= 0 && label_error < static_cast(nclass), + "MultiClassEvaluation: label must be in [0, num_class)," \ + " num_class=%d but found %d in label", + static_cast(nclass), label_error); float dat[2]; dat[0] = sum, dat[1] = wsum; if (distributed) { rabit::Allreduce(dat, 2); @@ -143,7 +152,7 @@ struct EvalMClassBase : public IEvaluator { * \param pred prediction value of current instance * \param nclass number of class in the prediction */ - inline static float EvalRow(float label, + inline static float EvalRow(int label, const float *pred, size_t nclass); /*! @@ -154,13 +163,15 @@ struct EvalMClassBase : public IEvaluator { inline static float GetFinal(float esum, float wsum) { return esum / wsum; } + // used to store error message + const char *error_msg_; }; /*! \brief match error */ struct EvalMatchError : public EvalMClassBase { virtual const char *Name(void) const { return "merror"; } - inline static float EvalRow(float label, + inline static float EvalRow(int label, const float *pred, size_t nclass) { return FindMaxIndex(pred, nclass) != static_cast(label); @@ -171,12 +182,11 @@ struct EvalMultiLogLoss : public EvalMClassBase { virtual const char *Name(void) const { return "mlogloss"; } - inline static float EvalRow(float label, + inline static float EvalRow(int label, const float *pred, size_t nclass) { const float eps = 1e-16f; size_t k = static_cast(label); - utils::Check(k < nclass, "mlogloss: label must be in [0, num_class)"); if (pred[k] > eps) { return -std::log(pred[k]); } else {