[EM] Pass batch parameter into extmem format. (#10736)
- Allow customization for format reading.
- Customize the number of pre-fetch batches.
parent 074cad2343
commit 25966e4ba8

@@ -239,42 +239,52 @@ struct Entry {
 };

 /**
- * \brief Parameters for constructing histogram index batches.
+ * @brief Parameters for constructing histogram index batches.
  */
 struct BatchParam {
   /**
-   * \brief Maximum number of bins per feature for histograms.
+   * @brief Maximum number of bins per feature for histograms.
    */
   bst_bin_t max_bin{0};
   /**
-   * \brief Hessian, used for sketching with future approx implementation.
+   * @brief Hessian, used for sketching with future approx implementation.
    */
   common::Span<float const> hess;
   /**
-   * \brief Whether should we force DMatrix to regenerate the batch. Only used for
+   * @brief Whether should we force DMatrix to regenerate the batch. Only used for
    * GHistIndex.
    */
   bool regen{false};
   /**
-   * \brief Forbid regenerating the gradient index. Used for internal validation.
+   * @brief Forbid regenerating the gradient index. Used for internal validation.
    */
   bool forbid_regen{false};
   /**
-   * \brief Parameter used to generate column matrix for hist.
+   * @brief Parameter used to generate column matrix for hist.
    */
   double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
+  /**
+   * @brief Used for GPU external memory. Whether to copy the data into device.
+   *
+   * This affects only the current round of iteration.
+   */
+  bool prefetch_copy{true};
+  /**
+   * @brief The number of batches to pre-fetch for external memory.
+   */
+  std::int32_t n_prefetch_batches{3};

   /**
-   * \brief Exact or others that don't need histogram.
+   * @brief Exact or others that don't need histogram.
    */
   BatchParam() = default;
   /**
-   * \brief Used by the hist tree method.
+   * @brief Used by the hist tree method.
    */
   BatchParam(bst_bin_t max_bin, double sparse_thresh)
       : max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
   /**
-   * \brief Used by the approx tree method.
+   * @brief Used by the approx tree method.
    *
    * Get batch with sketch weighted by hessian. The batch will be regenerated if the
    * span is changed, so caller should keep the span for each iteration.
@@ -295,7 +305,7 @@ struct BatchParam {
   }
   [[nodiscard]] bool Initialized() const { return max_bin != 0; }
   /**
-   * \brief Make a copy of self for DMatrix to describe how its existing index was generated.
+   * @brief Make a copy of self for DMatrix to describe how its existing index was generated.
    */
   [[nodiscard]] BatchParam MakeCache() const {
     auto p = *this;
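
The two knobs added above are plain members of the public BatchParam struct, so any caller that already builds a BatchParam can opt in directly. A minimal sketch, not taken from this commit; it assumes the XGBoost headers are on the include path and uses an arbitrary sparse threshold purely for illustration:

    #include "xgboost/data.h"  // for xgboost::BatchParam

    // Request histogram batches with at most 256 bins and tune the
    // external-memory behaviour introduced by this commit.
    xgboost::BatchParam MakeExtMemParam() {
      xgboost::BatchParam p{/*max_bin=*/256, /*sparse_thresh=*/0.2};
      p.prefetch_copy = false;    // do not copy pre-fetched pages into the device
      p.n_prefetch_batches = 2;   // pre-fetch two batches ahead of the consumer
      return p;
    }
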
@@ -60,7 +60,7 @@ template <typename T>
   RET_IF_NOT(fi->Read(&impl->is_dense));
   RET_IF_NOT(fi->Read(&impl->row_stride));

-  if (has_hmm_ats_) {
+  if (has_hmm_ats_ && !this->param_.prefetch_copy) {
     RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
   } else {
     RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer));
@@ -95,7 +95,7 @@ template <typename T>
   CHECK(this->cuts_->cut_values_.DeviceCanRead());
   impl->SetCuts(this->cuts_);

-  fi->Read(page);
+  fi->Read(page, this->param_.prefetch_copy);
   dh::DefaultStream().Sync();

   return true;

@@ -26,13 +26,17 @@ class EllpackHostCacheStream;
 class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
   std::shared_ptr<common::HistogramCuts const> cuts_;
   DeviceOrd device_;
+  BatchParam param_;
   // Supports CUDA HMM or ATS
   bool has_hmm_ats_{false};

  public:
   explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
-                                bool has_hmm_ats)
-      : cuts_{std::move(cuts)}, device_{device}, has_hmm_ats_{has_hmm_ats} {}
+                                BatchParam param, bool has_hmm_ats)
+      : cuts_{std::move(cuts)},
+        device_{device},
+        param_{std::move(param)},
+        has_hmm_ats_{has_hmm_ats} {}
   [[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
   [[nodiscard]] std::size_t Write(const EllpackPage& page,
                                   common::AlignedFileWriteStream* fo) override;

@@ -11,7 +11,6 @@

 #include "../common/common.h"  // for safe_cuda
 #include "../common/ref_resource_view.cuh"
-#include "../common/cuda_pinned_allocator.h"  // for pinned_allocator
 #include "../common/device_helpers.cuh"  // for CUDAStreamView, DefaultStream
 #include "../common/resource.cuh"  // for PrivateCudaMmapConstStream
 #include "ellpack_page.cuh"  // for EllpackPageImpl
@@ -19,7 +18,6 @@
 #include "ellpack_page_source.h"
 #include "proxy_dmatrix.cuh"  // for Dispatch
 #include "xgboost/base.h"  // for bst_idx_t
-#include "../common/cuda_rt_utils.h"  // for NvtxScopedRange
 #include "../common/transform_iterator.h"  // for MakeIndexTransformIter

 namespace xgboost::data {
@@ -91,14 +89,20 @@ class EllpackHostCacheStreamImpl {
     ptr_ += 1;
   }

-  void Read(EllpackPage* out) const {
+  void Read(EllpackPage* out, bool prefetch_copy) const {
     auto page = this->cache_->Get(ptr_);

     auto impl = out->Impl();
+    if (prefetch_copy) {
       impl->gidx_buffer =
           common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(page->gidx_buffer.size());
       dh::safe_cuda(cudaMemcpyAsync(impl->gidx_buffer.data(), page->gidx_buffer.data(),
                                     page->gidx_buffer.size_bytes(), cudaMemcpyDefault));
+    } else {
+      auto res = page->gidx_buffer.Resource();
+      impl->gidx_buffer = common::RefResourceView<common::CompressedByteT>{
+          res->DataAs<common::CompressedByteT>(), page->gidx_buffer.size(), res};
+    }

     impl->n_rows = page->Size();
     impl->is_dense = page->IsDense();
@@ -120,7 +124,9 @@ std::shared_ptr<EllpackHostCache> EllpackHostCacheStream::Share() { return p_imp

 void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); }

-void EllpackHostCacheStream::Read(EllpackPage* page) const { this->p_impl_->Read(page); }
+void EllpackHostCacheStream::Read(EllpackPage* page, bool prefetch_copy) const {
+  this->p_impl_->Read(page, prefetch_copy);
+}

 void EllpackHostCacheStream::Write(EllpackPage const& page) { this->p_impl_->Write(page); }

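
The Read() changes above are the heart of the prefetch_copy switch: when the flag is set, the gidx buffer is copied into a freshly allocated device buffer; when it is cleared, the returned page aliases the cached buffer through a RefResourceView that shares ownership of the underlying resource. A simplified, self-contained model of that choice using only standard C++ (the names View and ReadPage are invented for illustration and stand in for the XGBoost/CUDA types):

    #include <cstddef>
    #include <memory>
    #include <vector>

    // A non-owning pointer/size pair plus a shared handle that keeps the backing
    // storage alive, loosely modelling common::RefResourceView in the diff above.
    struct View {
      std::byte* data{nullptr};
      std::size_t size{0};
      std::shared_ptr<std::vector<std::byte>> keep_alive;  // empty when we own a copy
    };

    View ReadPage(std::shared_ptr<std::vector<std::byte>> const& cached,
                  std::vector<std::byte>* own_storage, bool prefetch_copy) {
      if (prefetch_copy) {
        // Copy path: materialise a private buffer (cudaMemcpyAsync into device
        // memory in the real code).
        own_storage->assign(cached->begin(), cached->end());
        return View{own_storage->data(), own_storage->size(), nullptr};
      }
      // Alias path: point straight at the cache and share ownership so the cached
      // buffer stays valid for as long as this view is used.
      return View{cached->data(), cached->size(), cached};
    }
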
@@ -162,8 +168,9 @@ EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateWriter(StringV

 template std::unique_ptr<
     typename EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::ReaderT>
-EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(
-    StringView name, std::uint64_t offset, std::uint64_t length) const;
+EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(StringView name,
+                                                                         bst_idx_t offset,
+                                                                         bst_idx_t length) const;

 /**
  * EllpackMmapStreamPolicy
@@ -233,6 +240,7 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
   ++(*this->source_);
   CHECK_GE(this->source_->Iter(), 1);
   cuda_impl::Dispatch(proxy_, [this](auto const& value) {
+    CHECK(this->proxy_->Ctx()->IsCUDA()) << "All batches must use the same device type.";
     proxy_->Info().feature_types.SetDevice(dh::GetDevice(this->ctx_));
     auto d_feature_types = proxy_->Info().feature_types.ConstDeviceSpan();
     auto n_samples = value.NumRows();

@@ -53,7 +53,7 @@ class EllpackHostCacheStream {

   void Seek(bst_idx_t offset_bytes);

-  void Read(EllpackPage* page) const;
+  void Read(EllpackPage* page, bool prefetch_copy) const;
   void Write(EllpackPage const& page);
 };

@@ -71,9 +71,9 @@ class EllpackFormatPolicy {
   // For testing with the HMM flag.
   explicit EllpackFormatPolicy(bool has_hmm) : has_hmm_{has_hmm} {}

-  [[nodiscard]] auto CreatePageFormat() const {
+  [[nodiscard]] auto CreatePageFormat(BatchParam const& param) const {
     CHECK_EQ(cuts_->cut_values_.Device(), device_);
-    std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_, has_hmm_}};
+    std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_, param, has_hmm_}};
     return fmt;
   }

@@ -66,6 +66,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
     Context const *ctx,
     std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
     DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref) {
+  xgboost_NVTX_FN_RANGE();
+
   auto proxy = MakeProxy(proxy_handle);
   CHECK(proxy);

@@ -118,7 +120,7 @@ BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndex(Context const
   }

   CHECK(this->ghist_index_source_);
-  this->ghist_index_source_->Reset();
+  this->ghist_index_source_->Reset(param);

   if (!std::isnan(param.sparse_thresh) &&
       param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) {

@@ -11,6 +11,7 @@
 #include "proxy_dmatrix.h"    // for DataIterProxy
 #include "xgboost/context.h"  // for Context
 #include "xgboost/data.h"     // for BatchParam
+#include "../common/cuda_rt_utils.h"

 namespace xgboost::data {
 void ExtMemQuantileDMatrix::InitFromCUDA(
@@ -78,9 +79,9 @@ BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
   }

   std::visit(
-      [this](auto &&ptr) {
+      [this, param](auto &&ptr) {
         CHECK(ptr);
-        ptr->Reset();
+        ptr->Reset(param);
       },
       this->ellpack_page_source_);

@@ -37,6 +37,7 @@ void ExtGradientIndexPageSource::Fetch() {
   CHECK_GE(source_->Iter(), 1);
   CHECK_NE(cuts_.Values().size(), 0);
   HostAdapterDispatch(proxy_, [this](auto const& value) {
+    CHECK(this->proxy_->Ctx()->IsCPU()) << "All batches must use the same device type.";
     // This does three things:
     // - Generate CSR matrix for gradient index.
     // - Generate the column matrix for gradient index.

@@ -31,7 +31,7 @@ class GHistIndexFormatPolicy {
   using FormatT = SparsePageFormat<GHistIndexMatrix>;

  public:
-  [[nodiscard]] auto CreatePageFormat() const {
+  [[nodiscard]] auto CreatePageFormat(BatchParam const&) const {
     std::unique_ptr<FormatT> fmt{new GHistIndexRawFormat{cuts_}};
     return fmt;
   }

@@ -82,7 +82,7 @@ void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
   // release the iterator and data.
   if (cache_info_.at(id)->written) {
     CHECK(sparse_page_source_);
-    sparse_page_source_->Reset();
+    sparse_page_source_->Reset({});
     return;
   }

@@ -114,7 +114,7 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
         std::make_shared<CSCPageSource>(this->missing_, ctx->Threads(), this->Info().num_col_,
                                         this->n_batches_, cache_info_.at(id), sparse_page_source_);
   } else {
-    column_source_->Reset();
+    column_source_->Reset({});
   }
   return BatchSet{BatchIterator<CSCPage>{this->column_source_}};
 }
@@ -129,7 +129,7 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const
         this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
         sparse_page_source_);
   } else {
-    sorted_column_source_->Reset();
+    sorted_column_source_->Reset({});
   }
   return BatchSet{BatchIterator<SortedCSCPage>{this->sorted_column_source_}};
 }
@@ -161,7 +161,7 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
         param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
   } else {
     CHECK(ghist_index_source_);
-    ghist_index_source_->Reset();
+    ghist_index_source_->Reset(param);
   }
   return BatchSet{BatchIterator<GHistIndexMatrix>{this->ghist_index_source_}};
 }

@@ -61,7 +61,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
         ellpack_page_source_);
   } else {
     CHECK(sparse_page_source_);
-    std::visit([&](auto&& ptr) { ptr->Reset(); }, this->ellpack_page_source_);
+    std::visit([&](auto&& ptr) { ptr->Reset(param); }, this->ellpack_page_source_);
   }

   auto batch_set =

@@ -204,7 +204,7 @@ class DefaultFormatPolicy {
   using FormatT = SparsePageFormat<S>;

  public:
-  auto CreatePageFormat() const {
+  auto CreatePageFormat(BatchParam const&) const {
     std::unique_ptr<FormatT> fmt{::xgboost::data::CreatePageFormat<S>("raw")};
     return fmt;
   }
@@ -245,6 +245,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
   std::uint32_t count_{0};
   // Total number of batches.
   bst_idx_t n_batches_{0};
+  // How we pre-fetch the data.
+  BatchParam param_;

   std::shared_ptr<Cache> cache_info_;

@@ -267,12 +269,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
     }
     // An heuristic for number of pre-fetched batches. We can make it part of BatchParam
     // to let user adjust number of pre-fetched batches when needed.
-    std::int32_t constexpr kPrefetches = 3;
-    std::int32_t n_prefetches = std::min(nthreads_, kPrefetches);
+    std::int32_t n_prefetches = std::min(nthreads_, this->param_.n_prefetch_batches);
     n_prefetches = std::max(n_prefetches, 1);
     std::int32_t n_prefetch_batches = std::min(static_cast<bst_idx_t>(n_prefetches), n_batches_);
-    CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
-    CHECK_LE(n_prefetch_batches, kPrefetches);
+    CHECK_GT(n_prefetch_batches, 0);
+    CHECK_LE(n_prefetch_batches, this->param_.n_prefetch_batches);
     std::size_t fetch_it = count_;

     exce_.Rethrow();
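
The hard-coded pre-fetch depth of three is gone: the depth now comes from BatchParam and is clamped so that it never exceeds the number of worker threads, never drops below one, and never exceeds the number of cached batches. A standalone restatement of that clamp with a few worked inputs (plain C++, detached from the class above; the function name is invented):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    std::int64_t ClampPrefetch(std::int32_t n_threads, std::int32_t n_prefetch_batches,
                               std::int64_t n_batches) {
      std::int32_t n = std::min(n_threads, n_prefetch_batches);
      n = std::max(n, 1);                           // always fetch at least one batch
      return std::min<std::int64_t>(n, n_batches);  // never fetch past the cache
    }

    int main() {
      assert(ClampPrefetch(/*n_threads=*/8, /*n_prefetch_batches=*/3, /*n_batches=*/16) == 3);
      assert(ClampPrefetch(8, 3, 2) == 2);   // tiny dataset: bounded by available batches
      assert(ClampPrefetch(1, 3, 16) == 1);  // single worker thread: no deeper pre-fetch
      return 0;
    }
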
@@ -287,7 +288,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
     ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] {
       auto page = std::make_shared<S>();
       this->exce_.Run([&] {
-        std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{this->CreatePageFormat()};
+        std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{
+            this->CreatePageFormat(this->param_)};
         auto name = self->cache_info_->ShardName();
         auto [offset, length] = self->cache_info_->View(fetch_it);
         std::unique_ptr<typename FormatStreamPolicy::ReaderT> fi{
@@ -317,7 +319,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
     CHECK(!cache_info_->written);
     common::Timer timer;
     timer.Start();
-    auto fmt{this->CreatePageFormat()};
+    auto fmt{this->CreatePageFormat(this->param_)};

     auto name = cache_info_->ShardName();
     std::unique_ptr<typename FormatStreamPolicy::WriterT> fo{
@@ -382,13 +384,16 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
     this->count_ = 0;
   }

-  virtual void Reset() {
+  virtual void Reset(BatchParam const& param) {
     TryLockGuard guard{single_threaded_};

     this->at_end_ = false;
     auto cnt = this->count_;
     this->count_ = 0;
-    if (cnt != 0) {
+    bool changed = this->param_.n_prefetch_batches != param.n_prefetch_batches;
+    this->param_ = param;
+
+    if (cnt != 0 || changed) {
       // The last iteration did not get to the end, clear the ring to start from 0.
       this->ring_ = std::make_unique<Ring>();
       this->Fetch();
@@ -468,12 +473,12 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
     return *this;
   }

-  void Reset() override {
+  void Reset(BatchParam const& param) override {
     if (proxy_) {
       TryLockGuard guard{single_threaded_};
       iter_.Reset();
     }
-    SparsePageSourceImpl::Reset();
+    SparsePageSourceImpl::Reset(param);

     TryLockGuard guard{single_threaded_};
     base_row_id_ = 0;
@@ -535,9 +540,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
     return *this;
   }

-  void Reset() final {
-    this->source_->Reset();
-    Super::Reset();
+  void Reset(BatchParam const& param) final {
+    this->source_->Reset(param);
+    Super::Reset(param);
   }
 };

@@ -626,11 +631,11 @@ class ExtQantileSourceMixin : public SparsePageSourceImpl<S, FormatCreatePolicy>
     return *this;
   }

-  void Reset() final {
+  void Reset(BatchParam const& param) final {
     if (this->source_) {
       this->source_->Reset();
     }
-    Super::Reset();
+    Super::Reset(param);
   }
 };
 }  // namespace xgboost::data
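
Reset() on every page source now takes the batch parameter, so a source can notice that the requested pre-fetch depth changed between two passes and rebuild its pre-fetch ring instead of reusing futures sized for the old depth; the SparsePageDMatrix call sites that have no meaningful parameter simply pass an empty {}. A minimal stand-in for that decision logic, detached from the class hierarchy above (the names are invented for illustration):

    #include <cstdint>

    struct PrefetchState {
      std::int32_t n_prefetch_batches{3};
      std::uint32_t position{0};  // batch index of the previous pass; 0 means it ran to the end

      // Returns true when the pre-fetch ring has to be cleared and refilled from batch 0,
      // mirroring the `cnt != 0 || changed` test in the diff above.
      bool Reset(std::int32_t new_n_prefetch_batches) {
        bool changed = n_prefetch_batches != new_n_prefetch_batches;
        n_prefetch_batches = new_n_prefetch_batches;
        bool stopped_midway = position != 0;
        position = 0;
        return stopped_midway || changed;
      }
    };
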
@@ -119,8 +119,11 @@ struct DeviceSplitCandidate {
 };

 namespace cuda_impl {
-inline BatchParam HistBatch(TrainParam const& param) {
-  return {param.max_bin, TrainParam::DftSparseThreshold()};
+inline BatchParam HistBatch(TrainParam const& param, bool prefetch_copy = true) {
+  auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()};
+  p.prefetch_copy = prefetch_copy;
+  p.n_prefetch_batches = 1;
+  return p;
 }

 inline BatchParam HistBatch(bst_bin_t max_bin) {
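
For the GPU hist method the policy now lives in one place: HistBatch() caps the external-memory pre-fetch at a single batch and exposes prefetch_copy as a defaulted argument, so a caller that wants the page read in place only has to pass false. A hypothetical call site, assuming a configured tree::TrainParam named tparam is in scope (not taken from this commit):

    // Default: the pre-fetched batch is copied into device memory.
    auto copy_param = cuda_impl::HistBatch(tparam);
    // Opt out of the copy for this round of iteration.
    auto in_place_param = cuda_impl::HistBatch(tparam, /*prefetch_copy=*/false);
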
@@ -13,11 +13,14 @@

 namespace xgboost::data {
 namespace {
-template <typename FormatStreamPolicy>
-void TestEllpackPageRawFormat(FormatStreamPolicy *p_policy) {
+class TestEllpackPageRawFormat : public ::testing::TestWithParam<bool> {
+ public:
+  template <typename FormatStreamPolicy>
+  void Run(FormatStreamPolicy *p_policy, bool prefetch_copy) {
   auto &policy = *p_policy;
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
+  param.prefetch_copy = prefetch_copy;

   auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
   dmlc::TemporaryDirectory tmpdir;
@@ -32,7 +35,7 @@ void TestEllpackPageRawFormat(FormatStreamPolicy *p_policy) {
   ASSERT_TRUE(cuts->cut_values_.DeviceCanRead());
   policy.SetCuts(cuts, ctx.Device());

-  std::unique_ptr<EllpackPageRawFormat> format{policy.CreatePageFormat()};
+  std::unique_ptr<EllpackPageRawFormat> format{policy.CreatePageFormat(param)};

   std::size_t n_bytes{0};
   {
@@ -59,31 +62,35 @@ void TestEllpackPageRawFormat(FormatStreamPolicy *p_policy) {
     [[maybe_unused]] auto h_orig_acc = orig->GetHostAccessor(&ctx, &h_orig);
     ASSERT_EQ(h_loaded, h_orig);
   }
 }
+};
 }  // anonymous namespace

-TEST(EllpackPageRawFormat, DiskIO) {
+TEST_P(TestEllpackPageRawFormat, DiskIO) {
   EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy> policy{false};
-  TestEllpackPageRawFormat(&policy);
+  this->Run(&policy, this->GetParam());
 }

-TEST(EllpackPageRawFormat, DiskIOHmm) {
+TEST_P(TestEllpackPageRawFormat, DiskIOHmm) {
   if (common::SupportsPageableMem()) {
     EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy> policy{true};
-    TestEllpackPageRawFormat(&policy);
+    this->Run(&policy, this->GetParam());
   } else {
     GTEST_SKIP_("HMM is not supported.");
   }
 }

-TEST(EllpackPageRawFormat, HostIO) {
+TEST_P(TestEllpackPageRawFormat, HostIO) {
   {
     EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy> policy;
-    TestEllpackPageRawFormat(&policy);
+    this->Run(&policy, this->GetParam());
   }
   {
     auto ctx = MakeCUDACtx(0);
     auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()};
+    param.n_prefetch_batches = 1;
+    param.prefetch_copy = this->GetParam();
+
     EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy> policy;
     std::unique_ptr<EllpackPageRawFormat> format{};
     Cache cache{false, "name", "ellpack", true};
@@ -92,7 +99,7 @@ TEST(EllpackPageRawFormat, HostIO) {
     for (auto const &page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
       if (!format) {
         policy.SetCuts(page.Impl()->CutsShared(), ctx.Device());
-        format = policy.CreatePageFormat();
+        format = policy.CreatePageFormat(param);
       }
       auto writer = policy.CreateWriter({}, i);
       auto n_bytes = format->Write(page, writer.get());
@@ -123,4 +130,6 @@ TEST(EllpackPageRawFormat, HostIO) {
     }
   }
 }

+INSTANTIATE_TEST_SUITE_P(EllpackPageRawFormat, TestEllpackPageRawFormat, ::testing::Bool());
 }  // namespace xgboost::data
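
The tests above move from plain TEST to GoogleTest value-parameterized tests, so every IO path is exercised once with prefetch_copy = true and once with false from a single test body. The bare pattern, stripped of the XGBoost specifics (self-contained apart from GoogleTest itself; the test names are invented):

    #include <gtest/gtest.h>

    class PrefetchCopyTest : public ::testing::TestWithParam<bool> {};

    TEST_P(PrefetchCopyTest, RoundTrip) {
      bool prefetch_copy = GetParam();  // true on one run, false on the other
      // ... exercise the write/read round trip with this flag ...
      SUCCEED() << "prefetch_copy=" << prefetch_copy;
    }

    // ::testing::Bool() expands to the two values {false, true}.
    INSTANTIATE_TEST_SUITE_P(PrefetchCopy, PrefetchCopyTest, ::testing::Bool());
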
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021-2023, XGBoost contributors
+ * Copyright 2021-2024, XGBoost contributors
  */
 #include <gtest/gtest.h>
 #include <xgboost/data.h>  // for CSCPage, SortedCSCPage, SparsePage
@@ -11,8 +11,6 @@
 #include "../../../src/data/sparse_page_writer.h"  // for CreatePageFormat
 #include "../helpers.h"  // for RandomDataGenerator
 #include "dmlc/filesystem.h"  // for TemporaryDirectory
-#include "dmlc/io.h"  // for Stream
-#include "gtest/gtest_pred_impl.h"  // for Test, AssertionResult, ASSERT_EQ, TEST
 #include "xgboost/context.h"  // for Context

 namespace xgboost::data {