fix crash in error

This commit is contained in:
tqchen 2015-04-06 08:58:33 -07:00
parent b6d85b9d9b
commit 3cc48d6707

View File

@ -120,16 +120,25 @@ struct EvalMClassBase : public IEvaluator {
"label and prediction size not match");
const size_t nclass = preds.size() / info.labels.size();
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
float sum = 0.0, wsum = 0.0;
int label_error = 0;
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
for (bst_omp_uint i = 0; i < ndata; ++i) {
const float wt = info.GetWeight(i);
sum += Derived::EvalRow(info.labels[i],
BeginPtr(preds) + i * nclass,
nclass) * wt;
wsum += wt;
int label = static_cast<int>(info.labels[i]);
if (label >= 0 && label < static_cast<int>(nclass)) {
sum += Derived::EvalRow(info.labels[i],
BeginPtr(preds) + i * nclass,
nclass) * wt;
wsum += wt;
} else {
label_error = label;
}
}
utils::Check(label_error >= 0 && label_error < static_cast<int>(nclass),
"MultiClassEvaluation: label must be in [0, num_class)," \
" num_class=%d but found %d in label",
static_cast<int>(nclass), label_error);
float dat[2]; dat[0] = sum, dat[1] = wsum;
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
@ -143,7 +152,7 @@ struct EvalMClassBase : public IEvaluator {
* \param pred prediction value of current instance
* \param nclass number of class in the prediction
*/
inline static float EvalRow(float label,
inline static float EvalRow(int label,
const float *pred,
size_t nclass);
/*!
@ -154,13 +163,15 @@ struct EvalMClassBase : public IEvaluator {
inline static float GetFinal(float esum, float wsum) {
return esum / wsum;
}
// used to store error message
const char *error_msg_;
};
/*! \brief match error */
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
virtual const char *Name(void) const {
return "merror";
}
inline static float EvalRow(float label,
inline static float EvalRow(int label,
const float *pred,
size_t nclass) {
return FindMaxIndex(pred, nclass) != static_cast<int>(label);
@ -171,12 +182,11 @@ struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
virtual const char *Name(void) const {
return "mlogloss";
}
inline static float EvalRow(float label,
inline static float EvalRow(int label,
const float *pred,
size_t nclass) {
const float eps = 1e-16f;
size_t k = static_cast<size_t>(label);
utils::Check(k < nclass, "mlogloss: label must be in [0, num_class)");
if (pred[k] > eps) {
return -std::log(pred[k]);
} else {