[EM] Pass batch parameter into extmem format. (#10736)

- Allow customization for format reading.
- Customize the number of pre-fetch batches.
This commit is contained in:
Jiaming Yuan
2024-08-27 02:37:50 +08:00
committed by GitHub
parent 074cad2343
commit 25966e4ba8
15 changed files with 144 additions and 103 deletions

View File

@@ -60,7 +60,7 @@ template <typename T>
RET_IF_NOT(fi->Read(&impl->is_dense));
RET_IF_NOT(fi->Read(&impl->row_stride));
if (has_hmm_ats_) {
if (has_hmm_ats_ && !this->param_.prefetch_copy) {
RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
} else {
RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer));
@@ -95,7 +95,7 @@ template <typename T>
CHECK(this->cuts_->cut_values_.DeviceCanRead());
impl->SetCuts(this->cuts_);
fi->Read(page);
fi->Read(page, this->param_.prefetch_copy);
dh::DefaultStream().Sync();
return true;

View File

@@ -26,13 +26,17 @@ class EllpackHostCacheStream;
class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
std::shared_ptr<common::HistogramCuts const> cuts_;
DeviceOrd device_;
BatchParam param_;
// Supports CUDA HMM or ATS
bool has_hmm_ats_{false};
public:
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
bool has_hmm_ats)
: cuts_{std::move(cuts)}, device_{device}, has_hmm_ats_{has_hmm_ats} {}
BatchParam param, bool has_hmm_ats)
: cuts_{std::move(cuts)},
device_{device},
param_{std::move(param)},
has_hmm_ats_{has_hmm_ats} {}
[[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
[[nodiscard]] std::size_t Write(const EllpackPage& page,
common::AlignedFileWriteStream* fo) override;

View File

@@ -11,7 +11,6 @@
#include "../common/common.h" // for safe_cuda
#include "../common/ref_resource_view.cuh"
#include "../common/cuda_pinned_allocator.h" // for pinned_allocator
#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream
#include "../common/resource.cuh" // for PrivateCudaMmapConstStream
#include "ellpack_page.cuh" // for EllpackPageImpl
@@ -19,7 +18,6 @@
#include "ellpack_page_source.h"
#include "proxy_dmatrix.cuh" // for Dispatch
#include "xgboost/base.h" // for bst_idx_t
#include "../common/cuda_rt_utils.h" // for NvtxScopedRange
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
namespace xgboost::data {
@@ -91,14 +89,20 @@ class EllpackHostCacheStreamImpl {
ptr_ += 1;
}
void Read(EllpackPage* out) const {
void Read(EllpackPage* out, bool prefetch_copy) const {
auto page = this->cache_->Get(ptr_);
auto impl = out->Impl();
impl->gidx_buffer =
common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(page->gidx_buffer.size());
dh::safe_cuda(cudaMemcpyAsync(impl->gidx_buffer.data(), page->gidx_buffer.data(),
page->gidx_buffer.size_bytes(), cudaMemcpyDefault));
if (prefetch_copy) {
impl->gidx_buffer =
common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(page->gidx_buffer.size());
dh::safe_cuda(cudaMemcpyAsync(impl->gidx_buffer.data(), page->gidx_buffer.data(),
page->gidx_buffer.size_bytes(), cudaMemcpyDefault));
} else {
auto res = page->gidx_buffer.Resource();
impl->gidx_buffer = common::RefResourceView<common::CompressedByteT>{
res->DataAs<common::CompressedByteT>(), page->gidx_buffer.size(), res};
}
impl->n_rows = page->Size();
impl->is_dense = page->IsDense();
@@ -120,7 +124,9 @@ std::shared_ptr<EllpackHostCache> EllpackHostCacheStream::Share() { return p_imp
void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); }
void EllpackHostCacheStream::Read(EllpackPage* page) const { this->p_impl_->Read(page); }
void EllpackHostCacheStream::Read(EllpackPage* page, bool prefetch_copy) const {
this->p_impl_->Read(page, prefetch_copy);
}
void EllpackHostCacheStream::Write(EllpackPage const& page) { this->p_impl_->Write(page); }
@@ -162,8 +168,9 @@ EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateWriter(StringV
template std::unique_ptr<
typename EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::ReaderT>
EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(
StringView name, std::uint64_t offset, std::uint64_t length) const;
EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(StringView name,
bst_idx_t offset,
bst_idx_t length) const;
/**
* EllpackMmapStreamPolicy
@@ -233,6 +240,7 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
++(*this->source_);
CHECK_GE(this->source_->Iter(), 1);
cuda_impl::Dispatch(proxy_, [this](auto const& value) {
CHECK(this->proxy_->Ctx()->IsCUDA()) << "All batches must use the same device type.";
proxy_->Info().feature_types.SetDevice(dh::GetDevice(this->ctx_));
auto d_feature_types = proxy_->Info().feature_types.ConstDeviceSpan();
auto n_samples = value.NumRows();

View File

@@ -53,7 +53,7 @@ class EllpackHostCacheStream {
void Seek(bst_idx_t offset_bytes);
void Read(EllpackPage* page) const;
void Read(EllpackPage* page, bool prefetch_copy) const;
void Write(EllpackPage const& page);
};
@@ -71,9 +71,9 @@ class EllpackFormatPolicy {
// For testing with the HMM flag.
explicit EllpackFormatPolicy(bool has_hmm) : has_hmm_{has_hmm} {}
[[nodiscard]] auto CreatePageFormat() const {
[[nodiscard]] auto CreatePageFormat(BatchParam const& param) const {
CHECK_EQ(cuts_->cut_values_.Device(), device_);
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_, has_hmm_}};
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_, param, has_hmm_}};
return fmt;
}

View File

@@ -66,6 +66,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
Context const *ctx,
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref) {
xgboost_NVTX_FN_RANGE();
auto proxy = MakeProxy(proxy_handle);
CHECK(proxy);
@@ -118,7 +120,7 @@ BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndex(Context const
}
CHECK(this->ghist_index_source_);
this->ghist_index_source_->Reset();
this->ghist_index_source_->Reset(param);
if (!std::isnan(param.sparse_thresh) &&
param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) {

View File

@@ -11,6 +11,7 @@
#include "proxy_dmatrix.h" // for DataIterProxy
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchParam
#include "../common/cuda_rt_utils.h"
namespace xgboost::data {
void ExtMemQuantileDMatrix::InitFromCUDA(
@@ -78,9 +79,9 @@ BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
}
std::visit(
[this](auto &&ptr) {
[this, param](auto &&ptr) {
CHECK(ptr);
ptr->Reset();
ptr->Reset(param);
},
this->ellpack_page_source_);

View File

@@ -37,6 +37,7 @@ void ExtGradientIndexPageSource::Fetch() {
CHECK_GE(source_->Iter(), 1);
CHECK_NE(cuts_.Values().size(), 0);
HostAdapterDispatch(proxy_, [this](auto const& value) {
CHECK(this->proxy_->Ctx()->IsCPU()) << "All batches must use the same device type.";
// This does three things:
// - Generate CSR matrix for gradient index.
// - Generate the column matrix for gradient index.

View File

@@ -31,7 +31,7 @@ class GHistIndexFormatPolicy {
using FormatT = SparsePageFormat<GHistIndexMatrix>;
public:
[[nodiscard]] auto CreatePageFormat() const {
[[nodiscard]] auto CreatePageFormat(BatchParam const&) const {
std::unique_ptr<FormatT> fmt{new GHistIndexRawFormat{cuts_}};
return fmt;
}

View File

@@ -82,7 +82,7 @@ void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
// release the iterator and data.
if (cache_info_.at(id)->written) {
CHECK(sparse_page_source_);
sparse_page_source_->Reset();
sparse_page_source_->Reset({});
return;
}
@@ -114,7 +114,7 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
std::make_shared<CSCPageSource>(this->missing_, ctx->Threads(), this->Info().num_col_,
this->n_batches_, cache_info_.at(id), sparse_page_source_);
} else {
column_source_->Reset();
column_source_->Reset({});
}
return BatchSet{BatchIterator<CSCPage>{this->column_source_}};
}
@@ -129,7 +129,7 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
sparse_page_source_);
} else {
sorted_column_source_->Reset();
sorted_column_source_->Reset({});
}
return BatchSet{BatchIterator<SortedCSCPage>{this->sorted_column_source_}};
}
@@ -161,7 +161,7 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
} else {
CHECK(ghist_index_source_);
ghist_index_source_->Reset();
ghist_index_source_->Reset(param);
}
return BatchSet{BatchIterator<GHistIndexMatrix>{this->ghist_index_source_}};
}

View File

@@ -61,7 +61,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
ellpack_page_source_);
} else {
CHECK(sparse_page_source_);
std::visit([&](auto&& ptr) { ptr->Reset(); }, this->ellpack_page_source_);
std::visit([&](auto&& ptr) { ptr->Reset(param); }, this->ellpack_page_source_);
}
auto batch_set =

View File

@@ -204,7 +204,7 @@ class DefaultFormatPolicy {
using FormatT = SparsePageFormat<S>;
public:
auto CreatePageFormat() const {
auto CreatePageFormat(BatchParam const&) const {
std::unique_ptr<FormatT> fmt{::xgboost::data::CreatePageFormat<S>("raw")};
return fmt;
}
@@ -245,6 +245,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
std::uint32_t count_{0};
// Total number of batches.
bst_idx_t n_batches_{0};
// How we pre-fetch the data.
BatchParam param_;
std::shared_ptr<Cache> cache_info_;
@@ -267,12 +269,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
}
// A heuristic for the number of pre-fetched batches: it is bounded by the number of
// threads and by `n_prefetch_batches` in BatchParam, which lets the user adjust it.
std::int32_t constexpr kPrefetches = 3;
std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
std::int32_t n_prefetches = std::min(nthreads_, this->param_.n_prefetch_batches);
n_prefetches = std::max(n_prefetches, 1);
std::int32_t n_prefetch_batches = std::min(static_cast<bst_idx_t>(n_prefetches), n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
CHECK_LE(n_prefetch_batches, kPrefetches);
CHECK_GT(n_prefetch_batches, 0);
CHECK_LE(n_prefetch_batches, this->param_.n_prefetch_batches);
std::size_t fetch_it = count_;
exce_.Rethrow();
@@ -287,7 +288,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] {
auto page = std::make_shared<S>();
this->exce_.Run([&] {
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{
this->CreatePageFormat(this->param_)};
auto name = self->cache_info_->ShardName();
auto [offset, length] = self->cache_info_->View(fetch_it);
std::unique_ptr<typename FormatStreamPolicy::ReaderT> fi{
@@ -317,7 +319,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
CHECK(!cache_info_->written);
common::Timer timer;
timer.Start();
auto fmt{this->CreatePageFormat()};
auto fmt{this->CreatePageFormat(this->param_)};
auto name = cache_info_->ShardName();
std::unique_ptr<typename FormatStreamPolicy::WriterT> fo{
@@ -382,13 +384,16 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
this->count_ = 0;
}
virtual void Reset() {
virtual void Reset(BatchParam const& param) {
TryLockGuard guard{single_threaded_};
this->at_end_ = false;
auto cnt = this->count_;
this->count_ = 0;
if (cnt != 0) {
bool changed = this->param_.n_prefetch_batches != param.n_prefetch_batches;
this->param_ = param;
if (cnt != 0 || changed) {
// Either the last iteration did not get to the end or the number of pre-fetched
// batches changed; clear the ring to start from 0.
this->ring_ = std::make_unique<Ring>();
this->Fetch();
@@ -468,12 +473,12 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
return *this;
}
void Reset() override {
void Reset(BatchParam const& param) override {
if (proxy_) {
TryLockGuard guard{single_threaded_};
iter_.Reset();
}
SparsePageSourceImpl::Reset();
SparsePageSourceImpl::Reset(param);
TryLockGuard guard{single_threaded_};
base_row_id_ = 0;
@@ -535,9 +540,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
return *this;
}
void Reset() final {
this->source_->Reset();
Super::Reset();
void Reset(BatchParam const& param) final {
this->source_->Reset(param);
Super::Reset(param);
}
};
@@ -626,11 +631,11 @@ class ExtQantileSourceMixin : public SparsePageSourceImpl<S, FormatCreatePolicy>
return *this;
}
void Reset() final {
void Reset(BatchParam const& param) final {
if (this->source_) {
this->source_->Reset();
}
Super::Reset();
Super::Reset(param);
}
};
} // namespace xgboost::data

View File

@@ -119,8 +119,11 @@ struct DeviceSplitCandidate {
};
namespace cuda_impl {
inline BatchParam HistBatch(TrainParam const& param) {
return {param.max_bin, TrainParam::DftSparseThreshold()};
inline BatchParam HistBatch(TrainParam const& param, bool prefetch_copy = true) {
auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()};
p.prefetch_copy = prefetch_copy;
p.n_prefetch_batches = 1;
return p;
}
inline BatchParam HistBatch(bst_bin_t max_bin) {