Jiaming Yuan 72e8331eab
Reimplement the NDCG metric. (#8906)
- Add support for non-exp gain.
- Cache the DMatrix object to avoid re-calculating the IDCG.
- Make GPU implementation deterministic. (no atomic add)
2023-03-15 03:26:17 +08:00

205 lines
6.1 KiB
C++

/**
* Copyright 2023 by XGBoost contributors
*/
#ifndef XGBOOST_CACHE_H_
#define XGBOOST_CACHE_H_
#include <xgboost/logging.h> // for CHECK_EQ, CHECK
#include <cstddef> // for size_t
#include <memory> // for weak_ptr, shared_ptr, make_shared
#include <mutex> // for mutex, lock_guard
#include <queue> // for queue
#include <thread> // for thread
#include <unordered_map> // for unordered_map
#include <utility> // for move
#include <vector> // for vector
namespace xgboost {
class DMatrix;
/**
* \brief Thread-aware FIFO cache for DMatrix related data.
*
* \tparam CacheT The type that needs to be cached.
*/
template <typename CacheT>
class DMatrixCache {
public:
struct Item {
// A weak pointer for checking whether the DMatrix object has expired.
std::weak_ptr<DMatrix> ref;
// The cached item
std::shared_ptr<CacheT> value;
CacheT const& Value() const { return *value; }
CacheT& Value() { return *value; }
Item(std::shared_ptr<DMatrix> m, std::shared_ptr<CacheT> v) : ref{m}, value{std::move(v)} {}
};
static constexpr std::size_t DefaultSize() { return 32; }
private:
mutable std::mutex lock_;
protected:
struct Key {
DMatrix const* ptr;
std::thread::id const thread_id;
bool operator==(Key const& that) const {
return ptr == that.ptr && thread_id == that.thread_id;
}
};
struct Hash {
std::size_t operator()(Key const& key) const noexcept {
std::size_t f = std::hash<DMatrix const*>()(key.ptr);
std::size_t s = std::hash<std::thread::id>()(key.thread_id);
if (f == s) {
return f;
}
return f ^ s;
}
};
std::unordered_map<Key, Item, Hash> container_;
std::queue<Key> queue_;
std::size_t max_size_;
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
void ClearExpired() {
// Clear expired entries
this->CheckConsistent();
std::vector<Key> expired;
std::queue<Key> remained;
while (!queue_.empty()) {
auto p_fmat = queue_.front();
auto it = container_.find(p_fmat);
CHECK(it != container_.cend());
if (it->second.ref.expired()) {
expired.push_back(it->first);
} else {
remained.push(it->first);
}
queue_.pop();
}
CHECK(queue_.empty());
CHECK_EQ(remained.size() + expired.size(), container_.size());
for (auto const& key : expired) {
container_.erase(key);
}
while (!remained.empty()) {
auto p_fmat = remained.front();
queue_.push(p_fmat);
remained.pop();
}
this->CheckConsistent();
}
void ClearExcess() {
this->CheckConsistent();
// clear half of the entries to prevent repeatingly clearing cache.
std::size_t half_size = max_size_ / 2;
while (queue_.size() >= half_size && !queue_.empty()) {
auto p_fmat = queue_.front();
queue_.pop();
container_.erase(p_fmat);
}
this->CheckConsistent();
}
public:
/**
* \param cache_size Maximum size of the cache.
*/
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
DMatrixCache& operator=(DMatrixCache&& that) {
CHECK(lock_.try_lock());
lock_.unlock();
CHECK(that.lock_.try_lock());
that.lock_.unlock();
std::swap(this->container_, that.container_);
std::swap(this->queue_, that.queue_);
std::swap(this->max_size_, that.max_size_);
return *this;
}
/**
* \brief Cache a new DMatrix if it's not in the cache already.
*
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
* entry this shared pointer is necessary. More importantly, the life time of this
* cache is tied to the shared pointer.
*
* \param m shared pointer to the DMatrix that needs to be cached.
* \param args The arguments for constructing a new cache item, if needed.
*
* \return The cache entry for passed in DMatrix, either an existing cache or newly
* created.
*/
template <typename... Args>
std::shared_ptr<CacheT> CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
CHECK(m);
std::lock_guard<std::mutex> guard{lock_};
this->ClearExpired();
if (container_.size() >= max_size_) {
this->ClearExcess();
}
// after clear, cache size < max_size
CHECK_LT(container_.size(), max_size_);
auto key = Key{m.get(), std::this_thread::get_id()};
auto it = container_.find(key);
if (it == container_.cend()) {
// after the new DMatrix, cache size is at most max_size
container_.emplace(key, Item{m, std::make_shared<CacheT>(args...)});
queue_.emplace(key);
}
return container_.at(key).value;
}
/**
* \brief Re-initialize the item in cache.
*
* Since the shared_ptr is used to hold the item, any reference that lives outside of
* the cache can no-longer be reached from the cache.
*
* We use reset instead of erase to avoid walking through the whole cache for renewing
* a single item. (the cache is FIFO, needs to maintain the order).
*/
template <typename... Args>
std::shared_ptr<CacheT> ResetItem(std::shared_ptr<DMatrix> m, Args const&... args) {
std::lock_guard<std::mutex> guard{lock_};
CheckConsistent();
auto key = Key{m.get(), std::this_thread::get_id()};
auto it = container_.find(key);
CHECK(it != container_.cend());
it->second = {m, std::make_shared<CacheT>(args...)};
CheckConsistent();
return it->second.value;
}
/**
* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container() {
std::lock_guard<std::mutex> guard{lock_};
this->ClearExpired();
return container_;
}
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
std::lock_guard<std::mutex> guard{lock_};
auto key = Key{m, std::this_thread::get_id()};
CHECK(container_.find(key) != container_.cend());
CHECK(!container_.at(key).ref.expired());
return container_.at(key).value;
}
};
} // namespace xgboost
#endif // XGBOOST_CACHE_H_