Use a thread pool for external memory. (#10288)

This commit is contained in:
Jiaming Yuan 2024-05-16 19:32:12 +08:00 committed by GitHub
parent ee2afb3256
commit 835e59e538
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 157 additions and 5 deletions

96
src/common/threadpool.h Normal file
View File

@ -0,0 +1,96 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#pragma once
#include <condition_variable> // for condition_variable
#include <cstdint> // for int32_t
#include <functional> // for function
#include <future> // for promise
#include <memory> // for make_shared
#include <mutex> // for mutex, unique_lock
#include <queue> // for queue
#include <thread> // for thread
#include <type_traits> // for invoke_result_t
#include <utility> // for move
#include <vector> // for vector
namespace xgboost::common {
/**
* @brief Simple implementation of a thread pool.
*/
class ThreadPool {
  std::mutex mu_;
  std::queue<std::function<void()>> tasks_;
  std::condition_variable cv_;
  std::vector<std::thread> pool_;
  // Set by the destructor; workers drain remaining tasks and exit once observed.
  bool stop_{false};

 public:
  /**
   * @param n_threads Number of worker threads to spawn.
   */
  explicit ThreadPool(std::int32_t n_threads) {
    for (std::int32_t i = 0; i < n_threads; ++i) {
      pool_.emplace_back([this] {
        while (true) {
          std::unique_lock lock{mu_};
          cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; });

          if (this->stop_) {
            // Shutting down: drain whatever is still queued, then exit.
            while (!tasks_.empty()) {
              auto fn = std::move(tasks_.front());
              tasks_.pop();
              fn();
            }
            return;
          }

          // Move (not copy) the task out of the queue, then run it without
          // holding the lock so other workers can make progress.
          auto fn = std::move(tasks_.front());
          tasks_.pop();
          lock.unlock();
          fn();
        }
      });
    }
  }

  ~ThreadPool() {
    {
      std::unique_lock lock{mu_};
      stop_ = true;
    }
    // Wake every worker so each one can observe stop_ and exit.
    cv_.notify_all();
    for (auto& t : pool_) {
      if (t.joinable()) {
        t.join();
      }
    }
  }

  /**
   * @brief Submit a function that doesn't take any argument.
   *
   * @param fn Callable invoked exactly once on a worker thread.
   * @return A std::future holding the callable's result.
   */
  template <typename Fn, typename R = std::invoke_result_t<Fn>>
  auto Submit(Fn&& fn) {
    // Use shared ptr to make the task copy constructible (std::function
    // requires a copyable target).
    auto p{std::make_shared<std::promise<R>>()};
    auto fut = p->get_future();
    // std::forward, not std::move: an lvalue callable passed by the caller
    // must be copied, not silently moved-from.
    auto ffn = std::function{[task = std::move(p), fn = std::forward<Fn>(fn)]() mutable {
      task->set_value(fn());
    }};
    {
      std::unique_lock lock{mu_};
      this->tasks_.push(std::move(ffn));
    }
    cv_.notify_one();
    return fut;
  }
};
} // namespace xgboost::common

View File

@ -20,6 +20,7 @@
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA)
#include "../common/io.h" // for PrivateMmapConstStream #include "../common/io.h" // for PrivateMmapConstStream
#include "../common/threadpool.h" // for ThreadPool
#include "../common/timer.h" // for Monitor, Timer #include "../common/timer.h" // for Monitor, Timer
#include "proxy_dmatrix.h" // for DMatrixProxy #include "proxy_dmatrix.h" // for DMatrixProxy
#include "sparse_page_writer.h" // for SparsePageFormat #include "sparse_page_writer.h" // for SparsePageFormat
@ -148,6 +149,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
std::mutex single_threaded_; std::mutex single_threaded_;
// The current page. // The current page.
std::shared_ptr<S> page_; std::shared_ptr<S> page_;
// Workers for fetching data from external memory.
common::ThreadPool workers_;
bool at_end_ {false}; bool at_end_ {false};
float missing_; float missing_;
@ -161,8 +164,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
std::shared_ptr<Cache> cache_info_; std::shared_ptr<Cache> cache_info_;
using Ring = std::vector<std::future<std::shared_ptr<S>>>; using Ring = std::vector<std::future<std::shared_ptr<S>>>;
// A ring storing futures to data. Since the DMatrix iterator is forward only, so we // A ring storing futures to data. Since the DMatrix iterator is forward only, we can
// can pre-fetch data in a ring. // pre-fetch data in a ring.
std::unique_ptr<Ring> ring_{new Ring}; std::unique_ptr<Ring> ring_{new Ring};
// Catching exception in pre-fetch threads to prevent segfault. Not always work though, // Catching exception in pre-fetch threads to prevent segfault. Not always work though,
// OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then // OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
@ -180,10 +183,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
} }
// An heuristic for number of pre-fetched batches. We can make it part of BatchParam // An heuristic for number of pre-fetched batches. We can make it part of BatchParam
// to let user adjust number of pre-fetched batches when needed. // to let user adjust number of pre-fetched batches when needed.
std::int32_t n_prefetches = std::max(nthreads_, 3); std::int32_t kPrefetches = 3;
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
n_prefetches = std::max(n_prefetches, 1);
std::int32_t n_prefetch_batches = std::int32_t n_prefetch_batches =
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_); std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
CHECK_LE(n_prefetch_batches, kPrefetches);
std::size_t fetch_it = count_; std::size_t fetch_it = count_;
exce_.Rethrow(); exce_.Rethrow();
@ -196,7 +202,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
} }
auto const* self = this; // make sure it's const auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size()); CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() { ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
*GlobalConfigThreadLocalStore::Get() = config; *GlobalConfigThreadLocalStore::Get() = config;
auto page = std::make_shared<S>(); auto page = std::make_shared<S>();
this->exce_.Run([&] { this->exce_.Run([&] {
@ -252,7 +258,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
public: public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
std::shared_ptr<Cache> cache) std::shared_ptr<Cache> cache)
: missing_{missing}, : workers_{nthreads},
missing_{missing},
nthreads_{nthreads}, nthreads_{nthreads},
n_features_{n_features}, n_features_{n_features},
n_batches_{n_batches}, n_batches_{n_batches},

View File

@ -0,0 +1,49 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>

#include <chrono>   // for milliseconds
#include <cstddef>  // for size_t
#include <cstdint>  // for int32_t
#include <future>   // for future
#include <string>   // for string
#include <thread>   // for sleep_for, thread
#include <vector>   // for vector

#include "../../../src/common/threadpool.h"
namespace xgboost::common {
TEST(ThreadPool, Basic) {
  std::int32_t n_threads = std::thread::hardware_concurrency();
  ThreadPool pool{n_threads};

  // A single task with a trivially-copyable result.
  {
    auto res = pool.Submit([] { return 3; });
    ASSERT_EQ(res.get(), 3);
  }
  // A single task with a non-trivial result type.
  {
    auto res = pool.Submit([] { return std::string{"ok"}; });
    ASSERT_EQ(res.get(), "ok");
  }
  // Oversubscribe with slow tasks; each future must yield its own index.
  {
    std::size_t const n_tasks = static_cast<std::size_t>(n_threads) * 16;
    std::vector<std::future<std::size_t>> results;
    for (std::size_t idx = 0; idx < n_tasks; ++idx) {
      results.emplace_back(pool.Submit([=] {
        std::this_thread::sleep_for(std::chrono::milliseconds{10});
        return idx;
      }));
    }
    for (std::size_t idx = 0; idx < results.size(); ++idx) {
      ASSERT_EQ(results[idx].get(), idx);
    }
  }
  // Same as above, but with tasks that return immediately.
  {
    std::size_t const n_tasks = static_cast<std::size_t>(n_threads) * 16;
    std::vector<std::future<std::size_t>> results;
    for (std::size_t idx = 0; idx < n_tasks; ++idx) {
      results.emplace_back(pool.Submit([=] { return idx; }));
    }
    for (std::size_t idx = 0; idx < results.size(); ++idx) {
      ASSERT_EQ(results[idx].get(), idx);
    }
  }
}
} // namespace xgboost::common