Fix cache with gc (#8851)

- Make DMatrixCache thread-safe.
- Remove the use of thread-local memory.
This commit is contained in:
Jiaming Yuan
2023-03-01 00:39:06 +08:00
committed by GitHub
parent d9688f93c7
commit d54ef56f6f
4 changed files with 143 additions and 65 deletions

View File

@@ -6,16 +6,18 @@
#include <xgboost/logging.h> // CHECK_EQ
#include <cstddef> // std::size_t
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
#include <queue> // std:queue
#include <unordered_map> // std::unordered_map
#include <vector> // std::vector
#include <cstddef> // for size_t
#include <memory> // for weak_ptr, shared_ptr, make_shared
#include <mutex> // for mutex, lock_guard
#include <queue> // for queue
#include <thread> // for thread
#include <unordered_map> // for unordered_map
#include <vector> // for vector
namespace xgboost {
class DMatrix;
/**
* \brief FIFO cache for DMatrix related data.
* \brief Thread-aware FIFO cache for DMatrix related data.
*
* \tparam CacheT The type that needs to be cached.
*/
@@ -34,9 +36,31 @@ class DMatrixCache {
static constexpr std::size_t DefaultSize() { return 32; }
private:
mutable std::mutex lock_;
protected:
std::unordered_map<DMatrix const*, Item> container_;
std::queue<DMatrix const*> queue_;
struct Key {
DMatrix const* ptr;
std::thread::id const thread_id;
bool operator==(Key const& that) const {
return ptr == that.ptr && thread_id == that.thread_id;
}
};
struct Hash {
std::size_t operator()(Key const& key) const noexcept {
std::size_t f = std::hash<DMatrix const*>()(key.ptr);
std::size_t s = std::hash<std::thread::id>()(key.thread_id);
if (f == s) {
return f;
}
return f ^ s;
}
};
std::unordered_map<Key, Item, Hash> container_;
std::queue<Key> queue_;
std::size_t max_size_;
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
@@ -44,8 +68,8 @@ class DMatrixCache {
void ClearExpired() {
// Clear expired entries
this->CheckConsistent();
std::vector<DMatrix const*> expired;
std::queue<DMatrix const*> remained;
std::vector<Key> expired;
std::queue<Key> remained;
while (!queue_.empty()) {
auto p_fmat = queue_.front();
@@ -61,8 +85,8 @@ class DMatrixCache {
CHECK(queue_.empty());
CHECK_EQ(remained.size() + expired.size(), container_.size());
for (auto const* p_fmat : expired) {
container_.erase(p_fmat);
for (auto const& key : expired) {
container_.erase(key);
}
while (!remained.empty()) {
auto p_fmat = remained.front();
@@ -74,7 +98,9 @@ class DMatrixCache {
void ClearExcess() {
this->CheckConsistent();
while (queue_.size() >= max_size_) {
// clear half of the entries to prevent repeatingly clearing cache.
std::size_t half_size = max_size_ / 2;
while (queue_.size() >= half_size && !queue_.empty()) {
auto p_fmat = queue_.front();
queue_.pop();
container_.erase(p_fmat);
@@ -88,7 +114,7 @@ class DMatrixCache {
*/
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
/**
* \brief Cache a new DMatrix if it's no in the cache already.
* \brief Cache a new DMatrix if it's not in the cache already.
*
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
* entry this shared pointer is necessary. More importantly, the life time of this
@@ -101,35 +127,42 @@ class DMatrixCache {
* created.
*/
template <typename... Args>
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
std::shared_ptr<CacheT> CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
CHECK(m);
std::lock_guard<std::mutex> guard{lock_};
this->ClearExpired();
if (container_.size() >= max_size_) {
this->ClearExcess();
}
// after clear, cache size < max_size
CHECK_LT(container_.size(), max_size_);
auto it = container_.find(m.get());
auto key = Key{m.get(), std::this_thread::get_id()};
auto it = container_.find(key);
if (it == container_.cend()) {
// after the new DMatrix, cache size is at most max_size
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
queue_.push(m.get());
container_[key] = {m, std::make_shared<CacheT>(args...)};
queue_.emplace(key);
}
return container_.at(m.get()).value;
return container_.at(key).value;
}
/**
* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container() {
std::lock_guard<std::mutex> guard{lock_};
this->ClearExpired();
return container_;
}
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
CHECK(container_.find(m) != container_.cend());
CHECK(!container_.at(m).ref.expired());
return container_.at(m).value;
std::lock_guard<std::mutex> guard{lock_};
auto key = Key{m, std::this_thread::get_id()};
CHECK(container_.find(key) != container_.cend());
CHECK(!container_.at(key).ref.expired());
return container_.at(key).value;
}
};
} // namespace xgboost

View File

@@ -14,6 +14,8 @@
#include <functional> // std::function
#include <memory>
#include <string>
#include <thread> // for get_id
#include <utility> // for make_pair
#include <vector>
// Forward declarations
@@ -48,18 +50,17 @@ struct PredictionCacheEntry {
* \brief A container for managed prediction caches.
*/
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
// we cache up to 32 DMatrix
std::size_t static constexpr DefaultSize() { return 32; }
// We cache up to 64 DMatrix for all threads
std::size_t static constexpr DefaultSize() { return 64; }
public:
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
this->CacheItem(m);
auto p_cache = this->container_.find(m.get());
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, std::int32_t device) {
auto p_cache = this->CacheItem(m);
if (device != Context::kCpuId) {
p_cache->second.Value().predictions.SetDevice(device);
p_cache->predictions.SetDevice(device);
}
return p_cache->second.Value();
return *p_cache;
}
};