[breaking] Add prediction function for DMatrix and use inplace predict for dask. (#6668)
* Add a new API function for predicting on `DMatrix`. This function aligns with the rest of the `XGBoosterPredictFrom*` functions in the semantics of its arguments. * Purge `ntree_limit` from libxgboost; use iteration ranges instead. * [dask] Use `inplace_predict` by default for dask sklearn models. * [dask] Run prediction shape inference on the worker instead of the client. The breaking change is in the Python sklearn `apply` function: I made it consistent with the other prediction functions, where `best_iteration` is used by default.
This commit is contained in:
@@ -619,20 +619,58 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
CHECK_HANDLE();
|
||||
auto *learner = static_cast<Learner*>(handle);
|
||||
auto& entry = learner->GetThreadLocal().prediction_entry;
|
||||
learner->Predict(
|
||||
*static_cast<std::shared_ptr<DMatrix>*>(dmat),
|
||||
(option_mask & 1) != 0,
|
||||
&entry.predictions, ntree_limit,
|
||||
static_cast<bool>(training),
|
||||
(option_mask & 2) != 0,
|
||||
(option_mask & 4) != 0,
|
||||
(option_mask & 8) != 0,
|
||||
(option_mask & 16) != 0);
|
||||
auto iteration_end = GetIterationFromTreeLimit(ntree_limit, learner);
|
||||
learner->Predict(*static_cast<std::shared_ptr<DMatrix> *>(dmat),
|
||||
(option_mask & 1) != 0, &entry.predictions, 0, iteration_end,
|
||||
static_cast<bool>(training), (option_mask & 2) != 0,
|
||||
(option_mask & 4) != 0, (option_mask & 8) != 0,
|
||||
(option_mask & 16) != 0);
|
||||
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
|
||||
*len = static_cast<xgboost::bst_ulong>(entry.predictions.Size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
/*!
 * \brief Run prediction on a DMatrix, configured through a JSON document.
 *
 * The JSON config must provide: "type" (integer, a PredictionType value),
 * "iteration_begin"/"iteration_end" (tree-iteration range, end == 0 means
 * all trees), "training" (boolean) and "strict_shape" (boolean).
 *
 * \param handle        Booster handle; fatal error if null/disposed.
 * \param dmat          DMatrix handle; fatal error if null/disposed.
 * \param c_json_config JSON configuration string described above.
 * \param out_shape     Out: shape of the prediction, thread-local storage.
 * \param out_dim       Out: number of dimensions in `out_shape`.
 * \param out_result    Out: flat prediction buffer, thread-local storage.
 */
XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
                                        DMatrixHandle dmat,
                                        char const* c_json_config,
                                        xgboost::bst_ulong const **out_shape,
                                        xgboost::bst_ulong *out_dim,
                                        bst_float const **out_result) {
  API_BEGIN();
  if (handle == nullptr) {
    LOG(FATAL) << "Booster has not been initialized or has already been disposed.";
  }
  if (dmat == nullptr) {
    LOG(FATAL) << "DMatrix has not been initialized or has already been disposed.";
  }
  auto config = Json::Load(StringView{c_json_config});

  auto *learner = static_cast<Learner*>(handle);
  // Prediction results live in thread-local storage so concurrent callers
  // on different threads do not clobber each other's output buffers.
  auto& entry = learner->GetThreadLocal().prediction_entry;
  auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
  auto type = PredictionType(get<Integer const>(config["type"]));
  auto iteration_begin = get<Integer const>(config["iteration_begin"]);
  auto iteration_end = get<Integer const>(config["iteration_end"]);
  // Reuse p_m rather than re-casting the handle a second time.
  learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions,
                   iteration_begin, iteration_end,
                   get<Boolean const>(config["training"]),
                   type == PredictionType::kLeaf,
                   type == PredictionType::kContribution,
                   type == PredictionType::kApproxContribution,
                   type == PredictionType::kInteraction);
  *out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
  auto &shape = learner->GetThreadLocal().prediction_shape;
  // Per-row output width; guard against division by zero on an empty DMatrix.
  auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;
  auto rounds = iteration_end - iteration_begin;
  // iteration_end == iteration_begin means "all trees": use the full model.
  rounds = rounds == 0 ? learner->BoostedRounds() : rounds;
  // Determine shape
  bool strict_shape = get<Boolean const>(config["strict_shape"]);
  CalcPredictShape(strict_shape, type, p_m->Info().num_row_,
                   p_m->Info().num_col_, chunksize, learner->Groups(), rounds,
                   &shape, out_dim);
  *out_shape = dmlc::BeginPtr(shape);
  API_END();
}
|
||||
|
||||
template <typename T>
|
||||
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
|
||||
@@ -705,7 +743,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(
|
||||
XGB_DLL int XGBoosterPredictFromCUDAArray(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
|
||||
const float **out_result) {
|
||||
@@ -715,7 +753,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
|
||||
XGB_DLL int XGBoosterPredictFromCUDAColumnar(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
|
||||
const float **out_result) {
|
||||
|
||||
@@ -66,8 +66,7 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
|
||||
API_END();
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
|
||||
XGB_DLL int XGBoosterPredictFromCudaColumnar(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim, const float **out_result) {
|
||||
@@ -79,8 +78,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
|
||||
handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
|
||||
}
|
||||
|
||||
// A hidden API as cache id is not being supported yet.
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(
|
||||
XGB_DLL int XGBoosterPredictFromCudaArray(
|
||||
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
|
||||
DMatrixHandle m, xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim, const float **out_result) {
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/learner.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -30,8 +31,8 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
|
||||
std::vector<bst_ulong> *out_shape,
|
||||
xgboost::bst_ulong *out_dim) {
|
||||
auto &shape = *out_shape;
|
||||
if ((type == PredictionType::kMargin || type == PredictionType::kValue) &&
|
||||
rows != 0) {
|
||||
if (type == PredictionType::kMargin && rows != 0) {
|
||||
// When kValue is used, softmax can change the chunksize.
|
||||
CHECK_EQ(chunksize, groups);
|
||||
}
|
||||
|
||||
@@ -110,5 +111,35 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
|
||||
std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}),
|
||||
chunksize * rows);
|
||||
}
|
||||
|
||||
// Reverse the ntree_limit in old prediction API.
|
||||
inline uint32_t GetIterationFromTreeLimit(uint32_t ntree_limit, Learner *learner) {
|
||||
// On Python and R, `best_ntree_limit` is set to `best_iteration * num_parallel_tree`.
|
||||
// To reverse it we just divide it by `num_parallel_tree`.
|
||||
if (ntree_limit != 0) {
|
||||
learner->Configure();
|
||||
uint32_t num_parallel_tree = 0;
|
||||
|
||||
Json config{Object()};
|
||||
learner->SaveConfig(&config);
|
||||
auto const &booster =
|
||||
get<String const>(config["learner"]["gradient_booster"]["name"]);
|
||||
if (booster == "gblinear") {
|
||||
num_parallel_tree = 0;
|
||||
} else if (booster == "dart") {
|
||||
num_parallel_tree = std::stoi(
|
||||
get<String const>(config["learner"]["gradient_booster"]["gbtree"]
|
||||
["gbtree_train_param"]["num_parallel_tree"]));
|
||||
} else if (booster == "gbtree") {
|
||||
num_parallel_tree = std::stoi(get<String const>(
|
||||
(config["learner"]["gradient_booster"]["gbtree_train_param"]
|
||||
["num_parallel_tree"])));
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown booster:" << booster;
|
||||
}
|
||||
ntree_limit /= std::max(num_parallel_tree, 1u);
|
||||
}
|
||||
return ntree_limit;
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
Reference in New Issue
Block a user