Fix approximated predict contribution. (#6811)

This commit is contained in:
Jiaming Yuan
2021-04-03 02:15:03 +08:00
committed by GitHub
parent 0cced530ea
commit 7e06c81894
6 changed files with 47 additions and 17 deletions

View File

@@ -651,13 +651,17 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
auto type = PredictionType(get<Integer const>(config["type"]));
auto iteration_begin = get<Integer const>(config["iteration_begin"]);
auto iteration_end = get<Integer const>(config["iteration_end"]);
learner->Predict(
*static_cast<std::shared_ptr<DMatrix> *>(dmat),
type == PredictionType::kMargin, &entry.predictions, iteration_begin,
iteration_end, get<Boolean const>(config["training"]),
type == PredictionType::kLeaf, type == PredictionType::kContribution,
type == PredictionType::kApproxContribution,
type == PredictionType::kInteraction);
bool approximate = type == PredictionType::kApproxContribution ||
type == PredictionType::kApproxInteraction;
bool contribs = type == PredictionType::kContribution ||
type == PredictionType::kApproxContribution;
bool interactions = type == PredictionType::kInteraction ||
type == PredictionType::kApproxInteraction;
bool training = get<Boolean const>(config["training"]);
learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions,
iteration_begin, iteration_end, training,
type == PredictionType::kLeaf, contribs, approximate,
interactions);
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;