add label error
This commit is contained in:
parent
30e61084eb
commit
529a732737
@ -197,6 +197,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
|||||||
gpair.resize(preds.size());
|
gpair.resize(preds.size());
|
||||||
const unsigned nstep = static_cast<unsigned>(info.labels.size() * nclass);
|
const unsigned nstep = static_cast<unsigned>(info.labels.size() * nclass);
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size() / nclass);
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size() / nclass);
|
||||||
|
int label_error = 0;
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
std::vector<float> rec(nclass);
|
std::vector<float> rec(nclass);
|
||||||
@ -208,8 +209,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
|||||||
Softmax(&rec);
|
Softmax(&rec);
|
||||||
const unsigned j = i % nstep;
|
const unsigned j = i % nstep;
|
||||||
int label = static_cast<int>(info.labels[j]);
|
int label = static_cast<int>(info.labels[j]);
|
||||||
utils::Check(label >= 0 && label < nclass,
|
if (label < 0 || label >= nclass) {
|
||||||
"SoftmaxMultiClassObj: label must be in [0, num_class)");
|
label_error = label; label = 0;
|
||||||
|
}
|
||||||
const float wt = info.GetWeight(j);
|
const float wt = info.GetWeight(j);
|
||||||
for (int k = 0; k < nclass; ++k) {
|
for (int k = 0; k < nclass; ++k) {
|
||||||
float p = rec[k];
|
float p = rec[k];
|
||||||
@ -222,6 +224,8 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
utils::Check(label_error >= 0 && label_error < nclass,
|
||||||
|
"SoftmaxMultiClassObj: label must be in [0, num_class), found %d in label", label_error);
|
||||||
}
|
}
|
||||||
virtual void PredTransform(std::vector<float> *io_preds) {
|
virtual void PredTransform(std::vector<float> *io_preds) {
|
||||||
this->Transform(io_preds, output_prob);
|
this->Transform(io_preds, output_prob);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user