Generalize prediction cache. (#8783)

* Extract most of the functionality into `DMatrixCache`.
* Move API entry to independent file to reduce dependency on `predictor.h` file.
* Add test.

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2023-02-13 12:36:43 +08:00
committed by GitHub
parent ed91e775ec
commit d11a0044cf
12 changed files with 278 additions and 126 deletions

View File

@@ -25,6 +25,7 @@
#include <vector>
#include "collective/communicator-inl.h"
#include "common/api_entry.h" // XGBAPIThreadLocalEntry
#include "common/charconv.h"
#include "common/common.h"
#include "common/io.h"
@@ -430,7 +431,9 @@ class LearnerConfiguration : public Learner {
monitor_.Init("Learner");
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
for (std::shared_ptr<DMatrix> const& d : cache) {
local_cache.Cache(d, Context::kCpuId);
if (d) {
local_cache.Cache(d, Context::kCpuId);
}
}
}
~LearnerConfiguration() override {
@@ -1296,9 +1299,8 @@ class LearnerImpl : public LearnerIO {
this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache();
local_cache->Cache(train, ctx_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
monitor_.Stop("BoostOneIter");
}