[R] Use new predict function. (#6819)

* Call new C prediction API.
* Add `strict_shape`.
* Add `iterationrange`.
* Update document.
This commit is contained in:
Jiaming Yuan
2021-06-11 13:03:29 +08:00
committed by GitHub
parent 25514e104a
commit b56614e9b8
18 changed files with 293 additions and 160 deletions

View File

@@ -662,9 +662,21 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
auto *learner = static_cast<Learner*>(handle);
auto& entry = learner->GetThreadLocal().prediction_entry;
auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
auto type = PredictionType(get<Integer const>(config["type"]));
auto iteration_begin = get<Integer const>(config["iteration_begin"]);
auto iteration_end = get<Integer const>(config["iteration_end"]);
auto const& j_config = get<Object const>(config);
auto type = PredictionType(get<Integer const>(j_config.at("type")));
auto iteration_begin = get<Integer const>(j_config.at("iteration_begin"));
auto iteration_end = get<Integer const>(j_config.at("iteration_end"));
auto ntree_limit_it = j_config.find("ntree_limit");
if (ntree_limit_it != j_config.cend() && !IsA<Null>(ntree_limit_it->second) &&
get<Integer const>(ntree_limit_it->second) != 0) {
CHECK(iteration_end == 0) <<
"Only one of the `ntree_limit` or `iteration_range` can be specified.";
LOG(WARNING) << "`ntree_limit` is deprecated, use `iteration_range` instead.";
iteration_end = GetIterationFromTreeLimit(get<Integer const>(ntree_limit_it->second), learner);
}
bool approximate = type == PredictionType::kApproxContribution ||
type == PredictionType::kApproxInteraction;
bool contribs = type == PredictionType::kContribution ||

View File

@@ -48,7 +48,7 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
*out_dim = 2;
shape.resize(*out_dim);
shape.front() = rows;
shape.back() = groups;
shape.back() = std::min(groups, chunksize);
}
break;
}