Remove update prediction cache from predictors. (#5312)
Move this function into gbtree, and uses only updater for doing so. As now the predictor knows exactly how many trees to predict, there's no need for it to update the prediction cache.
This commit is contained in:
@@ -7,7 +7,6 @@
|
||||
#include <thrust/fill.h>
|
||||
#include <memory>
|
||||
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/predictor.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
@@ -316,8 +315,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
CHECK_EQ(tree_begin, 0);
|
||||
auto* out_preds = &predts->predictions;
|
||||
CHECK_GE(predts->version, tree_begin);
|
||||
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
|
||||
CHECK_EQ(predts->version, 0);
|
||||
}
|
||||
if (predts->version == 0) {
|
||||
CHECK_EQ(out_preds->Size(), 0);
|
||||
this->InitOutPredictions(dmat->Info(), out_preds, model);
|
||||
}
|
||||
|
||||
@@ -370,32 +371,6 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(
|
||||
const gbm::GBTreeModel& model,
|
||||
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
|
||||
int num_new_trees,
|
||||
DMatrix* m,
|
||||
PredictionCacheEntry* predts) override {
|
||||
int device = generic_param_->gpu_id;
|
||||
ConfigureDevice(device);
|
||||
auto old_ntree = model.trees.size() - num_new_trees;
|
||||
// update cache entry
|
||||
auto* out = &predts->predictions;
|
||||
if (predts->predictions.Size() == 0) {
|
||||
InitOutPredictions(m->Info(), out, model);
|
||||
DevicePredictInternal(m, out, model, 0, model.trees.size());
|
||||
} else if (model.learner_model_param_->num_output_group == 1 &&
|
||||
updaters->size() > 0 &&
|
||||
num_new_trees == 1 &&
|
||||
updaters->back()->UpdatePredictionCache(m, out)) {
|
||||
{}
|
||||
} else {
|
||||
DevicePredictInternal(m, out, model, old_ntree, model.trees.size());
|
||||
}
|
||||
auto delta = num_new_trees / model.learner_model_param_->num_output_group;
|
||||
predts->Update(delta);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||
|
||||
Reference in New Issue
Block a user