Enhance the threadpool implementation. (#10531)
- Accept an initialization function. - Support void return tasks.
This commit is contained in:
@@ -26,20 +26,25 @@ class ThreadPool {
|
||||
bool stop_{false};
|
||||
|
||||
public:
|
||||
explicit ThreadPool(std::int32_t n_threads) {
|
||||
/**
|
||||
* @param n_threads The number of threads this pool should hold.
|
||||
* @param init_fn Function called once during thread creation.
|
||||
*/
|
||||
template <typename InitFn>
|
||||
explicit ThreadPool(std::int32_t n_threads, InitFn&& init_fn) {
|
||||
for (std::int32_t i = 0; i < n_threads; ++i) {
|
||||
pool_.emplace_back([&] {
|
||||
pool_.emplace_back([&, init_fn = std::forward<InitFn>(init_fn)] {
|
||||
init_fn();
|
||||
|
||||
while (true) {
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; });
|
||||
|
||||
if (this->stop_) {
|
||||
if (!tasks_.empty()) {
|
||||
while (!tasks_.empty()) {
|
||||
auto fn = tasks_.front();
|
||||
tasks_.pop();
|
||||
fn();
|
||||
}
|
||||
while (!tasks_.empty()) {
|
||||
auto fn = tasks_.front();
|
||||
tasks_.pop();
|
||||
fn();
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -81,8 +86,13 @@ class ThreadPool {
|
||||
// Use shared ptr to make the task copy constructible.
|
||||
auto p{std::make_shared<std::promise<R>>()};
|
||||
auto fut = p->get_future();
|
||||
auto ffn = std::function{[task = std::move(p), fn = std::move(fn)]() mutable {
|
||||
task->set_value(fn());
|
||||
auto ffn = std::function{[task = std::move(p), fn = std::forward<Fn>(fn)]() mutable {
|
||||
if constexpr (std::is_void_v<R>) {
|
||||
fn();
|
||||
task->set_value();
|
||||
} else {
|
||||
task->set_value(fn());
|
||||
}
|
||||
}};
|
||||
|
||||
std::unique_lock lock{mu_};
|
||||
|
||||
@@ -236,7 +236,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
||||
|
||||
exce_.Rethrow();
|
||||
|
||||
auto const config = *GlobalConfigThreadLocalStore::Get();
|
||||
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
|
||||
fetch_it %= n_batches_; // ring
|
||||
if (ring_->at(fetch_it).valid()) {
|
||||
@@ -244,8 +243,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
||||
}
|
||||
auto const* self = this; // make sure it's const
|
||||
CHECK_LT(fetch_it, cache_info_->offset.size());
|
||||
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
|
||||
*GlobalConfigThreadLocalStore::Get() = config;
|
||||
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] {
|
||||
auto page = std::make_shared<S>();
|
||||
this->exce_.Run([&] {
|
||||
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
|
||||
@@ -297,7 +295,10 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
|
||||
public:
|
||||
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
|
||||
std::shared_ptr<Cache> cache)
|
||||
: workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads.
|
||||
: workers_{std::max(2, std::min(nthreads, 16)),
|
||||
[config = *GlobalConfigThreadLocalStore::Get()] {
|
||||
*GlobalConfigThreadLocalStore::Get() = config;
|
||||
}},
|
||||
missing_{missing},
|
||||
nthreads_{nthreads},
|
||||
n_features_{n_features},
|
||||
|
||||
Reference in New Issue
Block a user