From 835e59e5388f60d382f0d183632eec9d9688d5db Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Thu, 16 May 2024 19:32:12 +0800
Subject: [PATCH] Use a thread pool for external memory. (#10288)

---
 src/common/threadpool.h             | 96 +++++++++++++++++++++++++++++
 src/data/sparse_page_source.h       | 17 +++--
 tests/cpp/common/test_threadpool.cc | 49 +++++++++++++++
 3 files changed, 157 insertions(+), 5 deletions(-)
 create mode 100644 src/common/threadpool.h
 create mode 100644 tests/cpp/common/test_threadpool.cc

diff --git a/src/common/threadpool.h b/src/common/threadpool.h
new file mode 100644
index 000000000..95d1deaaa
--- /dev/null
+++ b/src/common/threadpool.h
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2024, XGBoost Contributors
+ */
+#pragma once
+#include <condition_variable>  // for condition_variable
+#include <cstdint>             // for int32_t
+#include <functional>          // for function
+#include <future>              // for promise
+#include <memory>              // for make_shared
+#include <mutex>               // for mutex, unique_lock
+#include <queue>               // for queue
+#include <thread>              // for thread
+#include <type_traits>         // for invoke_result_t
+#include <utility>             // for move
+#include <vector>              // for vector
+
+namespace xgboost::common {
+/**
+ * @brief Simple implementation of a thread pool.
+ */
+class ThreadPool {
+  std::mutex mu_;
+  std::queue<std::function<void()>> tasks_;
+  std::condition_variable cv_;
+  std::vector<std::thread> pool_;
+  bool stop_{false};
+
+ public:
+  explicit ThreadPool(std::int32_t n_threads) {
+    for (std::int32_t i = 0; i < n_threads; ++i) {
+      pool_.emplace_back([&] {
+        while (true) {
+          std::unique_lock lock{mu_};
+          cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; });
+
+          if (this->stop_) {
+            if (!tasks_.empty()) {
+              while (!tasks_.empty()) {
+                auto fn = tasks_.front();
+                tasks_.pop();
+                fn();
+              }
+            }
+            return;
+          }
+
+          auto fn = tasks_.front();
+          tasks_.pop();
+          lock.unlock();
+          fn();
+        }
+      });
+    }
+  }
+
+  ~ThreadPool() {
+    std::unique_lock lock{mu_};
+    stop_ = true;
+    lock.unlock();
+
+    for (auto& t : pool_) {
+      if (t.joinable()) {
+        std::unique_lock lock{mu_};
+        this->cv_.notify_one();
+        lock.unlock();
+      }
+    }
+
+    for (auto& t : pool_) {
+      if (t.joinable()) {
+        t.join();
+      }
+    }
+  }
+
+  /**
+   * @brief Submit a function that doesn't take any argument.
+   */
+  template <typename Fn, typename R = std::invoke_result_t<Fn>>
+  auto Submit(Fn&& fn) {
+    // Use shared ptr to make the task copy constructible.
+    auto p{std::make_shared<std::promise<R>>()};
+    auto fut = p->get_future();
+    auto ffn = std::function{[task = std::move(p), fn = std::move(fn)]() mutable {
+      task->set_value(fn());
+    }};
+
+    std::unique_lock lock{mu_};
+    this->tasks_.push(std::move(ffn));
+    lock.unlock();
+
+    cv_.notify_one();
+    return fut;
+  }
+};
+}  // namespace xgboost::common
diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h
index 60129741b..ebb5fdf24 100644
--- a/src/data/sparse_page_source.h
+++ b/src/data/sparse_page_source.h
@@ -20,6 +20,7 @@
 #endif  // !defined(XGBOOST_USE_CUDA)
 
 #include "../common/io.h"          // for PrivateMmapConstStream
+#include "../common/threadpool.h"  // for ThreadPool
 #include "../common/timer.h"       // for Monitor, Timer
 #include "proxy_dmatrix.h"         // for DMatrixProxy
 #include "sparse_page_writer.h"    // for SparsePageFormat
@@ -148,6 +149,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   std::mutex single_threaded_;
   // The current page.
   std::shared_ptr<S> page_;
+  // Workers for fetching data from external memory.
+  common::ThreadPool workers_;
 
   bool at_end_ {false};
   float missing_;
@@ -161,8 +164,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   std::shared_ptr<Cache> cache_info_;
 
   using Ring = std::vector<std::future<std::shared_ptr<S>>>;
-  // A ring storing futures to data.  Since the DMatrix iterator is forward only, so we
-  // can pre-fetch data in a ring.
+  // A ring storing futures to data. Since the DMatrix iterator is forward only, we can
+  // pre-fetch data in a ring.
   std::unique_ptr<Ring> ring_{new Ring};
   // Catching exception in pre-fetch threads to prevent segfault. Not always work though,
   // OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
@@ -180,10 +183,13 @@
     }
    // An heuristic for number of pre-fetched batches. We can make it part of BatchParam
    // to let user adjust number of pre-fetched batches when needed.
-    std::int32_t n_prefetches = std::max(nthreads_, 3);
+    std::int32_t kPrefetches = 3;
+    std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
+    n_prefetches = std::max(n_prefetches, 1);
     std::int32_t n_prefetch_batches =
         std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
     CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
+    CHECK_LE(n_prefetch_batches, kPrefetches);
     std::size_t fetch_it = count_;
 
     exce_.Rethrow();
@@ -196,7 +202,7 @@
       }
       auto const* self = this;  // make sure it's const
       CHECK_LT(fetch_it, cache_info_->offset.size());
-      ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
+      ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
        *GlobalConfigThreadLocalStore::Get() = config;
        auto page = std::make_shared<S>();
        this->exce_.Run([&] {
@@ -252,7 +258,8 @@
  public:
  SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
                       std::shared_ptr<Cache> cache)
-      : missing_{missing},
+      : workers_{nthreads},
+        missing_{missing},
         nthreads_{nthreads},
         n_features_{n_features},
         n_batches_{n_batches},
diff --git a/tests/cpp/common/test_threadpool.cc b/tests/cpp/common/test_threadpool.cc
new file mode 100644
index 000000000..bd54a9ded
--- /dev/null
+++ b/tests/cpp/common/test_threadpool.cc
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2024, XGBoost Contributors
+ */
+#include <gtest/gtest.h>
+
+#include <cstddef>  // for size_t
+#include <cstdint>  // for int32_t
+#include <future>   // for future
+#include <thread>   // for sleep_for, thread
+
+#include "../../../src/common/threadpool.h"
+
+namespace xgboost::common {
+TEST(ThreadPool, Basic) {
+  std::int32_t n_threads = std::thread::hardware_concurrency();
+  ThreadPool pool{n_threads};
+  {
+    auto fut = pool.Submit([] { return 3; });
+    ASSERT_EQ(fut.get(), 3);
+  }
+  {
+    auto fut = pool.Submit([] { return std::string{"ok"}; });
+    ASSERT_EQ(fut.get(), "ok");
+  }
+  {
+    std::vector<std::future<std::size_t>> futures;
+    for (std::size_t i = 0; i < static_cast<std::size_t>(n_threads) * 16; ++i) {
+      futures.emplace_back(pool.Submit([=] {
+        std::this_thread::sleep_for(std::chrono::milliseconds{10});
+        return i;
+      }));
+    }
+    for (std::size_t i = 0; i < futures.size(); ++i) {
+      ASSERT_EQ(futures[i].get(), i);
+    }
+  }
+  {
+    std::vector<std::future<std::size_t>> futures;
+    for (std::size_t i = 0; i < static_cast<std::size_t>(n_threads) * 16; ++i) {
+      futures.emplace_back(pool.Submit([=] {
+        return i;
+      }));
+    }
+    for (std::size_t i = 0; i < futures.size(); ++i) {
+      ASSERT_EQ(futures[i].get(), i);
+    }
+  }
+}
+}  // namespace xgboost::common
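Usage note (editorial addition, not part of the patch): the sketch below illustrates how the ThreadPool::Submit API introduced in src/common/threadpool.h is intended to be called. It mirrors the std::async call sites the patch replaces, in that Submit returns a std::future for the task's result while the task runs on a pooled worker instead of a fresh thread. The file name demo.cc, the include path, and the sample tasks are hypothetical.

// demo.cc -- minimal illustrative sketch, assuming the header is reachable at this path.
#include <cstdint>   // for int32_t
#include <future>    // for future
#include <iostream>  // for cout
#include <vector>    // for vector

#include "src/common/threadpool.h"  // assumed path relative to the repository root

int main() {
  // Four worker threads; each Submit call enqueues a task and returns a std::future.
  xgboost::common::ThreadPool pool{4};
  std::vector<std::future<std::int32_t>> results;
  for (std::int32_t i = 0; i < 8; ++i) {
    results.emplace_back(pool.Submit([i]() -> std::int32_t { return i * i; }));
  }
  for (auto& fut : results) {
    std::cout << fut.get() << "\n";  // get() blocks until the corresponding task has run
  }
  return 0;  // ~ThreadPool drains any queued tasks and joins the workers
}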