fix crash in error
This commit is contained in:
parent
b6d85b9d9b
commit
3cc48d6707
@ -120,16 +120,25 @@ struct EvalMClassBase : public IEvaluator {
|
||||
"label and prediction size not match");
|
||||
const size_t nclass = preds.size() / info.labels.size();
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
||||
|
||||
float sum = 0.0, wsum = 0.0;
|
||||
int label_error = 0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
sum += Derived::EvalRow(info.labels[i],
|
||||
BeginPtr(preds) + i * nclass,
|
||||
nclass) * wt;
|
||||
wsum += wt;
|
||||
int label = static_cast<int>(info.labels[i]);
|
||||
if (label >= 0 && label < static_cast<int>(nclass)) {
|
||||
sum += Derived::EvalRow(info.labels[i],
|
||||
BeginPtr(preds) + i * nclass,
|
||||
nclass) * wt;
|
||||
wsum += wt;
|
||||
} else {
|
||||
label_error = label;
|
||||
}
|
||||
}
|
||||
utils::Check(label_error >= 0 && label_error < static_cast<int>(nclass),
|
||||
"MultiClassEvaluation: label must be in [0, num_class)," \
|
||||
" num_class=%d but found %d in label",
|
||||
static_cast<int>(nclass), label_error);
|
||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
@ -143,7 +152,7 @@ struct EvalMClassBase : public IEvaluator {
|
||||
* \param pred prediction value of current instance
|
||||
* \param nclass number of class in the prediction
|
||||
*/
|
||||
inline static float EvalRow(float label,
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass);
|
||||
/*!
|
||||
@ -154,13 +163,15 @@ struct EvalMClassBase : public IEvaluator {
|
||||
inline static float GetFinal(float esum, float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
// used to store error message
|
||||
const char *error_msg_;
|
||||
};
|
||||
/*! \brief match error */
|
||||
struct EvalMatchError : public EvalMClassBase<EvalMatchError> {
|
||||
virtual const char *Name(void) const {
|
||||
return "merror";
|
||||
}
|
||||
inline static float EvalRow(float label,
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass) {
|
||||
return FindMaxIndex(pred, nclass) != static_cast<int>(label);
|
||||
@ -171,12 +182,11 @@ struct EvalMultiLogLoss : public EvalMClassBase<EvalMultiLogLoss> {
|
||||
virtual const char *Name(void) const {
|
||||
return "mlogloss";
|
||||
}
|
||||
inline static float EvalRow(float label,
|
||||
inline static float EvalRow(int label,
|
||||
const float *pred,
|
||||
size_t nclass) {
|
||||
const float eps = 1e-16f;
|
||||
size_t k = static_cast<size_t>(label);
|
||||
utils::Check(k < nclass, "mlogloss: label must be in [0, num_class)");
|
||||
if (pred[k] > eps) {
|
||||
return -std::log(pred[k]);
|
||||
} else {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user