Use a thread pool for external memory. (#10288)
@@ -20,6 +20,7 @@
 #endif  // !defined(XGBOOST_USE_CUDA)
 
 #include "../common/io.h"          // for PrivateMmapConstStream
+#include "../common/threadpool.h"  // for ThreadPool
 #include "../common/timer.h"       // for Monitor, Timer
 #include "proxy_dmatrix.h"         // for DMatrixProxy
 #include "sparse_page_writer.h"    // for SparsePageFormat
@@ -148,6 +149,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   std::mutex single_threaded_;
   // The current page.
   std::shared_ptr<S> page_;
+  // Workers for fetching data from external memory.
+  common::ThreadPool workers_;
 
   bool at_end_ {false};
   float missing_;
@@ -161,8 +164,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   std::shared_ptr<Cache> cache_info_;
 
   using Ring = std::vector<std::future<std::shared_ptr<S>>>;
-  // A ring storing futures to data. Since the DMatrix iterator is forward only, so we
-  // can pre-fetch data in a ring.
+  // A ring storing futures to data. Since the DMatrix iterator is forward only, we can
+  // pre-fetch data in a ring.
   std::unique_ptr<Ring> ring_{new Ring};
   // Catching exception in pre-fetch threads to prevent segfault. Not always work though,
   // OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
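The reworded comment describes the core data structure: a ring of futures that is refilled as the forward-only DMatrix iterator advances. As a rough, self-contained illustration of that pattern (Page and LoadPage are hypothetical stand-ins, not code from this patch):

// Sketch of a forward-only prefetch ring; Page and LoadPage stand in for the
// real sparse page type and its reader.
#include <cstddef>
#include <future>
#include <memory>
#include <vector>

struct Page {};

std::shared_ptr<Page> LoadPage(std::size_t idx) {
  return std::make_shared<Page>();  // pretend to read batch `idx` from disk
}

int main() {
  std::size_t const n_batches = 8, n_prefetch = 3;
  std::vector<std::future<std::shared_ptr<Page>>> ring(n_prefetch);
  for (std::size_t i = 0; i < n_prefetch && i < n_batches; ++i) {
    ring[i % n_prefetch] = std::async(std::launch::async, LoadPage, i);  // warm up
  }
  for (std::size_t i = 0; i < n_batches; ++i) {
    auto page = ring[i % n_prefetch].get();  // block only on the batch needed now
    (void)page;                              // ... consume `page` ...
    if (std::size_t next = i + n_prefetch; next < n_batches) {
      ring[i % n_prefetch] = std::async(std::launch::async, LoadPage, next);  // refill slot
    }
  }
  return 0;
}

In the patch itself the ring is indexed by fetch_it and, after this change, filled through workers_.Submit rather than std::async.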
@@ -180,10 +183,13 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
     }
     // An heuristic for number of pre-fetched batches. We can make it part of BatchParam
     // to let user adjust number of pre-fetched batches when needed.
-    std::int32_t n_prefetches = std::max(nthreads_, 3);
+    std::int32_t kPrefetches = 3;
+    std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
+    n_prefetches = std::max(n_prefetches, 1);
     std::int32_t n_prefetch_batches =
         std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
     CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
+    CHECK_LE(n_prefetch_batches, kPrefetches);
     std::size_t fetch_it = count_;
 
     exce_.Rethrow();
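The old expression prefetched at least 3 and up to nthreads_ batches; the new one caps prefetching at kPrefetches = 3, keeps at least one in flight, and as before never exceeds the total number of batches. A standalone restatement of that arithmetic (illustrative names, not part of the patch):

// Illustrative restatement of the clamping above; names are not from the patch.
#include <algorithm>
#include <cstdint>
#include <iostream>

std::uint32_t PrefetchBatches(std::int32_t n_threads, std::uint32_t n_batches) {
  std::int32_t const kPrefetches = 3;
  std::int32_t n = std::min(n_threads, kPrefetches);  // no more than the cap
  n = std::max(n, 1);                                  // but at least one
  return std::min(static_cast<std::uint32_t>(n), n_batches);  // and no more than exist
}

int main() {
  std::cout << PrefetchBatches(16, 100) << "\n";  // 3: capped by kPrefetches
  std::cout << PrefetchBatches(1, 100) << "\n";   // 1: limited by the thread count
  std::cout << PrefetchBatches(16, 2) << "\n";    // 2: limited by the number of batches
  return 0;
}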
@@ -196,7 +202,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
     }
     auto const* self = this;  // make sure it's const
     CHECK_LT(fetch_it, cache_info_->offset.size());
-    ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
+    ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] {
       *GlobalConfigThreadLocalStore::Get() = config;
       auto page = std::make_shared<S>();
       this->exce_.Run([&] {
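The fetch task itself is unchanged; only its launch mechanism moves from a fresh std::async thread per batch to the pool created once per source. common::ThreadPool is not shown in this diff, so the following is only a guess at the general shape of a Submit() that returns a std::future, sketched from standard components:

// Hypothetical sketch of a fixed-size pool with a future-returning Submit();
// the real common::ThreadPool in xgboost may look quite different.
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

class ThreadPool {
 public:
  explicit ThreadPool(std::int32_t n_threads) {
    for (std::int32_t i = 0; i < n_threads; ++i) {
      workers_.emplace_back([this] {
        while (true) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lock{mu_};
            cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) { return; }
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();  // run the task outside the lock
        }
      });
    }
  }
  ~ThreadPool() {
    {
      std::lock_guard<std::mutex> lock{mu_};
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& t : workers_) { t.join(); }
  }
  // Queue `fn` on a worker thread and hand back a future for its result.
  template <typename Fn>
  auto Submit(Fn&& fn) -> std::future<std::invoke_result_t<Fn>> {
    using R = std::invoke_result_t<Fn>;
    auto task = std::make_shared<std::packaged_task<R()>>(std::forward<Fn>(fn));
    auto fut = task->get_future();
    {
      std::lock_guard<std::mutex> lock{mu_};
      tasks_.emplace([task] { (*task)(); });
    }
    cv_.notify_one();
    return fut;
  }

 private:
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mu_;
  std::condition_variable cv_;
  bool stop_{false};
};

With an interface of this shape, the hunk above reads naturally: the lambda that loads one page is queued once, and the resulting std::future<std::shared_ptr<S>> goes into the same ring slot that std::async used to fill.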
@@ -252,7 +258,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
  public:
   SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
                        std::shared_ptr<Cache> cache)
-      : missing_{missing},
+      : workers_{nthreads},
+        missing_{missing},
         nthreads_{nthreads},
         n_features_{n_features},
         n_batches_{n_batches},
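One detail worth noting: workers_{nthreads} now leads the initializer list, presumably to match the member's declaration order in the class (it is declared before missing_ in the earlier hunk). Since C++ initializes members in declaration order regardless of the list, keeping the two in sync avoids reorder warnings and makes it clear the pool is constructed before the rest of the source's state.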