diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h index 142c33a57..423274c50 100644 --- a/include/xgboost/cache.h +++ b/include/xgboost/cache.h @@ -6,16 +6,18 @@ #include // CHECK_EQ -#include // std::size_t -#include // std::weak_ptr,std::shared_ptr,std::make_shared -#include // std:queue -#include // std::unordered_map -#include // std::vector +#include // for size_t +#include // for weak_ptr, shared_ptr, make_shared +#include // for mutex, lock_guard +#include // for queue +#include // for thread +#include // for unordered_map +#include // for vector namespace xgboost { class DMatrix; /** - * \brief FIFO cache for DMatrix related data. + * \brief Thread-aware FIFO cache for DMatrix related data. * * \tparam CacheT The type that needs to be cached. */ @@ -34,9 +36,31 @@ class DMatrixCache { static constexpr std::size_t DefaultSize() { return 32; } + private: + mutable std::mutex lock_; + protected: - std::unordered_map container_; - std::queue queue_; + struct Key { + DMatrix const* ptr; + std::thread::id const thread_id; + + bool operator==(Key const& that) const { + return ptr == that.ptr && thread_id == that.thread_id; + } + }; + struct Hash { + std::size_t operator()(Key const& key) const noexcept { + std::size_t f = std::hash()(key.ptr); + std::size_t s = std::hash()(key.thread_id); + if (f == s) { + return f; + } + return f ^ s; + } + }; + + std::unordered_map container_; + std::queue queue_; std::size_t max_size_; void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); } @@ -44,8 +68,8 @@ class DMatrixCache { void ClearExpired() { // Clear expired entries this->CheckConsistent(); - std::vector expired; - std::queue remained; + std::vector expired; + std::queue remained; while (!queue_.empty()) { auto p_fmat = queue_.front(); @@ -61,8 +85,8 @@ class DMatrixCache { CHECK(queue_.empty()); CHECK_EQ(remained.size() + expired.size(), container_.size()); - for (auto const* p_fmat : expired) { - container_.erase(p_fmat); + for (auto const& key : expired) { + container_.erase(key); } while (!remained.empty()) { auto p_fmat = remained.front(); @@ -74,7 +98,9 @@ class DMatrixCache { void ClearExcess() { this->CheckConsistent(); - while (queue_.size() >= max_size_) { + // clear half of the entries to prevent repeatingly clearing cache. + std::size_t half_size = max_size_ / 2; + while (queue_.size() >= half_size && !queue_.empty()) { auto p_fmat = queue_.front(); queue_.pop(); container_.erase(p_fmat); @@ -88,7 +114,7 @@ class DMatrixCache { */ explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {} /** - * \brief Cache a new DMatrix if it's no in the cache already. + * \brief Cache a new DMatrix if it's not in the cache already. * * Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the * entry this shared pointer is necessary. More importantly, the life time of this @@ -101,35 +127,42 @@ class DMatrixCache { * created. */ template - std::shared_ptr& CacheItem(std::shared_ptr m, Args const&... args) { + std::shared_ptr CacheItem(std::shared_ptr m, Args const&... args) { CHECK(m); + std::lock_guard guard{lock_}; + this->ClearExpired(); if (container_.size() >= max_size_) { this->ClearExcess(); } // after clear, cache size < max_size CHECK_LT(container_.size(), max_size_); - auto it = container_.find(m.get()); + auto key = Key{m.get(), std::this_thread::get_id()}; + auto it = container_.find(key); if (it == container_.cend()) { // after the new DMatrix, cache size is at most max_size - container_[m.get()] = {m, std::make_shared(args...)}; - queue_.push(m.get()); + container_[key] = {m, std::make_shared(args...)}; + queue_.emplace(key); } - return container_.at(m.get()).value; + return container_.at(key).value; } /** * \brief Get a const reference to the underlying hash map. Clear expired caches before * returning. */ decltype(container_) const& Container() { + std::lock_guard guard{lock_}; + this->ClearExpired(); return container_; } std::shared_ptr Entry(DMatrix const* m) const { - CHECK(container_.find(m) != container_.cend()); - CHECK(!container_.at(m).ref.expired()); - return container_.at(m).value; + std::lock_guard guard{lock_}; + auto key = Key{m, std::this_thread::get_id()}; + CHECK(container_.find(key) != container_.cend()); + CHECK(!container_.at(key).ref.expired()); + return container_.at(key).value; } }; } // namespace xgboost diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 438c23465..50665341a 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -14,6 +14,8 @@ #include // std::function #include #include +#include // for get_id +#include // for make_pair #include // Forward declarations @@ -48,18 +50,17 @@ struct PredictionCacheEntry { * \brief A container for managed prediction caches. */ class PredictionContainer : public DMatrixCache { - // we cache up to 32 DMatrix - std::size_t static constexpr DefaultSize() { return 32; } + // We cache up to 64 DMatrix for all threads + std::size_t static constexpr DefaultSize() { return 64; } public: PredictionContainer() : DMatrixCache{DefaultSize()} {} - PredictionCacheEntry& Cache(std::shared_ptr m, int32_t device) { - this->CacheItem(m); - auto p_cache = this->container_.find(m.get()); + PredictionCacheEntry& Cache(std::shared_ptr m, std::int32_t device) { + auto p_cache = this->CacheItem(m); if (device != Context::kCpuId) { - p_cache->second.Value().predictions.SetDevice(device); + p_cache->predictions.SetDevice(device); } - return p_cache->second.Value(); + return *p_cache; } }; diff --git a/src/learner.cc b/src/learner.cc index 390889e9c..dfcab281d 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -328,9 +328,6 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam); using LearnerAPIThreadLocalStore = dmlc::ThreadLocalStore>; -using ThreadLocalPredictionCache = - dmlc::ThreadLocalStore>; - namespace { StringView ModelMsg() { return StringView{ @@ -368,6 +365,8 @@ class LearnerConfiguration : public Learner { LearnerModelParam learner_model_param_; LearnerTrainParam tparam_; // Initial prediction. + PredictionContainer prediction_container_; + std::vector metric_names_; void ConfigureModelParamWithoutBaseScore() { @@ -426,22 +425,15 @@ class LearnerConfiguration : public Learner { } public: - explicit LearnerConfiguration(std::vector > cache) + explicit LearnerConfiguration(std::vector> cache) : need_configuration_{true} { monitor_.Init("Learner"); - auto& local_cache = (*ThreadLocalPredictionCache::Get())[this]; for (std::shared_ptr const& d : cache) { if (d) { - local_cache.Cache(d, Context::kCpuId); + prediction_container_.Cache(d, Context::kCpuId); } } } - ~LearnerConfiguration() override { - auto local_cache = ThreadLocalPredictionCache::Get(); - if (local_cache->find(this) != local_cache->cend()) { - local_cache->erase(this); - } - } // Configuration before data is known. void Configure() override { @@ -499,10 +491,6 @@ class LearnerConfiguration : public Learner { CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0) << ModelNotFitted(); } - virtual PredictionContainer* GetPredictionCache() const { - return &((*ThreadLocalPredictionCache::Get())[this]); - } - void LoadConfig(Json const& in) override { // If configuration is loaded, ensure that the model came from the same version CHECK(IsA(in)); @@ -741,11 +729,10 @@ class LearnerConfiguration : public Learner { if (mparam_.num_feature == 0) { // TODO(hcho3): Change num_feature to 64-bit integer unsigned num_feature = 0; - auto local_cache = this->GetPredictionCache(); - for (auto& matrix : local_cache->Container()) { - CHECK(matrix.first); + for (auto const& matrix : prediction_container_.Container()) { + CHECK(matrix.first.ptr); CHECK(!matrix.second.ref.expired()); - const uint64_t num_col = matrix.first->Info().num_col_; + const uint64_t num_col = matrix.first.ptr->Info().num_col_; CHECK_LE(num_col, static_cast(std::numeric_limits::max())) << "Unfortunately, XGBoost does not support data matrices with " << std::numeric_limits::max() << " features or greater"; @@ -817,13 +804,13 @@ class LearnerConfiguration : public Learner { */ void ConfigureTargets() { CHECK(this->obj_); - auto const& cache = this->GetPredictionCache()->Container(); + auto const& cache = prediction_container_.Container(); size_t n_targets = 1; for (auto const& d : cache) { if (n_targets == 1) { - n_targets = this->obj_->Targets(d.first->Info()); + n_targets = this->obj_->Targets(d.first.ptr->Info()); } else { - auto t = this->obj_->Targets(d.first->Info()); + auto t = this->obj_->Targets(d.first.ptr->Info()); CHECK(n_targets == t || 1 == t) << "Inconsistent labels."; } } @@ -1275,8 +1262,7 @@ class LearnerImpl : public LearnerIO { this->ValidateDMatrix(train.get(), true); - auto local_cache = this->GetPredictionCache(); - auto& predt = local_cache->Cache(train, ctx_.gpu_id); + auto& predt = prediction_container_.Cache(train, ctx_.gpu_id); monitor_.Start("PredictRaw"); this->PredictRaw(train.get(), &predt, true, 0, 0); @@ -1303,8 +1289,7 @@ class LearnerImpl : public LearnerIO { this->ValidateDMatrix(train.get(), true); - auto local_cache = this->GetPredictionCache(); - auto& predt = local_cache->Cache(train, ctx_.gpu_id); + auto& predt = prediction_container_.Cache(train, ctx_.gpu_id); gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get()); monitor_.Stop("BoostOneIter"); } @@ -1326,10 +1311,9 @@ class LearnerImpl : public LearnerIO { metrics_.back()->Configure({cfg_.begin(), cfg_.end()}); } - auto local_cache = this->GetPredictionCache(); for (size_t i = 0; i < data_sets.size(); ++i) { std::shared_ptr m = data_sets[i]; - auto &predt = local_cache->Cache(m, ctx_.gpu_id); + auto &predt = prediction_container_.Cache(m, ctx_.gpu_id); this->ValidateDMatrix(m.get(), false); this->PredictRaw(m.get(), &predt, false, 0, 0); @@ -1370,8 +1354,7 @@ class LearnerImpl : public LearnerIO { } else if (pred_leaf) { gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end); } else { - auto local_cache = this->GetPredictionCache(); - auto& prediction = local_cache->Cache(data, ctx_.gpu_id); + auto& prediction = prediction_container_.Cache(data, ctx_.gpu_id); this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end); // Copy the prediction cache to output prediction. out_preds comes from C API out_preds->SetDevice(ctx_.gpu_id); diff --git a/tests/cpp/test_cache.cc b/tests/cpp/test_cache.cc index 4099fa2de..351730181 100644 --- a/tests/cpp/test_cache.cc +++ b/tests/cpp/test_cache.cc @@ -3,16 +3,18 @@ */ #include #include -#include // DMatrix +#include // for DMatrix -#include // std::size_t +#include // for size_t +#include // for uint32_t +#include // for thread -#include "helpers.h" // RandomDataGenerator +#include "helpers.h" // for RandomDataGenerator namespace xgboost { namespace { struct CacheForTest { - std::size_t i; + std::size_t const i; explicit CacheForTest(std::size_t k) : i{k} {} }; @@ -20,7 +22,7 @@ struct CacheForTest { TEST(DMatrixCache, Basic) { std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4; - DMatrixCache cache(kCacheSize); + DMatrixCache cache{kCacheSize}; auto add_cache = [&]() { // Create a lambda function here, so that p_fmat gets deleted upon the @@ -52,4 +54,63 @@ TEST(DMatrixCache, Basic) { } } } + +TEST(DMatrixCache, MultiThread) { + std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 3; + auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + + auto n = std::thread::hardware_concurrency() * 128u; + CHECK_NE(n, 0); + std::vector> results(n); + + { + DMatrixCache cache{kCacheSize}; + std::vector tasks; + for (std::uint32_t tidx = 0; tidx < n; ++tidx) { + tasks.emplace_back([&, i = tidx]() { + cache.CacheItem(p_fmat, i); + + auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + results[i] = cache.CacheItem(p_fmat_local, i); + }); + } + for (auto& t : tasks) { + t.join(); + } + for (std::uint32_t tidx = 0; tidx < n; ++tidx) { + ASSERT_EQ(results[tidx]->i, tidx); + } + + tasks.clear(); + + for (std::int32_t tidx = static_cast(n - 1); tidx >= 0; --tidx) { + tasks.emplace_back([&, i = tidx]() { + cache.CacheItem(p_fmat, i); + + auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + results[i] = cache.CacheItem(p_fmat_local, i); + }); + } + for (auto& t : tasks) { + t.join(); + } + for (std::uint32_t tidx = 0; tidx < n; ++tidx) { + ASSERT_EQ(results[tidx]->i, tidx); + } + } + + { + DMatrixCache cache{n}; + std::vector tasks; + for (std::uint32_t tidx = 0; tidx < n; ++tidx) { + tasks.emplace_back([&, tidx]() { results[tidx] = cache.CacheItem(p_fmat, tidx); }); + } + for (auto& t : tasks) { + t.join(); + } + for (std::uint32_t tidx = 0; tidx < n; ++tidx) { + ASSERT_EQ(results[tidx]->i, tidx); + } + } +} } // namespace xgboost