Fix cache with gc (#8851)

- Make DMatrixCache thread-safe.
- Remove the use of thread-local memory.
Jiaming Yuan, 2023-03-01 00:39:06 +08:00, committed by GitHub
parent d9688f93c7
commit d54ef56f6f
4 changed files with 143 additions and 65 deletions

File 1 of 4: xgboost/cache.h (DMatrixCache)

@@ -6,16 +6,18 @@
 #include <xgboost/logging.h>  // CHECK_EQ
-#include <cstddef>            // std::size_t
-#include <memory>             // std::weak_ptr,std::shared_ptr,std::make_shared
-#include <queue>              // std::queue
-#include <unordered_map>      // std::unordered_map
-#include <vector>             // std::vector
+#include <cstddef>            // for size_t
+#include <memory>             // for weak_ptr, shared_ptr, make_shared
+#include <mutex>              // for mutex, lock_guard
+#include <queue>              // for queue
+#include <thread>             // for thread
+#include <unordered_map>      // for unordered_map
+#include <vector>             // for vector
 
 namespace xgboost {
 class DMatrix;
 
 /**
- * \brief FIFO cache for DMatrix related data.
+ * \brief Thread-aware FIFO cache for DMatrix related data.
  *
  * \tparam CacheT The type that needs to be cached.
  */
@@ -34,9 +36,31 @@ class DMatrixCache {
   static constexpr std::size_t DefaultSize() { return 32; }
 
+ private:
+  mutable std::mutex lock_;
+
  protected:
-  std::unordered_map<DMatrix const*, Item> container_;
-  std::queue<DMatrix const*> queue_;
+  struct Key {
+    DMatrix const* ptr;
+    std::thread::id const thread_id;
+    bool operator==(Key const& that) const {
+      return ptr == that.ptr && thread_id == that.thread_id;
+    }
+  };
+  struct Hash {
+    std::size_t operator()(Key const& key) const noexcept {
+      std::size_t f = std::hash<DMatrix const*>()(key.ptr);
+      std::size_t s = std::hash<std::thread::id>()(key.thread_id);
+      if (f == s) {
+        return f;
+      }
+      return f ^ s;
+    }
+  };
+
+  std::unordered_map<Key, Item, Hash> container_;
+  std::queue<Key> queue_;
   std::size_t max_size_;
 
   void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
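The hunk above is the heart of the patch: cache entries are now keyed by the pair of DMatrix pointer and caller thread id, so two threads caching the same matrix get independent entries, and the new mutex serializes access to the map. Below is a self-contained sketch of the same composite-key pattern; all names in it (PerThreadCache, Payload, GetOrPut) are illustrative stand-ins, not XGBoost API.

#include <cstddef>        // for size_t
#include <functional>     // for hash
#include <iostream>
#include <mutex>          // for mutex, lock_guard
#include <thread>         // for thread, this_thread::get_id
#include <unordered_map>  // for unordered_map

struct Payload { int value; };

class PerThreadCache {
  struct Key {
    void const* ptr;
    std::thread::id tid;
    bool operator==(Key const& that) const { return ptr == that.ptr && tid == that.tid; }
  };
  struct Hash {
    std::size_t operator()(Key const& k) const noexcept {
      std::size_t f = std::hash<void const*>()(k.ptr);
      std::size_t s = std::hash<std::thread::id>()(k.tid);
      // Same guard as the patch: plain f ^ s collapses to 0 whenever the
      // two sub-hashes agree, funneling those keys into one bucket.
      return f == s ? f : f ^ s;
    }
  };
  std::mutex lock_;
  std::unordered_map<Key, Payload, Hash> map_;

 public:
  Payload GetOrPut(void const* ptr, int v) {
    std::lock_guard<std::mutex> guard{lock_};
    auto key = Key{ptr, std::this_thread::get_id()};
    auto it = map_.find(key);
    if (it == map_.cend()) {
      it = map_.emplace(key, Payload{v}).first;
    }
    return it->second;  // by value, like the patched CacheItem
  }
  std::size_t Size() {
    std::lock_guard<std::mutex> guard{lock_};
    return map_.size();
  }
};

int main() {
  int object = 0;  // stands in for a DMatrix
  PerThreadCache cache;
  cache.GetOrPut(&object, 1);
  std::thread t{[&] { cache.GetOrPut(&object, 2); }};
  t.join();
  std::cout << cache.Size() << "\n";  // 2: one entry per thread
  return 0;
}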
@@ -44,8 +68,8 @@ class DMatrixCache {
   void ClearExpired() {
     // Clear expired entries
     this->CheckConsistent();
-    std::vector<DMatrix const*> expired;
-    std::queue<DMatrix const*> remained;
+    std::vector<Key> expired;
+    std::queue<Key> remained;
 
     while (!queue_.empty()) {
       auto p_fmat = queue_.front();
@@ -61,8 +85,8 @@ class DMatrixCache {
     CHECK(queue_.empty());
     CHECK_EQ(remained.size() + expired.size(), container_.size());
 
-    for (auto const* p_fmat : expired) {
-      container_.erase(p_fmat);
+    for (auto const& key : expired) {
+      container_.erase(key);
     }
     while (!remained.empty()) {
       auto p_fmat = remained.front();
@@ -74,7 +98,9 @@ class DMatrixCache {
   void ClearExcess() {
     this->CheckConsistent();
-    while (queue_.size() >= max_size_) {
+    // Clear half of the entries to avoid repeatedly clearing the cache.
+    std::size_t half_size = max_size_ / 2;
+    while (queue_.size() >= half_size && !queue_.empty()) {
       auto p_fmat = queue_.front();
       queue_.pop();
       container_.erase(p_fmat);
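ClearExcess no longer evicts a single entry at a time once the cache is full; it sweeps down to half capacity, so a burst of fresh matrices does not trigger an eviction on every insertion. A standalone sketch of this halving policy with plain int keys (HalvingFifoCache is an illustrative name, not part of the codebase):

#include <cstddef>        // for size_t
#include <iostream>
#include <queue>          // for queue
#include <unordered_map>  // for unordered_map

class HalvingFifoCache {
  std::unordered_map<int, int> map_;
  std::queue<int> queue_;
  std::size_t max_size_;

 public:
  explicit HalvingFifoCache(std::size_t n) : max_size_{n} {}

  void Put(int key, int value) {
    if (map_.size() >= max_size_) {
      // Evict in FIFO order until fewer than half the slots are used, so the
      // next max_size_ / 2 insertions proceed without another sweep.
      std::size_t half_size = max_size_ / 2;
      while (queue_.size() >= half_size && !queue_.empty()) {
        map_.erase(queue_.front());
        queue_.pop();
      }
    }
    if (map_.emplace(key, value).second) {
      queue_.push(key);
    }
  }

  std::size_t Size() const { return map_.size(); }
};

int main() {
  HalvingFifoCache cache{8};
  for (int k = 0; k < 9; ++k) {
    cache.Put(k, k);
  }
  // The 9th insertion triggered one sweep down to 3 entries (keys 5..7),
  // then key 8 was added.
  std::cout << cache.Size() << "\n";  // prints 4
  return 0;
}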
@@ -88,7 +114,7 @@ class DMatrixCache {
    */
   explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
   /**
-   * \brief Cache a new DMatrix if it's no in the cache already.
+   * \brief Cache a new DMatrix if it's not in the cache already.
    *
    * Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
    * entry this shared pointer is necessary. More importantly, the life time of this
@@ -101,35 +127,42 @@ class DMatrixCache {
    * created.
    */
   template <typename... Args>
-  std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
+  std::shared_ptr<CacheT> CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
     CHECK(m);
+    std::lock_guard<std::mutex> guard{lock_};
     this->ClearExpired();
     if (container_.size() >= max_size_) {
       this->ClearExcess();
     }
     // after clear, cache size < max_size
     CHECK_LT(container_.size(), max_size_);
-    auto it = container_.find(m.get());
+    auto key = Key{m.get(), std::this_thread::get_id()};
+    auto it = container_.find(key);
     if (it == container_.cend()) {
       // after the new DMatrix, cache size is at most max_size
-      container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
-      queue_.push(m.get());
+      container_[key] = {m, std::make_shared<CacheT>(args...)};
+      queue_.emplace(key);
     }
-    return container_.at(m.get()).value;
+    return container_.at(key).value;
   }
   /**
    * \brief Get a const reference to the underlying hash map. Clear expired caches before
    *        returning.
    */
   decltype(container_) const& Container() {
+    std::lock_guard<std::mutex> guard{lock_};
     this->ClearExpired();
     return container_;
   }
 
   std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
-    CHECK(container_.find(m) != container_.cend());
-    CHECK(!container_.at(m).ref.expired());
-    return container_.at(m).value;
+    std::lock_guard<std::mutex> guard{lock_};
+    auto key = Key{m, std::this_thread::get_id()};
+    CHECK(container_.find(key) != container_.cend());
+    CHECK(!container_.at(key).ref.expired());
+    return container_.at(key).value;
   }
 };
 }  // namespace xgboost
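A detail worth noting in the last hunk: CacheItem and Entry now return the shared_ptr by value rather than by reference. A reference into container_ could dangle as soon as the lock is released and another thread evicts the entry or rehashes the map, whereas a shared_ptr copy keeps the cached object alive on its own. A minimal demonstration of that property, using only the standard library:

#include <cassert>
#include <memory>         // for shared_ptr, make_shared
#include <unordered_map>  // for unordered_map

int main() {
  // Toy cache mapping keys to shared_ptr values, as DMatrixCache does.
  std::unordered_map<int, std::shared_ptr<int>> cache;
  cache[0] = std::make_shared<int>(42);

  std::shared_ptr<int> held = cache.at(0);  // the caller copies the handle
  cache.erase(0);                           // concurrent eviction, e.g. ClearExcess()

  assert(*held == 42);  // the cached value outlives its map entry
  return 0;
}

The prediction container in the next file relies on exactly this: it consumes CacheItem's returned handle instead of re-reading the protected container_.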

File 2 of 4: prediction cache header (PredictionContainer)

@@ -14,6 +14,8 @@
 #include <functional>  // std::function
 #include <memory>
 #include <string>
+#include <thread>      // for get_id
+#include <utility>     // for make_pair
 #include <vector>
 
 // Forward declarations
@@ -48,18 +50,17 @@ struct PredictionCacheEntry {
  * \brief A container for managed prediction caches.
  */
 class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
-  // we cache up to 32 DMatrix
-  std::size_t static constexpr DefaultSize() { return 32; }
+  // We cache up to 64 DMatrix for all threads
+  std::size_t static constexpr DefaultSize() { return 64; }
 
  public:
   PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
-  PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
-    this->CacheItem(m);
-    auto p_cache = this->container_.find(m.get());
+  PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, std::int32_t device) {
+    auto p_cache = this->CacheItem(m);
     if (device != Context::kCpuId) {
-      p_cache->second.Value().predictions.SetDevice(device);
+      p_cache->predictions.SetDevice(device);
     }
-    return p_cache->second.Value();
+    return *p_cache;
   }
 };

File 3 of 4: learner implementation (LearnerConfiguration / LearnerImpl)

@@ -328,9 +328,6 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
 using LearnerAPIThreadLocalStore =
     dmlc::ThreadLocalStore<std::map<Learner const *, XGBAPIThreadLocalEntry>>;
 
-using ThreadLocalPredictionCache =
-    dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
-
 namespace {
 StringView ModelMsg() {
   return StringView{
@@ -368,6 +365,8 @@ class LearnerConfiguration : public Learner {
   LearnerModelParam learner_model_param_;
   LearnerTrainParam tparam_;
   // Initial prediction.
+  PredictionContainer prediction_container_;
+
   std::vector<std::string> metric_names_;
 
   void ConfigureModelParamWithoutBaseScore() {
@@ -426,22 +425,15 @@ class LearnerConfiguration : public Learner {
   }
 
  public:
-  explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
+  explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix>> cache)
       : need_configuration_{true} {
     monitor_.Init("Learner");
-    auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
     for (std::shared_ptr<DMatrix> const& d : cache) {
       if (d) {
-        local_cache.Cache(d, Context::kCpuId);
+        prediction_container_.Cache(d, Context::kCpuId);
       }
     }
   }
-  ~LearnerConfiguration() override {
-    auto local_cache = ThreadLocalPredictionCache::Get();
-    if (local_cache->find(this) != local_cache->cend()) {
-      local_cache->erase(this);
-    }
-  }
 
   // Configuration before data is known.
   void Configure() override {
@@ -499,10 +491,6 @@ class LearnerConfiguration : public Learner {
     CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0) << ModelNotFitted();
   }
 
-  virtual PredictionContainer* GetPredictionCache() const {
-    return &((*ThreadLocalPredictionCache::Get())[this]);
-  }
-
   void LoadConfig(Json const& in) override {
     // If configuration is loaded, ensure that the model came from the same version
     CHECK(IsA<Object>(in));
@@ -741,11 +729,10 @@ class LearnerConfiguration : public Learner {
     if (mparam_.num_feature == 0) {
       // TODO(hcho3): Change num_feature to 64-bit integer
       unsigned num_feature = 0;
-      auto local_cache = this->GetPredictionCache();
-      for (auto& matrix : local_cache->Container()) {
-        CHECK(matrix.first);
+      for (auto const& matrix : prediction_container_.Container()) {
+        CHECK(matrix.first.ptr);
         CHECK(!matrix.second.ref.expired());
-        const uint64_t num_col = matrix.first->Info().num_col_;
+        const uint64_t num_col = matrix.first.ptr->Info().num_col_;
         CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
             << "Unfortunately, XGBoost does not support data matrices with "
             << std::numeric_limits<unsigned>::max() << " features or greater";
@@ -817,13 +804,13 @@ class LearnerConfiguration : public Learner {
    */
   void ConfigureTargets() {
     CHECK(this->obj_);
-    auto const& cache = this->GetPredictionCache()->Container();
+    auto const& cache = prediction_container_.Container();
     size_t n_targets = 1;
     for (auto const& d : cache) {
       if (n_targets == 1) {
-        n_targets = this->obj_->Targets(d.first->Info());
+        n_targets = this->obj_->Targets(d.first.ptr->Info());
       } else {
-        auto t = this->obj_->Targets(d.first->Info());
+        auto t = this->obj_->Targets(d.first.ptr->Info());
         CHECK(n_targets == t || 1 == t) << "Inconsistent labels.";
       }
     }
@@ -1275,8 +1262,7 @@ class LearnerImpl : public LearnerIO {
     this->ValidateDMatrix(train.get(), true);
 
-    auto local_cache = this->GetPredictionCache();
-    auto& predt = local_cache->Cache(train, ctx_.gpu_id);
+    auto& predt = prediction_container_.Cache(train, ctx_.gpu_id);
 
     monitor_.Start("PredictRaw");
     this->PredictRaw(train.get(), &predt, true, 0, 0);
@@ -1303,8 +1289,7 @@ class LearnerImpl : public LearnerIO {
     this->ValidateDMatrix(train.get(), true);
 
-    auto local_cache = this->GetPredictionCache();
-    auto& predt = local_cache->Cache(train, ctx_.gpu_id);
+    auto& predt = prediction_container_.Cache(train, ctx_.gpu_id);
 
     gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
     monitor_.Stop("BoostOneIter");
   }
@@ -1326,10 +1311,9 @@ class LearnerImpl : public LearnerIO {
       metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
     }
 
-    auto local_cache = this->GetPredictionCache();
     for (size_t i = 0; i < data_sets.size(); ++i) {
       std::shared_ptr<DMatrix> m = data_sets[i];
-      auto &predt = local_cache->Cache(m, ctx_.gpu_id);
+      auto &predt = prediction_container_.Cache(m, ctx_.gpu_id);
       this->ValidateDMatrix(m.get(), false);
       this->PredictRaw(m.get(), &predt, false, 0, 0);
@@ -1370,8 +1354,7 @@ class LearnerImpl : public LearnerIO {
     } else if (pred_leaf) {
       gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
     } else {
-      auto local_cache = this->GetPredictionCache();
-      auto& prediction = local_cache->Cache(data, ctx_.gpu_id);
+      auto& prediction = prediction_container_.Cache(data, ctx_.gpu_id);
       this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
       // Copy the prediction cache to output prediction. out_preds comes from C API
       out_preds->SetDevice(ctx_.gpu_id);
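The learner now owns its prediction cache as a plain member, so the cache's lifetime matches the learner's no matter which thread destroys it. Judging from the removed code, the old ThreadLocalPredictionCache had two failure modes: a cache populated on one thread was invisible to every other thread, and the removed destructor could only erase the entry in the map of whichever thread ran it, which matters when a garbage collector (the 'gc' of the commit title) finalizes a Booster on a different thread. The sketch below reproduces the visibility half of the problem; tls_cache and Owner are illustrative names, and thread_local here is a simplified stand-in for dmlc::ThreadLocalStore:

#include <iostream>
#include <map>
#include <thread>

struct Owner {};  // stands in for a Learner

// Each thread sees its own copy of this map.
thread_local std::map<Owner const*, int> tls_cache;

int main() {
  Owner owner;
  // The worker populates *its* thread-local map, as training threads did.
  std::thread worker{[&] { tls_cache[&owner] = 1; }};
  worker.join();
  // Cleanup running on the main thread (e.g. inside a destructor) only sees
  // the main thread's map; the worker's entry was never reachable from here.
  tls_cache.erase(&owner);
  std::cout << "entries visible to main thread: " << tls_cache.size() << "\n";  // 0
  return 0;
}

With a mutex-guarded member instead, every thread sees the same cache and destroying the learner releases all entries at once, at the cost of one lock per cache access.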

File 4 of 4: DMatrixCache unit tests

@@ -3,16 +3,18 @@
  */
 #include <gtest/gtest.h>
 #include <xgboost/cache.h>
-#include <xgboost/data.h>  // DMatrix
+#include <xgboost/data.h>  // for DMatrix
 
-#include <cstddef>  // std::size_t
+#include <cstddef>  // for size_t
+#include <cstdint>  // for uint32_t
+#include <thread>   // for thread
 
-#include "helpers.h"  // RandomDataGenerator
+#include "helpers.h"  // for RandomDataGenerator
 
 namespace xgboost {
 namespace {
 struct CacheForTest {
-  std::size_t i;
+  std::size_t const i;
 
   explicit CacheForTest(std::size_t k) : i{k} {}
 };
@@ -20,7 +22,7 @@ struct CacheForTest {
 
 TEST(DMatrixCache, Basic) {
   std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4;
-  DMatrixCache<CacheForTest> cache(kCacheSize);
+  DMatrixCache<CacheForTest> cache{kCacheSize};
 
   auto add_cache = [&]() {
     // Create a lambda function here, so that p_fmat gets deleted upon the
@@ -52,4 +54,63 @@ TEST(DMatrixCache, Basic) {
     }
   }
 }
+
+TEST(DMatrixCache, MultiThread) {
+  std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 3;
+  auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
+  auto n = std::thread::hardware_concurrency() * 128u;
+  CHECK_NE(n, 0);
+  std::vector<std::shared_ptr<CacheForTest>> results(n);
+
+  {
+    DMatrixCache<CacheForTest> cache{kCacheSize};
+    std::vector<std::thread> tasks;
+    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
+      tasks.emplace_back([&, i = tidx]() {
+        cache.CacheItem(p_fmat, i);
+        auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
+        results[i] = cache.CacheItem(p_fmat_local, i);
+      });
+    }
+    for (auto& t : tasks) {
+      t.join();
+    }
+    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
+      ASSERT_EQ(results[tidx]->i, tidx);
+    }
+
+    tasks.clear();
+    for (std::int32_t tidx = static_cast<std::int32_t>(n - 1); tidx >= 0; --tidx) {
+      tasks.emplace_back([&, i = tidx]() {
+        cache.CacheItem(p_fmat, i);
+        auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
+        results[i] = cache.CacheItem(p_fmat_local, i);
+      });
+    }
+    for (auto& t : tasks) {
+      t.join();
+    }
+    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
+      ASSERT_EQ(results[tidx]->i, tidx);
+    }
+  }
+
+  {
+    DMatrixCache<CacheForTest> cache{n};
+    std::vector<std::thread> tasks;
+    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
+      tasks.emplace_back([&, tidx]() { results[tidx] = cache.CacheItem(p_fmat, tidx); });
+    }
+    for (auto& t : tasks) {
+      t.join();
+    }
+    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
+      ASSERT_EQ(results[tidx]->i, tidx);
+    }
+  }
+}
+
 }  // namespace xgboost