fix crash in error
This commit is contained in:
parent
b6d85b9d9b
commit
3cc48d6707
@ -120,16 +120,25 @@ struct EvalMClassBase : public IEvaluator {
|
|||||||
"label and prediction size not match");
|
"label and prediction size not match");
|
||||||
const size_t nclass = preds.size() / info.labels.size();
|
const size_t nclass = preds.size() / info.labels.size();
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||||
|
|
||||||
float sum = 0.0, wsum = 0.0;
|
float sum = 0.0, wsum = 0.0;
|
||||||
|
int label_error = 0;
|
||||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
const float wt = info.GetWeight(i);
|
const float wt = info.GetWeight(i);
|
||||||
|
int label = static_cast<int>(info.labels[i]);
|
||||||
|
if (label >= 0 && label < static_cast<int>(nclass)) {
|
||||||
sum += Derived::EvalRow(info.labels[i],
|
sum += Derived::EvalRow(info.labels[i],
|
||||||
BeginPtr(preds) + i * nclass,
|
BeginPtr(preds) + i * nclass,
|
||||||
nclass) * wt;
|
nclass) * wt;
|
||||||
wsum += wt;
|
wsum += wt;
|
||||||
|
} else {
|
||||||
|
label_error = label;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
utils::Check(label_error >= 0 && label_error < static_cast<int>(nclass),
|
||||||
|
"MultiClassEvaluation: label must be in [0, num_class)," \
|
||||||
|
" num_class=%d but found %d in label",
|
||||||
|
static_cast<int>(nclass), label_error);
|
||||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||||
if (distributed) {
|
if (distributed) {
|
||||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||||
@ -143,7 +152,7 @@ struct EvalMClassBase : public IEvaluator {
|
|||||||
* \param pred prediction value of current instance
|
* \param pred prediction value of current instance
|
||||||
* \param nclass number of class in the prediction
|
* \param nclass number of class in the prediction
|
||||||
*/
|
*/
|
||||||
inline static float EvalRow(float label,
|
inline static float EvalRow(int label,
|
||||||
const float *pred,
|
const float *pred,
|
||||||
size_t nclass);
|
size_t nclass);
|
||||||
/*!
|
/*!
|
||||||
@ -154,13 +163,15 @@ struct EvalMClassBase : public IEvaluator {
|
|||||||
inline static float GetFinal(float esum, float wsum) {
|
inline static float GetFinal(float esum, float wsum) {
|
||||||
return esum / wsum;
|
return esum / wsum;
|
||||||
}
|
}
|
||||||
|
// used to store error message
|
||||||
|
const char *error_msg_;
|
||||||
};
|
};
|
||||||
/*! \brief match error */
|
/*! \brief match error */
|
||||||
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
||||||
virtual const char *Name(void) const {
|
virtual const char *Name(void) const {
|
||||||
return "merror";
|
return "merror";
|
||||||
}
|
}
|
||||||
inline static float EvalRow(float label,
|
inline static float EvalRow(int label,
|
||||||
const float *pred,
|
const float *pred,
|
||||||
size_t nclass) {
|
size_t nclass) {
|
||||||
return FindMaxIndex(pred, nclass) != static_cast<int>(label);
|
return FindMaxIndex(pred, nclass) != static_cast<int>(label);
|
||||||
@ -171,12 +182,11 @@ struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
|
|||||||
virtual const char *Name(void) const {
|
virtual const char *Name(void) const {
|
||||||
return "mlogloss";
|
return "mlogloss";
|
||||||
}
|
}
|
||||||
inline static float EvalRow(float label,
|
inline static float EvalRow(int label,
|
||||||
const float *pred,
|
const float *pred,
|
||||||
size_t nclass) {
|
size_t nclass) {
|
||||||
const float eps = 1e-16f;
|
const float eps = 1e-16f;
|
||||||
size_t k = static_cast<size_t>(label);
|
size_t k = static_cast<size_t>(label);
|
||||||
utils::Check(k < nclass, "mlogloss: label must be in [0, num_class)");
|
|
||||||
if (pred[k] > eps) {
|
if (pred[k] > eps) {
|
||||||
return -std::log(pred[k]);
|
return -std::log(pred[k]);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user