Generalize prediction cache. (#8783)
* Extract most of the functionality into `DMatrixCache`. * Move API entry to independent file to reduce dependency on `predictor.h` file. * Add test. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -12,11 +12,11 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h"
|
||||
#include "../common/io.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "c_api_error.h"
|
||||
#include "c_api_utils.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
|
||||
35
src/common/api_entry.h
Normal file
35
src/common/api_entry.h
Normal file
@@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2016-2023 by XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_API_ENTRY_H_
|
||||
#define XGBOOST_COMMON_API_ENTRY_H_
|
||||
#include <string> // std::string
|
||||
#include <vector> // std::vector
|
||||
|
||||
#include "xgboost/base.h" // GradientPair,bst_ulong
|
||||
#include "xgboost/predictor.h" // PredictionCacheEntry
|
||||
|
||||
namespace xgboost {
|
||||
/**
|
||||
* \brief entry to to easily hold returning information
|
||||
*/
|
||||
struct XGBAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning string */
|
||||
std::string ret_str;
|
||||
/*! \brief result holder for returning raw buffer */
|
||||
std::vector<char> ret_char_vec;
|
||||
/*! \brief result holder for returning strings */
|
||||
std::vector<std::string> ret_vec_str;
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief returning float vector. */
|
||||
std::vector<float> ret_vec_float;
|
||||
/*! \brief temp variable of gradient pairs. */
|
||||
std::vector<GradientPair> tmp_gpair;
|
||||
/*! \brief Temp variable for returning prediction result. */
|
||||
PredictionCacheEntry prediction_entry;
|
||||
/*! \brief Temp variable for returning prediction shape. */
|
||||
std::vector<bst_ulong> prediction_shape;
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_API_ENTRY_H_
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <cstring>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/linalg_op.h"
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "collective/communicator-inl.h"
|
||||
#include "common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "common/charconv.h"
|
||||
#include "common/common.h"
|
||||
#include "common/io.h"
|
||||
@@ -430,7 +431,9 @@ class LearnerConfiguration : public Learner {
|
||||
monitor_.Init("Learner");
|
||||
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
|
||||
for (std::shared_ptr<DMatrix> const& d : cache) {
|
||||
local_cache.Cache(d, Context::kCpuId);
|
||||
if (d) {
|
||||
local_cache.Cache(d, Context::kCpuId);
|
||||
}
|
||||
}
|
||||
}
|
||||
~LearnerConfiguration() override {
|
||||
@@ -1296,9 +1299,8 @@ class LearnerImpl : public LearnerIO {
|
||||
this->ValidateDMatrix(train.get(), true);
|
||||
|
||||
auto local_cache = this->GetPredictionCache();
|
||||
local_cache->Cache(train, ctx_.gpu_id);
|
||||
|
||||
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
|
||||
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
|
||||
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
|
||||
monitor_.Stop("BoostOneIter");
|
||||
}
|
||||
|
||||
|
||||
@@ -1,57 +1,28 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by Contributors
|
||||
*/
|
||||
#include "xgboost/predictor.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <mutex>
|
||||
#include <string> // std::string
|
||||
|
||||
#include "../gbm/gbtree.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "../gbm/gbtree.h" // GBTreeModel
|
||||
#include "xgboost/base.h" // bst_row_t,bst_group_t
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/learner.h" // LearnerModelParam
|
||||
#include "xgboost/linalg.h" // Tensor
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
void PredictionContainer::ClearExpiredEntries() {
|
||||
std::vector<DMatrix*> expired;
|
||||
for (auto& kv : container_) {
|
||||
if (kv.second.ref.expired()) {
|
||||
expired.emplace_back(kv.first);
|
||||
}
|
||||
}
|
||||
for (auto const& ptr : expired) {
|
||||
container_.erase(ptr);
|
||||
}
|
||||
}
|
||||
void Predictor::Configure(Args const&) {}
|
||||
|
||||
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||
this->ClearExpiredEntries();
|
||||
container_[m.get()].ref = m;
|
||||
if (device != Context::kCpuId) {
|
||||
container_[m.get()].predictions.SetDevice(device);
|
||||
}
|
||||
return container_[m.get()];
|
||||
}
|
||||
|
||||
PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) {
|
||||
CHECK(container_.find(m) != container_.cend());
|
||||
CHECK(container_.at(m).ref.lock())
|
||||
<< "[Internal error]: DMatrix: " << m << " has expired.";
|
||||
return container_.at(m);
|
||||
}
|
||||
|
||||
decltype(PredictionContainer::container_) const& PredictionContainer::Container() {
|
||||
this->ClearExpiredEntries();
|
||||
return container_;
|
||||
}
|
||||
|
||||
void Predictor::Configure(
|
||||
const std::vector<std::pair<std::string, std::string>>&) {
|
||||
}
|
||||
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
|
||||
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
|
||||
Reference in New Issue
Block a user