[R] Use new predict function. (#6819)
* Call new C prediction API. * Add `strict_shape`. * Add `iterationrange`. * Update document.
This commit is contained in:
@@ -662,9 +662,21 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
|
||||
auto *learner = static_cast<Learner*>(handle);
|
||||
auto& entry = learner->GetThreadLocal().prediction_entry;
|
||||
auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
|
||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||
auto iteration_begin = get<Integer const>(config["iteration_begin"]);
|
||||
auto iteration_end = get<Integer const>(config["iteration_end"]);
|
||||
|
||||
auto const& j_config = get<Object const>(config);
|
||||
auto type = PredictionType(get<Integer const>(j_config.at("type")));
|
||||
auto iteration_begin = get<Integer const>(j_config.at("iteration_begin"));
|
||||
auto iteration_end = get<Integer const>(j_config.at("iteration_end"));
|
||||
|
||||
auto ntree_limit_it = j_config.find("ntree_limit");
|
||||
if (ntree_limit_it != j_config.cend() && !IsA<Null>(ntree_limit_it->second) &&
|
||||
get<Integer const>(ntree_limit_it->second) != 0) {
|
||||
CHECK(iteration_end == 0) <<
|
||||
"Only one of the `ntree_limit` or `iteration_range` can be specified.";
|
||||
LOG(WARNING) << "`ntree_limit` is deprecated, use `iteration_range` instead.";
|
||||
iteration_end = GetIterationFromTreeLimit(get<Integer const>(ntree_limit_it->second), learner);
|
||||
}
|
||||
|
||||
bool approximate = type == PredictionType::kApproxContribution ||
|
||||
type == PredictionType::kApproxInteraction;
|
||||
bool contribs = type == PredictionType::kContribution ||
|
||||
|
||||
@@ -48,7 +48,7 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
|
||||
*out_dim = 2;
|
||||
shape.resize(*out_dim);
|
||||
shape.front() = rows;
|
||||
shape.back() = groups;
|
||||
shape.back() = std::min(groups, chunksize);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user