Fix global config for external memory. (#10173)

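Pass the thread-local global configuration from the calling thread to the asynchronous page-prefetch threads, so that each worker runs with the caller's settings.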
This commit is contained in:
Jiaming Yuan
2024-04-11 01:29:28 +08:00
committed by GitHub
parent f0a138f33a
commit 1022909bbe
3 changed files with 28 additions and 22 deletions

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2014-2023, XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file sparse_page_source.h
*/
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
@@ -7,23 +7,26 @@
#include <algorithm> // for min
#include <atomic> // for atomic
#include <cstdio> // for remove
#include <future> // for async
#include <map>
#include <memory>
#include <mutex> // for mutex
#include <string>
#include <thread>
#include <utility> // for pair, move
#include <vector>
#include <memory> // for unique_ptr
#include <mutex> // for mutex
#include <string> // for string
#include <utility> // for pair, move
#include <vector> // for vector
#include "../common/common.h"
#include "../common/io.h" // for PrivateMmapConstStream
#include "../common/timer.h" // for Monitor, Timer
#include "adapter.h"
#include "proxy_dmatrix.h" // for DMatrixProxy
#include "sparse_page_writer.h" // for SparsePageFormat
#include "xgboost/base.h"
#include "xgboost/data.h"
#if !defined(XGBOOST_USE_CUDA)
#include "../common/common.h" // for AssertGPUSupport
#endif // !defined(XGBOOST_USE_CUDA)
#include "../common/io.h" // for PrivateMmapConstStream
#include "../common/timer.h" // for Monitor, Timer
#include "proxy_dmatrix.h" // for DMatrixProxy
#include "sparse_page_writer.h" // for SparsePageFormat
#include "xgboost/base.h" // for bst_feature_t
#include "xgboost/data.h" // for SparsePage, CSCPage
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
#include "xgboost/logging.h" // for CHECK_EQ
namespace xgboost::data {
inline void TryDeleteCacheFile(const std::string& file) {
@@ -185,6 +188,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
exce_.Rethrow();
auto const config = *GlobalConfigThreadLocalStore::Get();
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) {
@@ -192,7 +196,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
}
auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
*GlobalConfigThreadLocalStore::Get() = config;
auto page = std::make_shared<S>();
this->exce_.Run([&] {
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};