add label error
This commit is contained in:
parent
30e61084eb
commit
529a732737
@ -197,6 +197,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
gpair.resize(preds.size());
|
||||
const unsigned nstep = static_cast<unsigned>(info.labels.size() * nclass);
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size() / nclass);
|
||||
int label_error = 0;
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<float> rec(nclass);
|
||||
@ -208,8 +209,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
Softmax(&rec);
|
||||
const unsigned j = i % nstep;
|
||||
int label = static_cast<int>(info.labels[j]);
|
||||
utils::Check(label >= 0 && label < nclass,
|
||||
"SoftmaxMultiClassObj: label must be in [0, num_class)");
|
||||
if (label < 0 || label >= nclass) {
|
||||
label_error = label; label = 0;
|
||||
}
|
||||
const float wt = info.GetWeight(j);
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
float p = rec[k];
|
||||
@ -222,6 +224,8 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
}
|
||||
}
|
||||
}
|
||||
utils::Check(label_error >= 0 && label_error < nclass,
|
||||
"SoftmaxMultiClassObj: label must be in [0, num_class), found %d in label", label_error);
|
||||
}
|
||||
virtual void PredTransform(std::vector<float> *io_preds) {
|
||||
this->Transform(io_preds, output_prob);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user