External memory support for hist (#7531)
* Generate the column matrix from GHistIndexMatrix.
* Avoid synchronization with the sparse page once the cache is written.
* Cleanups: remove member variables/functions, and change the update routine to look like approx and gpu_hist.
* Remove the pruner.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 XGBoost contributors
|
||||
* Copyright 2019-2022 XGBoost contributors
|
||||
*/
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
@@ -12,6 +12,13 @@ namespace data {
|
||||
void EllpackPageSource::Fetch() {
|
||||
dh::safe_cuda(cudaSetDevice(param_.gpu_id));
|
||||
if (!this->ReadCache()) {
|
||||
if (count_ != 0 && !sync_) {
|
||||
// source is initialized to be the 0th page during construction, so when count_ is 0
|
||||
// there's no need to increment the source.
|
||||
++(*source_);
|
||||
}
|
||||
// This is not read from cache so we still need it to be synced with sparse page source.
|
||||
CHECK_EQ(count_, source_->Iter());
|
||||
auto const &csr = source_->Page();
|
||||
this->page_.reset(new EllpackPage{});
|
||||
auto *impl = this->page_->Impl();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
|
||||
@@ -25,15 +25,17 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
|
||||
std::unique_ptr<common::HistogramCuts> cuts_;
|
||||
|
||||
public:
|
||||
EllpackPageSource(
|
||||
float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
|
||||
std::shared_ptr<Cache> cache, BatchParam param,
|
||||
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense,
|
||||
size_t row_stride, common::Span<FeatureType const> feature_types,
|
||||
std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
|
||||
is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)},
|
||||
feature_types_{feature_types}, cuts_{std::move(cuts)} {
|
||||
EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
|
||||
std::shared_ptr<Cache> cache, BatchParam param,
|
||||
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense, size_t row_stride,
|
||||
common::Span<FeatureType const> feature_types,
|
||||
std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
|
||||
is_dense_{is_dense},
|
||||
row_stride_{row_stride},
|
||||
param_{std::move(param)},
|
||||
feature_types_{feature_types},
|
||||
cuts_{std::move(cuts)} {
|
||||
this->source_ = source;
|
||||
this->Fetch();
|
||||
}
|
||||
|
||||
@@ -144,7 +144,6 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
||||
|
||||
this->p_fmat = p_fmat;
|
||||
size_t new_size = 1;
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
new_size += batch.Size();
|
||||
@@ -164,6 +163,16 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
|
||||
prev_sum = row_ptr[rbegin + batch.Size()];
|
||||
rbegin += batch.Size();
|
||||
}
|
||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||
|
||||
// hessian is empty when hist tree method is used or when dataset is empty
|
||||
if (hess.empty() && !std::isnan(sparse_thresh)) {
|
||||
// hist
|
||||
CHECK(!sorted_sketch);
|
||||
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
|
||||
this->columns_->Init(page, *this, sparse_thresh, n_threads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
|
||||
@@ -187,6 +196,10 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType co
|
||||
size_t prev_sum = 0;
|
||||
|
||||
this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
|
||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||
if (!std::isnan(sparse_thresh)) {
|
||||
this->columns_->Init(batch, *this, sparse_thresh, n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
|
||||
@@ -205,4 +218,17 @@ void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
|
||||
index.Resize((sizeof(uint32_t)) * n_index);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * \brief Access the column matrix held by this gradient index.
 *
 * The CHECK fails when `columns_` has not been allocated.
 *
 * \return Const reference to the owned common::ColumnMatrix.
 */
common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
  CHECK(columns_);
  common::ColumnMatrix const &transposed = *columns_;
  return transposed;
}
|
||||
|
||||
/*!
 * \brief Load the column matrix of this page from a stream.
 *
 * The cut pointers of this page (`cut.Ptrs()`) are handed to the column matrix
 * reader together with the stream.
 *
 * \param fi Stream to read the serialized column matrix from.
 * \return The result reported by common::ColumnMatrix::Read.
 */
bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) {
  // Same guard as Transpose(): fail with a message instead of dereferencing a
  // null pointer when the column matrix was never allocated.
  CHECK(columns_) << "Column matrix is not allocated.";
  return this->columns_->Read(fi, this->cut.Ptrs().data());
}
|
||||
|
||||
/*!
 * \brief Serialize the column matrix of this page to a stream.
 *
 * \param fo Stream to write to.
 * \return The value returned by common::ColumnMatrix::Write; the raw page
 *         format adds it to its running byte count.
 */
size_t GHistIndexMatrix::WriteColumnPage(dmlc::Stream *fo) const {
  // Same guard as Transpose(): fail with a message instead of dereferencing a
  // null pointer when the column matrix was never allocated.
  CHECK(columns_) << "Column matrix is not allocated.";
  return this->columns_->Write(fo);
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -40,7 +40,6 @@ class GHistIndexMatrix {
|
||||
std::vector<size_t> hit_count;
|
||||
/*! \brief The corresponding cuts */
|
||||
common::HistogramCuts cut;
|
||||
DMatrix* p_fmat;
|
||||
/*! \brief max_bin for each feature. */
|
||||
size_t max_num_bins;
|
||||
/*! \brief base row index for current page (used by external memory) */
|
||||
@@ -119,8 +118,12 @@ class GHistIndexMatrix {
|
||||
return row_ptr.empty() ? 0 : row_ptr.size() - 1;
|
||||
}
|
||||
|
||||
bool ReadColumnPage(dmlc::SeekStream* fi);
|
||||
size_t WriteColumnPage(dmlc::Stream* fo) const;
|
||||
|
||||
common::ColumnMatrix const& Transpose() const;
|
||||
|
||||
private:
|
||||
// unused at the moment: https://github.com/dmlc/xgboost/pull/7531
|
||||
std::unique_ptr<common::ColumnMatrix> columns_;
|
||||
std::vector<size_t> hit_count_tloc_;
|
||||
bool isDense_;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
* Copyright 2021-2022 XGBoost contributors
|
||||
*/
|
||||
#include "sparse_page_writer.h"
|
||||
#include "gradient_index.h"
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
|
||||
public:
|
||||
bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override {
|
||||
@@ -50,6 +49,8 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
|
||||
if (is_dense) {
|
||||
page->index.SetBinOffset(page->cut.Ptrs());
|
||||
}
|
||||
|
||||
page->ReadColumnPage(fi);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -81,6 +82,8 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
|
||||
bytes += sizeof(page.base_rowid);
|
||||
fo->Write(page.IsDense());
|
||||
bytes += sizeof(page.IsDense());
|
||||
|
||||
bytes += page.WriteColumnPage(fo);
|
||||
return bytes;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -7,11 +7,18 @@ namespace xgboost {
|
||||
namespace data {
|
||||
void GradientIndexPageSource::Fetch() {
|
||||
if (!this->ReadCache()) {
|
||||
if (count_ != 0 && !sync_) {
|
||||
// source is initialized to be the 0th page during construction, so when count_ is 0
|
||||
// there's no need to increment the source.
|
||||
++(*source_);
|
||||
}
|
||||
// This is not read from cache so we still need it to be synced with sparse page source.
|
||||
CHECK_EQ(count_, source_->Iter());
|
||||
auto const& csr = source_->Page();
|
||||
this->page_.reset(new GHistIndexMatrix());
|
||||
CHECK_NE(cuts_.Values().size(), 0);
|
||||
this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_,
|
||||
sparse_thresh_, nthreads_);
|
||||
this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, sparse_thresh_,
|
||||
nthreads_);
|
||||
this->WriteCache();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,13 +22,14 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
|
||||
public:
|
||||
GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
|
||||
std::shared_ptr<Cache> cache, BatchParam param,
|
||||
common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat,
|
||||
common::HistogramCuts cuts, bool is_dense,
|
||||
common::Span<FeatureType const> feature_types,
|
||||
std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
|
||||
std::isnan(param.sparse_thresh)),
|
||||
cuts_{std::move(cuts)},
|
||||
is_dense_{is_dense},
|
||||
max_bin_per_feat_{max_bin_per_feat},
|
||||
max_bin_per_feat_{param.max_bin},
|
||||
feature_types_{feature_types},
|
||||
sparse_thresh_{param.sparse_thresh} {
|
||||
this->source_ = source;
|
||||
|
||||
@@ -159,21 +159,6 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
||||
|
||||
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam ¶m) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
if (param.hess.empty() && !param.regen) {
|
||||
// hist method doesn't support full external memory implementation, so we concatenate
|
||||
// all index here.
|
||||
if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
|
||||
this->InitializeSparsePage();
|
||||
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.sparse_thresh,
|
||||
param.regen, ctx_.Threads()});
|
||||
this->InitializeSparsePage();
|
||||
batch_param_ = param;
|
||||
}
|
||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
||||
new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_index_page_));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
this->InitializeSparsePage();
|
||||
if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
|
||||
@@ -190,10 +175,9 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam
|
||||
ghist_index_source_.reset();
|
||||
CHECK_NE(cuts.Values().size(), 0);
|
||||
auto ft = this->info_.feature_types.ConstHostSpan();
|
||||
ghist_index_source_.reset(
|
||||
new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
||||
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
|
||||
this->IsDense(), param.max_bin, ft, sparse_page_source_));
|
||||
ghist_index_source_.reset(new GradientIndexPageSource(
|
||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_,
|
||||
cache_info_.at(id), param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
|
||||
} else {
|
||||
CHECK(ghist_index_source_);
|
||||
ghist_index_source_->Reset();
|
||||
|
||||
@@ -11,6 +11,9 @@ namespace data {
|
||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
|
||||
CHECK_GE(param.gpu_id, 0);
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
if (!(batch_param_ != BatchParam{})) {
|
||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
}
|
||||
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
size_t row_stride = 0;
|
||||
this->InitializeSparsePage();
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "proxy_dmatrix.h"
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/timer.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@@ -118,26 +119,30 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
|
||||
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
|
||||
size_t fetch_it = count_;
|
||||
|
||||
for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
|
||||
fetch_it %= n_batches_; // ring
|
||||
if (ring_->at(fetch_it).valid()) { continue; }
|
||||
if (ring_->at(fetch_it).valid()) {
|
||||
continue;
|
||||
}
|
||||
auto const *self = this; // make sure it's const
|
||||
CHECK_LT(fetch_it, cache_info_->offset.size());
|
||||
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() {
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
|
||||
auto n = self->cache_info_->ShardName();
|
||||
size_t offset = self->cache_info_->offset.at(fetch_it);
|
||||
std::unique_ptr<dmlc::SeekStream> fi{
|
||||
dmlc::SeekStream::CreateForRead(n.c_str())};
|
||||
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(n.c_str())};
|
||||
fi->Seek(offset);
|
||||
CHECK_EQ(fi->Tell(), offset);
|
||||
auto page = std::make_shared<S>();
|
||||
CHECK(fmt->Read(page.get(), fi.get()));
|
||||
LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds.";
|
||||
return page;
|
||||
});
|
||||
}
|
||||
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(),
|
||||
[](auto const &f) { return f.valid(); }),
|
||||
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
|
||||
n_prefetch_batches)
|
||||
<< "Sparse DMatrix assumes forward iteration.";
|
||||
page_ = (*ring_)[count_].get();
|
||||
@@ -146,12 +151,18 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
|
||||
void WriteCache() {
|
||||
CHECK(!cache_info_->written);
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
|
||||
if (!fo_) {
|
||||
auto n = cache_info_->ShardName();
|
||||
fo_.reset(dmlc::Stream::Create(n.c_str(), "w"));
|
||||
}
|
||||
auto bytes = fmt->Write(*page_, fo_.get());
|
||||
timer.Stop();
|
||||
|
||||
LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
|
||||
<< timer.ElapsedSeconds() << " seconds.";
|
||||
cache_info_->offset.push_back(bytes);
|
||||
}
|
||||
|
||||
@@ -280,15 +291,24 @@ template <typename S>
|
||||
class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
|
||||
protected:
|
||||
std::shared_ptr<SparsePageSource> source_;
|
||||
using Super = SparsePageSourceImpl<S>;
|
||||
// synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
|
||||
// so we avoid fetching it.
|
||||
bool sync_{true};
|
||||
|
||||
public:
|
||||
using SparsePageSourceImpl<S>::SparsePageSourceImpl;
|
||||
PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
|
||||
std::shared_ptr<Cache> cache, bool sync)
|
||||
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
|
||||
|
||||
PageSourceIncMixIn& operator++() final {
|
||||
TryLockGuard guard{this->single_threaded_};
|
||||
++(*source_);
|
||||
if (sync_) {
|
||||
++(*source_);
|
||||
}
|
||||
|
||||
++this->count_;
|
||||
this->at_end_ = source_->AtEnd();
|
||||
this->at_end_ = this->count_ == this->n_batches_;
|
||||
|
||||
if (this->at_end_) {
|
||||
this->cache_info_->Commit();
|
||||
@@ -299,7 +319,10 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
|
||||
} else {
|
||||
this->Fetch();
|
||||
}
|
||||
CHECK_EQ(source_->Iter(), this->count_);
|
||||
|
||||
if (sync_) {
|
||||
CHECK_EQ(source_->Iter(), this->count_);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
@@ -318,12 +341,9 @@ class CSCPageSource : public PageSourceIncMixIn<CSCPage> {
|
||||
}
|
||||
|
||||
public:
|
||||
CSCPageSource(
|
||||
float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
|
||||
std::shared_ptr<Cache> cache,
|
||||
std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features,
|
||||
n_batches, cache) {
|
||||
CSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
|
||||
std::shared_ptr<Cache> cache, std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) {
|
||||
this->source_ = source;
|
||||
this->Fetch();
|
||||
}
|
||||
@@ -349,7 +369,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
|
||||
SortedCSCPageSource(float missing, int nthreads, bst_feature_t n_features,
|
||||
uint32_t n_batches, std::shared_ptr<Cache> cache,
|
||||
std::shared_ptr<SparsePageSource> source)
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache) {
|
||||
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) {
|
||||
this->source_ = source;
|
||||
this->Fetch();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user