Use a thread pool for external memory. (#10288)
parent ee2afb3256
commit 835e59e538
96  src/common/threadpool.h  Normal file
@@ -0,0 +1,96 @@
/**
 * Copyright 2024, XGBoost Contributors
 */
#pragma once

#include <condition_variable>  // for condition_variable
#include <cstdint>             // for int32_t
#include <functional>          // for function
#include <future>              // for promise
#include <memory>              // for make_shared
#include <mutex>               // for mutex, unique_lock
#include <queue>               // for queue
#include <thread>              // for thread
#include <type_traits>         // for invoke_result_t
#include <utility>             // for move
#include <vector>              // for vector

namespace xgboost::common {
/**
 * @brief Simple implementation of a thread pool.
 */
class ThreadPool {
  std::mutex mu_;
  std::queue<std::function<void()>> tasks_;
  std::condition_variable cv_;
  std::vector<std::thread> pool_;
  bool stop_{false};

 public:
  explicit ThreadPool(std::int32_t n_threads) {
    for (std::int32_t i = 0; i < n_threads; ++i) {
      pool_.emplace_back([&] {
        while (true) {
          std::unique_lock lock{mu_};
          cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; });

          if (this->stop_) {
            if (!tasks_.empty()) {
              while (!tasks_.empty()) {
                auto fn = tasks_.front();
                tasks_.pop();
                fn();
              }
            }
            return;
          }

          auto fn = tasks_.front();
          tasks_.pop();
          lock.unlock();
          fn();
        }
      });
    }
  }

  ~ThreadPool() {
    std::unique_lock lock{mu_};
    stop_ = true;
    lock.unlock();

    for (auto& t : pool_) {
      if (t.joinable()) {
        std::unique_lock lock{mu_};
        this->cv_.notify_one();
        lock.unlock();
      }
    }

    for (auto& t : pool_) {
      if (t.joinable()) {
        t.join();
      }
    }
  }

  /**
   * @brief Submit a function that doesn't take any argument.
   */
  template <typename Fn, typename R = std::invoke_result_t<Fn>>
  auto Submit(Fn&& fn) {
    // Use shared ptr to make the task copy constructible.
    auto p{std::make_shared<std::promise<R>>()};
    auto fut = p->get_future();
    auto ffn = std::function{[task = std::move(p), fn = std::move(fn)]() mutable {
      task->set_value(fn());
    }};

    std::unique_lock lock{mu_};
    this->tasks_.push(std::move(ffn));
    lock.unlock();

    cv_.notify_one();
    return fut;
  }
};
}  // namespace xgboost::common
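For reference, a minimal sketch of how the pool is meant to be used: Submit() takes a callable with no arguments and returns a std::future for its result. The include path and the main() driver below are illustrative only and are not part of this commit.

#include <future>    // for future
#include <iostream>  // for cout
#include <vector>    // for vector

#include "common/threadpool.h"  // adjust the include path to your build setup

int main() {
  xgboost::common::ThreadPool pool{4};  // four worker threads

  // Submit() accepts any callable taking no arguments and hands back a
  // std::future holding the callable's result type.
  std::vector<std::future<int>> results;
  for (int i = 0; i < 8; ++i) {
    results.emplace_back(pool.Submit([i] { return i * i; }));
  }

  for (auto& fut : results) {
    std::cout << fut.get() << "\n";  // blocks until that task has run
  }
  // The destructor stops the workers; tasks still queued are drained first.
  return 0;
}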
src/data/sparse_page_source.h

@@ -20,6 +20,7 @@
 #endif  // !defined(XGBOOST_USE_CUDA)

 #include "../common/io.h"          // for PrivateMmapConstStream
+#include "../common/threadpool.h"  // for ThreadPool
 #include "../common/timer.h"       // for Monitor, Timer
 #include "proxy_dmatrix.h"         // for DMatrixProxy
 #include "sparse_page_writer.h"    // for SparsePageFormat
@@ -148,6 +149,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   std::mutex single_threaded_;
   // The current page.
   std::shared_ptr<S> page_;
+  // Workers for fetching data from external memory.
+  common::ThreadPool workers_;

   bool at_end_ {false};
   float missing_;
@@ -161,8 +164,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   std::shared_ptr<Cache> cache_info_;

   using Ring = std::vector<std::future<std::shared_ptr<S>>>;
-  // A ring storing futures to data. Since the DMatrix iterator is forward only, so we
-  // can pre-fetch data in a ring.
+  // A ring storing futures to data. Since the DMatrix iterator is forward only, we can
+  // pre-fetch data in a ring.
   std::unique_ptr<Ring> ring_{new Ring};
   // Catching exception in pre-fetch threads to prevent segfault. Not always work though,
   // OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
@@ -180,10 +183,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   }
   // An heuristic for number of pre-fetched batches. We can make it part of BatchParam
   // to let user adjust number of pre-fetched batches when needed.
-  std::int32_t n_prefetches = std::max(nthreads_, 3);
+  std::int32_t kPrefetches = 3;
+  std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
+  n_prefetches = std::max(n_prefetches, 1);
   std::int32_t n_prefetch_batches =
       std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
   CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
+  CHECK_LE(n_prefetch_batches, kPrefetches);
   std::size_t fetch_it = count_;

   exce_.Rethrow();
@@ -196,7 +202,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
     }
     auto const* self = this;  // make sure it's const
     CHECK_LT(fetch_it, cache_info_->offset.size());
-    ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
+    ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
      *GlobalConfigThreadLocalStore::Get() = config;
      auto page = std::make_shared<S>();
      this->exce_.Run([&] {
@@ -252,7 +258,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
  public:
  SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
                       std::shared_ptr<Cache> cache)
-      : missing_{missing},
+      : workers_{nthreads},
+        missing_{missing},
        nthreads_{nthreads},
        n_features_{n_features},
        n_batches_{n_batches},
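The hunks above swap the per-batch std::async call for a task submitted to the shared worker pool while keeping the existing ring of futures. Below is a simplified, self-contained sketch of that pre-fetch pattern; the Batch struct, LoadBatch() helper, and the constants are hypothetical stand-ins for illustration and do not correspond to actual XGBoost symbols.

#include <algorithm>  // for max, min
#include <cstddef>    // for size_t
#include <cstdint>    // for int32_t
#include <future>     // for future
#include <memory>     // for make_shared, shared_ptr
#include <vector>     // for vector

#include "common/threadpool.h"  // illustrative include path

// Stand-in for a page of data; the real code uses SparsePage and friends.
struct Batch {
  std::size_t id;
};

// Hypothetical loader; the real code reads a cached page from disk.
std::shared_ptr<Batch> LoadBatch(std::size_t it) { return std::make_shared<Batch>(Batch{it}); }

int main() {
  std::int32_t n_threads = 4;
  std::size_t n_batches = 16;

  // Same heuristic as the diff: at most 3 pre-fetched batches, at least 1,
  // and never more than the total number of batches.
  std::int32_t n_prefetches = std::max<std::int32_t>(std::min<std::int32_t>(n_threads, 3), 1);
  std::size_t n_prefetch_batches = std::min(static_cast<std::size_t>(n_prefetches), n_batches);

  xgboost::common::ThreadPool workers{n_threads};
  std::vector<std::future<std::shared_ptr<Batch>>> ring(n_batches);

  for (std::size_t count = 0; count < n_batches; ++count) {
    // Schedule the next few batches; the iterator is forward-only, so every
    // batch scheduled here is eventually consumed below.
    for (std::size_t fetch_it = count;
         fetch_it < count + n_prefetch_batches && fetch_it < n_batches; ++fetch_it) {
      if (!ring[fetch_it].valid()) {
        ring[fetch_it] = workers.Submit([fetch_it] { return LoadBatch(fetch_it); });
      }
    }
    auto page = ring[count].get();  // blocks only if the batch is not ready yet
    (void)page;                     // ... use the page ...
  }
  return 0;
}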
49  tests/cpp/common/test_threadpool.cc  Normal file
@@ -0,0 +1,49 @@
/**
 * Copyright 2024, XGBoost Contributors
 */
#include <gtest/gtest.h>

#include <cstddef>  // for size_t
#include <cstdint>  // for int32_t
#include <future>   // for future
#include <thread>   // for sleep_for, thread

#include "../../../src/common/threadpool.h"

namespace xgboost::common {
TEST(ThreadPool, Basic) {
  std::int32_t n_threads = std::thread::hardware_concurrency();
  ThreadPool pool{n_threads};
  {
    auto fut = pool.Submit([] { return 3; });
    ASSERT_EQ(fut.get(), 3);
  }
  {
    auto fut = pool.Submit([] { return std::string{"ok"}; });
    ASSERT_EQ(fut.get(), "ok");
  }
  {
    std::vector<std::future<std::size_t>> futures;
    for (std::size_t i = 0; i < static_cast<std::size_t>(n_threads) * 16; ++i) {
      futures.emplace_back(pool.Submit([=] {
        std::this_thread::sleep_for(std::chrono::milliseconds{10});
        return i;
      }));
    }
    for (std::size_t i = 0; i < futures.size(); ++i) {
      ASSERT_EQ(futures[i].get(), i);
    }
  }
  {
    std::vector<std::future<std::size_t>> futures;
    for (std::size_t i = 0; i < static_cast<std::size_t>(n_threads) * 16; ++i) {
      futures.emplace_back(pool.Submit([=] {
        return i;
      }));
    }
    for (std::size_t i = 0; i < futures.size(); ++i) {
      ASSERT_EQ(futures[i].get(), i);
    }
  }
}
}  // namespace xgboost::common
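One property of the destructor worth noting, which the test above does not exercise directly: tasks still queued when the pool is destroyed are drained by the workers before they exit, so futures returned by Submit() remain usable. A small sketch under that reading of the code (not part of the commit):

#include <future>    // for future
#include <iostream>  // for cout

#include "common/threadpool.h"  // illustrative include path

int main() {
  std::future<int> fut;
  {
    xgboost::common::ThreadPool pool{1};
    fut = pool.Submit([] { return 42; });
    // Destructor runs here: stop_ is set, the worker is woken up, and any
    // task still in the queue is executed before the thread joins.
  }
  std::cout << fut.get() << "\n";  // prints 42
  return 0;
}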