Fix cache with gc (#8851)
- Make DMatrixCache thread-safe.
- Remove the use of thread-local memory.
This commit is contained in:
@@ -6,16 +6,18 @@
|
||||
|
||||
#include <xgboost/logging.h> // CHECK_EQ
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
|
||||
#include <queue> // std:queue
|
||||
#include <unordered_map> // std::unordered_map
|
||||
#include <vector> // std::vector
|
||||
#include <cstddef> // for size_t
|
||||
#include <memory> // for weak_ptr, shared_ptr, make_shared
|
||||
#include <mutex> // for mutex, lock_guard
|
||||
#include <queue> // for queue
|
||||
#include <thread> // for thread
|
||||
#include <unordered_map> // for unordered_map
|
||||
#include <vector> // for vector
|
||||
|
||||
namespace xgboost {
|
||||
class DMatrix;
|
||||
/**
|
||||
* \brief FIFO cache for DMatrix related data.
|
||||
* \brief Thread-aware FIFO cache for DMatrix related data.
|
||||
*
|
||||
* \tparam CacheT The type that needs to be cached.
|
||||
*/
|
||||
@@ -34,9 +36,31 @@ class DMatrixCache {
|
||||
|
||||
static constexpr std::size_t DefaultSize() { return 32; }
|
||||
|
||||
private:
|
||||
mutable std::mutex lock_;
|
||||
|
||||
protected:
|
||||
std::unordered_map<DMatrix const*, Item> container_;
|
||||
std::queue<DMatrix const*> queue_;
|
||||
struct Key {
|
||||
DMatrix const* ptr;
|
||||
std::thread::id const thread_id;
|
||||
|
||||
bool operator==(Key const& that) const {
|
||||
return ptr == that.ptr && thread_id == that.thread_id;
|
||||
}
|
||||
};
|
||||
struct Hash {
|
||||
std::size_t operator()(Key const& key) const noexcept {
|
||||
std::size_t f = std::hash<DMatrix const*>()(key.ptr);
|
||||
std::size_t s = std::hash<std::thread::id>()(key.thread_id);
|
||||
if (f == s) {
|
||||
return f;
|
||||
}
|
||||
return f ^ s;
|
||||
}
|
||||
};
|
||||
|
||||
std::unordered_map<Key, Item, Hash> container_;
|
||||
std::queue<Key> queue_;
|
||||
std::size_t max_size_;
|
||||
|
||||
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
|
||||
@@ -44,8 +68,8 @@ class DMatrixCache {
|
||||
void ClearExpired() {
|
||||
// Clear expired entries
|
||||
this->CheckConsistent();
|
||||
std::vector<DMatrix const*> expired;
|
||||
std::queue<DMatrix const*> remained;
|
||||
std::vector<Key> expired;
|
||||
std::queue<Key> remained;
|
||||
|
||||
while (!queue_.empty()) {
|
||||
auto p_fmat = queue_.front();
|
||||
@@ -61,8 +85,8 @@ class DMatrixCache {
|
||||
CHECK(queue_.empty());
|
||||
CHECK_EQ(remained.size() + expired.size(), container_.size());
|
||||
|
||||
for (auto const* p_fmat : expired) {
|
||||
container_.erase(p_fmat);
|
||||
for (auto const& key : expired) {
|
||||
container_.erase(key);
|
||||
}
|
||||
while (!remained.empty()) {
|
||||
auto p_fmat = remained.front();
|
||||
@@ -74,7 +98,9 @@ class DMatrixCache {
|
||||
|
||||
void ClearExcess() {
|
||||
this->CheckConsistent();
|
||||
while (queue_.size() >= max_size_) {
|
||||
// Clear half of the entries to prevent repeatedly clearing the cache.
|
||||
std::size_t half_size = max_size_ / 2;
|
||||
while (queue_.size() >= half_size && !queue_.empty()) {
|
||||
auto p_fmat = queue_.front();
|
||||
queue_.pop();
|
||||
container_.erase(p_fmat);
|
||||
@@ -88,7 +114,7 @@ class DMatrixCache {
|
||||
*/
|
||||
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
|
||||
/**
|
||||
* \brief Cache a new DMatrix if it's no in the cache already.
|
||||
* \brief Cache a new DMatrix if it's not in the cache already.
|
||||
*
|
||||
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
||||
* entry this shared pointer is necessary. More importantly, the life time of this
|
||||
@@ -101,35 +127,42 @@ class DMatrixCache {
|
||||
* created.
|
||||
*/
|
||||
template <typename... Args>
|
||||
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
|
||||
std::shared_ptr<CacheT> CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
|
||||
CHECK(m);
|
||||
std::lock_guard<std::mutex> guard{lock_};
|
||||
|
||||
this->ClearExpired();
|
||||
if (container_.size() >= max_size_) {
|
||||
this->ClearExcess();
|
||||
}
|
||||
// after clear, cache size < max_size
|
||||
CHECK_LT(container_.size(), max_size_);
|
||||
auto it = container_.find(m.get());
|
||||
auto key = Key{m.get(), std::this_thread::get_id()};
|
||||
auto it = container_.find(key);
|
||||
if (it == container_.cend()) {
|
||||
// After inserting the new DMatrix, the cache size is at most max_size.
|
||||
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
|
||||
queue_.push(m.get());
|
||||
container_[key] = {m, std::make_shared<CacheT>(args...)};
|
||||
queue_.emplace(key);
|
||||
}
|
||||
return container_.at(m.get()).value;
|
||||
return container_.at(key).value;
|
||||
}
|
||||
/**
|
||||
* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
||||
* returning.
|
||||
*/
|
||||
decltype(container_) const& Container() {
|
||||
std::lock_guard<std::mutex> guard{lock_};
|
||||
|
||||
this->ClearExpired();
|
||||
return container_;
|
||||
}
|
||||
|
||||
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
|
||||
CHECK(container_.find(m) != container_.cend());
|
||||
CHECK(!container_.at(m).ref.expired());
|
||||
return container_.at(m).value;
|
||||
std::lock_guard<std::mutex> guard{lock_};
|
||||
auto key = Key{m, std::this_thread::get_id()};
|
||||
CHECK(container_.find(key) != container_.cend());
|
||||
CHECK(!container_.at(key).ref.expired());
|
||||
return container_.at(key).value;
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#include <functional> // std::function
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread> // for get_id
|
||||
#include <utility> // for make_pair
|
||||
#include <vector>
|
||||
|
||||
// Forward declarations
|
||||
@@ -48,18 +50,17 @@ struct PredictionCacheEntry {
|
||||
* \brief A container for managed prediction caches.
|
||||
*/
|
||||
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
|
||||
// we cache up to 32 DMatrix
|
||||
std::size_t static constexpr DefaultSize() { return 32; }
|
||||
// We cache up to 64 DMatrix for all threads
|
||||
std::size_t static constexpr DefaultSize() { return 64; }
|
||||
|
||||
public:
|
||||
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
|
||||
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||
this->CacheItem(m);
|
||||
auto p_cache = this->container_.find(m.get());
|
||||
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, std::int32_t device) {
|
||||
auto p_cache = this->CacheItem(m);
|
||||
if (device != Context::kCpuId) {
|
||||
p_cache->second.Value().predictions.SetDevice(device);
|
||||
p_cache->predictions.SetDevice(device);
|
||||
}
|
||||
return p_cache->second.Value();
|
||||
return *p_cache;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user