Enhance the threadpool implementation. (#10531)
- Accept an initialization function. - Support void return tasks.
This commit is contained in:
parent
9cb4c938da
commit
628411a654
@ -26,21 +26,26 @@ class ThreadPool {
|
|||||||
bool stop_{false};
|
bool stop_{false};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ThreadPool(std::int32_t n_threads) {
|
/**
|
||||||
|
* @param n_threads The number of threads this pool should hold.
|
||||||
|
* @param init_fn Function called once during thread creation.
|
||||||
|
*/
|
||||||
|
template <typename InitFn>
|
||||||
|
explicit ThreadPool(std::int32_t n_threads, InitFn&& init_fn) {
|
||||||
for (std::int32_t i = 0; i < n_threads; ++i) {
|
for (std::int32_t i = 0; i < n_threads; ++i) {
|
||||||
pool_.emplace_back([&] {
|
pool_.emplace_back([&, init_fn = std::forward<InitFn>(init_fn)] {
|
||||||
|
init_fn();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
std::unique_lock lock{mu_};
|
std::unique_lock lock{mu_};
|
||||||
cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; });
|
cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; });
|
||||||
|
|
||||||
if (this->stop_) {
|
if (this->stop_) {
|
||||||
if (!tasks_.empty()) {
|
|
||||||
while (!tasks_.empty()) {
|
while (!tasks_.empty()) {
|
||||||
auto fn = tasks_.front();
|
auto fn = tasks_.front();
|
||||||
tasks_.pop();
|
tasks_.pop();
|
||||||
fn();
|
fn();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,8 +86,13 @@ class ThreadPool {
|
|||||||
// Use shared ptr to make the task copy constructible.
|
// Use shared ptr to make the task copy constructible.
|
||||||
auto p{std::make_shared<std::promise<R>>()};
|
auto p{std::make_shared<std::promise<R>>()};
|
||||||
auto fut = p->get_future();
|
auto fut = p->get_future();
|
||||||
auto ffn = std::function{[task = std::move(p), fn = std::move(fn)]() mutable {
|
auto ffn = std::function{[task = std::move(p), fn = std::forward<Fn>(fn)]() mutable {
|
||||||
|
if constexpr (std::is_void_v<R>) {
|
||||||
|
fn();
|
||||||
|
task->set_value();
|
||||||
|
} else {
|
||||||
task->set_value(fn());
|
task->set_value(fn());
|
||||||
|
}
|
||||||
}};
|
}};
|
||||||
|
|
||||||
std::unique_lock lock{mu_};
|
std::unique_lock lock{mu_};
|
||||||
|
|||||||
@ -236,7 +236,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
|||||||
|
|
||||||
exce_.Rethrow();
|
exce_.Rethrow();
|
||||||
|
|
||||||
auto const config = *GlobalConfigThreadLocalStore::Get();
|
|
||||||
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
|
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
|
||||||
fetch_it %= n_batches_; // ring
|
fetch_it %= n_batches_; // ring
|
||||||
if (ring_->at(fetch_it).valid()) {
|
if (ring_->at(fetch_it).valid()) {
|
||||||
@ -244,8 +243,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
|||||||
}
|
}
|
||||||
auto const* self = this; // make sure it's const
|
auto const* self = this; // make sure it's const
|
||||||
CHECK_LT(fetch_it, cache_info_->offset.size());
|
CHECK_LT(fetch_it, cache_info_->offset.size());
|
||||||
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
|
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] {
|
||||||
*GlobalConfigThreadLocalStore::Get() = config;
|
|
||||||
auto page = std::make_shared<S>();
|
auto page = std::make_shared<S>();
|
||||||
this->exce_.Run([&] {
|
this->exce_.Run([&] {
|
||||||
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
|
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
|
||||||
@ -297,7 +295,10 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
|||||||
public:
|
public:
|
||||||
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
|
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
|
||||||
std::shared_ptr<Cache> cache)
|
std::shared_ptr<Cache> cache)
|
||||||
: workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads.
|
: workers_{std::max(2, std::min(nthreads, 16)),
|
||||||
|
[config = *GlobalConfigThreadLocalStore::Get()] {
|
||||||
|
*GlobalConfigThreadLocalStore::Get() = config;
|
||||||
|
}},
|
||||||
missing_{missing},
|
missing_{missing},
|
||||||
nthreads_{nthreads},
|
nthreads_{nthreads},
|
||||||
n_features_{n_features},
|
n_features_{n_features},
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
* Copyright 2024, XGBoost Contributors
|
* Copyright 2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/global_config.h> // for GlobalConfigThreadLocalStore
|
||||||
|
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
@ -13,7 +14,23 @@
|
|||||||
namespace xgboost::common {
|
namespace xgboost::common {
|
||||||
TEST(ThreadPool, Basic) {
|
TEST(ThreadPool, Basic) {
|
||||||
std::int32_t n_threads = std::thread::hardware_concurrency();
|
std::int32_t n_threads = std::thread::hardware_concurrency();
|
||||||
ThreadPool pool{n_threads};
|
|
||||||
|
// Set verbosity to 4 for the thread-local variable.
|
||||||
|
auto orig = GlobalConfigThreadLocalStore::Get()->verbosity;
|
||||||
|
GlobalConfigThreadLocalStore::Get()->verbosity = 4;
|
||||||
|
// 4 is an invalid value, it's only possible to set it by bypassing the parameter
|
||||||
|
// validation.
|
||||||
|
ASSERT_NE(orig, GlobalConfigThreadLocalStore::Get()->verbosity);
|
||||||
|
ThreadPool pool{n_threads, [config = *GlobalConfigThreadLocalStore::Get()] {
|
||||||
|
*GlobalConfigThreadLocalStore::Get() = config;
|
||||||
|
}};
|
||||||
|
GlobalConfigThreadLocalStore::Get()->verbosity = orig; // restore
|
||||||
|
|
||||||
|
{
|
||||||
|
auto fut = pool.Submit([] { return GlobalConfigThreadLocalStore::Get()->verbosity; });
|
||||||
|
ASSERT_EQ(fut.get(), 4);
|
||||||
|
ASSERT_EQ(GlobalConfigThreadLocalStore::Get()->verbosity, orig);
|
||||||
|
}
|
||||||
{
|
{
|
||||||
auto fut = pool.Submit([] { return 3; });
|
auto fut = pool.Submit([] { return 3; });
|
||||||
ASSERT_EQ(fut.get(), 3);
|
ASSERT_EQ(fut.get(), 3);
|
||||||
@ -45,5 +62,12 @@ TEST(ThreadPool, Basic) {
|
|||||||
ASSERT_EQ(futures[i].get(), i);
|
ASSERT_EQ(futures[i].get(), i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
std::int32_t val{0};
|
||||||
|
auto fut = pool.Submit([&] { val = 3; });
|
||||||
|
static_assert(std::is_void_v<decltype(fut.get())>);
|
||||||
|
fut.get();
|
||||||
|
ASSERT_EQ(val, 3);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace xgboost::common
|
} // namespace xgboost::common
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user