Enhance the thread pool implementation. (#10531)

- Accept an initialization function.
- Support tasks that return void.
This commit is contained in:
Jiaming Yuan
2024-07-03 12:13:27 +08:00
committed by GitHub
parent 9cb4c938da
commit 628411a654
3 changed files with 50 additions and 15 deletions

View File

@@ -236,7 +236,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
exce_.Rethrow();
auto const config = *GlobalConfigThreadLocalStore::Get();
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) {
@@ -244,8 +243,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
}
auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
*GlobalConfigThreadLocalStore::Get() = config;
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] {
auto page = std::make_shared<S>();
this->exce_.Run([&] {
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
@@ -297,7 +295,10 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
std::shared_ptr<Cache> cache)
: workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads.
: workers_{std::max(2, std::min(nthreads, 16)),
[config = *GlobalConfigThreadLocalStore::Get()] {
*GlobalConfigThreadLocalStore::Get() = config;
}},
missing_{missing},
nthreads_{nthreads},
n_features_{n_features},