Use mmap for external memory. (#9282)

- Have basic infrastructure for mmap.
- Release file write handle.
This commit is contained in:
Jiaming Yuan
2023-06-19 18:52:55 +08:00
committed by GitHub
parent d8beb517ed
commit ee6809e642
16 changed files with 599 additions and 275 deletions

View File

@@ -1,35 +1,34 @@
/*!
* Copyright 2014-2022 by XGBoost Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file sparse_page_source.h
*/
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#include <algorithm> // std::min
#include <string>
#include <utility>
#include <vector>
#include <future>
#include <thread>
#include <algorithm> // for min
#include <future> // async
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "../common/common.h"
#include "../common/io.h" // for PrivateMmapStream, PadPageForMMAP
#include "../common/timer.h" // for Monitor, Timer
#include "adapter.h"
#include "dmlc/common.h" // OMPException
#include "proxy_dmatrix.h"
#include "sparse_page_writer.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "adapter.h"
#include "sparse_page_writer.h"
#include "proxy_dmatrix.h"
#include "../common/common.h"
#include "../common/timer.h"
namespace xgboost {
namespace data {
namespace xgboost::data {
/**
 * \brief Best-effort removal of an external memory cache file.
 *
 * Calls std::remove on the path. A non-zero return is not treated as an error: a warning
 * is logged and execution continues.
 *
 * \param file Path of the cache file to delete.
 */
inline void TryDeleteCacheFile(const std::string& file) {
  if (std::remove(file.c_str()) != 0) {
    // The message continuation is emitted exactly once; a second copy of that line would
    // be a free-standing `<< ...;` statement, which is not valid C++.
    LOG(WARNING) << "Couldn't remove external memory cache file " << file
                 << "; you may want to remove it manually";
  }
}
@@ -54,6 +53,9 @@ struct Cache {
// File name of this cache shard: forwards this object's `name` and `format` members to
// the two-argument ShardName overload and returns what it produces.
std::string ShardName() {
  std::string shard_path = ShardName(this->name, this->format);
  return shard_path;
}
// Append one entry, `n_bytes`, to the end of the `offset` list.
void Push(std::size_t n_bytes) {
  offset.emplace_back(n_bytes);
}
// The write is completed.
void Commit() {
@@ -95,56 +97,72 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
// Number of batches; also used as the size of the pre-fetch ring in ReadCache.
uint32_t n_batches_ {0};
// Handle to the cache descriptor (shared ownership through shared_ptr).
std::shared_ptr<Cache> cache_info_;
std::unique_ptr<dmlc::Stream> fo_;
using Ring = std::vector<std::future<std::shared_ptr<S>>>;
// A ring storing futures to data. Because the DMatrix iterator is forward only, we
// can pre-fetch data in a ring.
std::unique_ptr<Ring> ring_{new Ring};
// NOTE(review): presumably records exceptions thrown inside Run() so that Rethrow() can
// raise them on the calling thread -- confirm against dmlc::OMPException.
dmlc::OMPException exec_;
// Start()/Stop() are bracketed around the "launch" and "Wait" phases of ReadCache.
common::Monitor monitor_;
// Load the page for the current position (`count_`) from the on-disk cache.
//
// Returns false when the cache has not been marked as written. Otherwise it makes sure
// asynchronous read tasks are queued in `ring_` for a window of batches starting at
// `count_` (wrapping around modulo `n_batches_`), blocks on the future stored at
// `count_`, assigns the result to `page_` and returns true.
//
// NOTE(review): this listing interleaves the pre-change and post-change lines of the
// commit, so several statements (e.g. `kPreFetch`, `fetch_it`, the `std::async` call)
// appear in both their old and their new form.
bool ReadCache() {
CHECK(!at_end_);
if (!cache_info_->written) {
return false;
}
if (fo_) {
fo_.reset(); // flush the data to disk.
if (ring_->empty()) {
ring_->resize(n_batches_);
}
// A heuristic for number of pre-fetched batches. We can make it part of BatchParam
// to let user adjust number of pre-fetched batches when needed.
uint32_t constexpr kPreFetch = 4;
uint32_t constexpr kPreFetch = 3;
// Never try to pre-fetch more batches than exist.
size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
size_t fetch_it = count_;
std::size_t fetch_it = count_;
for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
exec_.Rethrow();
monitor_.Start("launch");
for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring
// A slot whose future is still valid already holds a queued or finished read.
if (ring_->at(fetch_it).valid()) {
continue;
}
auto const *self = this; // make sure it's const
auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() {
common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
auto n = self->cache_info_->ShardName();
size_t offset = self->cache_info_->offset.at(fetch_it);
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(n.c_str())};
fi->Seek(offset);
CHECK_EQ(fi->Tell(), offset);
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
auto page = std::make_shared<S>();
CHECK(fmt->Read(page.get(), fi.get()));
LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds.";
this->exec_.Run([&] {
common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
auto n = self->cache_info_->ShardName();
// The page's byte range is delimited by two consecutive entries of the offset list.
std::uint64_t offset = self->cache_info_->offset.at(fetch_it);
std::uint64_t length = self->cache_info_->offset.at(fetch_it + 1) - offset;
auto fi = std::make_unique<common::PrivateMmapConstStream>(n, offset, length);
CHECK(fmt->Read(page.get(), fi.get()));
timer.Stop();
LOG(INFO) << "Read a page `" << typeid(S).name() << "` in " << timer.ElapsedSeconds()
<< " seconds.";
});
return page;
});
}
monitor_.Stop("launch");
// Exactly the pre-fetch window should hold valid futures at this point.
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
n_prefetch_batches)
<< "Sparse DMatrix assumes forward iteration.";
monitor_.Start("Wait");
// std::future::get() blocks until the task finishes and leaves the future invalid,
// which is what the CHECK below verifies.
page_ = (*ring_)[count_].get();
monitor_.Stop("Wait");
CHECK(!(*ring_)[count_].valid());
// NOTE(review): presumably re-raises an exception recorded by exec_.Run in a reader
// task -- confirm against dmlc::OMPException.
exec_.Rethrow();
return true;
}
@@ -153,25 +171,35 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
if (!fo_) {
auto n = cache_info_->ShardName();
fo_.reset(dmlc::Stream::Create(n.c_str(), "w"));
}
auto bytes = fmt->Write(*page_, fo_.get());
timer.Stop();
auto name = cache_info_->ShardName();
std::unique_ptr<dmlc::Stream> fo;
if (this->Iter() == 0) {
fo.reset(dmlc::Stream::Create(name.c_str(), "wb"));
} else {
fo.reset(dmlc::Stream::Create(name.c_str(), "ab"));
}
auto bytes = fmt->Write(*page_, fo.get());
timer.Stop();
LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
<< timer.ElapsedSeconds() << " seconds.";
cache_info_->offset.push_back(bytes);
cache_info_->Push(bytes);
}
// Pure virtual hook: each concrete page source supplies its own Fetch().
// NOTE(review): presumably makes the page for the current position available -- confirm
// in the subclasses.
virtual void Fetch() = 0;
public:
// Construct a page source: copies `missing`, `nthreads`, `n_features` and `n_batches`
// into the matching members and moves the shared cache handle into `cache_info_`.
// NOTE(review): the constructor appears twice below because this listing shows both the
// pre-change and the post-change form; only the second one initialises `monitor_`.
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features,
uint32_t n_batches, std::shared_ptr<Cache> cache)
: missing_{missing}, nthreads_{nthreads}, n_features_{n_features},
n_batches_{n_batches}, cache_info_{std::move(cache)} {}
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
std::shared_ptr<Cache> cache)
: missing_{missing},
nthreads_{nthreads},
n_features_{n_features},
n_batches_{n_batches},
cache_info_{std::move(cache)} {
// The monitor is labelled with the implementation-defined name of the page type S.
monitor_.Init(typeid(S).name()); // not pretty, but works for basic profiling
}
// Copy construction is disabled.
SparsePageSourceImpl(SparsePageSourceImpl const &that) = delete;
@@ -244,7 +272,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
iter_{iter}, proxy_{proxy} {
if (!cache_info_->written) {
iter_.Reset();
CHECK_EQ(iter_.Next(), 1) << "Must have at least 1 batch.";
CHECK(iter_.Next()) << "Must have at least 1 batch.";
}
this->Fetch();
}
@@ -259,6 +287,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
}
if (at_end_) {
CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
cache_info_->Commit();
if (n_batches_ != 0) {
CHECK_EQ(count_, n_batches_);
@@ -371,6 +400,5 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
this->Fetch();
}
};
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_