Generalize prediction cache. (#8783)

* Extract most of the functionality into `DMatrixCache`.
* Move API entry to independent file to reduce dependency on `predictor.h` file.
* Add test.

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2023-02-13 12:36:43 +08:00
committed by GitHub
parent ed91e775ec
commit d11a0044cf
12 changed files with 278 additions and 126 deletions

View File

@@ -12,11 +12,11 @@
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/charconv.h"
#include "../common/io.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "c_api_error.h"
#include "c_api_utils.h"
#include "xgboost/base.h"
#include "xgboost/data.h"

View File

@@ -1,6 +1,7 @@
/**
* Copyright 2019-2023 by XGBoost Contributors
*/
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/threading_utils.h"
#include "../data/device_adapter.cuh"
#include "../data/proxy_dmatrix.h"

35
src/common/api_entry.h Normal file
View File

@@ -0,0 +1,35 @@
/**
* Copyright 2016-2023 by XGBoost contributors
*/
#ifndef XGBOOST_COMMON_API_ENTRY_H_
#define XGBOOST_COMMON_API_ENTRY_H_
#include <string> // std::string
#include <vector> // std::vector
#include "xgboost/base.h" // GradientPair,bst_ulong
#include "xgboost/predictor.h" // PredictionCacheEntry
namespace xgboost {
/**
* \brief entry to to easily hold returning information
*/
struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string */
std::string ret_str;
/*! \brief result holder for returning raw buffer */
std::vector<char> ret_char_vec;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief returning float vector. */
std::vector<float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returning prediction result. */
PredictionCacheEntry prediction_entry;
/*! \brief Temp variable for returning prediction shape. */
std::vector<bst_ulong> prediction_shape;
};
} // namespace xgboost
#endif // XGBOOST_COMMON_API_ENTRY_H_

View File

@@ -10,6 +10,7 @@
#include <cstring>
#include "../collective/communicator-inl.h"
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"

View File

@@ -25,6 +25,7 @@
#include <vector>
#include "collective/communicator-inl.h"
#include "common/api_entry.h" // XGBAPIThreadLocalEntry
#include "common/charconv.h"
#include "common/common.h"
#include "common/io.h"
@@ -430,7 +431,9 @@ class LearnerConfiguration : public Learner {
monitor_.Init("Learner");
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
for (std::shared_ptr<DMatrix> const& d : cache) {
local_cache.Cache(d, Context::kCpuId);
if (d) {
local_cache.Cache(d, Context::kCpuId);
}
}
}
~LearnerConfiguration() override {
@@ -1296,9 +1299,8 @@ class LearnerImpl : public LearnerIO {
this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache();
local_cache->Cache(train, ctx_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
monitor_.Stop("BoostOneIter");
}

View File

@@ -1,57 +1,28 @@
/*!
* Copyright 2017-2021 by Contributors
/**
* Copyright 2017-2023 by Contributors
*/
#include "xgboost/predictor.h"
#include <dmlc/registry.h>
#include <mutex>
#include <string> // std::string
#include "../gbm/gbtree.h"
#include "xgboost/context.h"
#include "xgboost/data.h"
#include "../gbm/gbtree.h" // GBTreeModel
#include "xgboost/base.h" // bst_row_t,bst_group_t
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/learner.h" // LearnerModelParam
#include "xgboost/linalg.h" // Tensor
#include "xgboost/logging.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
} // namespace dmlc
namespace xgboost {
void PredictionContainer::ClearExpiredEntries() {
std::vector<DMatrix*> expired;
for (auto& kv : container_) {
if (kv.second.ref.expired()) {
expired.emplace_back(kv.first);
}
}
for (auto const& ptr : expired) {
container_.erase(ptr);
}
}
void Predictor::Configure(Args const&) {}
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
this->ClearExpiredEntries();
container_[m.get()].ref = m;
if (device != Context::kCpuId) {
container_[m.get()].predictions.SetDevice(device);
}
return container_[m.get()];
}
PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) {
CHECK(container_.find(m) != container_.cend());
CHECK(container_.at(m).ref.lock())
<< "[Internal error]: DMatrix: " << m << " has expired.";
return container_.at(m);
}
decltype(PredictionContainer::container_) const& PredictionContainer::Container() {
this->ClearExpiredEntries();
return container_;
}
void Predictor::Configure(
const std::vector<std::pair<std::string, std::string>>&) {
}
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
if (e == nullptr) {