[EM] Add GPU version of the external memory QDM. (#10689)
This commit is contained in:
parent
18b28d9315
commit
d414fdf2e7
@ -641,7 +641,7 @@ class DMatrix {
|
|||||||
typename XGDMatrixCallbackNext>
|
typename XGDMatrixCallbackNext>
|
||||||
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
|
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
|
||||||
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
|
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
|
||||||
std::int32_t nthread, bst_bin_t max_bin, std::string cache);
|
std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host);
|
||||||
|
|
||||||
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
|
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
|
||||||
|
|
||||||
|
|||||||
@ -116,6 +116,13 @@ inline int32_t CurrentDevice() {
|
|||||||
return device;
|
return device;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to get a device from a potentially CPU context.
|
||||||
|
inline auto GetDevice(xgboost::Context const *ctx) {
|
||||||
|
auto d = (ctx->IsCUDA()) ? ctx->Device() : xgboost::DeviceOrd::CUDA(dh::CurrentDevice());
|
||||||
|
CHECK(!d.IsCPU());
|
||||||
|
return d;
|
||||||
|
}
|
||||||
|
|
||||||
inline size_t TotalMemory(int device_idx) {
|
inline size_t TotalMemory(int device_idx) {
|
||||||
size_t device_free = 0;
|
size_t device_free = 0;
|
||||||
size_t device_total = 0;
|
size_t device_total = 0;
|
||||||
|
|||||||
@ -914,9 +914,9 @@ template <typename DataIterHandle, typename DMatrixHandle, typename DataIterRese
|
|||||||
typename XGDMatrixCallbackNext>
|
typename XGDMatrixCallbackNext>
|
||||||
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
|
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
|
||||||
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
|
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
|
||||||
std::int32_t nthread, bst_bin_t max_bin, std::string cache) {
|
std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host) {
|
||||||
return new data::ExtMemQuantileDMatrix{
|
return new data::ExtMemQuantileDMatrix{
|
||||||
iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin};
|
iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin, on_host};
|
||||||
}
|
}
|
||||||
|
|
||||||
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
|
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
|
||||||
@ -935,7 +935,7 @@ template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCa
|
|||||||
template DMatrix*
|
template DMatrix*
|
||||||
DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback, XGDMatrixCallbackNext>(
|
DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback, XGDMatrixCallbackNext>(
|
||||||
DataIterHandle, DMatrixHandle, std::shared_ptr<DMatrix>, DataIterResetCallback*,
|
DataIterHandle, DMatrixHandle, std::shared_ptr<DMatrix>, DataIterResetCallback*,
|
||||||
XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string);
|
XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string, bool);
|
||||||
|
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
|
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
|
||||||
|
|||||||
@ -47,6 +47,18 @@ bst_idx_t EllpackPage::Size() const {
|
|||||||
"EllpackPage is required";
|
"EllpackPage is required";
|
||||||
return impl_->Cuts();
|
return impl_->Cuts();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bst_idx_t EllpackPage::BaseRowId() const {
|
||||||
|
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||||
|
"EllpackPage is required";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool EllpackPage::IsDense() const {
|
||||||
|
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||||
|
"EllpackPage is required";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
#endif // XGBOOST_USE_CUDA
|
#endif // XGBOOST_USE_CUDA
|
||||||
|
|||||||
@ -39,6 +39,9 @@ void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id)
|
|||||||
return impl_->Cuts();
|
return impl_->Cuts();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bst_idx_t EllpackPage::BaseRowId() const { return this->Impl()->base_rowid; }
|
||||||
|
[[nodiscard]] bool EllpackPage::IsDense() const { return this->Impl()->IsDense(); }
|
||||||
|
|
||||||
// Bin each input data entry, store the bin indices in compressed form.
|
// Bin each input data entry, store the bin indices in compressed form.
|
||||||
__global__ void CompressBinEllpackKernel(
|
__global__ void CompressBinEllpackKernel(
|
||||||
common::CompressedBufferWriter wr,
|
common::CompressedBufferWriter wr,
|
||||||
@ -397,7 +400,7 @@ struct CopyPage {
|
|||||||
size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) {
|
size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) {
|
||||||
monitor_.Start(__func__);
|
monitor_.Start(__func__);
|
||||||
bst_idx_t num_elements = page->n_rows * page->row_stride;
|
bst_idx_t num_elements = page->n_rows * page->row_stride;
|
||||||
CHECK_EQ(row_stride, page->row_stride);
|
CHECK_EQ(this->row_stride, page->row_stride);
|
||||||
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
||||||
CHECK_GE(n_rows * row_stride, offset + num_elements);
|
CHECK_GE(n_rows * row_stride, offset + num_elements);
|
||||||
if (page == this) {
|
if (page == this) {
|
||||||
|
|||||||
@ -203,6 +203,7 @@ class EllpackPageImpl {
|
|||||||
[[nodiscard]] std::shared_ptr<common::HistogramCuts const> CutsShared() const { return cuts_; }
|
[[nodiscard]] std::shared_ptr<common::HistogramCuts const> CutsShared() const { return cuts_; }
|
||||||
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts) { cuts_ = cuts; }
|
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts) { cuts_ = cuts; }
|
||||||
|
|
||||||
|
[[nodiscard]] bool IsDense() const { return is_dense; }
|
||||||
/** @return Estimation of memory cost of this page. */
|
/** @return Estimation of memory cost of this page. */
|
||||||
static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ;
|
static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ;
|
||||||
|
|
||||||
|
|||||||
@ -42,6 +42,7 @@ class EllpackPage {
|
|||||||
|
|
||||||
/*! \return Number of instances in the page. */
|
/*! \return Number of instances in the page. */
|
||||||
[[nodiscard]] bst_idx_t Size() const;
|
[[nodiscard]] bst_idx_t Size() const;
|
||||||
|
[[nodiscard]] bool IsDense() const;
|
||||||
|
|
||||||
/*! \brief Set the base row id for this page. */
|
/*! \brief Set the base row id for this page. */
|
||||||
void SetBaseRowId(std::size_t row_id);
|
void SetBaseRowId(std::size_t row_id);
|
||||||
@ -50,6 +51,7 @@ class EllpackPage {
|
|||||||
EllpackPageImpl* Impl() { return impl_.get(); }
|
EllpackPageImpl* Impl() { return impl_.get(); }
|
||||||
|
|
||||||
[[nodiscard]] common::HistogramCuts const& Cuts() const;
|
[[nodiscard]] common::HistogramCuts const& Cuts() const;
|
||||||
|
[[nodiscard]] bst_idx_t BaseRowId() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<EllpackPageImpl> impl_;
|
std::unique_ptr<EllpackPageImpl> impl_;
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
#include "ellpack_page.cuh" // for EllpackPageImpl
|
#include "ellpack_page.cuh" // for EllpackPageImpl
|
||||||
#include "ellpack_page.h" // for EllpackPage
|
#include "ellpack_page.h" // for EllpackPage
|
||||||
#include "ellpack_page_source.h"
|
#include "ellpack_page_source.h"
|
||||||
|
#include "proxy_dmatrix.cuh" // for Dispatch
|
||||||
#include "xgboost/base.h" // for bst_idx_t
|
#include "xgboost/base.h" // for bst_idx_t
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
@ -182,4 +183,51 @@ template void
|
|||||||
EllpackPageSourceImpl<EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
EllpackPageSourceImpl<EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||||
template void
|
template void
|
||||||
EllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
EllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ExtEllpackPageSourceImpl
|
||||||
|
*/
|
||||||
|
template <typename F>
|
||||||
|
void ExtEllpackPageSourceImpl<F>::Fetch() {
|
||||||
|
dh::safe_cuda(cudaSetDevice(this->Device().ordinal));
|
||||||
|
if (!this->ReadCache()) {
|
||||||
|
auto iter = this->source_->Iter();
|
||||||
|
CHECK_EQ(this->count_, iter);
|
||||||
|
++(*this->source_);
|
||||||
|
CHECK_GE(this->source_->Iter(), 1);
|
||||||
|
cuda_impl::Dispatch(proxy_, [this](auto const& value) {
|
||||||
|
proxy_->Info().feature_types.SetDevice(dh::GetDevice(this->ctx_));
|
||||||
|
auto d_feature_types = proxy_->Info().feature_types.ConstDeviceSpan();
|
||||||
|
auto n_samples = value.NumRows();
|
||||||
|
|
||||||
|
dh::device_vector<size_t> row_counts(n_samples + 1, 0);
|
||||||
|
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
|
||||||
|
cuda_impl::Dispatch(proxy_, [=](auto const& value) {
|
||||||
|
return GetRowCounts(value, row_counts_span, dh::GetDevice(this->ctx_), this->missing_);
|
||||||
|
});
|
||||||
|
|
||||||
|
this->page_.reset(new EllpackPage{});
|
||||||
|
*this->page_->Impl() = EllpackPageImpl{this->ctx_,
|
||||||
|
value,
|
||||||
|
this->missing_,
|
||||||
|
this->info_->IsDense(),
|
||||||
|
row_counts_span,
|
||||||
|
d_feature_types,
|
||||||
|
this->ext_info_.row_stride,
|
||||||
|
n_samples,
|
||||||
|
this->GetCuts()};
|
||||||
|
this->info_->Extend(proxy_->Info(), false, true);
|
||||||
|
});
|
||||||
|
this->page_->SetBaseRowId(this->ext_info_.base_rows.at(iter));
|
||||||
|
this->WriteCache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instantiation
|
||||||
|
template void
|
||||||
|
ExtEllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||||
|
template void
|
||||||
|
ExtEllpackPageSourceImpl<EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||||
|
template void
|
||||||
|
ExtEllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <utility> // for move
|
#include <utility> // for move
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../common/cuda_rt_utils.h" // for SupportsPageableMem
|
#include "../common/cuda_rt_utils.h" // for SupportsPageableMem
|
||||||
#include "../common/hist_util.h" // for HistogramCuts
|
#include "../common/hist_util.h" // for HistogramCuts
|
||||||
@ -169,6 +170,51 @@ using EllpackPageHostSource =
|
|||||||
using EllpackPageSource =
|
using EllpackPageSource =
|
||||||
EllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
|
EllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
|
||||||
|
|
||||||
|
template <typename FormatCreatePolicy>
|
||||||
|
class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin<EllpackPage, FormatCreatePolicy> {
|
||||||
|
using Super = ExtQantileSourceMixin<EllpackPage, FormatCreatePolicy>;
|
||||||
|
|
||||||
|
Context const* ctx_;
|
||||||
|
BatchParam p_;
|
||||||
|
DMatrixProxy* proxy_;
|
||||||
|
MetaInfo* info_;
|
||||||
|
ExternalDataInfo ext_info_;
|
||||||
|
|
||||||
|
std::vector<bst_idx_t> base_rows_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ExtEllpackPageSourceImpl(
|
||||||
|
Context const* ctx, float missing, MetaInfo* info, ExternalDataInfo ext_info,
|
||||||
|
std::shared_ptr<Cache> cache, BatchParam param, std::shared_ptr<common::HistogramCuts> cuts,
|
||||||
|
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source,
|
||||||
|
DMatrixProxy* proxy, std::vector<bst_idx_t> base_rows)
|
||||||
|
: Super{missing,
|
||||||
|
ctx->Threads(),
|
||||||
|
static_cast<bst_feature_t>(info->num_col_),
|
||||||
|
ext_info.n_batches,
|
||||||
|
source,
|
||||||
|
cache},
|
||||||
|
ctx_{ctx},
|
||||||
|
p_{std::move(param)},
|
||||||
|
proxy_{proxy},
|
||||||
|
info_{info},
|
||||||
|
ext_info_{std::move(ext_info)},
|
||||||
|
base_rows_{std::move(base_rows)} {
|
||||||
|
this->SetCuts(std::move(cuts), ctx->Device());
|
||||||
|
this->Fetch();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Fetch() final;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Cache to host
|
||||||
|
using ExtEllpackPageHostSource =
|
||||||
|
ExtEllpackPageSourceImpl<EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
|
||||||
|
|
||||||
|
// Cache to disk
|
||||||
|
using ExtEllpackPageSource =
|
||||||
|
ExtEllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
template <typename F>
|
template <typename F>
|
||||||
inline void EllpackPageSourceImpl<F>::Fetch() {
|
inline void EllpackPageSourceImpl<F>::Fetch() {
|
||||||
@ -177,6 +223,11 @@ inline void EllpackPageSourceImpl<F>::Fetch() {
|
|||||||
(void)(is_dense_);
|
(void)(is_dense_);
|
||||||
common::AssertGPUSupport();
|
common::AssertGPUSupport();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
inline void ExtEllpackPageSourceImpl<F>::Fetch() {
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
}
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|
||||||
|
|||||||
@ -24,8 +24,8 @@ ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrix
|
|||||||
DataIterResetCallback *reset,
|
DataIterResetCallback *reset,
|
||||||
XGDMatrixCallbackNext *next, float missing,
|
XGDMatrixCallbackNext *next, float missing,
|
||||||
std::int32_t n_threads, std::string cache,
|
std::int32_t n_threads, std::string cache,
|
||||||
bst_bin_t max_bin)
|
bst_bin_t max_bin, bool on_host)
|
||||||
: cache_prefix_{std::move(cache)} {
|
: cache_prefix_{std::move(cache)}, on_host_{on_host} {
|
||||||
auto iter = std::make_shared<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>>(
|
auto iter = std::make_shared<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>>(
|
||||||
iter_handle, reset, next);
|
iter_handle, reset, next);
|
||||||
iter->Reset();
|
iter->Reset();
|
||||||
@ -72,13 +72,7 @@ void ExtMemQuantileDMatrix::InitFromCPU(
|
|||||||
common::HistogramCuts cuts;
|
common::HistogramCuts cuts;
|
||||||
ExternalDataInfo ext_info;
|
ExternalDataInfo ext_info;
|
||||||
cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
|
cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
|
||||||
|
ext_info.SetInfo(ctx, &this->info_);
|
||||||
// From here on Info() has the correct data shape
|
|
||||||
this->Info().num_row_ = ext_info.accumulated_rows;
|
|
||||||
this->Info().num_col_ = ext_info.n_features;
|
|
||||||
this->Info().num_nonzero_ = ext_info.nnz;
|
|
||||||
this->Info().SynchronizeNumberOfColumns(ctx);
|
|
||||||
ext_info.Validate();
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate quantiles
|
* Generate quantiles
|
||||||
@ -110,7 +104,7 @@ void ExtMemQuantileDMatrix::InitFromCPU(
|
|||||||
CHECK_EQ(n_total_samples, ext_info.accumulated_rows);
|
CHECK_EQ(n_total_samples, ext_info.accumulated_rows);
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndexImpl() {
|
[[nodiscard]] BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndexImpl() {
|
||||||
return BatchSet{BatchIterator<GHistIndexMatrix>{this->ghist_index_source_}};
|
return BatchSet{BatchIterator<GHistIndexMatrix>{this->ghist_index_source_}};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,5 +142,13 @@ BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
|
|||||||
this->ellpack_page_source_);
|
this->ellpack_page_source_);
|
||||||
return batch_set;
|
return batch_set;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackPageImpl() {
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
auto batch_set =
|
||||||
|
std::visit([this](auto &&ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
|
||||||
|
this->ellpack_page_source_);
|
||||||
|
return batch_set;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -4,21 +4,81 @@
|
|||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <variant> // for visit
|
#include <variant> // for visit
|
||||||
|
|
||||||
|
#include "batch_utils.h" // for CheckParam, RegenGHist
|
||||||
|
#include "ellpack_page.cuh" // for EllpackPage
|
||||||
#include "extmem_quantile_dmatrix.h"
|
#include "extmem_quantile_dmatrix.h"
|
||||||
|
#include "proxy_dmatrix.h" // for DataIterProxy
|
||||||
|
#include "xgboost/context.h" // for Context
|
||||||
|
#include "xgboost/data.h" // for BatchParam
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
void ExtMemQuantileDMatrix::InitFromCUDA(
|
void ExtMemQuantileDMatrix::InitFromCUDA(
|
||||||
Context const *, std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>>,
|
Context const *ctx,
|
||||||
DMatrixHandle, BatchParam const &, float, std::shared_ptr<DMatrix>) {
|
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
|
||||||
LOG(FATAL) << "Not implemented.";
|
DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref) {
|
||||||
|
// A handle passed to external iterator.
|
||||||
|
auto proxy = MakeProxy(proxy_handle);
|
||||||
|
CHECK(proxy);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate quantiles
|
||||||
|
*/
|
||||||
|
auto cuts = std::make_shared<common::HistogramCuts>();
|
||||||
|
ExternalDataInfo ext_info;
|
||||||
|
cuda_impl::MakeSketches(ctx, iter.get(), proxy, ref, p, missing, cuts, this->Info(), &ext_info);
|
||||||
|
ext_info.SetInfo(ctx, &this->info_);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate gradient index
|
||||||
|
*/
|
||||||
|
auto id = MakeCache(this, ".ellpack.page", false, cache_prefix_, &cache_info_);
|
||||||
|
if (on_host_ && std::get_if<EllpackHostPtr>(&ellpack_page_source_) == nullptr) {
|
||||||
|
ellpack_page_source_.emplace<EllpackHostPtr>(nullptr);
|
||||||
|
}
|
||||||
|
std::visit(
|
||||||
|
[&](auto &&ptr) {
|
||||||
|
using SourceT = typename std::remove_reference_t<decltype(ptr)>::element_type;
|
||||||
|
ptr = std::make_shared<SourceT>(ctx, missing, &this->Info(), ext_info, cache_info_.at(id),
|
||||||
|
p, cuts, iter, proxy, ext_info.base_rows);
|
||||||
|
},
|
||||||
|
ellpack_page_source_);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Force initialize the cache and do some sanity checks along the way
|
||||||
|
*/
|
||||||
|
bst_idx_t batch_cnt = 0, k = 0;
|
||||||
|
bst_idx_t n_total_samples = 0;
|
||||||
|
for (auto const &page : this->GetEllpackPageImpl()) {
|
||||||
|
n_total_samples += page.Size();
|
||||||
|
CHECK_EQ(page.Impl()->base_rowid, ext_info.base_rows[k]);
|
||||||
|
CHECK_EQ(page.Impl()->row_stride, ext_info.row_stride);
|
||||||
|
++k, ++batch_cnt;
|
||||||
|
}
|
||||||
|
CHECK_EQ(batch_cnt, ext_info.n_batches);
|
||||||
|
CHECK_EQ(n_total_samples, ext_info.accumulated_rows);
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
|
[[nodiscard]] BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackPageImpl() {
|
||||||
const BatchParam &) {
|
|
||||||
LOG(FATAL) << "Not implemented.";
|
|
||||||
auto batch_set =
|
auto batch_set =
|
||||||
std::visit([this](auto &&ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
|
std::visit([this](auto &&ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
|
||||||
this->ellpack_page_source_);
|
this->ellpack_page_source_);
|
||||||
return batch_set;
|
return batch_set;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
|
||||||
|
const BatchParam ¶m) {
|
||||||
|
if (param.Initialized()) {
|
||||||
|
detail::CheckParam(this->batch_, param);
|
||||||
|
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::visit(
|
||||||
|
[this](auto &&ptr) {
|
||||||
|
CHECK(ptr);
|
||||||
|
ptr->Reset();
|
||||||
|
},
|
||||||
|
this->ellpack_page_source_);
|
||||||
|
|
||||||
|
return this->GetEllpackPageImpl();
|
||||||
|
}
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
|
|||||||
ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||||
std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
|
std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
|
||||||
XGDMatrixCallbackNext *next, float missing, std::int32_t n_threads,
|
XGDMatrixCallbackNext *next, float missing, std::int32_t n_threads,
|
||||||
std::string cache, bst_bin_t max_bin);
|
std::string cache, bst_bin_t max_bin, bool on_host);
|
||||||
~ExtMemQuantileDMatrix() override;
|
~ExtMemQuantileDMatrix() override;
|
||||||
|
|
||||||
[[nodiscard]] bool SingleColBlock() const override { return false; }
|
[[nodiscard]] bool SingleColBlock() const override { return false; }
|
||||||
@ -45,9 +45,10 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
|
|||||||
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
|
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
|
||||||
DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref);
|
DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref);
|
||||||
|
|
||||||
BatchSet<GHistIndexMatrix> GetGradientIndexImpl();
|
[[nodiscard]] BatchSet<GHistIndexMatrix> GetGradientIndexImpl();
|
||||||
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const *ctx, BatchParam const ¶m) override;
|
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const *ctx, BatchParam const ¶m) override;
|
||||||
|
|
||||||
|
[[nodiscard]] BatchSet<EllpackPage> GetEllpackPageImpl();
|
||||||
BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam ¶m) override;
|
BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam ¶m) override;
|
||||||
|
|
||||||
[[nodiscard]] bool EllpackExists() const override {
|
[[nodiscard]] bool EllpackExists() const override {
|
||||||
@ -60,10 +61,11 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
|
|||||||
|
|
||||||
std::map<std::string, std::shared_ptr<Cache>> cache_info_;
|
std::map<std::string, std::shared_ptr<Cache>> cache_info_;
|
||||||
std::string cache_prefix_;
|
std::string cache_prefix_;
|
||||||
|
bool on_host_;
|
||||||
BatchParam batch_;
|
BatchParam batch_;
|
||||||
|
|
||||||
using EllpackDiskPtr = std::shared_ptr<EllpackPageSource>;
|
using EllpackDiskPtr = std::shared_ptr<ExtEllpackPageSource>;
|
||||||
using EllpackHostPtr = std::shared_ptr<EllpackPageHostSource>;
|
using EllpackHostPtr = std::shared_ptr<ExtEllpackPageHostSource>;
|
||||||
std::variant<EllpackDiskPtr, EllpackHostPtr> ellpack_page_source_;
|
std::variant<EllpackDiskPtr, EllpackHostPtr> ellpack_page_source_;
|
||||||
std::shared_ptr<ExtGradientIndexPageSource> ghist_index_source_;
|
std::shared_ptr<ExtGradientIndexPageSource> ghist_index_source_;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -242,6 +242,7 @@ class GHistIndexMatrix {
|
|||||||
|
|
||||||
[[nodiscard]] bool IsDense() const { return isDense_; }
|
[[nodiscard]] bool IsDense() const { return isDense_; }
|
||||||
void SetDense(bool is_dense) { isDense_ = is_dense; }
|
void SetDense(bool is_dense) { isDense_ = is_dense; }
|
||||||
|
[[nodiscard]] bst_idx_t BaseRowId() const { return base_rowid; }
|
||||||
/**
|
/**
|
||||||
* @brief Get the local row index.
|
* @brief Get the local row index.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -39,45 +39,6 @@ class GHistIndexFormatPolicy {
|
|||||||
void SetCuts(common::HistogramCuts cuts) { std::swap(cuts_, cuts); }
|
void SetCuts(common::HistogramCuts cuts) { std::swap(cuts_, cuts); }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename S,
|
|
||||||
typename FormatCreatePolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
|
|
||||||
class ExtQantileSourceMixin : public SparsePageSourceImpl<S, FormatCreatePolicy> {
|
|
||||||
protected:
|
|
||||||
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source_;
|
|
||||||
using Super = SparsePageSourceImpl<S, FormatCreatePolicy>;
|
|
||||||
|
|
||||||
public:
|
|
||||||
ExtQantileSourceMixin(float missing, std::int32_t nthreads, bst_feature_t n_features,
|
|
||||||
bst_idx_t n_batches, std::shared_ptr<Cache> cache)
|
|
||||||
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache} {}
|
|
||||||
// This function always operate on the source first, then the downstream. The downstream
|
|
||||||
// can assume the source to be ready.
|
|
||||||
[[nodiscard]] ExtQantileSourceMixin& operator++() final {
|
|
||||||
TryLockGuard guard{this->single_threaded_};
|
|
||||||
// Increment self.
|
|
||||||
++this->count_;
|
|
||||||
// Set at end.
|
|
||||||
this->at_end_ = this->count_ == this->n_batches_;
|
|
||||||
|
|
||||||
if (this->at_end_) {
|
|
||||||
this->EndIter();
|
|
||||||
|
|
||||||
CHECK(this->cache_info_->written);
|
|
||||||
source_ = nullptr; // release the source
|
|
||||||
}
|
|
||||||
this->Fetch();
|
|
||||||
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Reset() final {
|
|
||||||
if (this->source_) {
|
|
||||||
this->source_->Reset();
|
|
||||||
}
|
|
||||||
Super::Reset();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class GradientIndexPageSource
|
class GradientIndexPageSource
|
||||||
: public PageSourceIncMixIn<
|
: public PageSourceIncMixIn<
|
||||||
GHistIndexMatrix, DefaultFormatStreamPolicy<GHistIndexMatrix, GHistIndexFormatPolicy>> {
|
GHistIndexMatrix, DefaultFormatStreamPolicy<GHistIndexMatrix, GHistIndexFormatPolicy>> {
|
||||||
@ -125,14 +86,13 @@ class ExtGradientIndexPageSource
|
|||||||
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source,
|
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source,
|
||||||
DMatrixProxy* proxy, std::vector<bst_idx_t> base_rows)
|
DMatrixProxy* proxy, std::vector<bst_idx_t> base_rows)
|
||||||
: ExtQantileSourceMixin{missing, ctx->Threads(), static_cast<bst_feature_t>(info->num_col_),
|
: ExtQantileSourceMixin{missing, ctx->Threads(), static_cast<bst_feature_t>(info->num_col_),
|
||||||
n_batches, cache},
|
n_batches, source, cache},
|
||||||
p_{std::move(param)},
|
p_{std::move(param)},
|
||||||
ctx_{ctx},
|
ctx_{ctx},
|
||||||
proxy_{proxy},
|
proxy_{proxy},
|
||||||
info_{info},
|
info_{info},
|
||||||
feature_types_{info_->feature_types.ConstHostSpan()},
|
feature_types_{info_->feature_types.ConstHostSpan()},
|
||||||
base_rows_{std::move(base_rows)} {
|
base_rows_{std::move(base_rows)} {
|
||||||
this->source_ = source;
|
|
||||||
this->SetCuts(std::move(cuts));
|
this->SetCuts(std::move(cuts));
|
||||||
this->Fetch();
|
this->Fetch();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -63,13 +63,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
|
|||||||
common::HistogramCuts cuts;
|
common::HistogramCuts cuts;
|
||||||
ExternalDataInfo ext_info;
|
ExternalDataInfo ext_info;
|
||||||
cpu_impl::GetDataShape(ctx, proxy, iter, missing, &ext_info);
|
cpu_impl::GetDataShape(ctx, proxy, iter, missing, &ext_info);
|
||||||
|
ext_info.SetInfo(ctx, &this->info_);
|
||||||
// From here on Info() has the correct data shape
|
|
||||||
this->Info().num_row_ = ext_info.accumulated_rows;
|
|
||||||
this->Info().num_col_ = ext_info.n_features;
|
|
||||||
this->Info().num_nonzero_ = ext_info.nnz;
|
|
||||||
this->Info().SynchronizeNumberOfColumns(ctx);
|
|
||||||
ext_info.Validate();
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate quantiles
|
* Generate quantiles
|
||||||
|
|||||||
@ -1,20 +1,15 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2020-2024, XGBoost contributors
|
* Copyright 2020-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <algorithm> // for max
|
|
||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <utility> // for move
|
#include <utility> // for move
|
||||||
#include <vector> // for vector
|
|
||||||
|
|
||||||
#include "../collective/allreduce.h"
|
|
||||||
#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs
|
|
||||||
#include "../common/hist_util.cuh"
|
|
||||||
#include "batch_utils.h" // for RegenGHist, CheckParam
|
#include "batch_utils.h" // for RegenGHist, CheckParam
|
||||||
#include "device_adapter.cuh"
|
#include "device_adapter.cuh"
|
||||||
#include "ellpack_page.cuh"
|
#include "ellpack_page.cuh"
|
||||||
#include "iterative_dmatrix.h"
|
#include "iterative_dmatrix.h"
|
||||||
#include "proxy_dmatrix.cuh"
|
#include "proxy_dmatrix.cuh"
|
||||||
#include "proxy_dmatrix.h"
|
#include "proxy_dmatrix.h" // for BatchSamples, BatchColumns
|
||||||
#include "simple_batch_iterator.h"
|
#include "simple_batch_iterator.h"
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
@ -31,103 +26,32 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
|||||||
|
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
|
|
||||||
auto num_rows = [&]() {
|
|
||||||
return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumRows(); });
|
|
||||||
};
|
|
||||||
auto num_cols = [&]() {
|
|
||||||
return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumCols(); });
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t row_stride = 0;
|
|
||||||
size_t nnz = 0;
|
|
||||||
// Sketch for all batches.
|
// Sketch for all batches.
|
||||||
std::vector<common::SketchContainer> sketch_containers;
|
|
||||||
size_t batches = 0;
|
|
||||||
size_t accumulated_rows = 0;
|
|
||||||
bst_feature_t cols = 0;
|
|
||||||
|
|
||||||
int32_t current_device;
|
std::int32_t current_device{dh::CurrentDevice()};
|
||||||
dh::safe_cuda(cudaGetDevice(¤t_device));
|
|
||||||
auto get_ctx = [&]() {
|
auto get_ctx = [&]() {
|
||||||
Context d_ctx = (ctx->IsCUDA()) ? *ctx : Context{}.MakeCUDA(current_device);
|
Context d_ctx = (ctx->IsCUDA()) ? *ctx : Context{}.MakeCUDA(current_device);
|
||||||
CHECK(!d_ctx.IsCPU());
|
CHECK(!d_ctx.IsCPU());
|
||||||
return d_ctx;
|
return d_ctx;
|
||||||
};
|
};
|
||||||
auto get_device = [&]() {
|
|
||||||
auto d = (ctx->IsCUDA()) ? ctx->Device() : DeviceOrd::CUDA(current_device);
|
|
||||||
CHECK(!d.IsCPU());
|
|
||||||
return d;
|
|
||||||
};
|
|
||||||
fmat_ctx_ = get_ctx();
|
fmat_ctx_ = get_ctx();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate quantiles
|
* Generate quantiles
|
||||||
*/
|
*/
|
||||||
auto cuts = std::make_shared<common::HistogramCuts>();
|
auto cuts = std::make_shared<common::HistogramCuts>();
|
||||||
do {
|
ExternalDataInfo ext_info;
|
||||||
// We use do while here as the first batch is fetched in ctor
|
cuda_impl::MakeSketches(ctx, &iter, proxy, ref, p, missing, cuts, this->Info(), &ext_info);
|
||||||
CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs());
|
ext_info.SetInfo(ctx, &this->info_);
|
||||||
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
|
|
||||||
if (cols == 0) {
|
|
||||||
cols = num_cols();
|
|
||||||
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&cols, 1), collective::Op::kMax);
|
|
||||||
SafeColl(rc);
|
|
||||||
this->info_.num_col_ = cols;
|
|
||||||
} else {
|
|
||||||
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
|
||||||
}
|
|
||||||
if (!ref) {
|
|
||||||
sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, cols, num_rows(),
|
|
||||||
get_device());
|
|
||||||
auto* p_sketch = &sketch_containers.back();
|
|
||||||
proxy->Info().weights_.SetDevice(get_device());
|
|
||||||
cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
|
||||||
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
auto batch_rows = num_rows();
|
|
||||||
accumulated_rows += batch_rows;
|
|
||||||
dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
|
|
||||||
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
|
|
||||||
row_stride = std::max(row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
|
||||||
return GetRowCounts(value, row_counts_span, get_device(), missing);
|
|
||||||
}));
|
|
||||||
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end());
|
|
||||||
batches++;
|
|
||||||
} while (iter.Next());
|
|
||||||
iter.Reset();
|
|
||||||
|
|
||||||
auto n_features = cols;
|
auto init_page = [this, &cuts, &ext_info]() {
|
||||||
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
|
|
||||||
|
|
||||||
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
|
|
||||||
if (!ref) {
|
|
||||||
HostDeviceVector<FeatureType> ft;
|
|
||||||
common::SketchContainer final_sketch(
|
|
||||||
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin, cols,
|
|
||||||
accumulated_rows, get_device());
|
|
||||||
for (auto const& sketch : sketch_containers) {
|
|
||||||
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
|
||||||
final_sketch.FixError();
|
|
||||||
}
|
|
||||||
sketch_containers.clear();
|
|
||||||
sketch_containers.shrink_to_fit();
|
|
||||||
|
|
||||||
final_sketch.MakeCuts(ctx, cuts.get(), this->info_.IsColumnSplit());
|
|
||||||
} else {
|
|
||||||
GetCutsFromRef(ctx, ref, Info().num_col_, p, cuts.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
this->info_.num_row_ = accumulated_rows;
|
|
||||||
this->info_.num_nonzero_ = nnz;
|
|
||||||
|
|
||||||
auto init_page = [this, &cuts, row_stride, accumulated_rows, get_device]() {
|
|
||||||
if (!ellpack_) {
|
if (!ellpack_) {
|
||||||
// Should be put inside the while loop to protect against empty batch. In
|
// Should be put inside the while loop to protect against empty batch. In
|
||||||
// that case device id is invalid.
|
// that case device id is invalid.
|
||||||
ellpack_.reset(new EllpackPage);
|
ellpack_.reset(new EllpackPage);
|
||||||
*(ellpack_->Impl()) =
|
*(ellpack_->Impl()) = EllpackPageImpl(&fmat_ctx_, cuts, this->IsDense(), ext_info.row_stride,
|
||||||
EllpackPageImpl(&fmat_ctx_, cuts, this->IsDense(), row_stride, accumulated_rows);
|
ext_info.accumulated_rows);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -139,43 +63,42 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
|||||||
size_t n_batches_for_verification = 0;
|
size_t n_batches_for_verification = 0;
|
||||||
while (iter.Next()) {
|
while (iter.Next()) {
|
||||||
init_page();
|
init_page();
|
||||||
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
|
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
||||||
auto rows = num_rows();
|
auto rows = BatchSamples(proxy);
|
||||||
dh::device_vector<size_t> row_counts(rows + 1, 0);
|
dh::device_vector<size_t> row_counts(rows + 1, 0);
|
||||||
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
|
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
|
||||||
cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
||||||
return GetRowCounts(value, row_counts_span, get_device(), missing);
|
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
|
||||||
});
|
});
|
||||||
auto is_dense = this->IsDense();
|
auto is_dense = this->IsDense();
|
||||||
|
|
||||||
proxy->Info().feature_types.SetDevice(get_device());
|
proxy->Info().feature_types.SetDevice(dh::GetDevice(ctx));
|
||||||
auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
|
auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
|
||||||
auto new_impl = cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
auto new_impl = cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
||||||
return EllpackPageImpl(&fmat_ctx_, value, missing, is_dense, row_counts_span, d_feature_types,
|
return EllpackPageImpl(&fmat_ctx_, value, missing, is_dense, row_counts_span, d_feature_types,
|
||||||
row_stride, rows, cuts);
|
ext_info.row_stride, rows, cuts);
|
||||||
});
|
});
|
||||||
std::size_t num_elements = ellpack_->Impl()->Copy(&fmat_ctx_, &new_impl, offset);
|
std::size_t num_elements = ellpack_->Impl()->Copy(&fmat_ctx_, &new_impl, offset);
|
||||||
offset += num_elements;
|
offset += num_elements;
|
||||||
|
|
||||||
proxy->Info().num_row_ = num_rows();
|
proxy->Info().num_row_ = BatchSamples(proxy);
|
||||||
proxy->Info().num_col_ = cols;
|
proxy->Info().num_col_ = ext_info.n_features;
|
||||||
if (batches != 1) {
|
if (ext_info.n_batches != 1) {
|
||||||
this->info_.Extend(std::move(proxy->Info()), false, true);
|
this->info_.Extend(std::move(proxy->Info()), false, true);
|
||||||
}
|
}
|
||||||
n_batches_for_verification++;
|
n_batches_for_verification++;
|
||||||
}
|
}
|
||||||
CHECK_EQ(batches, n_batches_for_verification)
|
CHECK_EQ(ext_info.n_batches, n_batches_for_verification)
|
||||||
<< "Different number of batches returned between 2 iterations";
|
<< "Different number of batches returned between 2 iterations";
|
||||||
|
|
||||||
if (batches == 1) {
|
if (ext_info.n_batches == 1) {
|
||||||
this->info_ = std::move(proxy->Info());
|
this->info_ = std::move(proxy->Info());
|
||||||
this->info_.num_nonzero_ = nnz;
|
this->info_.num_nonzero_ = ext_info.nnz;
|
||||||
CHECK_EQ(proxy->Info().labels.Size(), 0);
|
CHECK_EQ(proxy->Info().labels.Size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
iter.Reset();
|
iter.Reset();
|
||||||
// Synchronise worker columns
|
// Synchronise worker columns
|
||||||
info_.SynchronizeNumberOfColumns(ctx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
|
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
|
||||||
|
|||||||
@ -142,13 +142,14 @@ inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
|
|||||||
* @brief Shape and basic information for data fetched from an external data iterator.
|
* @brief Shape and basic information for data fetched from an external data iterator.
|
||||||
*/
|
*/
|
||||||
struct ExternalDataInfo {
|
struct ExternalDataInfo {
|
||||||
std::uint64_t n_features = 0; // The number of columns
|
bst_idx_t n_features = 0; // The number of columns
|
||||||
bst_idx_t n_batches = 0; // The number of batches
|
bst_idx_t n_batches = 0; // The number of batches
|
||||||
bst_idx_t accumulated_rows = 0; // The total number of rows
|
bst_idx_t accumulated_rows = 0; // The total number of rows
|
||||||
bst_idx_t nnz = 0; // The number of non-missing values
|
bst_idx_t nnz = 0; // The number of non-missing values
|
||||||
std::vector<bst_idx_t> column_sizes; // The nnz for each column
|
std::vector<bst_idx_t> column_sizes; // The nnz for each column
|
||||||
std::vector<bst_idx_t> batch_nnz; // nnz for each batch
|
std::vector<bst_idx_t> batch_nnz; // nnz for each batch
|
||||||
std::vector<bst_idx_t> base_rows{0}; // base_rowid
|
std::vector<bst_idx_t> base_rows{0}; // base_rowid
|
||||||
|
bst_idx_t row_stride{0}; // Used by ellpack
|
||||||
|
|
||||||
void Validate() const {
|
void Validate() const {
|
||||||
CHECK(std::none_of(this->column_sizes.cbegin(), this->column_sizes.cend(), [&](auto f) {
|
CHECK(std::none_of(this->column_sizes.cbegin(), this->column_sizes.cend(), [&](auto f) {
|
||||||
@ -157,6 +158,16 @@ struct ExternalDataInfo {
|
|||||||
|
|
||||||
CHECK_GE(this->n_features, 1) << "Data must has at least 1 column.";
|
CHECK_GE(this->n_features, 1) << "Data must has at least 1 column.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetInfo(Context const* ctx, MetaInfo* p_info) {
|
||||||
|
// From here on Info() has the correct data shape
|
||||||
|
auto& info = *p_info;
|
||||||
|
info.num_row_ = this->accumulated_rows;
|
||||||
|
info.num_col_ = this->n_features;
|
||||||
|
info.num_nonzero_ = this->nnz;
|
||||||
|
info.SynchronizeNumberOfColumns(ctx);
|
||||||
|
this->Validate();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1,10 +1,93 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2024, XGBoost Contributors
|
* Copyright 2020-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include "ellpack_page.cuh"
|
#include <algorithm> // for max
|
||||||
|
#include <numeric> // for partial_sum
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../collective/allreduce.h" // for Allreduce
|
||||||
|
#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs
|
||||||
|
#include "../common/device_vector.cuh" // for XGBCachingDeviceAllocator
|
||||||
|
#include "../common/hist_util.cuh" // for AdapterDeviceSketch
|
||||||
|
#include "../common/quantile.cuh" // for SketchContainer
|
||||||
|
#include "ellpack_page.cuh" // for EllpackPage
|
||||||
|
#include "proxy_dmatrix.cuh" // for Dispatch
|
||||||
|
#include "proxy_dmatrix.h" // for DataIterProxy
|
||||||
|
#include "quantile_dmatrix.h" // for GetCutsFromRef
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
void GetCutsFromEllpack(EllpackPage const& page, common::HistogramCuts* cuts) {
|
void GetCutsFromEllpack(EllpackPage const& page, common::HistogramCuts* cuts) {
|
||||||
*cuts = page.Impl()->Cuts();
|
*cuts = page.Impl()->Cuts();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace cuda_impl {
|
||||||
|
void MakeSketches(Context const* ctx,
|
||||||
|
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>* iter,
|
||||||
|
DMatrixProxy* proxy, std::shared_ptr<DMatrix> ref, BatchParam const& p,
|
||||||
|
float missing, std::shared_ptr<common::HistogramCuts> cuts, MetaInfo const& info,
|
||||||
|
ExternalDataInfo* p_ext_info) {
|
||||||
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
|
std::vector<common::SketchContainer> sketch_containers;
|
||||||
|
auto& ext_info = *p_ext_info;
|
||||||
|
|
||||||
|
do {
|
||||||
|
// We use do while here as the first batch is fetched in ctor
|
||||||
|
CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs());
|
||||||
|
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
||||||
|
if (ext_info.n_features == 0) {
|
||||||
|
ext_info.n_features = data::BatchColumns(proxy);
|
||||||
|
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&ext_info.n_features, 1),
|
||||||
|
collective::Op::kMax);
|
||||||
|
SafeColl(rc);
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(ext_info.n_features, ::xgboost::data::BatchColumns(proxy))
|
||||||
|
<< "Inconsistent number of columns.";
|
||||||
|
}
|
||||||
|
if (!ref) {
|
||||||
|
sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, ext_info.n_features,
|
||||||
|
data::BatchSamples(proxy), dh::GetDevice(ctx));
|
||||||
|
auto* p_sketch = &sketch_containers.back();
|
||||||
|
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
|
||||||
|
cuda_impl::Dispatch(proxy, [&](auto const& value) {
|
||||||
|
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
auto batch_rows = data::BatchSamples(proxy);
|
||||||
|
ext_info.accumulated_rows += batch_rows;
|
||||||
|
dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
|
||||||
|
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
|
||||||
|
ext_info.row_stride =
|
||||||
|
std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
|
||||||
|
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
|
||||||
|
}));
|
||||||
|
ext_info.nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end());
|
||||||
|
ext_info.n_batches++;
|
||||||
|
ext_info.base_rows.push_back(batch_rows);
|
||||||
|
} while (iter->Next());
|
||||||
|
iter->Reset();
|
||||||
|
|
||||||
|
CHECK_GE(ext_info.n_features, 1) << "Data must has at least 1 column.";
|
||||||
|
std::partial_sum(ext_info.base_rows.cbegin(), ext_info.base_rows.cend(),
|
||||||
|
ext_info.base_rows.begin());
|
||||||
|
|
||||||
|
// Get reference
|
||||||
|
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
|
||||||
|
if (!ref) {
|
||||||
|
HostDeviceVector<FeatureType> ft;
|
||||||
|
common::SketchContainer final_sketch(
|
||||||
|
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin,
|
||||||
|
ext_info.n_features, ext_info.accumulated_rows, dh::GetDevice(ctx));
|
||||||
|
for (auto const& sketch : sketch_containers) {
|
||||||
|
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
||||||
|
final_sketch.FixError();
|
||||||
|
}
|
||||||
|
sketch_containers.clear();
|
||||||
|
sketch_containers.shrink_to_fit();
|
||||||
|
|
||||||
|
final_sketch.MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
|
||||||
|
} else {
|
||||||
|
GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace cuda_impl
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -104,4 +104,12 @@ void MakeSketches(Context const *ctx,
|
|||||||
common::HistogramCuts *cuts, BatchParam const &p, MetaInfo const &info,
|
common::HistogramCuts *cuts, BatchParam const &p, MetaInfo const &info,
|
||||||
ExternalDataInfo const &ext_info, std::vector<FeatureType> *p_h_ft);
|
ExternalDataInfo const &ext_info, std::vector<FeatureType> *p_h_ft);
|
||||||
} // namespace cpu_impl
|
} // namespace cpu_impl
|
||||||
|
|
||||||
|
namespace cuda_impl {
|
||||||
|
void MakeSketches(Context const *ctx,
|
||||||
|
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> *iter,
|
||||||
|
DMatrixProxy *proxy, std::shared_ptr<DMatrix> ref, BatchParam const &p,
|
||||||
|
float missing, std::shared_ptr<common::HistogramCuts> cuts, MetaInfo const &info,
|
||||||
|
ExternalDataInfo *p_ext_info);
|
||||||
|
} // namespace cuda_impl
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
|
|||||||
@ -38,30 +38,23 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
|
|||||||
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
||||||
iter_, reset_, next_};
|
iter_, reset_, next_};
|
||||||
|
|
||||||
std::uint32_t n_batches = 0;
|
ExternalDataInfo ext_info;
|
||||||
bst_feature_t n_features = 0;
|
|
||||||
bst_idx_t n_samples = 0;
|
|
||||||
bst_idx_t nnz = 0;
|
|
||||||
|
|
||||||
// The proxy is iterated together with the sparse page source so we can obtain all
|
// The proxy is iterated together with the sparse page source so we can obtain all
|
||||||
// information in 1 pass.
|
// information in 1 pass.
|
||||||
for (auto const &page : this->GetRowBatchesImpl(&ctx)) {
|
for (auto const &page : this->GetRowBatchesImpl(&ctx)) {
|
||||||
this->info_.Extend(std::move(proxy->Info()), false, false);
|
this->info_.Extend(std::move(proxy->Info()), false, false);
|
||||||
n_features = std::max(n_features, BatchColumns(proxy));
|
ext_info.n_features =
|
||||||
n_samples += BatchSamples(proxy);
|
std::max(static_cast<bst_feature_t>(ext_info.n_features), BatchColumns(proxy));
|
||||||
nnz += page.data.Size();
|
ext_info.accumulated_rows += BatchSamples(proxy);
|
||||||
n_batches++;
|
ext_info.nnz += page.data.Size();
|
||||||
|
ext_info.n_batches++;
|
||||||
}
|
}
|
||||||
|
|
||||||
iter.Reset();
|
iter.Reset();
|
||||||
|
|
||||||
this->n_batches_ = n_batches;
|
this->n_batches_ = ext_info.n_batches;
|
||||||
this->info_.num_row_ = n_samples;
|
ext_info.SetInfo(&ctx, &this->info_);
|
||||||
this->info_.num_col_ = n_features;
|
|
||||||
this->info_.num_nonzero_ = nnz;
|
|
||||||
|
|
||||||
info_.SynchronizeNumberOfColumns(&ctx);
|
|
||||||
CHECK_NE(info_.num_col_, 0);
|
|
||||||
|
|
||||||
fmat_ctx_ = ctx;
|
fmat_ctx_ = ctx;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -585,5 +585,50 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
|
|||||||
this->Fetch();
|
this->Fetch();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief operator++ implementation for QDM.
|
||||||
|
*/
|
||||||
|
template <typename S,
|
||||||
|
typename FormatCreatePolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
|
||||||
|
class ExtQantileSourceMixin : public SparsePageSourceImpl<S, FormatCreatePolicy> {
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source_;
|
||||||
|
using Super = SparsePageSourceImpl<S, FormatCreatePolicy>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ExtQantileSourceMixin(
|
||||||
|
float missing, std::int32_t nthreads, bst_feature_t n_features, bst_idx_t n_batches,
|
||||||
|
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source,
|
||||||
|
std::shared_ptr<Cache> cache)
|
||||||
|
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache},
|
||||||
|
source_{std::move(source)} {}
|
||||||
|
// This function always operate on the source first, then the downstream. The downstream
|
||||||
|
// can assume the source to be ready.
|
||||||
|
[[nodiscard]] ExtQantileSourceMixin& operator++() final {
|
||||||
|
TryLockGuard guard{this->single_threaded_};
|
||||||
|
// Increment self.
|
||||||
|
++this->count_;
|
||||||
|
// Set at end.
|
||||||
|
this->at_end_ = this->count_ == this->n_batches_;
|
||||||
|
|
||||||
|
if (this->at_end_) {
|
||||||
|
this->EndIter();
|
||||||
|
|
||||||
|
CHECK(this->cache_info_->written);
|
||||||
|
source_ = nullptr; // release the source
|
||||||
|
}
|
||||||
|
this->Fetch();
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reset() final {
|
||||||
|
if (this->source_) {
|
||||||
|
this->source_->Reset();
|
||||||
|
}
|
||||||
|
Super::Reset();
|
||||||
|
}
|
||||||
|
};
|
||||||
} // namespace xgboost::data
|
} // namespace xgboost::data
|
||||||
#endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
#endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2024, XGBoost Contributors
|
* Copyright 2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include "test_extmem_quantile_dmatrix.h" // for TestExtMemQdmBasic
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h> // for BatchParam
|
#include <xgboost/data.h> // for BatchParam
|
||||||
|
|
||||||
@ -9,76 +11,30 @@
|
|||||||
#include "../../../src/common/column_matrix.h" // for ColumnMatrix
|
#include "../../../src/common/column_matrix.h" // for ColumnMatrix
|
||||||
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||||
#include "../../../src/tree/param.h" // for TrainParam
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../helpers.h" // for RandomDataGenerator
|
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
namespace {
|
namespace {
|
||||||
class ExtMemQuantileDMatrixCpu : public ::testing::TestWithParam<float> {
|
class ExtMemQuantileDMatrixCpu : public ::testing::TestWithParam<float> {
|
||||||
public:
|
public:
|
||||||
void Run(float sparsity) {
|
void Run(float sparsity) {
|
||||||
bst_idx_t n_samples = 256, n_features = 16, n_batches = 4;
|
auto equal = [](Context const*, GHistIndexMatrix const& orig, GHistIndexMatrix const& sparse) {
|
||||||
bst_bin_t max_bin = 64;
|
|
||||||
bst_target_t n_targets = 3;
|
|
||||||
auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity}
|
|
||||||
.Bins(max_bin)
|
|
||||||
.Batches(n_batches)
|
|
||||||
.Targets(n_targets)
|
|
||||||
.GenerateExtMemQuantileDMatrix("temp", true);
|
|
||||||
ASSERT_FALSE(p_fmat->SingleColBlock());
|
|
||||||
|
|
||||||
BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
|
|
||||||
Context ctx;
|
|
||||||
|
|
||||||
// Loop over the batches and count the number of pages
|
|
||||||
bst_idx_t batch_cnt = 0;
|
|
||||||
bst_idx_t base_cnt = 0;
|
|
||||||
bst_idx_t row_cnt = 0;
|
|
||||||
for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, p)) {
|
|
||||||
ASSERT_EQ(page.base_rowid, base_cnt);
|
|
||||||
++batch_cnt;
|
|
||||||
base_cnt += n_samples / n_batches;
|
|
||||||
row_cnt += page.Size();
|
|
||||||
ASSERT_EQ((sparsity == 0.0f), page.IsDense());
|
|
||||||
}
|
|
||||||
ASSERT_EQ(n_batches, batch_cnt);
|
|
||||||
ASSERT_EQ(p_fmat->Info().num_row_, n_samples);
|
|
||||||
EXPECT_EQ(p_fmat->Info().num_row_, row_cnt);
|
|
||||||
ASSERT_EQ(p_fmat->Info().num_col_, n_features);
|
|
||||||
if (sparsity == 0.0f) {
|
|
||||||
ASSERT_EQ(p_fmat->Info().num_nonzero_, n_samples * n_features);
|
|
||||||
} else {
|
|
||||||
ASSERT_LT(p_fmat->Info().num_nonzero_, n_samples * n_features);
|
|
||||||
ASSERT_GT(p_fmat->Info().num_nonzero_, 0);
|
|
||||||
}
|
|
||||||
ASSERT_EQ(p_fmat->Info().labels.Shape(0), n_samples);
|
|
||||||
ASSERT_EQ(p_fmat->Info().labels.Shape(1), n_targets);
|
|
||||||
|
|
||||||
// Compare against the sparse page DMatrix
|
|
||||||
auto p_sparse = RandomDataGenerator{n_samples, n_features, sparsity}
|
|
||||||
.Bins(max_bin)
|
|
||||||
.Batches(n_batches)
|
|
||||||
.Targets(n_targets)
|
|
||||||
.GenerateSparsePageDMatrix("temp", true);
|
|
||||||
auto it = p_fmat->GetBatches<GHistIndexMatrix>(&ctx, p).begin();
|
|
||||||
for (auto const& page : p_sparse->GetBatches<GHistIndexMatrix>(&ctx, p)) {
|
|
||||||
auto orig = it.Page();
|
|
||||||
// Check the CSR matrix
|
// Check the CSR matrix
|
||||||
auto orig_cuts = it.Page()->Cuts();
|
auto orig_cuts = orig.Cuts();
|
||||||
auto sparse_cuts = page.Cuts();
|
auto sparse_cuts = sparse.Cuts();
|
||||||
ASSERT_EQ(orig_cuts.Values(), sparse_cuts.Values());
|
ASSERT_EQ(orig_cuts.Values(), sparse_cuts.Values());
|
||||||
ASSERT_EQ(orig_cuts.MinValues(), sparse_cuts.MinValues());
|
ASSERT_EQ(orig_cuts.MinValues(), sparse_cuts.MinValues());
|
||||||
ASSERT_EQ(orig_cuts.Ptrs(), sparse_cuts.Ptrs());
|
ASSERT_EQ(orig_cuts.Ptrs(), sparse_cuts.Ptrs());
|
||||||
|
|
||||||
auto orig_ptr = orig->data.data();
|
auto orig_ptr = orig.data.data();
|
||||||
auto sparse_ptr = page.data.data();
|
auto sparse_ptr = sparse.data.data();
|
||||||
ASSERT_EQ(orig->data.size(), page.data.size());
|
ASSERT_EQ(orig.data.size(), sparse.data.size());
|
||||||
|
|
||||||
auto equal = std::equal(orig_ptr, orig_ptr + orig->data.size(), sparse_ptr);
|
auto equal = std::equal(orig_ptr, orig_ptr + orig.data.size(), sparse_ptr);
|
||||||
ASSERT_TRUE(equal);
|
ASSERT_TRUE(equal);
|
||||||
|
|
||||||
// Check the column matrix
|
// Check the column matrix
|
||||||
common::ColumnMatrix const& orig_columns = orig->Transpose();
|
common::ColumnMatrix const& orig_columns = orig.Transpose();
|
||||||
common::ColumnMatrix const& sparse_columns = page.Transpose();
|
common::ColumnMatrix const& sparse_columns = sparse.Transpose();
|
||||||
|
|
||||||
std::string str_orig, str_sparse;
|
std::string str_orig, str_sparse;
|
||||||
common::AlignedMemWriteStream fo_orig{&str_orig}, fo_sparse{&str_sparse};
|
common::AlignedMemWriteStream fo_orig{&str_orig}, fo_sparse{&str_sparse};
|
||||||
@ -86,18 +42,10 @@ class ExtMemQuantileDMatrixCpu : public ::testing::TestWithParam<float> {
|
|||||||
auto n_bytes_sparse = sparse_columns.Write(&fo_sparse);
|
auto n_bytes_sparse = sparse_columns.Write(&fo_sparse);
|
||||||
ASSERT_EQ(n_bytes_orig, n_bytes_sparse);
|
ASSERT_EQ(n_bytes_orig, n_bytes_sparse);
|
||||||
ASSERT_EQ(str_orig, str_sparse);
|
ASSERT_EQ(str_orig, str_sparse);
|
||||||
|
};
|
||||||
|
|
||||||
++it;
|
Context ctx;
|
||||||
}
|
TestExtMemQdmBasic<GHistIndexMatrix>(&ctx, false, sparsity, equal);
|
||||||
|
|
||||||
// Check meta info
|
|
||||||
auto h_y_sparse = p_sparse->Info().labels.HostView();
|
|
||||||
auto h_y = p_fmat->Info().labels.HostView();
|
|
||||||
for (std::size_t i = 0, m = h_y_sparse.Shape(0); i < m; ++i) {
|
|
||||||
for (std::size_t j = 0, n = h_y_sparse.Shape(1); j < n; ++j) {
|
|
||||||
ASSERT_EQ(h_y(i, j), h_y_sparse(i, j));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|||||||
45
tests/cpp/data/test_extmem_quantile_dmatrix.cu
Normal file
45
tests/cpp/data/test_extmem_quantile_dmatrix.cu
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2024, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/data.h> // for BatchParam
|
||||||
|
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl
|
||||||
|
#include "../helpers.h" // for RandomDataGenerator
|
||||||
|
#include "test_extmem_quantile_dmatrix.h" // for TestExtMemQdmBasic
|
||||||
|
|
||||||
|
namespace xgboost::data {
|
||||||
|
class ExtMemQuantileDMatrixGpu : public ::testing::TestWithParam<float> {
|
||||||
|
public:
|
||||||
|
void Run(float sparsity) {
|
||||||
|
auto equal = [](Context const* ctx, EllpackPage const& orig, EllpackPage const& sparse) {
|
||||||
|
auto const& orig_cuts = orig.Cuts();
|
||||||
|
auto const& sparse_cuts = sparse.Cuts();
|
||||||
|
ASSERT_EQ(orig_cuts.Values(), sparse_cuts.Values());
|
||||||
|
ASSERT_EQ(orig_cuts.MinValues(), sparse_cuts.MinValues());
|
||||||
|
ASSERT_EQ(orig_cuts.Ptrs(), sparse_cuts.Ptrs());
|
||||||
|
|
||||||
|
std::vector<common::CompressedByteT> h_orig, h_sparse;
|
||||||
|
auto orig_acc = orig.Impl()->GetHostAccessor(ctx, &h_orig, {});
|
||||||
|
auto sparse_acc = sparse.Impl()->GetHostAccessor(ctx, &h_sparse, {});
|
||||||
|
ASSERT_EQ(h_orig.size(), h_sparse.size());
|
||||||
|
|
||||||
|
auto equal = std::equal(h_orig.cbegin(), h_orig.cend(), h_sparse.cbegin());
|
||||||
|
ASSERT_TRUE(equal);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
TestExtMemQdmBasic<EllpackPage>(&ctx, true, sparsity, equal);
|
||||||
|
TestExtMemQdmBasic<EllpackPage>(&ctx, false, sparsity, equal);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(ExtMemQuantileDMatrixGpu, Basic) { this->Run(this->GetParam()); }
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(ExtMemQuantileDMatrix, ExtMemQuantileDMatrixGpu, ::testing::ValuesIn([] {
|
||||||
|
std::vector<float> sparsities{0.0f, 0.2f, 0.4f, 0.8f};
|
||||||
|
return sparsities;
|
||||||
|
}()));
|
||||||
|
} // namespace xgboost::data
|
||||||
73
tests/cpp/data/test_extmem_quantile_dmatrix.h
Normal file
73
tests/cpp/data/test_extmem_quantile_dmatrix.h
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2024, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <xgboost/base.h>
|
||||||
|
#include <xgboost/context.h>
|
||||||
|
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
|
#include "../helpers.h" // for RandomDataGenerator
|
||||||
|
|
||||||
|
namespace xgboost::data {
|
||||||
|
template <typename Page, typename Equal>
|
||||||
|
void TestExtMemQdmBasic(Context const* ctx, bool on_host, float sparsity, Equal&& check_equal) {
|
||||||
|
bst_idx_t n_samples = 256, n_features = 16, n_batches = 4;
|
||||||
|
bst_bin_t max_bin = 64;
|
||||||
|
bst_target_t n_targets = 3;
|
||||||
|
BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
|
||||||
|
|
||||||
|
auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity}
|
||||||
|
.Bins(max_bin)
|
||||||
|
.Batches(n_batches)
|
||||||
|
.Targets(n_targets)
|
||||||
|
.Device(ctx->Device())
|
||||||
|
.OnHost(on_host)
|
||||||
|
.GenerateExtMemQuantileDMatrix("temp", true);
|
||||||
|
ASSERT_FALSE(p_fmat->SingleColBlock());
|
||||||
|
|
||||||
|
// Loop over the batches and count the number of pages
|
||||||
|
bst_idx_t batch_cnt = 0, base_cnt = 0, row_cnt = 0;
|
||||||
|
for (auto const& page : p_fmat->GetBatches<Page>(ctx, p)) {
|
||||||
|
ASSERT_EQ(page.BaseRowId(), base_cnt);
|
||||||
|
++batch_cnt;
|
||||||
|
base_cnt += n_samples / n_batches;
|
||||||
|
row_cnt += page.Size();
|
||||||
|
ASSERT_EQ((sparsity == 0.0f), page.IsDense());
|
||||||
|
}
|
||||||
|
ASSERT_EQ(n_batches, batch_cnt);
|
||||||
|
ASSERT_EQ(p_fmat->Info().num_row_, n_samples);
|
||||||
|
EXPECT_EQ(p_fmat->Info().num_row_, row_cnt);
|
||||||
|
ASSERT_EQ(p_fmat->Info().num_col_, n_features);
|
||||||
|
if (sparsity == 0.0f) {
|
||||||
|
ASSERT_EQ(p_fmat->Info().num_nonzero_, n_samples * n_features);
|
||||||
|
} else {
|
||||||
|
ASSERT_LT(p_fmat->Info().num_nonzero_, n_samples * n_features);
|
||||||
|
ASSERT_GT(p_fmat->Info().num_nonzero_, 0);
|
||||||
|
}
|
||||||
|
ASSERT_EQ(p_fmat->Info().labels.Shape(0), n_samples);
|
||||||
|
ASSERT_EQ(p_fmat->Info().labels.Shape(1), n_targets);
|
||||||
|
|
||||||
|
// Compare against the sparse page DMatrix
|
||||||
|
auto p_sparse = RandomDataGenerator{n_samples, n_features, sparsity}
|
||||||
|
.Bins(max_bin)
|
||||||
|
.Batches(n_batches)
|
||||||
|
.Targets(n_targets)
|
||||||
|
.Device(ctx->Device())
|
||||||
|
.OnHost(on_host)
|
||||||
|
.GenerateSparsePageDMatrix("temp", true);
|
||||||
|
auto it = p_fmat->GetBatches<Page>(ctx, p).begin();
|
||||||
|
for (auto const& page : p_sparse->GetBatches<Page>(ctx, p)) {
|
||||||
|
auto orig = it.Page();
|
||||||
|
check_equal(ctx, *orig, page);
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check meta info
|
||||||
|
auto h_y_sparse = p_sparse->Info().labels.HostView();
|
||||||
|
auto h_y = p_fmat->Info().labels.HostView();
|
||||||
|
for (std::size_t i = 0, m = h_y_sparse.Shape(0); i < m; ++i) {
|
||||||
|
for (std::size_t j = 0, n = h_y_sparse.Shape(1); j < n; ++j) {
|
||||||
|
ASSERT_EQ(h_y(i, j), h_y_sparse(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace xgboost::data
|
||||||
@ -483,12 +483,15 @@ void RandomDataGenerator::GenerateCSR(
|
|||||||
}
|
}
|
||||||
CHECK(iter);
|
CHECK(iter);
|
||||||
|
|
||||||
std::shared_ptr<DMatrix> p_fmat{
|
std::shared_ptr<DMatrix> p_fmat{DMatrix::Create(
|
||||||
DMatrix::Create(static_cast<DataIterHandle>(iter.get()), iter->Proxy(), nullptr, Reset, Next,
|
static_cast<DataIterHandle>(iter.get()), iter->Proxy(), nullptr, Reset, Next,
|
||||||
std::numeric_limits<float>::quiet_NaN(), 0, this->bins_, prefix)};
|
std::numeric_limits<float>::quiet_NaN(), 0, this->bins_, prefix, this->on_host_)};
|
||||||
|
|
||||||
auto page_path = data::MakeId(prefix, p_fmat.get()) + ".gradient_index.page";
|
auto page_path = data::MakeId(prefix, p_fmat.get());
|
||||||
|
page_path += device_.IsCPU() ? ".gradient_index.page" : ".ellpack.page";
|
||||||
|
if (!this->on_host_) {
|
||||||
EXPECT_TRUE(FileExists(page_path)) << page_path;
|
EXPECT_TRUE(FileExists(page_path)) << page_path;
|
||||||
|
}
|
||||||
|
|
||||||
if (with_label) {
|
if (with_label) {
|
||||||
RandomDataGenerator{static_cast<bst_idx_t>(p_fmat->Info().num_row_), this->n_targets_, 0.0f}
|
RandomDataGenerator{static_cast<bst_idx_t>(p_fmat->Info().num_row_), this->n_targets_, 0.0f}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user