[breaking] Add prediction fucntion for DMatrix and use inplace predict for dask. (#6668)
* Add a new API function for predicting on `DMatrix`. This function aligns with rest of the `XGBoosterPredictFrom*` functions on semantic of function arguments. * Purge `ntree_limit` from libxgboost, use iteration instead. * [dask] Use `inplace_predict` by default for dask sklearn models. * [dask] Run prediction shape inference on worker instead of client. The breaking change is in the Python sklearn `apply` function, I made it to be consistent with other prediction functions where `best_iteration` is used by default.
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
|
||||
#include "dmlc/any.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/model.h"
|
||||
#include "xgboost/predictor.h"
|
||||
@@ -996,7 +997,7 @@ class LearnerImpl : public LearnerIO {
|
||||
auto& predt = local_cache->Cache(train, generic_parameters_.gpu_id);
|
||||
|
||||
monitor_.Start("PredictRaw");
|
||||
this->PredictRaw(train.get(), &predt, true);
|
||||
this->PredictRaw(train.get(), &predt, true, 0, 0);
|
||||
TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
|
||||
monitor_.Stop("PredictRaw");
|
||||
|
||||
@@ -1057,7 +1058,7 @@ class LearnerImpl : public LearnerIO {
|
||||
std::shared_ptr<DMatrix> m = data_sets[i];
|
||||
auto &predt = local_cache->Cache(m, generic_parameters_.gpu_id);
|
||||
this->ValidateDMatrix(m.get(), false);
|
||||
this->PredictRaw(m.get(), &predt, false);
|
||||
this->PredictRaw(m.get(), &predt, false, 0, 0);
|
||||
|
||||
auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;
|
||||
out.Resize(predt.predictions.Size());
|
||||
@@ -1075,8 +1076,8 @@ class LearnerImpl : public LearnerIO {
|
||||
}
|
||||
|
||||
void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
|
||||
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
bool training,
|
||||
HostDeviceVector<bst_float> *out_preds, unsigned layer_begin,
|
||||
unsigned layer_end, bool training,
|
||||
bool pred_leaf, bool pred_contribs, bool approx_contribs,
|
||||
bool pred_interactions) override {
|
||||
int multiple_predictions = static_cast<int>(pred_leaf) +
|
||||
@@ -1085,16 +1086,16 @@ class LearnerImpl : public LearnerIO {
|
||||
this->Configure();
|
||||
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
|
||||
if (pred_contribs) {
|
||||
gbm_->PredictContribution(data.get(), out_preds, ntree_limit, approx_contribs);
|
||||
gbm_->PredictContribution(data.get(), out_preds, layer_begin, layer_end, approx_contribs);
|
||||
} else if (pred_interactions) {
|
||||
gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
|
||||
gbm_->PredictInteractionContributions(data.get(), out_preds, layer_begin, layer_end,
|
||||
approx_contribs);
|
||||
} else if (pred_leaf) {
|
||||
gbm_->PredictLeaf(data.get(), out_preds, ntree_limit);
|
||||
gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
|
||||
} else {
|
||||
auto local_cache = this->GetPredictionCache();
|
||||
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
|
||||
this->PredictRaw(data.get(), &prediction, training, ntree_limit);
|
||||
this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
|
||||
// Copy the prediction cache to output prediction. out_preds comes from C API
|
||||
out_preds->SetDevice(generic_parameters_.gpu_id);
|
||||
out_preds->Resize(prediction.predictions.Size());
|
||||
@@ -1151,12 +1152,11 @@ class LearnerImpl : public LearnerIO {
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
* \param training allow dropout when the DART booster is being used
|
||||
*/
|
||||
void PredictRaw(DMatrix* data, PredictionCacheEntry* out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit = 0) const {
|
||||
void PredictRaw(DMatrix *data, PredictionCacheEntry *out_preds, bool training,
|
||||
unsigned layer_begin, unsigned layer_end) const {
|
||||
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
|
||||
this->ValidateDMatrix(data, false);
|
||||
gbm_->PredictBatch(data, out_preds, training, ntree_limit);
|
||||
gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end);
|
||||
}
|
||||
|
||||
void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {
|
||||
|
||||
Reference in New Issue
Block a user