Use a thread pool for external memory. (#10288)
This commit is contained in:
parent
ee2afb3256
commit
835e59e538
96
src/common/threadpool.h
Normal file
96
src/common/threadpool.h
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2024, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <condition_variable> // for condition_variable
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
|
#include <functional> // for function
|
||||||
|
#include <future> // for promise
|
||||||
|
#include <memory> // for make_shared
|
||||||
|
#include <mutex> // for mutex, unique_lock
|
||||||
|
#include <queue> // for queue
|
||||||
|
#include <thread> // for thread
|
||||||
|
#include <type_traits> // for invoke_result_t
|
||||||
|
#include <utility> // for move
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
namespace xgboost::common {
|
||||||
|
/**
 * @brief Simple implementation of a thread pool.
 *
 * Tasks are executed FIFO by a fixed number of worker threads. The destructor
 * drains all outstanding tasks before joining the workers, so no submitted work
 * is lost.
 */
class ThreadPool {
  std::mutex mu_;
  std::queue<std::function<void()>> tasks_;
  std::condition_variable cv_;
  std::vector<std::thread> pool_;
  bool stop_{false};  // guarded by mu_

 public:
  /**
   * @brief Spawn @p n_threads workers that run queued tasks until destruction.
   */
  explicit ThreadPool(std::int32_t n_threads) {
    for (std::int32_t i = 0; i < n_threads; ++i) {
      // Capture only `this`. A blanket `[&]` would also capture the loop
      // variable by reference, which dangles once the constructor returns.
      pool_.emplace_back([this] {
        while (true) {
          std::unique_lock lock{mu_};
          cv_.wait(lock, [this] { return !tasks_.empty() || stop_; });
          // Exit only once the queue is drained, so every submitted task runs.
          if (stop_ && tasks_.empty()) {
            return;
          }
          auto fn = std::move(tasks_.front());
          tasks_.pop();
          lock.unlock();
          // Run outside the lock so a task may call Submit without deadlocking
          // on the non-recursive mutex, including during shutdown drain.
          fn();
        }
      });
    }
  }

  ~ThreadPool() {
    {
      std::unique_lock lock{mu_};
      stop_ = true;
    }
    // One broadcast wakes every worker; no need to lock and signal per thread.
    cv_.notify_all();
    for (auto& t : pool_) {
      if (t.joinable()) {
        t.join();
      }
    }
  }

  /**
   * @brief Submit a function that doesn't take any argument.
   *
   * @return A future holding the task's result. An exception thrown by the task
   *         is captured and rethrown by `future::get`. Void-returning callables
   *         are supported.
   */
  template <typename Fn, typename R = std::invoke_result_t<Fn>>
  auto Submit(Fn&& fn) {
    // Use shared ptr to make the task copy constructible, as std::function requires.
    auto p{std::make_shared<std::promise<R>>()};
    auto fut = p->get_future();
    // Forward (not move) the callable: moving from a deduced Fn& would silently
    // steal the caller's lvalue.
    auto ffn = std::function<void()>{[task = std::move(p), fn = std::forward<Fn>(fn)]() mutable {
      try {
        if constexpr (std::is_void_v<R>) {
          fn();
          task->set_value();
        } else {
          task->set_value(fn());
        }
      } catch (...) {
        // Deliver the error to the waiting caller instead of terminating the
        // worker thread with an unhandled exception.
        task->set_exception(std::current_exception());
      }
    }};

    {
      std::unique_lock lock{mu_};
      tasks_.push(std::move(ffn));
    }
    cv_.notify_one();
    return fut;
  }
};
|
||||||
|
} // namespace xgboost::common
|
||||||
@ -20,6 +20,7 @@
|
|||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
|
|
||||||
#include "../common/io.h" // for PrivateMmapConstStream
|
#include "../common/io.h" // for PrivateMmapConstStream
|
||||||
|
#include "../common/threadpool.h" // for ThreadPool
|
||||||
#include "../common/timer.h" // for Monitor, Timer
|
#include "../common/timer.h" // for Monitor, Timer
|
||||||
#include "proxy_dmatrix.h" // for DMatrixProxy
|
#include "proxy_dmatrix.h" // for DMatrixProxy
|
||||||
#include "sparse_page_writer.h" // for SparsePageFormat
|
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||||
@ -148,6 +149,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
std::mutex single_threaded_;
|
std::mutex single_threaded_;
|
||||||
// The current page.
|
// The current page.
|
||||||
std::shared_ptr<S> page_;
|
std::shared_ptr<S> page_;
|
||||||
|
// Workers for fetching data from external memory.
|
||||||
|
common::ThreadPool workers_;
|
||||||
|
|
||||||
bool at_end_ {false};
|
bool at_end_ {false};
|
||||||
float missing_;
|
float missing_;
|
||||||
@ -161,8 +164,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
std::shared_ptr<Cache> cache_info_;
|
std::shared_ptr<Cache> cache_info_;
|
||||||
|
|
||||||
using Ring = std::vector<std::future<std::shared_ptr<S>>>;
|
using Ring = std::vector<std::future<std::shared_ptr<S>>>;
|
||||||
// A ring storing futures to data. Since the DMatrix iterator is forward only, so we
|
// A ring storing futures to data. Since the DMatrix iterator is forward only, we can
|
||||||
// can pre-fetch data in a ring.
|
// pre-fetch data in a ring.
|
||||||
std::unique_ptr<Ring> ring_{new Ring};
|
std::unique_ptr<Ring> ring_{new Ring};
|
||||||
// Catching exceptions in pre-fetch threads to prevent segfaults. This doesn't always work though,
|
// Catching exceptions in pre-fetch threads to prevent segfaults. This doesn't always work though,
|
||||||
// OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
|
// OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
|
||||||
@ -180,10 +183,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
}
|
}
|
||||||
// A heuristic for the number of pre-fetched batches. We can make it part of BatchParam
|
// A heuristic for the number of pre-fetched batches. We can make it part of BatchParam
|
||||||
// to let user adjust number of pre-fetched batches when needed.
|
// to let user adjust number of pre-fetched batches when needed.
|
||||||
std::int32_t n_prefetches = std::max(nthreads_, 3);
|
std::int32_t kPrefetches = 3;
|
||||||
|
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
|
||||||
|
n_prefetches = std::max(n_prefetches, 1);
|
||||||
std::int32_t n_prefetch_batches =
|
std::int32_t n_prefetch_batches =
|
||||||
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
|
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
|
||||||
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
|
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
|
||||||
|
CHECK_LE(n_prefetch_batches, kPrefetches);
|
||||||
std::size_t fetch_it = count_;
|
std::size_t fetch_it = count_;
|
||||||
|
|
||||||
exce_.Rethrow();
|
exce_.Rethrow();
|
||||||
@ -196,7 +202,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
}
|
}
|
||||||
auto const* self = this; // make sure it's const
|
auto const* self = this; // make sure it's const
|
||||||
CHECK_LT(fetch_it, cache_info_->offset.size());
|
CHECK_LT(fetch_it, cache_info_->offset.size());
|
||||||
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
|
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
|
||||||
*GlobalConfigThreadLocalStore::Get() = config;
|
*GlobalConfigThreadLocalStore::Get() = config;
|
||||||
auto page = std::make_shared<S>();
|
auto page = std::make_shared<S>();
|
||||||
this->exce_.Run([&] {
|
this->exce_.Run([&] {
|
||||||
@ -252,7 +258,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
public:
|
public:
|
||||||
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
|
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
|
||||||
std::shared_ptr<Cache> cache)
|
std::shared_ptr<Cache> cache)
|
||||||
: missing_{missing},
|
: workers_{nthreads},
|
||||||
|
missing_{missing},
|
||||||
nthreads_{nthreads},
|
nthreads_{nthreads},
|
||||||
n_features_{n_features},
|
n_features_{n_features},
|
||||||
n_batches_{n_batches},
|
n_batches_{n_batches},
|
||||||
|
|||||||
49
tests/cpp/common/test_threadpool.cc
Normal file
49
tests/cpp/common/test_threadpool.cc
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2024, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
|
#include <future> // for future
|
||||||
|
#include <thread> // for sleep_for, thread
|
||||||
|
|
||||||
|
#include "../../../src/common/threadpool.h"
|
||||||
|
|
||||||
|
namespace xgboost::common {
|
||||||
|
TEST(ThreadPool, Basic) {
  std::int32_t const n_threads = std::thread::hardware_concurrency();
  ThreadPool pool{n_threads};

  // A single task returning an int.
  {
    auto res = pool.Submit([] { return 3; });
    ASSERT_EQ(res.get(), 3);
  }
  // A single task returning a string.
  {
    auto res = pool.Submit([] { return std::string{"ok"}; });
    ASSERT_EQ(res.get(), "ok");
  }

  // Submit many tasks (16 per worker) and check each future yields its own
  // index; run once with a sleep inside the task and once without.
  auto run_batch = [&](bool sleep_in_task) {
    std::size_t const n_tasks = static_cast<std::size_t>(n_threads) * 16;
    std::vector<std::future<std::size_t>> results;
    for (std::size_t i = 0; i < n_tasks; ++i) {
      results.emplace_back(pool.Submit([=] {
        if (sleep_in_task) {
          std::this_thread::sleep_for(std::chrono::milliseconds{10});
        }
        return i;
      }));
    }
    for (std::size_t i = 0; i < results.size(); ++i) {
      ASSERT_EQ(results[i].get(), i);
    }
  };
  run_batch(true);
  run_batch(false);
}
|
||||||
|
} // namespace xgboost::common
|
||||||
Loading…
x
Reference in New Issue
Block a user