Calculate base_score based on input labels for mae. (#8107)

Fit an intercept as base score for abs loss.
This commit is contained in:
Jiaming Yuan
2022-09-20 20:53:54 +08:00
committed by GitHub
parent 4f42aa5f12
commit fffb1fca52
42 changed files with 999 additions and 343 deletions

View File

@@ -429,11 +429,12 @@ class CPUPredictor : public Predictor {
}
out_preds->resize(model.learner_model_param->num_output_group *
(model.param.size_leaf_vector + 1));
auto base_score = model.learner_model_param->BaseScore(ctx_)(0);
// loop over output groups
for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) {
(*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid,
&feat_vecs[0], 0, ntree_limit) +
model.learner_model_param->base_score;
(*out_preds)[gid] =
PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0], 0, ntree_limit) +
base_score;
}
}
@@ -504,7 +505,8 @@ class CPUPredictor : public Predictor {
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
});
auto base_margin = info.base_margin_.View(GenericParameter::kCpuId);
auto base_margin = info.base_margin_.View(Context::kCpuId);
auto base_score = model.learner_model_param->BaseScore(Context::kCpuId)(0);
// start collecting the contributions
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
auto page = batch.GetView();
@@ -548,7 +550,7 @@ class CPUPredictor : public Predictor {
CHECK_EQ(base_margin.Shape(1), ngroup);
p_contribs[ncolumns - 1] += base_margin(row_idx, gid);
} else {
p_contribs[ncolumns - 1] += model.learner_model_param->base_score;
p_contribs[ncolumns - 1] += base_score;
}
}
});