Use a thread pool for external memory. (#10288)

This commit is contained in:
Jiaming Yuan
2024-05-16 19:32:12 +08:00
committed by GitHub
parent ee2afb3256
commit 835e59e538
3 changed files with 157 additions and 5 deletions

View File

@@ -20,6 +20,7 @@
#endif // !defined(XGBOOST_USE_CUDA)
#include "../common/io.h" // for PrivateMmapConstStream
#include "../common/threadpool.h" // for ThreadPool
#include "../common/timer.h" // for Monitor, Timer
#include "proxy_dmatrix.h" // for DMatrixProxy
#include "sparse_page_writer.h" // for SparsePageFormat
@@ -148,6 +149,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
std::mutex single_threaded_;
// The current page.
std::shared_ptr<S> page_;
// Workers for fetching data from external memory.
common::ThreadPool workers_;
bool at_end_ {false};
float missing_;
@@ -161,8 +164,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
std::shared_ptr<Cache> cache_info_;
using Ring = std::vector<std::future<std::shared_ptr<S>>>;
// A ring storing futures to data. Since the DMatrix iterator is forward only, so we
// can pre-fetch data in a ring.
// A ring storing futures to data. Since the DMatrix iterator is forward only, we can
// pre-fetch data in a ring.
std::unique_ptr<Ring> ring_{new Ring};
// Catching exceptions in pre-fetch threads to prevent segfault. This does not always work though,
// OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
@@ -180,10 +183,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
}
// A heuristic for the number of pre-fetched batches. We can make it part of BatchParam
// to let user adjust number of pre-fetched batches when needed.
std::int32_t n_prefetches = std::max(nthreads_, 3);
std::int32_t kPrefetches = 3;
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
n_prefetches = std::max(n_prefetches, 1);
std::int32_t n_prefetch_batches =
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
CHECK_LE(n_prefetch_batches, kPrefetches);
std::size_t fetch_it = count_;
exce_.Rethrow();
@@ -196,7 +202,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
}
auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
*GlobalConfigThreadLocalStore::Get() = config;
auto page = std::make_shared<S>();
this->exce_.Run([&] {
@@ -252,7 +258,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
std::shared_ptr<Cache> cache)
: missing_{missing},
: workers_{nthreads},
missing_{missing},
nthreads_{nthreads},
n_features_{n_features},
n_batches_{n_batches},