Generalize prediction cache. (#8783)

* Extract most of the functionality into `DMatrixCache`.
* Move API entry to independent file to reduce dependency on `predictor.h` file.
* Add test.

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2023-02-13 12:36:43 +08:00 committed by GitHub
parent ed91e775ec
commit d11a0044cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 278 additions and 126 deletions

134
include/xgboost/cache.h Normal file
View File

@ -0,0 +1,134 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#ifndef XGBOOST_CACHE_H_
#define XGBOOST_CACHE_H_
#include <xgboost/logging.h> // CHECK_EQ
#include <cstddef> // std::size_t
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
#include <queue> // std:queue
#include <unordered_map> // std::unordered_map
#include <vector> // std::vector
namespace xgboost {
class DMatrix;
/**
* \brief FIFO cache for DMatrix related data.
*
* \tparam CacheT The type that needs to be cached.
*/
template <typename CacheT>
class DMatrixCache {
public:
struct Item {
// A weak pointer for checking whether the DMatrix object has expired.
std::weak_ptr<DMatrix> ref;
// The cached item
std::shared_ptr<CacheT> value;
CacheT const& Value() const { return *value; }
CacheT& Value() { return *value; }
};
protected:
std::unordered_map<DMatrix const*, Item> container_;
std::queue<DMatrix const*> queue_;
std::size_t max_size_;
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
void ClearExpired() {
// Clear expired entries
this->CheckConsistent();
std::vector<DMatrix const*> expired;
std::queue<DMatrix const*> remained;
while (!queue_.empty()) {
auto p_fmat = queue_.front();
auto it = container_.find(p_fmat);
CHECK(it != container_.cend());
if (it->second.ref.expired()) {
expired.push_back(it->first);
} else {
remained.push(it->first);
}
queue_.pop();
}
CHECK(queue_.empty());
CHECK_EQ(remained.size() + expired.size(), container_.size());
for (auto const* p_fmat : expired) {
container_.erase(p_fmat);
}
while (!remained.empty()) {
auto p_fmat = remained.front();
queue_.push(p_fmat);
remained.pop();
}
this->CheckConsistent();
}
void ClearExcess() {
this->CheckConsistent();
while (queue_.size() >= max_size_) {
auto p_fmat = queue_.front();
queue_.pop();
container_.erase(p_fmat);
}
this->CheckConsistent();
}
public:
/**
* \param cache_size Maximum size of the cache.
*/
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
/**
* \brief Cache a new DMatrix if it's no in the cache already.
*
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
* entry this shared pointer is necessary. More importantly, the life time of this
* cache is tied to the shared pointer.
*
* \param m shared pointer to the DMatrix that needs to be cached.
* \param args The arguments for constructing a new cache item, if needed.
*
* \return The cache entry for passed in DMatrix, either an existing cache or newly
* created.
*/
template <typename... Args>
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
CHECK(m);
this->ClearExpired();
if (container_.size() >= max_size_) {
this->ClearExcess();
}
// after clear, cache size < max_size
CHECK_LT(container_.size(), max_size_);
auto it = container_.find(m.get());
if (it == container_.cend()) {
// after the new DMatrix, cache size is at most max_size
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
queue_.push(m.get());
}
return container_.at(m.get()).value;
}
/**
* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container() {
this->ClearExpired();
return container_;
}
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
CHECK(container_.find(m) != container_.cend());
CHECK(!container_.at(m).ref.expired());
return container_.at(m).value;
}
};
} // namespace xgboost
#endif // XGBOOST_CACHE_H_

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2014-2022 by XGBoost Contributors * Copyright 2014-2023 by XGBoost Contributors
* \file gbm.h * \file gbm.h
* \brief Interface of gradient booster, * \brief Interface of gradient booster,
* that learns through gradient statistics. * that learns through gradient statistics.
@ -31,7 +31,6 @@ class ObjFunction;
struct Context; struct Context;
struct LearnerModelParam; struct LearnerModelParam;
struct PredictionCacheEntry; struct PredictionCacheEntry;
class PredictionContainer;
/*! /*!
* \brief interface of gradient boosting model. * \brief interface of gradient boosting model.

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2015-2022 by XGBoost Contributors * Copyright 2015-2023 by XGBoost Contributors
* \file learner.h * \file learner.h
* \brief Learner interface that integrates objective, gbm and evaluation together. * \brief Learner interface that integrates objective, gbm and evaluation together.
* This is the user facing XGBoost training module. * This is the user facing XGBoost training module.
@ -8,12 +8,13 @@
#ifndef XGBOOST_LEARNER_H_ #ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_ #define XGBOOST_LEARNER_H_
#include <dmlc/io.h> // Serializable
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/context.h> // Context #include <xgboost/context.h> // Context
#include <xgboost/feature_map.h> #include <xgboost/feature_map.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/linalg.h> // Tensor
#include <xgboost/model.h> #include <xgboost/model.h>
#include <xgboost/predictor.h>
#include <xgboost/task.h> #include <xgboost/task.h>
#include <map> #include <map>
@ -29,6 +30,7 @@ class GradientBooster;
class ObjFunction; class ObjFunction;
class DMatrix; class DMatrix;
class Json; class Json;
struct XGBAPIThreadLocalEntry;
enum class PredictionType : std::uint8_t { // NOLINT enum class PredictionType : std::uint8_t { // NOLINT
kValue = 0, kValue = 0,
@ -40,26 +42,6 @@ enum class PredictionType : std::uint8_t { // NOLINT
kLeaf = 6 kLeaf = 6
}; };
/*! \brief entry to to easily hold returning information */
struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string */
std::string ret_str;
/*! \brief result holder for returning raw buffer */
std::vector<char> ret_char_vec;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief returning float vector. */
std::vector<bst_float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returning prediction result. */
PredictionCacheEntry prediction_entry;
/*! \brief Temp variable for returning prediction shape. */
std::vector<bst_ulong> prediction_shape;
};
/*! /*!
* \brief Learner class that does training and prediction. * \brief Learner class that does training and prediction.
* This is the user facing module of xgboost training. * This is the user facing module of xgboost training.

View File

@ -1,95 +1,66 @@
/*! /**
* Copyright 2017-2022 by Contributors * Copyright 2017-2023 by Contributors
* \file predictor.h * \file predictor.h
* \brief Interface of predictor, * \brief Interface of predictor,
* performs predictions for a gradient booster. * performs predictions for a gradient booster.
*/ */
#pragma once #pragma once
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/cache.h> // DMatrixCache
#include <xgboost/context.h> #include <xgboost/context.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <functional> #include <functional> // std::function
#include <memory> #include <memory>
#include <mutex>
#include <string> #include <string>
#include <unordered_map>
#include <utility>
#include <vector> #include <vector>
// Forward declarations // Forward declarations
namespace xgboost { namespace xgboost {
class TreeUpdater;
namespace gbm { namespace gbm {
struct GBTreeModel; struct GBTreeModel;
} // namespace gbm } // namespace gbm
} } // namespace xgboost
namespace xgboost { namespace xgboost {
/** /**
* \struct PredictionCacheEntry
*
* \brief Contains pointer to input matrix and associated cached predictions. * \brief Contains pointer to input matrix and associated cached predictions.
*/ */
struct PredictionCacheEntry { struct PredictionCacheEntry {
// A storage for caching prediction values // A storage for caching prediction values
HostDeviceVector<bst_float> predictions; HostDeviceVector<bst_float> predictions;
// The version of current cache, corresponding number of layers of trees // The version of current cache, corresponding number of layers of trees
uint32_t version { 0 }; std::uint32_t version{0};
// A weak pointer for checking whether the DMatrix object has expired.
std::weak_ptr< DMatrix > ref;
PredictionCacheEntry() = default; PredictionCacheEntry() = default;
/* \brief Update the cache entry by number of versions. /**
* \brief Update the cache entry by number of versions.
* *
* \param v Added versions. * \param v Added versions.
*/ */
void Update(uint32_t v) { void Update(std::uint32_t v) {
version += v; version += v;
} }
}; };
/* \brief A container for managed prediction caches. /**
* \brief A container for managed prediction caches.
*/ */
class PredictionContainer { class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
std::unordered_map<DMatrix *, PredictionCacheEntry> container_; // we cache up to 32 DMatrix
void ClearExpiredEntries(); std::size_t static constexpr DefaultSize() { return 32; }
public: public:
PredictionContainer() = default; PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
/* \brief Add a new DMatrix to the cache, at the same time this function will clear out PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
* all expired caches by checking the `std::weak_ptr`. Caching an existing this->CacheItem(m);
* DMatrix won't renew it. auto p_cache = this->container_.find(m.get());
* if (device != Context::kCpuId) {
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the p_cache->second.Value().predictions.SetDevice(device);
* entry this shared pointer is necessary. More importantly, the life time of this }
* cache is tied to the shared pointer. return p_cache->second.Value();
* }
* Another way to make a safe cache is create a proxy to this entry, with anther shared
* pointer defined inside, and pass this proxy around instead of the real entry. But
* seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
* (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
*
* \param m shared pointer to the DMatrix that needs to be cached.
* \param device Which device should the cache be allocated on. Pass
* Context::kCpuId for CPU or positive integer for GPU id.
*
* \return the cache entry for passed in DMatrix, either an existing cache or newly
* created.
*/
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
/* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
* method. Otherwise a dmlc::Error is thrown.
*
* \param m pointer to the DMatrix.
* \return The prediction cache for passed in DMatrix.
*/
PredictionCacheEntry& Entry(DMatrix* m);
/* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container();
}; };
/** /**
@ -114,7 +85,7 @@ class Predictor {
* *
* \param cfg The configuration. * \param cfg The configuration.
*/ */
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&); virtual void Configure(Args const&);
/** /**
* \brief Initialize output prediction * \brief Initialize output prediction

View File

@ -12,11 +12,11 @@
#include <vector> #include <vector>
#include "../collective/communicator-inl.h" #include "../collective/communicator-inl.h"
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/charconv.h" #include "../common/charconv.h"
#include "../common/io.h" #include "../common/io.h"
#include "../data/adapter.h" #include "../data/adapter.h"
#include "../data/simple_dmatrix.h" #include "../data/simple_dmatrix.h"
#include "c_api_error.h"
#include "c_api_utils.h" #include "c_api_utils.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"

View File

@ -1,6 +1,7 @@
/** /**
* Copyright 2019-2023 by XGBoost Contributors * Copyright 2019-2023 by XGBoost Contributors
*/ */
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/threading_utils.h" #include "../common/threading_utils.h"
#include "../data/device_adapter.cuh" #include "../data/device_adapter.cuh"
#include "../data/proxy_dmatrix.h" #include "../data/proxy_dmatrix.h"

35
src/common/api_entry.h Normal file
View File

@ -0,0 +1,35 @@
/**
* Copyright 2016-2023 by XGBoost contributors
*/
#ifndef XGBOOST_COMMON_API_ENTRY_H_
#define XGBOOST_COMMON_API_ENTRY_H_
#include <string> // std::string
#include <vector> // std::vector
#include "xgboost/base.h" // GradientPair,bst_ulong
#include "xgboost/predictor.h" // PredictionCacheEntry
namespace xgboost {
/**
* \brief entry to to easily hold returning information
*/
struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string */
std::string ret_str;
/*! \brief result holder for returning raw buffer */
std::vector<char> ret_char_vec;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief returning float vector. */
std::vector<float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returning prediction result. */
PredictionCacheEntry prediction_entry;
/*! \brief Temp variable for returning prediction shape. */
std::vector<bst_ulong> prediction_shape;
};
} // namespace xgboost
#endif // XGBOOST_COMMON_API_ENTRY_H_

View File

@ -10,6 +10,7 @@
#include <cstring> #include <cstring>
#include "../collective/communicator-inl.h" #include "../collective/communicator-inl.h"
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/group_data.h" #include "../common/group_data.h"
#include "../common/io.h" #include "../common/io.h"
#include "../common/linalg_op.h" #include "../common/linalg_op.h"

View File

@ -25,6 +25,7 @@
#include <vector> #include <vector>
#include "collective/communicator-inl.h" #include "collective/communicator-inl.h"
#include "common/api_entry.h" // XGBAPIThreadLocalEntry
#include "common/charconv.h" #include "common/charconv.h"
#include "common/common.h" #include "common/common.h"
#include "common/io.h" #include "common/io.h"
@ -430,7 +431,9 @@ class LearnerConfiguration : public Learner {
monitor_.Init("Learner"); monitor_.Init("Learner");
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this]; auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
for (std::shared_ptr<DMatrix> const& d : cache) { for (std::shared_ptr<DMatrix> const& d : cache) {
local_cache.Cache(d, Context::kCpuId); if (d) {
local_cache.Cache(d, Context::kCpuId);
}
} }
} }
~LearnerConfiguration() override { ~LearnerConfiguration() override {
@ -1296,9 +1299,8 @@ class LearnerImpl : public LearnerIO {
this->ValidateDMatrix(train.get(), true); this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache(); auto local_cache = this->GetPredictionCache();
local_cache->Cache(train, ctx_.gpu_id); auto& predt = local_cache->Cache(train, ctx_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
monitor_.Stop("BoostOneIter"); monitor_.Stop("BoostOneIter");
} }

View File

@ -1,57 +1,28 @@
/*! /**
* Copyright 2017-2021 by Contributors * Copyright 2017-2023 by Contributors
*/ */
#include "xgboost/predictor.h" #include "xgboost/predictor.h"
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <mutex> #include <string> // std::string
#include "../gbm/gbtree.h" #include "../gbm/gbtree.h" // GBTreeModel
#include "xgboost/context.h" #include "xgboost/base.h" // bst_row_t,bst_group_t
#include "xgboost/data.h" #include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/learner.h" // LearnerModelParam
#include "xgboost/linalg.h" // Tensor
#include "xgboost/logging.h"
namespace dmlc { namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg); DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
} // namespace dmlc } // namespace dmlc
namespace xgboost { namespace xgboost {
void PredictionContainer::ClearExpiredEntries() { void Predictor::Configure(Args const&) {}
std::vector<DMatrix*> expired;
for (auto& kv : container_) {
if (kv.second.ref.expired()) {
expired.emplace_back(kv.first);
}
}
for (auto const& ptr : expired) {
container_.erase(ptr);
}
}
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
this->ClearExpiredEntries();
container_[m.get()].ref = m;
if (device != Context::kCpuId) {
container_[m.get()].predictions.SetDevice(device);
}
return container_[m.get()];
}
PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) {
CHECK(container_.find(m) != container_.cend());
CHECK(container_.at(m).ref.lock())
<< "[Internal error]: DMatrix: " << m << " has expired.";
return container_.at(m);
}
decltype(PredictionContainer::container_) const& PredictionContainer::Container() {
this->ClearExpiredEntries();
return container_;
}
void Predictor::Configure(
const std::vector<std::pair<std::string, std::string>>&) {
}
Predictor* Predictor::Create(std::string const& name, Context const* ctx) { Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name); auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
if (e == nullptr) { if (e == nullptr) {

55
tests/cpp/test_cache.cc Normal file
View File

@ -0,0 +1,55 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/cache.h>
#include <xgboost/data.h> // DMatrix
#include <cstddef> // std::size_t
#include "helpers.h" // RandomDataGenerator
namespace xgboost {
namespace {
struct CacheForTest {
std::size_t i;
explicit CacheForTest(std::size_t k) : i{k} {}
};
} // namespace
TEST(DMatrixCache, Basic) {
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4;
DMatrixCache<CacheForTest> cache(kCacheSize);
auto add_cache = [&]() {
// Create a lambda function here, so that p_fmat gets deleted upon the
// end of the lambda. This is to test how the cache handle expired
// cache entries.
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
cache.CacheItem(p_fmat, 3);
DMatrix* m = p_fmat.get();
return m;
};
auto m = add_cache();
ASSERT_EQ(cache.Container().size(), 0);
ASSERT_THROW(cache.Entry(m), dmlc::Error);
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
auto item = cache.CacheItem(p_fmat, 1);
ASSERT_EQ(cache.Entry(p_fmat.get())->i, 1);
std::vector<std::shared_ptr<DMatrix>> items;
for (std::size_t i = 0; i < kCacheSize * 2; ++i) {
items.emplace_back(RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix());
cache.CacheItem(items.back(), i);
ASSERT_EQ(cache.Entry(items.back().get())->i, i);
ASSERT_LE(cache.Container().size(), kCacheSize);
if (i > kCacheSize) {
auto k = i - kCacheSize - 1;
ASSERT_THROW(cache.Entry(items[k].get()), dmlc::Error);
}
}
}
} // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2022 by XGBoost contributors * Copyright 2017-2023 by XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/learner.h> #include <xgboost/learner.h>
@ -10,6 +10,7 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "../../src/common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../../src/common/io.h" #include "../../src/common/io.h"
#include "../../src/common/linalg_op.h" #include "../../src/common/linalg_op.h"
#include "../../src/common/random.h" #include "../../src/common/random.h"