Generalize prediction cache. (#8783)
* Extract most of the functionality into `DMatrixCache`. * Move API entry to independent file to reduce dependency on `predictor.h` file. * Add test. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -25,6 +25,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "collective/communicator-inl.h"
|
||||
#include "common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "common/charconv.h"
|
||||
#include "common/common.h"
|
||||
#include "common/io.h"
|
||||
@@ -430,7 +431,9 @@ class LearnerConfiguration : public Learner {
|
||||
monitor_.Init("Learner");
|
||||
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
|
||||
for (std::shared_ptr<DMatrix> const& d : cache) {
|
||||
local_cache.Cache(d, Context::kCpuId);
|
||||
if (d) {
|
||||
local_cache.Cache(d, Context::kCpuId);
|
||||
}
|
||||
}
|
||||
}
|
||||
~LearnerConfiguration() override {
|
||||
@@ -1296,9 +1299,8 @@ class LearnerImpl : public LearnerIO {
|
||||
this->ValidateDMatrix(train.get(), true);
|
||||
|
||||
auto local_cache = this->GetPredictionCache();
|
||||
local_cache->Cache(train, ctx_.gpu_id);
|
||||
|
||||
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
|
||||
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
|
||||
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
|
||||
monitor_.Stop("BoostOneIter");
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user