The base margin will need to have length `[num_class] * [number of data points]`. Otherwise, the array holding prediction results will be only partially initialized, causing undefined behavior. Fix: check the length of the base margin. If the length is not correct, use the global bias (`base_score`) instead. Warn the user about the substitution.
This commit is contained in:
parent
4a429a7c4f
commit
8c633d1ca3
@ -131,10 +131,25 @@ class CPUPredictor : public Predictor {
|
|||||||
const std::vector<bst_float>& base_margin = info.base_margin_;
|
const std::vector<bst_float>& base_margin = info.base_margin_;
|
||||||
out_preds->Resize(n);
|
out_preds->Resize(n);
|
||||||
std::vector<bst_float>& out_preds_h = out_preds->HostVector();
|
std::vector<bst_float>& out_preds_h = out_preds->HostVector();
|
||||||
if (base_margin.size() != 0) {
|
if (base_margin.size() == n) {
|
||||||
CHECK_EQ(out_preds->Size(), n);
|
CHECK_EQ(out_preds->Size(), n);
|
||||||
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
|
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
|
||||||
} else {
|
} else {
|
||||||
|
if (!base_margin.empty()) {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "Warning: Ignoring the base margin, since it has incorrect length. "
|
||||||
|
<< "The base margin must be an array of length ";
|
||||||
|
if (model.param.num_output_group > 1) {
|
||||||
|
oss << "[num_class] * [number of data points], i.e. "
|
||||||
|
<< model.param.num_output_group << " * " << info.num_row_
|
||||||
|
<< " = " << n << ". ";
|
||||||
|
} else {
|
||||||
|
oss << "[number of data points], i.e. " << info.num_row_ << ". ";
|
||||||
|
}
|
||||||
|
oss << "Instead, all data points will use "
|
||||||
|
<< "base_score = " << model.base_margin;
|
||||||
|
LOG(INFO) << oss.str();
|
||||||
|
}
|
||||||
std::fill(out_preds_h.begin(), out_preds_h.end(), model.base_margin);
|
std::fill(out_preds_h.begin(), out_preds_h.end(), model.base_margin);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user