Generalize prediction cache. (#8783)

* Extract most of the functionality into `DMatrixCache`.
* Move API entry to independent file to reduce dependency on `predictor.h` file.
* Add test.

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2023-02-13 12:36:43 +08:00
committed by GitHub
parent ed91e775ec
commit d11a0044cf
12 changed files with 278 additions and 126 deletions

View File

@@ -1,57 +1,28 @@
/*!
* Copyright 2017-2021 by Contributors
/**
* Copyright 2017-2023 by Contributors
*/
#include "xgboost/predictor.h"
#include <dmlc/registry.h>
#include <mutex>
#include <string> // std::string
#include "../gbm/gbtree.h"
#include "xgboost/context.h"
#include "xgboost/data.h"
#include "../gbm/gbtree.h" // GBTreeModel
#include "xgboost/base.h" // bst_row_t,bst_group_t
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/learner.h" // LearnerModelParam
#include "xgboost/linalg.h" // Tensor
#include "xgboost/logging.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
} // namespace dmlc
namespace xgboost {
void PredictionContainer::ClearExpiredEntries() {
std::vector<DMatrix*> expired;
for (auto& kv : container_) {
if (kv.second.ref.expired()) {
expired.emplace_back(kv.first);
}
}
for (auto const& ptr : expired) {
container_.erase(ptr);
}
}
void Predictor::Configure(Args const&) {}
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
this->ClearExpiredEntries();
container_[m.get()].ref = m;
if (device != Context::kCpuId) {
container_[m.get()].predictions.SetDevice(device);
}
return container_[m.get()];
}
PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) {
CHECK(container_.find(m) != container_.cend());
CHECK(container_.at(m).ref.lock())
<< "[Internal error]: DMatrix: " << m << " has expired.";
return container_.at(m);
}
decltype(PredictionContainer::container_) const& PredictionContainer::Container() {
this->ClearExpiredEntries();
return container_;
}
void Predictor::Configure(
const std::vector<std::pair<std::string, std::string>>&) {
}
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
if (e == nullptr) {