Remove unnecessary fetch operations in external memory. (#10342)
This commit is contained in:
parent
c2e3d4f3cd
commit
d2d01d977a
@ -468,10 +468,7 @@ class BatchIterator {
|
|||||||
return *(*impl_);
|
return *(*impl_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator!=(const BatchIterator&) const {
|
[[nodiscard]] bool operator!=(const BatchIterator&) const { return !this->AtEnd(); }
|
||||||
CHECK(impl_ != nullptr);
|
|
||||||
return !impl_->AtEnd();
|
|
||||||
}
|
|
||||||
|
|
||||||
[[nodiscard]] bool AtEnd() const {
|
[[nodiscard]] bool AtEnd() const {
|
||||||
CHECK(impl_ != nullptr);
|
CHECK(impl_ != nullptr);
|
||||||
@ -506,13 +503,13 @@ class DMatrix {
|
|||||||
public:
|
public:
|
||||||
/*! \brief default constructor */
|
/*! \brief default constructor */
|
||||||
DMatrix() = default;
|
DMatrix() = default;
|
||||||
/*! \brief meta information of the dataset */
|
/** @brief meta information of the dataset */
|
||||||
virtual MetaInfo& Info() = 0;
|
[[nodiscard]] virtual MetaInfo& Info() = 0;
|
||||||
virtual void SetInfo(const char* key, std::string const& interface_str) {
|
virtual void SetInfo(const char* key, std::string const& interface_str) {
|
||||||
auto const& ctx = *this->Ctx();
|
auto const& ctx = *this->Ctx();
|
||||||
this->Info().SetInfo(ctx, key, StringView{interface_str});
|
this->Info().SetInfo(ctx, key, StringView{interface_str});
|
||||||
}
|
}
|
||||||
/*! \brief meta information of the dataset */
|
/** @brief meta information of the dataset */
|
||||||
[[nodiscard]] virtual const MetaInfo& Info() const = 0;
|
[[nodiscard]] virtual const MetaInfo& Info() const = 0;
|
||||||
|
|
||||||
/*! \brief Get thread local memory for returning data from DMatrix. */
|
/*! \brief Get thread local memory for returning data from DMatrix. */
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2019-2023, XGBoost contributors
|
* Copyright 2019-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "ellpack_page.cuh"
|
#include "ellpack_page.cuh"
|
||||||
#include "ellpack_page.h" // for EllpackPage
|
#include "ellpack_page.h" // for EllpackPage
|
||||||
|
|||||||
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
#include <any> // for any, any_cast
|
#include <any> // for any, any_cast
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
|
||||||
#include <type_traits> // for invoke_result_t
|
#include <type_traits> // for invoke_result_t
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
|||||||
@ -56,10 +56,10 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
|||||||
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
||||||
iter_, reset_, next_};
|
iter_, reset_, next_};
|
||||||
|
|
||||||
uint32_t n_batches = 0;
|
std::uint32_t n_batches = 0;
|
||||||
size_t n_features = 0;
|
bst_feature_t n_features = 0;
|
||||||
size_t n_samples = 0;
|
bst_idx_t n_samples = 0;
|
||||||
size_t nnz = 0;
|
bst_idx_t nnz = 0;
|
||||||
|
|
||||||
auto num_rows = [&]() {
|
auto num_rows = [&]() {
|
||||||
bool type_error {false};
|
bool type_error {false};
|
||||||
@ -72,7 +72,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
|||||||
};
|
};
|
||||||
auto num_cols = [&]() {
|
auto num_cols = [&]() {
|
||||||
bool type_error {false};
|
bool type_error {false};
|
||||||
size_t n_features = HostAdapterDispatch(
|
bst_feature_t n_features = HostAdapterDispatch(
|
||||||
proxy, [](auto const &value) { return value.NumCols(); }, &type_error);
|
proxy, [](auto const &value) { return value.NumCols(); }, &type_error);
|
||||||
if (type_error) {
|
if (type_error) {
|
||||||
n_features = detail::NFeaturesDevice(proxy);
|
n_features = detail::NFeaturesDevice(proxy);
|
||||||
@ -121,10 +121,9 @@ void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
|
|||||||
this->n_batches_, cache_info_.at(id));
|
this->n_batches_, cache_info_.at(id));
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatchesImpl(Context const* ctx) {
|
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatchesImpl(Context const *ctx) {
|
||||||
this->InitializeSparsePage(ctx);
|
this->InitializeSparsePage(ctx);
|
||||||
auto begin_iter = BatchIterator<SparsePage>(sparse_page_source_);
|
return BatchSet{BatchIterator<SparsePage>{this->sparse_page_source_}};
|
||||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(begin_iter));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
||||||
@ -143,8 +142,7 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
|
|||||||
} else {
|
} else {
|
||||||
column_source_->Reset();
|
column_source_->Reset();
|
||||||
}
|
}
|
||||||
auto begin_iter = BatchIterator<CSCPage>(column_source_);
|
return BatchSet{BatchIterator<CSCPage>{this->column_source_}};
|
||||||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(begin_iter));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
|
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
|
||||||
@ -158,8 +156,7 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const
|
|||||||
} else {
|
} else {
|
||||||
sorted_column_source_->Reset();
|
sorted_column_source_->Reset();
|
||||||
}
|
}
|
||||||
auto begin_iter = BatchIterator<SortedCSCPage>(sorted_column_source_);
|
return BatchSet{BatchIterator<SortedCSCPage>{this->sorted_column_source_}};
|
||||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(begin_iter));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ctx,
|
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ctx,
|
||||||
@ -169,8 +166,8 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
|
|||||||
}
|
}
|
||||||
detail::CheckEmpty(batch_param_, param);
|
detail::CheckEmpty(batch_param_, param);
|
||||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||||
this->InitializeSparsePage(ctx);
|
|
||||||
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||||
|
this->InitializeSparsePage(ctx);
|
||||||
cache_info_.erase(id);
|
cache_info_.erase(id);
|
||||||
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||||
LOG(INFO) << "Generating new Gradient Index.";
|
LOG(INFO) << "Generating new Gradient Index.";
|
||||||
@ -190,15 +187,13 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
|
|||||||
CHECK(ghist_index_source_);
|
CHECK(ghist_index_source_);
|
||||||
ghist_index_source_->Reset();
|
ghist_index_source_->Reset();
|
||||||
}
|
}
|
||||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(ghist_index_source_);
|
return BatchSet{BatchIterator<GHistIndexMatrix>{this->ghist_index_source_}};
|
||||||
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(begin_iter));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) {
|
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) {
|
||||||
common::AssertGPUSupport();
|
common::AssertGPUSupport();
|
||||||
auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
|
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
|
||||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
|
||||||
}
|
}
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2021-2023 by XGBoost contributors
|
* Copyright 2021-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <memory> // for unique_ptr
|
#include <memory> // for unique_ptr
|
||||||
|
|
||||||
@ -21,8 +21,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
|||||||
detail::CheckEmpty(batch_param_, param);
|
detail::CheckEmpty(batch_param_, param);
|
||||||
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||||
size_t row_stride = 0;
|
size_t row_stride = 0;
|
||||||
this->InitializeSparsePage(ctx);
|
|
||||||
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||||
|
this->InitializeSparsePage(ctx);
|
||||||
// reinitialize the cache
|
// reinitialize the cache
|
||||||
cache_info_.erase(id);
|
cache_info_.erase(id);
|
||||||
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||||
@ -52,7 +52,6 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
|||||||
ellpack_page_source_->Reset();
|
ellpack_page_source_->Reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
|
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
|
||||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
|
||||||
}
|
}
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2015-2023, XGBoost Contributors
|
* Copyright 2015-2024, XGBoost Contributors
|
||||||
* \file sparse_page_dmatrix.h
|
* \file sparse_page_dmatrix.h
|
||||||
* \brief External-memory version of DMatrix.
|
* \brief External-memory version of DMatrix.
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@ -7,12 +7,10 @@
|
|||||||
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
||||||
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "ellpack_page_source.h"
|
#include "ellpack_page_source.h"
|
||||||
#include "gradient_index_page_source.h"
|
#include "gradient_index_page_source.h"
|
||||||
@ -22,7 +20,7 @@
|
|||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
/**
|
/**
|
||||||
* \brief DMatrix used for external memory.
|
* @brief DMatrix used for external memory.
|
||||||
*
|
*
|
||||||
* The external memory is created for controlling memory usage by splitting up data into
|
* The external memory is created for controlling memory usage by splitting up data into
|
||||||
* multiple batches. However that doesn't mean we will actually process exactly 1 batch
|
* multiple batches. However that doesn't mean we will actually process exactly 1 batch
|
||||||
@ -51,8 +49,13 @@ namespace xgboost::data {
|
|||||||
* want to change the generated page like Ellpack, pass parameter into `GetBatches` to
|
* want to change the generated page like Ellpack, pass parameter into `GetBatches` to
|
||||||
* re-generate them instead of trying to modify the pages in-place.
|
* re-generate them instead of trying to modify the pages in-place.
|
||||||
*
|
*
|
||||||
* A possible optimization is dropping the sparse page once dependent pages like ellpack
|
* The overall chain of responsibility of external memory DMatrix:
|
||||||
* are constructed and cached.
|
*
|
||||||
|
* User defined iterator (in Python/C/R) -> Proxy DMatrix -> Sparse page Source ->
|
||||||
|
* Other sources (Like Ellpack) -> Sparse Page DMatrix -> Caller
|
||||||
|
*
|
||||||
|
* A possible optimization is skipping the sparse page source for `hist` based algorithms
|
||||||
|
* similar to the Quantile DMatrix.
|
||||||
*/
|
*/
|
||||||
class SparsePageDMatrix : public DMatrix {
|
class SparsePageDMatrix : public DMatrix {
|
||||||
MetaInfo info_;
|
MetaInfo info_;
|
||||||
@ -67,7 +70,7 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
float missing_;
|
float missing_;
|
||||||
Context fmat_ctx_;
|
Context fmat_ctx_;
|
||||||
std::string cache_prefix_;
|
std::string cache_prefix_;
|
||||||
uint32_t n_batches_{0};
|
std::uint32_t n_batches_{0};
|
||||||
// sparse page is the source to other page types, we make a special member function.
|
// sparse page is the source to other page types, we make a special member function.
|
||||||
void InitializeSparsePage(Context const *ctx);
|
void InitializeSparsePage(Context const *ctx);
|
||||||
// Non-virtual version that can be used in constructor
|
// Non-virtual version that can be used in constructor
|
||||||
@ -93,11 +96,11 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MetaInfo &Info() override;
|
[[nodiscard]] MetaInfo &Info() override;
|
||||||
const MetaInfo &Info() const override;
|
[[nodiscard]] const MetaInfo &Info() const override;
|
||||||
Context const *Ctx() const override { return &fmat_ctx_; }
|
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
|
||||||
// The only DMatrix implementation that returns false.
|
// The only DMatrix implementation that returns false.
|
||||||
bool SingleColBlock() const override { return false; }
|
[[nodiscard]] bool SingleColBlock() const override { return false; }
|
||||||
DMatrix *Slice(common::Span<int32_t const>) override {
|
DMatrix *Slice(common::Span<int32_t const>) override {
|
||||||
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -107,6 +110,20 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool EllpackExists() const override {
|
||||||
|
return static_cast<bool>(ellpack_page_source_);
|
||||||
|
}
|
||||||
|
[[nodiscard]] bool GHistIndexExists() const override {
|
||||||
|
return static_cast<bool>(ghist_index_source_);
|
||||||
|
}
|
||||||
|
[[nodiscard]] bool SparsePageExists() const override {
|
||||||
|
return static_cast<bool>(sparse_page_source_);
|
||||||
|
}
|
||||||
|
// For testing, getter for the number of fetches for sparse page source.
|
||||||
|
[[nodiscard]] auto SparsePageFetchCount() const {
|
||||||
|
return this->sparse_page_source_->FetchCount();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BatchSet<SparsePage> GetRowBatches() override;
|
BatchSet<SparsePage> GetRowBatches() override;
|
||||||
BatchSet<CSCPage> GetColumnBatches(Context const *ctx) override;
|
BatchSet<CSCPage> GetColumnBatches(Context const *ctx) override;
|
||||||
@ -118,24 +135,24 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
// source data pointers.
|
// source data pointers.
|
||||||
std::shared_ptr<SparsePageSource> sparse_page_source_;
|
std::shared_ptr<SparsePageSource> sparse_page_source_;
|
||||||
std::shared_ptr<EllpackPageSource> ellpack_page_source_;
|
std::shared_ptr<EllpackPageSource> ellpack_page_source_;
|
||||||
std::shared_ptr<CSCPageSource> column_source_;
|
std::shared_ptr<CSCPageSource> column_source_;
|
||||||
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
|
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
|
||||||
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;
|
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;
|
||||||
|
|
||||||
bool EllpackExists() const override { return static_cast<bool>(ellpack_page_source_); }
|
|
||||||
bool GHistIndexExists() const override { return static_cast<bool>(ghist_index_source_); }
|
|
||||||
bool SparsePageExists() const override { return static_cast<bool>(sparse_page_source_); }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
inline std::string MakeId(std::string prefix, SparsePageDMatrix *ptr) {
|
[[nodiscard]] inline std::string MakeId(std::string prefix, SparsePageDMatrix *ptr) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << ptr;
|
ss << ptr;
|
||||||
return prefix + "-" + ss.str();
|
return prefix + "-" + ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Make cache if it doesn't exist yet.
|
||||||
|
*/
|
||||||
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
|
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
|
||||||
std::map<std::string, std::shared_ptr<Cache>> *out) {
|
std::map<std::string, std::shared_ptr<Cache>> *out) {
|
||||||
auto &cache_info = *out;
|
auto &cache_info = *out;
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
#include <algorithm> // for min
|
#include <algorithm> // for min
|
||||||
#include <atomic> // for atomic
|
#include <atomic> // for atomic
|
||||||
#include <cstdio> // for remove
|
#include <cstdio> // for remove
|
||||||
#include <future> // for async
|
#include <future> // for future
|
||||||
#include <memory> // for unique_ptr
|
#include <memory> // for unique_ptr
|
||||||
#include <mutex> // for mutex
|
#include <mutex> // for mutex
|
||||||
#include <numeric> // for partial_sum
|
#include <numeric> // for partial_sum
|
||||||
@ -55,7 +55,7 @@ struct Cache {
|
|||||||
offset.push_back(0);
|
offset.push_back(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string ShardName(std::string name, std::string format) {
|
[[nodiscard]] static std::string ShardName(std::string name, std::string format) {
|
||||||
CHECK_EQ(format.front(), '.');
|
CHECK_EQ(format.front(), '.');
|
||||||
return name + format;
|
return name + format;
|
||||||
}
|
}
|
||||||
@ -174,7 +174,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
ExceHandler exce_;
|
ExceHandler exce_;
|
||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
|
|
||||||
bool ReadCache() {
|
[[nodiscard]] bool ReadCache() {
|
||||||
CHECK(!at_end_);
|
CHECK(!at_end_);
|
||||||
if (!cache_info_->written) {
|
if (!cache_info_->written) {
|
||||||
return false;
|
return false;
|
||||||
@ -216,7 +216,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
return page;
|
return page;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
|
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
|
||||||
n_prefetch_batches)
|
n_prefetch_batches)
|
||||||
<< "Sparse DMatrix assumes forward iteration.";
|
<< "Sparse DMatrix assumes forward iteration.";
|
||||||
@ -279,9 +278,9 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] uint32_t Iter() const { return count_; }
|
[[nodiscard]] std::uint32_t Iter() const { return count_; }
|
||||||
|
|
||||||
const S &operator*() const override {
|
[[nodiscard]] S const& operator*() const override {
|
||||||
CHECK(page_);
|
CHECK(page_);
|
||||||
return *page_;
|
return *page_;
|
||||||
}
|
}
|
||||||
@ -311,22 +310,29 @@ inline void DevicePush(DMatrixProxy*, float, SparsePage*) { common::AssertGPUSup
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
|
class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
|
||||||
// This is the source from the user.
|
// This is the source iterator from the user.
|
||||||
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> iter_;
|
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> iter_;
|
||||||
DMatrixProxy* proxy_;
|
DMatrixProxy* proxy_;
|
||||||
std::size_t base_row_id_{0};
|
std::size_t base_row_id_{0};
|
||||||
|
bst_idx_t fetch_cnt_{0}; // Used for sanity check.
|
||||||
|
|
||||||
void Fetch() final {
|
void Fetch() final {
|
||||||
|
fetch_cnt_++;
|
||||||
page_ = std::make_shared<SparsePage>();
|
page_ = std::make_shared<SparsePage>();
|
||||||
|
// The first round of reading, this is responsible for initialization.
|
||||||
if (!this->ReadCache()) {
|
if (!this->ReadCache()) {
|
||||||
bool type_error { false };
|
bool type_error{false};
|
||||||
CHECK(proxy_);
|
CHECK(proxy_);
|
||||||
HostAdapterDispatch(proxy_, [&](auto const &adapter_batch) {
|
HostAdapterDispatch(
|
||||||
page_->Push(adapter_batch, this->missing_, this->nthreads_);
|
proxy_,
|
||||||
}, &type_error);
|
[&](auto const& adapter_batch) {
|
||||||
|
page_->Push(adapter_batch, this->missing_, this->nthreads_);
|
||||||
|
},
|
||||||
|
&type_error);
|
||||||
if (type_error) {
|
if (type_error) {
|
||||||
DevicePush(proxy_, missing_, page_.get());
|
DevicePush(proxy_, missing_, page_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
page_->SetBaseRowId(base_row_id_);
|
page_->SetBaseRowId(base_row_id_);
|
||||||
base_row_id_ += page_->Size();
|
base_row_id_ += page_->Size();
|
||||||
n_batches_++;
|
n_batches_++;
|
||||||
@ -351,11 +357,13 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
|
|||||||
SparsePageSource& operator++() final {
|
SparsePageSource& operator++() final {
|
||||||
TryLockGuard guard{single_threaded_};
|
TryLockGuard guard{single_threaded_};
|
||||||
count_++;
|
count_++;
|
||||||
|
|
||||||
if (cache_info_->written) {
|
if (cache_info_->written) {
|
||||||
at_end_ = (count_ == n_batches_);
|
at_end_ = (count_ == n_batches_);
|
||||||
} else {
|
} else {
|
||||||
at_end_ = !iter_.Next();
|
at_end_ = !iter_.Next();
|
||||||
}
|
}
|
||||||
|
CHECK_LE(count_, n_batches_);
|
||||||
|
|
||||||
if (at_end_) {
|
if (at_end_) {
|
||||||
CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
|
CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
|
||||||
@ -381,6 +389,8 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
|
|||||||
TryLockGuard guard{single_threaded_};
|
TryLockGuard guard{single_threaded_};
|
||||||
base_row_id_ = 0;
|
base_row_id_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] auto FetchCount() const { return fetch_cnt_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// A mixin for advancing the iterator.
|
// A mixin for advancing the iterator.
|
||||||
@ -394,11 +404,11 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
|
|||||||
bool sync_{true};
|
bool sync_{true};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
|
PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, std::uint32_t n_batches,
|
||||||
std::shared_ptr<Cache> cache, bool sync)
|
std::shared_ptr<Cache> cache, bool sync)
|
||||||
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
|
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
|
||||||
|
|
||||||
PageSourceIncMixIn& operator++() final {
|
[[nodiscard]] PageSourceIncMixIn& operator++() final {
|
||||||
TryLockGuard guard{this->single_threaded_};
|
TryLockGuard guard{this->single_threaded_};
|
||||||
if (sync_) {
|
if (sync_) {
|
||||||
++(*source_);
|
++(*source_);
|
||||||
@ -422,6 +432,13 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
|
|||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Reset() final {
|
||||||
|
if (sync_) {
|
||||||
|
this->source_->Reset();
|
||||||
|
}
|
||||||
|
Super::Reset();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class CSCPageSource : public PageSourceIncMixIn<CSCPage> {
|
class CSCPageSource : public PageSourceIncMixIn<CSCPage> {
|
||||||
|
|||||||
@ -7,7 +7,6 @@
|
|||||||
#include <algorithm> // for max, copy, transform
|
#include <algorithm> // for max, copy, transform
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for uint32_t, int32_t
|
#include <cstdint> // for uint32_t, int32_t
|
||||||
#include <exception> // for exception
|
|
||||||
#include <memory> // for allocator, unique_ptr, make_unique, shared_ptr
|
#include <memory> // for allocator, unique_ptr, make_unique, shared_ptr
|
||||||
#include <ostream> // for operator<<, basic_ostream, char_traits
|
#include <ostream> // for operator<<, basic_ostream, char_traits
|
||||||
#include <utility> // for move
|
#include <utility> // for move
|
||||||
@ -20,7 +19,6 @@
|
|||||||
#include "../common/random.h" // for ColumnSampler
|
#include "../common/random.h" // for ColumnSampler
|
||||||
#include "../common/threading_utils.h" // for ParallelFor
|
#include "../common/threading_utils.h" // for ParallelFor
|
||||||
#include "../common/timer.h" // for Monitor
|
#include "../common/timer.h" // for Monitor
|
||||||
#include "../common/transform_iterator.h" // for IndexTransformIter
|
|
||||||
#include "../data/gradient_index.h" // for GHistIndexMatrix
|
#include "../data/gradient_index.h" // for GHistIndexMatrix
|
||||||
#include "common_row_partitioner.h" // for CommonRowPartitioner
|
#include "common_row_partitioner.h" // for CommonRowPartitioner
|
||||||
#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG
|
#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2021-2023, XGBoost contributors
|
* Copyright 2021-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/context.h> // for Context
|
#include <xgboost/context.h> // for Context
|
||||||
@ -8,10 +8,10 @@
|
|||||||
#include <memory> // for unique_ptr
|
#include <memory> // for unique_ptr
|
||||||
|
|
||||||
#include "../../../src/common/column_matrix.h"
|
#include "../../../src/common/column_matrix.h"
|
||||||
#include "../../../src/common/io.h" // for MmapResource, AlignedResourceReadStream...
|
#include "../../../src/common/io.h" // for MmapResource, AlignedResourceReadStream...
|
||||||
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||||
#include "../../../src/data/sparse_page_source.h"
|
#include "../../../src/data/sparse_page_writer.h" // for CreatePageFormat
|
||||||
#include "../helpers.h" // for RandomDataGenerator
|
#include "../helpers.h" // for RandomDataGenerator
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
TEST(GHistIndexPageRawFormat, IO) {
|
TEST(GHistIndexPageRawFormat, IO) {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2016-2023 by XGBoost Contributors
|
* Copyright 2016-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
@ -115,9 +115,67 @@ TEST(SparsePageDMatrix, RetainSparsePage) {
|
|||||||
TestRetainPage<SortedCSCPage>();
|
TestRetainPage<SortedCSCPage>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test GHistIndexMatrix can avoid loading sparse page after the initialization.
|
||||||
|
TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
|
||||||
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix(
|
||||||
|
tmpdir.path + "/", true);
|
||||||
|
Context ctx;
|
||||||
|
bst_bin_t n_bins{256};
|
||||||
|
double sparse_thresh{0.8};
|
||||||
|
BatchParam batch_param{n_bins, sparse_thresh};
|
||||||
|
|
||||||
|
auto check_ghist = [&] {
|
||||||
|
std::int32_t k = 0;
|
||||||
|
for (auto const &page : Xy->GetBatches<GHistIndexMatrix>(&ctx, batch_param)) {
|
||||||
|
ASSERT_EQ(page.Size(), 30);
|
||||||
|
ASSERT_EQ(k, page.base_rowid);
|
||||||
|
k += page.Size();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
check_ghist();
|
||||||
|
|
||||||
|
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
||||||
|
CHECK(casted);
|
||||||
|
// Make sure the number of fetches doesn't change (no new fetch)
|
||||||
|
auto n_init_fetches = casted->SparsePageFetchCount();
|
||||||
|
|
||||||
|
std::vector<float> hess(Xy->Info().num_row_, 1.0f);
|
||||||
|
// Run multiple iterations to make sure fetches are consistent after reset.
|
||||||
|
for (std::int32_t i = 0; i < 4; ++i) {
|
||||||
|
auto n_fetches = casted->SparsePageFetchCount();
|
||||||
|
check_ghist();
|
||||||
|
ASSERT_EQ(casted->SparsePageFetchCount(), n_fetches);
|
||||||
|
if (i == 0) {
|
||||||
|
ASSERT_EQ(n_fetches, n_init_fetches);
|
||||||
|
}
|
||||||
|
// Make sure other page types don't interfere with the GHist. This way, we can reuse the
|
||||||
|
// DMatrix for multiple purposes.
|
||||||
|
for ([[maybe_unused]] auto const &page : Xy->GetBatches<SparsePage>(&ctx)) {
|
||||||
|
}
|
||||||
|
for ([[maybe_unused]] auto const &page : Xy->GetBatches<SortedCSCPage>(&ctx)) {
|
||||||
|
}
|
||||||
|
for ([[maybe_unused]] auto const &page : Xy->GetBatches<GHistIndexMatrix>(&ctx, batch_param)) {
|
||||||
|
}
|
||||||
|
// Approx tree method pages
|
||||||
|
{
|
||||||
|
BatchParam regen{n_bins, common::Span{hess.data(), hess.size()}, false};
|
||||||
|
for ([[maybe_unused]] auto const &page : Xy->GetBatches<GHistIndexMatrix>(&ctx, regen)) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
BatchParam regen{n_bins, common::Span{hess.data(), hess.size()}, true};
|
||||||
|
for ([[maybe_unused]] auto const &page : Xy->GetBatches<GHistIndexMatrix>(&ctx, regen)) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Restore the batch parameter by passing it in again through check_ghist
|
||||||
|
check_ghist();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, MetaInfo) {
|
TEST(SparsePageDMatrix, MetaInfo) {
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
const std::string tmp_file = tmpdir.path + "/simple.libsvm";
|
||||||
size_t constexpr kEntries = 24;
|
size_t constexpr kEntries = 24;
|
||||||
CreateBigTestData(tmp_file, kEntries);
|
CreateBigTestData(tmp_file, kEntries);
|
||||||
|
|
||||||
|
|||||||
@ -42,6 +42,36 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
|||||||
delete dmat;
|
delete dmat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
|
||||||
|
// Test Ellpack can avoid loading sparse page after the initialization.
|
||||||
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(6).GenerateSparsePageDMatrix(
|
||||||
|
tmpdir.path + "/", true);
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
bst_bin_t n_bins{256};
|
||||||
|
double sparse_thresh{0.8};
|
||||||
|
BatchParam batch_param{n_bins, sparse_thresh};
|
||||||
|
|
||||||
|
std::int32_t k = 0;
|
||||||
|
for (auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
|
auto impl = page.Impl();
|
||||||
|
ASSERT_EQ(page.Size(), 30);
|
||||||
|
ASSERT_EQ(k, impl->base_rowid);
|
||||||
|
k += page.Size();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
||||||
|
CHECK(casted);
|
||||||
|
// Make sure the number of fetches doesn't change (no new fetch)
|
||||||
|
auto n_fetches = casted->SparsePageFetchCount();
|
||||||
|
for (std::int32_t i = 0; i < 3; ++i) {
|
||||||
|
for ([[maybe_unused]] auto const& page : Xy->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
|
}
|
||||||
|
auto casted = std::dynamic_pointer_cast<data::SparsePageDMatrix>(Xy);
|
||||||
|
ASSERT_EQ(casted->SparsePageFetchCount(), n_fetches);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||||
Context ctx{MakeCUDACtx(0)};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
|||||||
@ -59,7 +59,7 @@ def test_nested_config() -> None:
|
|||||||
assert verbosity == 1
|
assert verbosity == 1
|
||||||
|
|
||||||
|
|
||||||
def test_thread_safty():
|
def test_thread_safety():
|
||||||
n_threads = multiprocessing.cpu_count()
|
n_threads = multiprocessing.cpu_count()
|
||||||
futures = []
|
futures = []
|
||||||
with ThreadPoolExecutor(max_workers=n_threads) as executor:
|
with ThreadPoolExecutor(max_workers=n_threads) as executor:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user