Generalize prediction cache. (#8783)
* Extract most of the functionality into `DMatrixCache`. * Move API entry to independent file to reduce dependency on `predictor.h` file. * Add test. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,57 +1,28 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by Contributors
|
||||
*/
|
||||
#include "xgboost/predictor.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <mutex>
|
||||
#include <string> // std::string
|
||||
|
||||
#include "../gbm/gbtree.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "../gbm/gbtree.h" // GBTreeModel
|
||||
#include "xgboost/base.h" // bst_row_t,bst_group_t
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/learner.h" // LearnerModelParam
|
||||
#include "xgboost/linalg.h" // Tensor
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
void PredictionContainer::ClearExpiredEntries() {
|
||||
std::vector<DMatrix*> expired;
|
||||
for (auto& kv : container_) {
|
||||
if (kv.second.ref.expired()) {
|
||||
expired.emplace_back(kv.first);
|
||||
}
|
||||
}
|
||||
for (auto const& ptr : expired) {
|
||||
container_.erase(ptr);
|
||||
}
|
||||
}
|
||||
void Predictor::Configure(Args const&) {}
|
||||
|
||||
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||
this->ClearExpiredEntries();
|
||||
container_[m.get()].ref = m;
|
||||
if (device != Context::kCpuId) {
|
||||
container_[m.get()].predictions.SetDevice(device);
|
||||
}
|
||||
return container_[m.get()];
|
||||
}
|
||||
|
||||
PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) {
|
||||
CHECK(container_.find(m) != container_.cend());
|
||||
CHECK(container_.at(m).ref.lock())
|
||||
<< "[Internal error]: DMatrix: " << m << " has expired.";
|
||||
return container_.at(m);
|
||||
}
|
||||
|
||||
decltype(PredictionContainer::container_) const& PredictionContainer::Container() {
|
||||
this->ClearExpiredEntries();
|
||||
return container_;
|
||||
}
|
||||
|
||||
void Predictor::Configure(
|
||||
const std::vector<std::pair<std::string, std::string>>&) {
|
||||
}
|
||||
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
|
||||
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
|
||||
Reference in New Issue
Block a user