Generalize prediction cache. (#8783)
* Extract most of the functionality into `DMatrixCache`. * Move API entry to independent file to reduce dependency on `predictor.h` file. * Add test. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
134
include/xgboost/cache.h
Normal file
134
include/xgboost/cache.h
Normal file
@@ -0,0 +1,134 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_CACHE_H_
|
||||
#define XGBOOST_CACHE_H_
|
||||
|
||||
#include <xgboost/logging.h> // CHECK_EQ
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
|
||||
#include <queue> // std:queue
|
||||
#include <unordered_map> // std::unordered_map
|
||||
#include <vector> // std::vector
|
||||
|
||||
namespace xgboost {
|
||||
class DMatrix;
|
||||
/**
|
||||
* \brief FIFO cache for DMatrix related data.
|
||||
*
|
||||
* \tparam CacheT The type that needs to be cached.
|
||||
*/
|
||||
template <typename CacheT>
|
||||
class DMatrixCache {
|
||||
public:
|
||||
struct Item {
|
||||
// A weak pointer for checking whether the DMatrix object has expired.
|
||||
std::weak_ptr<DMatrix> ref;
|
||||
// The cached item
|
||||
std::shared_ptr<CacheT> value;
|
||||
|
||||
CacheT const& Value() const { return *value; }
|
||||
CacheT& Value() { return *value; }
|
||||
};
|
||||
|
||||
protected:
|
||||
std::unordered_map<DMatrix const*, Item> container_;
|
||||
std::queue<DMatrix const*> queue_;
|
||||
std::size_t max_size_;
|
||||
|
||||
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
|
||||
|
||||
void ClearExpired() {
|
||||
// Clear expired entries
|
||||
this->CheckConsistent();
|
||||
std::vector<DMatrix const*> expired;
|
||||
std::queue<DMatrix const*> remained;
|
||||
|
||||
while (!queue_.empty()) {
|
||||
auto p_fmat = queue_.front();
|
||||
auto it = container_.find(p_fmat);
|
||||
CHECK(it != container_.cend());
|
||||
if (it->second.ref.expired()) {
|
||||
expired.push_back(it->first);
|
||||
} else {
|
||||
remained.push(it->first);
|
||||
}
|
||||
queue_.pop();
|
||||
}
|
||||
CHECK(queue_.empty());
|
||||
CHECK_EQ(remained.size() + expired.size(), container_.size());
|
||||
|
||||
for (auto const* p_fmat : expired) {
|
||||
container_.erase(p_fmat);
|
||||
}
|
||||
while (!remained.empty()) {
|
||||
auto p_fmat = remained.front();
|
||||
queue_.push(p_fmat);
|
||||
remained.pop();
|
||||
}
|
||||
this->CheckConsistent();
|
||||
}
|
||||
|
||||
void ClearExcess() {
|
||||
this->CheckConsistent();
|
||||
while (queue_.size() >= max_size_) {
|
||||
auto p_fmat = queue_.front();
|
||||
queue_.pop();
|
||||
container_.erase(p_fmat);
|
||||
}
|
||||
this->CheckConsistent();
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* \param cache_size Maximum size of the cache.
|
||||
*/
|
||||
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
|
||||
/**
|
||||
* \brief Cache a new DMatrix if it's no in the cache already.
|
||||
*
|
||||
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
||||
* entry this shared pointer is necessary. More importantly, the life time of this
|
||||
* cache is tied to the shared pointer.
|
||||
*
|
||||
* \param m shared pointer to the DMatrix that needs to be cached.
|
||||
* \param args The arguments for constructing a new cache item, if needed.
|
||||
*
|
||||
* \return The cache entry for passed in DMatrix, either an existing cache or newly
|
||||
* created.
|
||||
*/
|
||||
template <typename... Args>
|
||||
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
|
||||
CHECK(m);
|
||||
this->ClearExpired();
|
||||
if (container_.size() >= max_size_) {
|
||||
this->ClearExcess();
|
||||
}
|
||||
// after clear, cache size < max_size
|
||||
CHECK_LT(container_.size(), max_size_);
|
||||
auto it = container_.find(m.get());
|
||||
if (it == container_.cend()) {
|
||||
// after the new DMatrix, cache size is at most max_size
|
||||
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
|
||||
queue_.push(m.get());
|
||||
}
|
||||
return container_.at(m.get()).value;
|
||||
}
|
||||
/**
|
||||
* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
||||
* returning.
|
||||
*/
|
||||
decltype(container_) const& Container() {
|
||||
this->ClearExpired();
|
||||
return container_;
|
||||
}
|
||||
|
||||
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
|
||||
CHECK(container_.find(m) != container_.cend());
|
||||
CHECK(!container_.at(m).ref.expired());
|
||||
return container_.at(m).value;
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_CACHE_H_
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
* \file gbm.h
|
||||
* \brief Interface of gradient booster,
|
||||
* that learns through gradient statistics.
|
||||
@@ -31,7 +31,6 @@ class ObjFunction;
|
||||
struct Context;
|
||||
struct LearnerModelParam;
|
||||
struct PredictionCacheEntry;
|
||||
class PredictionContainer;
|
||||
|
||||
/*!
|
||||
* \brief interface of gradient boosting model.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file learner.h
|
||||
* \brief Learner interface that integrates objective, gbm and evaluation together.
|
||||
* This is the user facing XGBoost training module.
|
||||
@@ -8,12 +8,13 @@
|
||||
#ifndef XGBOOST_LEARNER_H_
|
||||
#define XGBOOST_LEARNER_H_
|
||||
|
||||
#include <dmlc/io.h> // Serializable
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/context.h> // Context
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/linalg.h> // Tensor
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/task.h>
|
||||
|
||||
#include <map>
|
||||
@@ -29,6 +30,7 @@ class GradientBooster;
|
||||
class ObjFunction;
|
||||
class DMatrix;
|
||||
class Json;
|
||||
struct XGBAPIThreadLocalEntry;
|
||||
|
||||
enum class PredictionType : std::uint8_t { // NOLINT
|
||||
kValue = 0,
|
||||
@@ -40,26 +42,6 @@ enum class PredictionType : std::uint8_t { // NOLINT
|
||||
kLeaf = 6
|
||||
};
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct XGBAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning string */
|
||||
std::string ret_str;
|
||||
/*! \brief result holder for returning raw buffer */
|
||||
std::vector<char> ret_char_vec;
|
||||
/*! \brief result holder for returning strings */
|
||||
std::vector<std::string> ret_vec_str;
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief returning float vector. */
|
||||
std::vector<bst_float> ret_vec_float;
|
||||
/*! \brief temp variable of gradient pairs. */
|
||||
std::vector<GradientPair> tmp_gpair;
|
||||
/*! \brief Temp variable for returning prediction result. */
|
||||
PredictionCacheEntry prediction_entry;
|
||||
/*! \brief Temp variable for returning prediction shape. */
|
||||
std::vector<bst_ulong> prediction_shape;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Learner class that does training and prediction.
|
||||
* This is the user facing module of xgboost training.
|
||||
|
||||
@@ -1,95 +1,66 @@
|
||||
/*!
|
||||
* Copyright 2017-2022 by Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by Contributors
|
||||
* \file predictor.h
|
||||
* \brief Interface of predictor,
|
||||
* performs predictions for a gradient booster.
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/cache.h> // DMatrixCache
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
|
||||
#include <functional>
|
||||
#include <functional> // std::function
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// Forward declarations
|
||||
namespace xgboost {
|
||||
class TreeUpdater;
|
||||
namespace gbm {
|
||||
struct GBTreeModel;
|
||||
} // namespace gbm
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
/**
|
||||
* \struct PredictionCacheEntry
|
||||
*
|
||||
* \brief Contains pointer to input matrix and associated cached predictions.
|
||||
*/
|
||||
struct PredictionCacheEntry {
|
||||
// A storage for caching prediction values
|
||||
HostDeviceVector<bst_float> predictions;
|
||||
// The version of current cache, corresponding number of layers of trees
|
||||
uint32_t version { 0 };
|
||||
// A weak pointer for checking whether the DMatrix object has expired.
|
||||
std::weak_ptr< DMatrix > ref;
|
||||
std::uint32_t version{0};
|
||||
|
||||
PredictionCacheEntry() = default;
|
||||
/* \brief Update the cache entry by number of versions.
|
||||
/**
|
||||
* \brief Update the cache entry by number of versions.
|
||||
*
|
||||
* \param v Added versions.
|
||||
*/
|
||||
void Update(uint32_t v) {
|
||||
void Update(std::uint32_t v) {
|
||||
version += v;
|
||||
}
|
||||
};
|
||||
|
||||
/* \brief A container for managed prediction caches.
|
||||
/**
|
||||
* \brief A container for managed prediction caches.
|
||||
*/
|
||||
class PredictionContainer {
|
||||
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
|
||||
void ClearExpiredEntries();
|
||||
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
|
||||
// we cache up to 32 DMatrix
|
||||
std::size_t static constexpr DefaultSize() { return 32; }
|
||||
|
||||
public:
|
||||
PredictionContainer() = default;
|
||||
/* \brief Add a new DMatrix to the cache, at the same time this function will clear out
|
||||
* all expired caches by checking the `std::weak_ptr`. Caching an existing
|
||||
* DMatrix won't renew it.
|
||||
*
|
||||
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
||||
* entry this shared pointer is necessary. More importantly, the life time of this
|
||||
* cache is tied to the shared pointer.
|
||||
*
|
||||
* Another way to make a safe cache is create a proxy to this entry, with anther shared
|
||||
* pointer defined inside, and pass this proxy around instead of the real entry. But
|
||||
* seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
|
||||
* (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
|
||||
*
|
||||
* \param m shared pointer to the DMatrix that needs to be cached.
|
||||
* \param device Which device should the cache be allocated on. Pass
|
||||
* Context::kCpuId for CPU or positive integer for GPU id.
|
||||
*
|
||||
* \return the cache entry for passed in DMatrix, either an existing cache or newly
|
||||
* created.
|
||||
*/
|
||||
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
|
||||
/* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
|
||||
* method. Otherwise a dmlc::Error is thrown.
|
||||
*
|
||||
* \param m pointer to the DMatrix.
|
||||
* \return The prediction cache for passed in DMatrix.
|
||||
*/
|
||||
PredictionCacheEntry& Entry(DMatrix* m);
|
||||
/* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
||||
* returning.
|
||||
*/
|
||||
decltype(container_) const& Container();
|
||||
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
|
||||
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||
this->CacheItem(m);
|
||||
auto p_cache = this->container_.find(m.get());
|
||||
if (device != Context::kCpuId) {
|
||||
p_cache->second.Value().predictions.SetDevice(device);
|
||||
}
|
||||
return p_cache->second.Value();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -114,7 +85,7 @@ class Predictor {
|
||||
*
|
||||
* \param cfg The configuration.
|
||||
*/
|
||||
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);
|
||||
virtual void Configure(Args const&);
|
||||
|
||||
/**
|
||||
* \brief Initialize output prediction
|
||||
|
||||
Reference in New Issue
Block a user