add object bound checking
This commit is contained in:
parent
df3eafc5ba
commit
e90b25a381
@ -41,6 +41,25 @@ struct LossType {
|
|||||||
default: utils::Error("unknown loss_type"); return 0.0f;
|
default: utils::Error("unknown loss_type"); return 0.0f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief check if label range is valid
|
||||||
|
*/
|
||||||
|
inline bool CheckLabel(float x) const {
|
||||||
|
if (loss_type != kLinearSquare) {
|
||||||
|
return x >= 0.0f && x <= 1.0f;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief error message displayed when check label fail
|
||||||
|
*/
|
||||||
|
inline const char * CheckLabelErrorMsg(void) const {
|
||||||
|
if (loss_type != kLinearSquare) {
|
||||||
|
return "label must be in [0,1] for logistic regression";
|
||||||
|
} else {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief calculate first order gradient of loss, given transformed prediction
|
* \brief calculate first order gradient of loss, given transformed prediction
|
||||||
* \param predt transformed prediction
|
* \param predt transformed prediction
|
||||||
@ -115,6 +134,8 @@ class RegLossObj : public IObjFunction{
|
|||||||
"labels are not correctly provided");
|
"labels are not correctly provided");
|
||||||
std::vector<bst_gpair> &gpair = *out_gpair;
|
std::vector<bst_gpair> &gpair = *out_gpair;
|
||||||
gpair.resize(preds.size());
|
gpair.resize(preds.size());
|
||||||
|
// check if label in range
|
||||||
|
bool label_correct = true;
|
||||||
// start calculating gradient
|
// start calculating gradient
|
||||||
const unsigned nstep = static_cast<unsigned>(info.labels.size());
|
const unsigned nstep = static_cast<unsigned>(info.labels.size());
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
||||||
@ -124,9 +145,11 @@ class RegLossObj : public IObjFunction{
|
|||||||
float p = loss.PredTransform(preds[i]);
|
float p = loss.PredTransform(preds[i]);
|
||||||
float w = info.GetWeight(j);
|
float w = info.GetWeight(j);
|
||||||
if (info.labels[j] == 1.0f) w *= scale_pos_weight;
|
if (info.labels[j] == 1.0f) w *= scale_pos_weight;
|
||||||
|
if (!loss.CheckLabel(info.labels[j])) label_correct = false;
|
||||||
gpair[i] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
|
gpair[i] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
|
||||||
loss.SecondOrderGradient(p, info.labels[j]) * w);
|
loss.SecondOrderGradient(p, info.labels[j]) * w);
|
||||||
}
|
}
|
||||||
|
utils::Check(label_correct, loss.CheckLabelErrorMsg());
|
||||||
}
|
}
|
||||||
virtual const char* DefaultEvalMetric(void) const {
|
virtual const char* DefaultEvalMetric(void) const {
|
||||||
return loss.DefaultEvalMetric();
|
return loss.DefaultEvalMetric();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user