add label error

This commit is contained in:
tqchen 2015-04-06 08:45:54 -07:00
parent 30e61084eb
commit 529a732737

View File

@ -197,6 +197,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
gpair.resize(preds.size());
const unsigned nstep = static_cast<unsigned>(info.labels.size() * nclass);
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size() / nclass);
int label_error = 0;
#pragma omp parallel
{
std::vector<float> rec(nclass);
@ -208,8 +209,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
Softmax(&rec);
const unsigned j = i % nstep;
int label = static_cast<int>(info.labels[j]);
utils::Check(label >= 0 && label < nclass,
"SoftmaxMultiClassObj: label must be in [0, num_class)");
if (label < 0 || label >= nclass) {
label_error = label; label = 0;
}
const float wt = info.GetWeight(j);
for (int k = 0; k < nclass; ++k) {
float p = rec[k];
@ -222,6 +224,8 @@ class SoftmaxMultiClassObj : public IObjFunction {
}
}
}
utils::Check(label_error >= 0 && label_error < nclass,
"SoftmaxMultiClassObj: label must be in [0, num_class), found %d in label", label_error);
}
virtual void PredTransform(std::vector<float> *io_preds) {
this->Transform(io_preds, output_prob);