Generalize prediction cache. (#8783)

* Extract most of the functionality into `DMatrixCache`.
* Move API entry to independent file to reduce dependency on `predictor.h` file.
* Add test.

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2023-02-13 12:36:43 +08:00
committed by GitHub
parent ed91e775ec
commit d11a0044cf
12 changed files with 278 additions and 126 deletions

134
include/xgboost/cache.h Normal file
View File

@@ -0,0 +1,134 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#ifndef XGBOOST_CACHE_H_
#define XGBOOST_CACHE_H_
#include <xgboost/logging.h> // CHECK_EQ
#include <cstddef> // std::size_t
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
#include <queue> // std::queue
#include <unordered_map> // std::unordered_map
#include <vector> // std::vector
namespace xgboost {
class DMatrix;
/**
* \brief FIFO cache for DMatrix related data.
*
* \tparam CacheT The type that needs to be cached.
*/
template <typename CacheT>
class DMatrixCache {
public:
struct Item {
// A weak pointer for checking whether the DMatrix object has expired.
std::weak_ptr<DMatrix> ref;
// The cached item
std::shared_ptr<CacheT> value;
CacheT const& Value() const { return *value; }
CacheT& Value() { return *value; }
};
protected:
std::unordered_map<DMatrix const*, Item> container_;
std::queue<DMatrix const*> queue_;
std::size_t max_size_;
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
void ClearExpired() {
// Clear expired entries
this->CheckConsistent();
std::vector<DMatrix const*> expired;
std::queue<DMatrix const*> remained;
while (!queue_.empty()) {
auto p_fmat = queue_.front();
auto it = container_.find(p_fmat);
CHECK(it != container_.cend());
if (it->second.ref.expired()) {
expired.push_back(it->first);
} else {
remained.push(it->first);
}
queue_.pop();
}
CHECK(queue_.empty());
CHECK_EQ(remained.size() + expired.size(), container_.size());
for (auto const* p_fmat : expired) {
container_.erase(p_fmat);
}
while (!remained.empty()) {
auto p_fmat = remained.front();
queue_.push(p_fmat);
remained.pop();
}
this->CheckConsistent();
}
void ClearExcess() {
this->CheckConsistent();
while (queue_.size() >= max_size_) {
auto p_fmat = queue_.front();
queue_.pop();
container_.erase(p_fmat);
}
this->CheckConsistent();
}
public:
/**
* \param cache_size Maximum size of the cache.
*/
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
/**
* \brief Cache a new DMatrix if it's no in the cache already.
*
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
* entry this shared pointer is necessary. More importantly, the life time of this
* cache is tied to the shared pointer.
*
* \param m shared pointer to the DMatrix that needs to be cached.
* \param args The arguments for constructing a new cache item, if needed.
*
* \return The cache entry for passed in DMatrix, either an existing cache or newly
* created.
*/
template <typename... Args>
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
CHECK(m);
this->ClearExpired();
if (container_.size() >= max_size_) {
this->ClearExcess();
}
// after clear, cache size < max_size
CHECK_LT(container_.size(), max_size_);
auto it = container_.find(m.get());
if (it == container_.cend()) {
// after the new DMatrix, cache size is at most max_size
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
queue_.push(m.get());
}
return container_.at(m.get()).value;
}
/**
* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container() {
this->ClearExpired();
return container_;
}
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
CHECK(container_.find(m) != container_.cend());
CHECK(!container_.at(m).ref.expired());
return container_.at(m).value;
}
};
} // namespace xgboost
#endif // XGBOOST_CACHE_H_

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2022 by XGBoost Contributors
/**
* Copyright 2014-2023 by XGBoost Contributors
* \file gbm.h
* \brief Interface of gradient booster,
* that learns through gradient statistics.
@@ -31,7 +31,6 @@ class ObjFunction;
struct Context;
struct LearnerModelParam;
struct PredictionCacheEntry;
class PredictionContainer;
/*!
* \brief interface of gradient boosting model.

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file learner.h
* \brief Learner interface that integrates objective, gbm and evaluation together.
* This is the user facing XGBoost training module.
@@ -8,12 +8,13 @@
#ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_
#include <dmlc/io.h> // Serializable
#include <xgboost/base.h>
#include <xgboost/context.h> // Context
#include <xgboost/feature_map.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/linalg.h> // Tensor
#include <xgboost/model.h>
#include <xgboost/predictor.h>
#include <xgboost/task.h>
#include <map>
@@ -29,6 +30,7 @@ class GradientBooster;
class ObjFunction;
class DMatrix;
class Json;
struct XGBAPIThreadLocalEntry;
enum class PredictionType : std::uint8_t { // NOLINT
kValue = 0,
@@ -40,26 +42,6 @@ enum class PredictionType : std::uint8_t { // NOLINT
kLeaf = 6
};
/*! \brief Entry to easily hold returned information. */
struct XGBAPIThreadLocalEntry {
/*! \brief Result holder for returning a string. */
std::string ret_str;
/*! \brief Result holder for returning a raw buffer. */
std::vector<char> ret_char_vec;
/*! \brief Result holder for returning strings. */
std::vector<std::string> ret_vec_str;
/*! \brief Result holder for returning string pointers. */
std::vector<const char *> ret_vec_charp;
/*! \brief Result holder for returning a float vector. */
std::vector<bst_float> ret_vec_float;
/*! \brief Temporary variable for gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temporary variable for returning the prediction result. */
PredictionCacheEntry prediction_entry;
/*! \brief Temporary variable for returning the prediction shape. */
std::vector<bst_ulong> prediction_shape;
};
/*!
* \brief Learner class that does training and prediction.
* This is the user facing module of xgboost training.

View File

@@ -1,95 +1,66 @@
/*!
* Copyright 2017-2022 by Contributors
/**
* Copyright 2017-2023 by Contributors
* \file predictor.h
* \brief Interface of predictor,
* performs predictions for a gradient booster.
*/
#pragma once
#include <xgboost/base.h>
#include <xgboost/cache.h> // DMatrixCache
#include <xgboost/context.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <functional>
#include <functional> // std::function
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
// Forward declarations
namespace xgboost {
class TreeUpdater;
namespace gbm {
struct GBTreeModel;
} // namespace gbm
}
} // namespace xgboost
namespace xgboost {
/**
* \struct PredictionCacheEntry
*
* \brief Contains pointer to input matrix and associated cached predictions.
*/
struct PredictionCacheEntry {
// A storage for caching prediction values
HostDeviceVector<bst_float> predictions;
// The version of current cache, corresponding number of layers of trees
uint32_t version { 0 };
// A weak pointer for checking whether the DMatrix object has expired.
std::weak_ptr< DMatrix > ref;
std::uint32_t version{0};
PredictionCacheEntry() = default;
/* \brief Update the cache entry by number of versions.
/**
* \brief Update the cache entry by number of versions.
*
* \param v Added versions.
*/
void Update(uint32_t v) {
void Update(std::uint32_t v) {
version += v;
}
};
/* \brief A container for managed prediction caches.
/**
* \brief A container for managed prediction caches.
*/
class PredictionContainer {
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
void ClearExpiredEntries();
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
// we cache up to 32 DMatrix
std::size_t static constexpr DefaultSize() { return 32; }
public:
PredictionContainer() = default;
/* \brief Add a new DMatrix to the cache, at the same time this function will clear out
* all expired caches by checking the `std::weak_ptr`. Caching an existing
* DMatrix won't renew it.
*
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
* entry this shared pointer is necessary. More importantly, the life time of this
* cache is tied to the shared pointer.
*
* Another way to make a safe cache is create a proxy to this entry, with anther shared
* pointer defined inside, and pass this proxy around instead of the real entry. But
* seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
* (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
*
* \param m shared pointer to the DMatrix that needs to be cached.
* \param device Which device should the cache be allocated on. Pass
* Context::kCpuId for CPU or positive integer for GPU id.
*
* \return the cache entry for passed in DMatrix, either an existing cache or newly
* created.
*/
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
/* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
* method. Otherwise a dmlc::Error is thrown.
*
* \param m pointer to the DMatrix.
* \return The prediction cache for passed in DMatrix.
*/
PredictionCacheEntry& Entry(DMatrix* m);
/* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container();
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
/**
 * \brief Fetch the prediction cache of a DMatrix, creating it when it is not cached yet.
 *
 * \param m      shared pointer to the DMatrix that needs to be cached.
 * \param device Device for the cached predictions. When it differs from
 *               `Context::kCpuId`, `SetDevice` is called on the prediction vector.
 *
 * \return The cache entry for the passed in DMatrix.
 */
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
  // `CacheItem` returns the shared pointer held by the entry of `m`; dereferencing it
  // yields the same object the entry's `Value()` accessor exposes.
  PredictionCacheEntry& entry = *this->CacheItem(m);
  bool const on_cpu = (device == Context::kCpuId);
  if (!on_cpu) {
    entry.predictions.SetDevice(device);
  }
  return entry;
}
};
/**
@@ -114,7 +85,7 @@ class Predictor {
*
* \param cfg The configuration.
*/
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);
virtual void Configure(Args const&);
/**
* \brief Initialize output prediction