[EM] CPU implementation for external memory QDM. (#10682)

- A new DMatrix type.
- Extract common code into a new QDM base class.

Not yet working:
- Not exposed to the interface yet, will wait for the GPU implementation.
- ~No meta info yet, still working on the source.~
- Exporting data to CSR is not supported yet.
Jiaming Yuan 2024-08-09 09:38:02 +08:00 committed by GitHub
parent ac8366654b
commit 7bccc1ea2c
33 changed files with 1198 additions and 497 deletions
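
The "QDM base class" mentioned above is the new `QuantileDMatrix`, now shared by the existing in-core `IterativeDMatrix` and the new external-memory `ExtMemQuantileDMatrix` (see the diffs below). A minimal sketch of the resulting hierarchy, with all members stubbed out, looks roughly like this:

```cpp
// Simplified sketch only; the real classes live under src/data/ and carry the members
// shown in the diffs below (MetaInfo, the saved BatchParam, fmat_ctx_, batch getters, ...).
#include <memory>

class DMatrix {                                    // stand-in for xgboost::DMatrix
 public:
  virtual ~DMatrix() = default;
};

class QuantileDMatrix : public DMatrix {};         // new base class holding the shared QDM code

class IterativeDMatrix : public QuantileDMatrix {};       // existing in-core QuantileDMatrix
class ExtMemQuantileDMatrix : public QuantileDMatrix {};  // new external-memory variant

int main() {
  std::unique_ptr<DMatrix> in_core = std::make_unique<IterativeDMatrix>();
  std::unique_ptr<DMatrix> ext_mem = std::make_unique<ExtMemQuantileDMatrix>();
}
```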


@ -73,6 +73,9 @@ OBJECTS= \
$(PKGROOT)/src/data/gradient_index_format.o \
$(PKGROOT)/src/data/sparse_page_dmatrix.o \
$(PKGROOT)/src/data/sparse_page_source.o \
$(PKGROOT)/src/data/extmem_quantile_dmatrix.o \
$(PKGROOT)/src/data/quantile_dmatrix.o \
$(PKGROOT)/src/data/batch_utils.o \
$(PKGROOT)/src/data/proxy_dmatrix.o \
$(PKGROOT)/src/data/iterative_dmatrix.o \
$(PKGROOT)/src/predictor/predictor.o \


@ -73,6 +73,9 @@ OBJECTS= \
$(PKGROOT)/src/data/gradient_index_format.o \
$(PKGROOT)/src/data/sparse_page_dmatrix.o \
$(PKGROOT)/src/data/sparse_page_source.o \
$(PKGROOT)/src/data/extmem_quantile_dmatrix.o \
$(PKGROOT)/src/data/quantile_dmatrix.o \
$(PKGROOT)/src/data/batch_utils.o \
$(PKGROOT)/src/data/proxy_dmatrix.o \
$(PKGROOT)/src/data/iterative_dmatrix.o \
$(PKGROOT)/src/predictor/predictor.o \


@ -17,6 +17,7 @@
#include <xgboost/string_view.h>
#include <algorithm>
#include <cstdint> // for int32_t, uint8_t
#include <limits>
#include <memory>
#include <string>
@ -499,8 +500,12 @@ class BatchSet {
struct XGBAPIThreadLocalEntry;
/*!
* \brief Internal data structured used by XGBoost during training.
/**
* @brief Internal data structure used by XGBoost to hold all external data.
*
* There are multiple variants of the DMatrix class, which can be accessed through the
* @ref Create() methods. The DMatrix itself holds the predictor `X`, while other data
* including labels and sample weights are stored in the @ref MetaInfo class.
*/
class DMatrix {
public:
@ -518,13 +523,13 @@ class DMatrix {
/*! \brief Get thread local memory for returning data from DMatrix. */
[[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
/**
* \brief Get the context object of this DMatrix. The context is created during construction of
* @brief Get the context object of this DMatrix. The context is created during construction of
* DMatrix with user specified `nthread` parameter.
*/
[[nodiscard]] virtual Context const* Ctx() const = 0;
/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
* @brief Gets batches. Use range based for loop over BatchSet to access individual batches.
*/
template <typename T>
BatchSet<T> GetBatches();
@ -548,27 +553,27 @@ class DMatrix {
[[nodiscard]] bool IsDense() const { return this->Info().IsDense(); }
/**
* \brief Load DMatrix from URI.
* @brief Load DMatrix from URI.
*
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode Indicate how the data was split beforehand.
* \return The created DMatrix.
* @param uri The URI of input.
* @param silent Whether to print information during loading.
* @param data_split_mode Indicate how the data was split beforehand.
* @return The created DMatrix.
*/
static DMatrix* Load(const std::string& uri, bool silent = true,
DataSplitMode data_split_mode = DataSplitMode::kRow);
/**
* \brief Creates a new DMatrix from an external data adapter.
* @brief Creates a new DMatrix from an external data adapter.
*
* \tparam AdapterT Type of the adapter.
* \param [in,out] adapter View onto an external data.
* \param missing Values to count as missing.
* \param nthread Number of threads for construction.
* \param cache_prefix (Optional) The cache prefix for external memory.
* \param data_split_mode (Optional) Data split mode.
* @tparam AdapterT Type of the adapter.
* @param [in,out] adapter View onto an external data.
* @param missing Values to count as missing.
* @param nthread Number of threads for construction.
* @param cache_prefix (Optional) The cache prefix for external memory.
* @param data_split_mode (Optional) Data split mode.
*
* \return a Created DMatrix.
* @return a Created DMatrix.
*/
template <typename AdapterT>
static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
@ -576,29 +581,29 @@ class DMatrix {
DataSplitMode data_split_mode = DataSplitMode::kRow);
/**
* \brief Create a new Quantile based DMatrix used for histogram based algorithm.
* @brief Create a new Quantile based DMatrix used for histogram based algorithm.
*
* \tparam DataIterHandle External iterator type, defined in C API.
* \tparam DMatrixHandle DMatrix handle, defined in C API.
* \tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
* @tparam DataIterHandle External iterator type, defined in C API.
* @tparam DMatrixHandle DMatrix handle, defined in C API.
* @tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* @tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
*
* \param iter External data iterator
* \param proxy A hanlde to ProxyDMatrix
* \param ref Reference Quantile DMatrix.
* \param reset Callback for reset
* \param next Callback for next
* \param missing Value that should be treated as missing.
* \param nthread number of threads used for initialization.
* \param max_bin Maximum number of bins.
* @param iter External data iterator
* @param proxy A handle to ProxyDMatrix
* @param ref Reference Quantile DMatrix.
* @param reset Callback for reset
* @param next Callback for next
* @param missing Value that should be treated as missing.
* @param nthread number of threads used for initialization.
* @param max_bin Maximum number of bins.
*
* \return A created quantile based DMatrix.
* @return A created quantile based DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
int nthread, bst_bin_t max_bin);
std::int32_t nthread, bst_bin_t max_bin);
/**
* @brief Create an external memory DMatrix with callbacks.
@ -622,9 +627,22 @@ class DMatrix {
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int32_t nthread,
XGDMatrixCallbackNext* next, float missing, std::int32_t nthread,
std::string cache, bool on_host);
/**
* @brief Create an external memory quantile DMatrix with callbacks.
*
* Parameters are a combination of the external memory DMatrix and the quantile DMatrix.
*
* @return A created external memory quantile DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
std::int32_t nthread, bst_bin_t max_bin, std::string cache);
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
/**

src/data/batch_utils.cc (new file, 13 additions)

@ -0,0 +1,13 @@
/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include "batch_utils.h"
#include "../common/error_msg.h" // for InconsistentMaxBin
namespace xgboost::data::detail {
void CheckParam(BatchParam const& init, BatchParam const& param) {
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`.";
}
} // namespace xgboost::data::detail
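
`detail::CheckParam` is shared by both `QuantileDMatrix` variants: the `BatchParam` recorded at construction time is compared against whatever parameter a later `GetBatches()` caller passes (see `GetGradientIndex` further down). A self-contained analogue of the check, with a trimmed-down `BatchParam` and plain asserts standing in for XGBoost's `CHECK` macros:

```cpp
#include <cassert>
#include <vector>

// Trimmed-down stand-in for xgboost::BatchParam; only the fields CheckParam inspects.
struct BatchParam {
  int max_bin{256};
  bool regen{false};
  std::vector<float> hess;  // a non-empty hessian implies the approx tree method
};

// Analogue of data::detail::CheckParam(init, param).
void CheckParam(BatchParam const& init, BatchParam const& param) {
  assert(param.max_bin == init.max_bin);       // error::InconsistentMaxBin()
  assert(!param.regen && param.hess.empty());  // only `hist` may use a QuantileDMatrix
}

int main() {
  BatchParam init{/*max_bin=*/256, /*regen=*/false, {}};
  BatchParam caller{/*max_bin=*/256, /*regen=*/false, {}};
  CheckParam(init, caller);  // passes: same max_bin, no regen request, no hessian
}
```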


@ -29,5 +29,10 @@ inline bool RegenGHist(BatchParam old, BatchParam p) {
}
return p.regen || old.ParamNotEqual(p);
}
/**
* @brief Validate the batch parameter from the caller
*/
void CheckParam(BatchParam const& init, BatchParam const& param);
} // namespace xgboost::data::detail
#endif // XGBOOST_DATA_BATCH_UTILS_H_


@ -17,43 +17,44 @@
#include <tuple> // for get, apply
#include <type_traits> // for remove_pointer_t, remove_reference
#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
#include "../collective/allgather.h"
#include "../collective/allreduce.h"
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream
#include "../common/linalg_op.h" // for ElementWiseTransformHost
#include "../common/math.h" // for CheckNAN
#include "../common/numeric.h" // for Iota, RunLengthEncode
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/version.h" // for Version
#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
#include "../data/iterative_dmatrix.h" // for IterativeDMatrix
#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix
#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/common.h" // for OMPException
#include "dmlc/data.h" // for Parser
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
#include "dmlc/io.h" // for Stream
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "ellpack_page.h" // for EllpackPage
#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
#include "gradient_index.h" // for GHistIndexMatrix
#include "simple_dmatrix.h" // for SimpleDMatrix
#include "sparse_page_writer.h" // for SparsePageFormatReg
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
#include "xgboost/base.h" // for bst_group_t, bst_idx_t, bst_float, bst_ulong
#include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/learner.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
#include "xgboost/string_view.h" // for operator==, operator<<, StringView
#include "../collective/allgather.h" // for AllgatherStrings
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/communicator-inl.h" // for GetRank, IsFederated
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream
#include "../common/linalg_op.h" // for ElementWiseTransformHost
#include "../common/math.h" // for CheckNAN
#include "../common/numeric.h" // for Iota, RunLengthEncode
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/version.h" // for Version
#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
#include "../data/extmem_quantile_dmatrix.h" // for ExtMemQuantileDMatrix
#include "../data/iterative_dmatrix.h" // for IterativeDMatrix
#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix
#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/common.h" // for OMPException
#include "dmlc/data.h" // for Parser
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
#include "dmlc/io.h" // for Stream
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "ellpack_page.h" // for EllpackPage
#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
#include "gradient_index.h" // for GHistIndexMatrix
#include "simple_dmatrix.h" // for SimpleDMatrix
#include "sparse_page_writer.h" // for SparsePageFormatReg
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
#include "xgboost/base.h" // for bst_group_t, bst_idx_t, bst_float, bst_ulong
#include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/learner.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
#include "xgboost/string_view.h" // for operator==, operator<<, StringView
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
@ -909,6 +910,15 @@ DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, DataIterReset
return new data::SparsePageDMatrix{iter, proxy, reset, next, missing, n_threads, cache, on_host};
}
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
std::int32_t nthread, bst_bin_t max_bin, std::string cache) {
return new data::ExtMemQuantileDMatrix{
iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin};
}
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
XGDMatrixCallbackNext>(DataIterHandle iter, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref,
@ -922,6 +932,11 @@ template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCa
XGDMatrixCallbackNext* next, float missing,
int32_t n_threads, std::string, bool);
template DMatrix*
DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback, XGDMatrixCallbackNext>(
DataIterHandle, DMatrixHandle, std::shared_ptr<DMatrix>, DataIterResetCallback*,
XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string);
template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
DataSplitMode data_split_mode) {


@ -566,7 +566,7 @@ EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
CHECK_EQ(h_gidx_buffer->size(), gidx_buffer.size());
CHECK_NE(gidx_buffer.size(), 0);
dh::safe_cuda(cudaMemcpyAsync(h_gidx_buffer->data(), gidx_buffer.data(), gidx_buffer.size_bytes(),
cudaMemcpyDefault, dh::DefaultStream()));
cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
return {DeviceOrd::CPU(),
cuts_,
is_dense,


@ -0,0 +1,152 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include "extmem_quantile_dmatrix.h"
#include <memory> // for shared_ptr
#include <string> // for string
#include <utility> // for move
#include <vector> // for vector
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for CheckParam, RegenGHist
#include "proxy_dmatrix.h" // for DataIterProxy, HostAdapterDispatch
#include "quantile_dmatrix.h" // for GetDataShape, MakeSketches
#include "simple_batch_iterator.h" // for SimpleBatchIteratorImpl
#if !defined(XGBOOST_USE_CUDA)
#include "../common/common.h" // for AssertGPUSupport
#endif
namespace xgboost::data {
ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
std::int32_t n_threads, std::string cache,
bst_bin_t max_bin)
: cache_prefix_{std::move(cache)} {
auto iter = std::make_shared<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>>(
iter_handle, reset, next);
iter->Reset();
// Fetch the first iter
bool valid = iter->Next();
CHECK(valid) << "Quantile DMatrix must have at least 1 batch.";
auto pctx = MakeProxy(proxy)->Ctx();
Context ctx;
ctx.Init(Args{{"nthread", std::to_string(n_threads)}, {"device", pctx->DeviceName()}});
BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
if (ctx.IsCPU()) {
this->InitFromCPU(&ctx, iter, proxy, p, missing, ref);
} else {
this->InitFromCUDA(&ctx, iter, proxy, p, missing, ref);
}
this->batch_ = p;
this->fmat_ctx_ = ctx;
}
ExtMemQuantileDMatrix::~ExtMemQuantileDMatrix() {
// Clear out all resources before deleting the cache file.
ghist_index_source_.reset();
std::visit([](auto &&ptr) { ptr.reset(); }, ellpack_page_source_);
DeleteCacheFiles(cache_info_);
}
BatchSet<ExtSparsePage> ExtMemQuantileDMatrix::GetExtBatches(Context const *, BatchParam const &) {
LOG(FATAL) << "Not implemented";
auto begin_iter =
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
return BatchSet<ExtSparsePage>{begin_iter};
}
void ExtMemQuantileDMatrix::InitFromCPU(
Context const *ctx,
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref) {
auto proxy = MakeProxy(proxy_handle);
CHECK(proxy);
common::HistogramCuts cuts;
ExternalDataInfo ext_info;
cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
// From here on Info() has the correct data shape
this->Info().num_row_ = ext_info.accumulated_rows;
this->Info().num_col_ = ext_info.n_features;
this->Info().num_nonzero_ = ext_info.nnz;
this->Info().SynchronizeNumberOfColumns(ctx);
ext_info.Validate();
/**
* Generate quantiles
*/
std::vector<FeatureType> h_ft;
cpu_impl::MakeSketches(ctx, iter.get(), proxy, ref, missing, &cuts, p, this->Info(), ext_info,
&h_ft);
/**
* Generate gradient index
*/
auto id = MakeCache(this, ".gradient_index.page", false, cache_prefix_, &cache_info_);
this->ghist_index_source_ = std::make_unique<ExtGradientIndexPageSource>(
ctx, missing, &this->Info(), ext_info.n_batches, cache_info_.at(id), p, cuts, iter, proxy,
ext_info.base_rows);
/**
* Force initialize the cache and do some sanity checks along the way
*/
bst_idx_t batch_cnt = 0, k = 0;
bst_idx_t n_total_samples = 0;
for (auto const &page : this->GetGradientIndexImpl()) {
n_total_samples += page.Size();
CHECK_EQ(page.base_rowid, ext_info.base_rows[k]);
CHECK_EQ(page.Features(), this->Info().num_col_);
++k, ++batch_cnt;
}
CHECK_EQ(batch_cnt, ext_info.n_batches);
CHECK_EQ(n_total_samples, ext_info.accumulated_rows);
}
BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndexImpl() {
return BatchSet{BatchIterator<GHistIndexMatrix>{this->ghist_index_source_}};
}
BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndex(Context const *,
BatchParam const &param) {
if (param.Initialized()) {
detail::CheckParam(this->batch_, param);
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
}
CHECK(this->ghist_index_source_);
this->ghist_index_source_->Reset();
if (!std::isnan(param.sparse_thresh) &&
param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) {
LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
"of `DMatrix`.";
}
return this->GetGradientIndexImpl();
}
#if !defined(XGBOOST_USE_CUDA)
void ExtMemQuantileDMatrix::InitFromCUDA(
Context const *, std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>>,
DMatrixHandle, BatchParam const &, float, std::shared_ptr<DMatrix>) {
common::AssertGPUSupport();
}
BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
const BatchParam &) {
common::AssertGPUSupport();
auto batch_set =
std::visit([this](auto &&ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
this->ellpack_page_source_);
return batch_set;
}
#endif
} // namespace xgboost::data


@ -0,0 +1,24 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include <memory> // for shared_ptr
#include <variant> // for visit
#include "extmem_quantile_dmatrix.h"
namespace xgboost::data {
void ExtMemQuantileDMatrix::InitFromCUDA(
Context const *, std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>>,
DMatrixHandle, BatchParam const &, float, std::shared_ptr<DMatrix>) {
LOG(FATAL) << "Not implemented.";
}
BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
const BatchParam &) {
LOG(FATAL) << "Not implemented.";
auto batch_set =
std::visit([this](auto &&ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
this->ellpack_page_source_);
return batch_set;
}
} // namespace xgboost::data


@ -0,0 +1,70 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#pragma once
#include <map> // for map
#include <memory> // for shared_ptr
#include <string> // for string
#include <variant> // for variant
#include "ellpack_page_source.h" // for EllpackPageSource, EllpackPageHostSource
#include "gradient_index_page_source.h" // for GradientIndexPageSource
#include "quantile_dmatrix.h" // for QuantileDMatrix, ExternalIter
#include "xgboost/base.h" // for bst_bin_t
#include "xgboost/c_api.h" // for DataIterHandle, DMatrixHandle
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo, BatchParam
namespace xgboost::data {
/**
* @brief A DMatrix class for building a `QuantileDMatrix` from an external memory iterator.
*
* This is a combination of the `IterativeDMatrix` and the `SparsePageDMatrix`. It builds
* the gradient index directly from iterator inputs without going through the `SparsePage`,
* similar to how the `IterativeDMatrix` works. Also, similar to the `SparsePageDMatrix`,
* it caches the gradient index and fetches the pages in batches on demand.
*/
class ExtMemQuantileDMatrix : public QuantileDMatrix {
public:
ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, std::int32_t n_threads,
std::string cache, bst_bin_t max_bin);
~ExtMemQuantileDMatrix() override;
[[nodiscard]] bool SingleColBlock() const override { return false; }
private:
void InitFromCPU(
Context const *ctx,
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
DMatrixHandle proxy, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref);
void InitFromCUDA(
Context const *ctx,
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref);
BatchSet<GHistIndexMatrix> GetGradientIndexImpl();
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const *ctx, BatchParam const &param) override;
BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
[[nodiscard]] bool EllpackExists() const override {
return std::visit([](auto &&v) { return static_cast<bool>(v); }, ellpack_page_source_);
}
[[nodiscard]] bool GHistIndexExists() const override { return true; }
[[nodiscard]] BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx,
BatchParam const &param) override;
std::map<std::string, std::shared_ptr<Cache>> cache_info_;
std::string cache_prefix_;
BatchParam batch_;
using EllpackDiskPtr = std::shared_ptr<EllpackPageSource>;
using EllpackHostPtr = std::shared_ptr<EllpackPageHostSource>;
std::variant<EllpackDiskPtr, EllpackHostPtr> ellpack_page_source_;
std::shared_ptr<ExtGradientIndexPageSource> ghist_index_source_;
};
} // namespace xgboost::data
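
With the interface above, a consumer (for example the `hist` tree method) drives the cached pages through the usual range-for over `GetBatches`; each iteration pulls one `GHistIndexMatrix` page out of the `ExtGradientIndexPageSource`. A sketch of that access pattern, assuming XGBoost's internal headers are available and `p_fmat` points at an `ExtMemQuantileDMatrix`:

```cpp
// Sketch only: gradient_index.h is an internal header, not part of the installed API.
#include "xgboost/context.h"
#include "xgboost/data.h"
#include "gradient_index.h"  // internal: xgboost::GHistIndexMatrix

void VisitPages(xgboost::DMatrix* p_fmat, xgboost::Context const* ctx,
                xgboost::BatchParam const& param) {
  for (auto const& page : p_fmat->GetBatches<xgboost::GHistIndexMatrix>(ctx, param)) {
    // Each page covers one external-iterator batch; page.base_rowid is its row offset.
    (void)page;
  }
}
```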


@ -62,7 +62,17 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &
hit_count{common::MakeFixedVecWithMalloc(cuts.TotalBins(), std::size_t{0})},
cut{std::forward<common::HistogramCuts>(cuts)},
max_numeric_bins_per_feat(max_bin_per_feat),
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
isDense_{info.IsDense()} {}
GHistIndexMatrix::GHistIndexMatrix(bst_idx_t n_samples, bst_idx_t base_rowid,
common::HistogramCuts &&cuts, bst_bin_t max_bin_per_feat,
bool is_dense)
: row_ptr{common::MakeFixedVecWithMalloc(n_samples + 1, std::size_t{0})},
hit_count{common::MakeFixedVecWithMalloc(cuts.TotalBins(), std::size_t{0})},
cut{std::forward<common::HistogramCuts>(cuts)},
max_numeric_bins_per_feat(max_bin_per_feat),
base_rowid{base_rowid},
isDense_{is_dense} {}
#if !defined(XGBOOST_USE_CUDA)
GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &,
@ -122,6 +132,11 @@ INSTANTIATION_PUSH(data::SparsePageAdapterBatch)
INSTANTIATION_PUSH(data::ColumnarAdapterBatch)
#undef INSTANTIATION_PUSH
void GHistIndexMatrix::ResizeColumns(double sparse_thresh) {
CHECK(!std::isnan(sparse_thresh));
this->columns_ = std::make_unique<common::ColumnMatrix>(*this, sparse_thresh);
}
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
auto make_index = [this, n_index](auto t, common::BinTypeSize t_size) {
// Must resize instead of allocating a new one. This function is called every time a


@ -7,8 +7,8 @@
#include <algorithm> // for min
#include <atomic> // for atomic
#include <cinttypes> // for uint32_t
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <memory> // for make_unique
#include <vector>
@ -53,10 +53,10 @@ class GHistIndexMatrix {
}
/**
* \brief Push a page into index matrix, the function is only necessary because hist has
* partial support for external memory.
* @brief Push a sparse page into the index matrix.
*/
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, int32_t n_threads);
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft,
std::int32_t n_threads);
template <typename Batch, typename BinIdxType, typename GetOffset, typename IsValid>
void SetIndexData(common::Span<BinIdxType> index_data_span, size_t rbegin,
@ -135,6 +135,9 @@ class GHistIndexMatrix {
this->GatherHitCount(n_threads, n_bins_total);
}
// The function is only created to avoid using the column matrix in the header.
void ResizeColumns(double sparse_thresh);
public:
/** @brief row pointer to rows by element position */
common::RefResourceView<std::size_t> row_ptr;
@ -157,34 +160,49 @@ class GHistIndexMatrix {
~GHistIndexMatrix();
/**
* \brief Constrcutor for SimpleDMatrix.
* @brief Constructor for SimpleDMatrix.
*/
GHistIndexMatrix(Context const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat,
double sparse_thresh, bool sorted_sketch, common::Span<float const> hess = {});
/**
* \brief Constructor for Iterative DMatrix. Initialize basic information and prepare
* @brief Constructor for Quantile DMatrix. Initialize basic information and prepare
* for push batch.
*/
GHistIndexMatrix(MetaInfo const& info, common::HistogramCuts&& cuts, bst_bin_t max_bin_per_feat);
GHistIndexMatrix(MetaInfo const& info, common::HistogramCuts&& cuts,
bst_bin_t max_bin_per_feat);
/**
* \brief Constructor fro Iterative DMatrix where we might copy an existing ellpack page
* @brief Constructor for the external memory Quantile DMatrix. Initialize basic
* information and prepare for push batch.
*/
GHistIndexMatrix(bst_idx_t n_samples, bst_idx_t base_rowid, common::HistogramCuts&& cuts,
bst_bin_t max_bin_per_feat, bool is_dense);
/**
* @brief Constructor for Quantile DMatrix where we might copy an existing ellpack page
* to host gradient index.
*/
GHistIndexMatrix(Context const* ctx, MetaInfo const& info, EllpackPage const& page,
BatchParam const& p);
/**
* \brief Constructor for external memory.
* @brief Constructor for external memory.
*/
GHistIndexMatrix(SparsePage const& page, common::Span<FeatureType const> ft,
common::HistogramCuts cuts, int32_t max_bins_per_feat, bool is_dense,
double sparse_thresh, int32_t n_threads);
double sparse_thresh, std::int32_t n_threads);
GHistIndexMatrix(); // also for ext mem, empty ctor so that we can read the cache back.
/**
* @brief Push a single batch into the gradient index.
*
* @param n_samples_total The total number of rows for all batches; the column matrix is
* created once all batches are pushed.
*/
template <typename Batch>
void PushAdapterBatch(Context const* ctx, size_t rbegin, size_t prev_sum, Batch const& batch,
float missing, common::Span<FeatureType const> ft, double sparse_thresh,
size_t n_samples_total) {
void PushAdapterBatch(Context const* ctx, std::size_t rbegin, std::size_t prev_sum,
Batch const& batch, float missing, common::Span<FeatureType const> ft,
double sparse_thresh, bst_idx_t n_samples_total) {
auto n_bins_total = cut.TotalBins();
hit_count_tloc_.clear();
hit_count_tloc_.resize(ctx->Threads() * n_bins_total, 0);
@ -200,8 +218,7 @@ class GHistIndexMatrix {
if (rbegin + batch.Size() == n_samples_total) {
// finished
CHECK(!std::isnan(sparse_thresh));
this->columns_ = std::make_unique<common::ColumnMatrix>(*this, sparse_thresh);
this->ResizeColumns(sparse_thresh);
}
}
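
The `n_samples_total` argument above is what triggers column-matrix construction: the columns are built only once the last pushed batch reaches the total row count. For the external-memory path each page holds a single batch, so the total is simply that batch's size and the columns are built immediately. A small self-contained analogue of the trigger:

```cpp
#include <cassert>
#include <cstddef>

// Analogue of the check at the end of PushAdapterBatch:
//   if (rbegin + batch.Size() == n_samples_total) { this->ResizeColumns(sparse_thresh); }
struct GHistLike {
  bool has_columns{false};
  void PushBatch(std::size_t rbegin, std::size_t batch_size, std::size_t n_samples_total) {
    if (rbegin + batch_size == n_samples_total) {
      has_columns = true;  // stands in for ResizeColumns()
    }
  }
};

int main() {
  GHistLike in_core;               // in-core QuantileDMatrix: batches accumulate into one index
  in_core.PushBatch(0, 64, 128);   // columns not built yet
  in_core.PushBatch(64, 64, 128);  // last batch: columns built
  assert(in_core.has_columns);

  GHistLike per_batch;             // external memory: one page per batch
  per_batch.PushBatch(0, 64, 64);  // n_samples_total == batch size, built right away
  assert(per_batch.has_columns);
}
```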


@ -3,6 +3,12 @@
*/
#include "gradient_index_page_source.h"
#include <memory> // for make_shared
#include <utility> // for move
#include "../common/hist_util.h" // for HistogramCuts
#include "gradient_index.h" // for GHistIndexMatrix
namespace xgboost::data {
void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
@ -18,8 +24,43 @@ void GradientIndexPageSource::Fetch() {
CHECK_EQ(count_, source_->Iter());
auto const& csr = source_->Page();
CHECK_NE(cuts_.Values().size(), 0);
this->page_.reset(new GHistIndexMatrix(*csr, feature_types_, cuts_, max_bin_per_feat_,
is_dense_, sparse_thresh_, nthreads_));
this->page_.reset(new GHistIndexMatrix{*csr, feature_types_, cuts_, max_bin_per_feat_,
is_dense_, sparse_thresh_, nthreads_});
this->WriteCache();
}
}
void ExtGradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
CHECK_EQ(count_, source_->Iter());
++(*source_);
CHECK_GE(source_->Iter(), 1);
CHECK_NE(cuts_.Values().size(), 0);
HostAdapterDispatch(proxy_, [this](auto const& value) {
// This does three things:
// - Generate CSR matrix for gradient index.
// - Generate the column matrix for gradient index.
// - Concatenate the meta info.
common::HistogramCuts cuts{this->cuts_};
this->page_.reset();
// The external iterator has the data when the `next` method is called. Therefore,
// it's one step ahead of this source.
// FIXME(jiamingy): For now, we use the `info->IsDense()` to represent all batches
// similar to the sparse DMatrix source. We should use per-batch property with proxy
// DMatrix info instead. This requires more fine-grained tests.
this->page_ = std::make_shared<GHistIndexMatrix>(
value.NumRows(), this->base_rows_.at(source_->Iter() - 1), std::move(cuts),
this->p_.max_bin, info_->IsDense());
bst_idx_t prev_sum = 0;
bst_idx_t rbegin = 0;
// Use `value.NumRows()` for the size of a single batch. Unlike the
// `IterativeDMatrix`, external memory doesn't concatenate the pages.
this->page_->PushAdapterBatch(ctx_, rbegin, prev_sum, value, this->missing_,
this->feature_types_, this->p_.sparse_thresh, value.NumRows());
this->page_->PushAdapterBatchColumns(ctx_, value, this->missing_, rbegin);
this->info_->Extend(proxy_->Info(), false, false);
});
this->WriteCache();
}
}


@ -8,6 +8,7 @@
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <utility> // for move
#include <vector> // for vector
#include "../common/hist_util.h" // for HistogramCuts
#include "gradient_index.h" // for GHistIndexMatrix
@ -38,6 +39,45 @@ class GHistIndexFormatPolicy {
void SetCuts(common::HistogramCuts cuts) { std::swap(cuts_, cuts); }
};
template <typename S,
typename FormatCreatePolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
class ExtQantileSourceMixin : public SparsePageSourceImpl<S, FormatCreatePolicy> {
protected:
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source_;
using Super = SparsePageSourceImpl<S, FormatCreatePolicy>;
public:
ExtQantileSourceMixin(float missing, std::int32_t nthreads, bst_feature_t n_features,
bst_idx_t n_batches, std::shared_ptr<Cache> cache)
: Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache} {}
// This function always operates on the source first, then the downstream. The downstream
// can assume the source to be ready.
[[nodiscard]] ExtQantileSourceMixin& operator++() final {
TryLockGuard guard{this->single_threaded_};
// Increment self.
++this->count_;
// Set at end.
this->at_end_ = this->count_ == this->n_batches_;
if (this->at_end_) {
this->EndIter();
CHECK(this->cache_info_->written);
source_ = nullptr; // release the source
}
this->Fetch();
return *this;
}
void Reset() final {
if (this->source_) {
this->source_->Reset();
}
Super::Reset();
}
};
class GradientIndexPageSource
: public PageSourceIncMixIn<
GHistIndexMatrix, DefaultFormatStreamPolicy<GHistIndexMatrix, GHistIndexFormatPolicy>> {
@ -65,5 +105,39 @@ class GradientIndexPageSource
void Fetch() final;
};
class ExtGradientIndexPageSource
: public ExtQantileSourceMixin<
GHistIndexMatrix, DefaultFormatStreamPolicy<GHistIndexMatrix, GHistIndexFormatPolicy>> {
BatchParam p_;
Context const* ctx_;
DMatrixProxy* proxy_;
MetaInfo* info_;
common::Span<FeatureType const> feature_types_;
std::vector<bst_idx_t> base_rows_;
public:
ExtGradientIndexPageSource(
Context const* ctx, float missing, MetaInfo* info, bst_idx_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param, common::HistogramCuts cuts,
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source,
DMatrixProxy* proxy, std::vector<bst_idx_t> base_rows)
: ExtQantileSourceMixin{missing, ctx->Threads(), static_cast<bst_feature_t>(info->num_col_),
n_batches, cache},
p_{std::move(param)},
ctx_{ctx},
proxy_{proxy},
info_{info},
feature_types_{info_->feature_types.ConstHostSpan()},
base_rows_{std::move(base_rows)} {
this->source_ = source;
this->SetCuts(std::move(cuts));
this->Fetch();
}
void Fetch() final;
};
} // namespace xgboost::data
#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_


@ -1,24 +1,24 @@
/**
* Copyright 2022-2023, XGBoost contributors
* Copyright 2022-2024, XGBoost contributors
*/
#include "iterative_dmatrix.h"
#include <algorithm> // for copy
#include <cstddef> // for size_t
#include <memory> // for shared_ptr
#include <type_traits> // for underlying_type_t
#include <vector> // for vector
#include <algorithm> // for copy
#include <cstddef> // for size_t
#include <memory> // for shared_ptr
#include <utility> // for move
#include <vector> // for vector
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/communicator-inl.h" // for IsDistributed
#include "../common/categorical.h" // common::IsCat
#include "../common/column_matrix.h"
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for RegenGHist
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
#include "xgboost/data.h" // for FeatureType, DMatrix
#include "../common/categorical.h" // common::IsCat
#include "../common/hist_util.h" // for HistogramCuts
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for RegenGHist
#include "gradient_index.h" // for GHistIndexMatrix
#include "proxy_dmatrix.h" // for DataIterProxy
#include "quantile_dmatrix.h" // for GetCutsFromRef
#include "quantile_dmatrix.h" // for GetDataShape, MakeSketches
#include "simple_batch_iterator.h" // for SimpleBatchIteratorImpl
#include "xgboost/data.h" // for FeatureType, DMatrix
#include "xgboost/logging.h"
namespace xgboost::data {
@ -51,71 +51,6 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
this->batch_ = p;
}
void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_feature_t n_features,
BatchParam p, common::HistogramCuts* p_cuts) {
CHECK(ref);
CHECK(p_cuts);
p.forbid_regen = true;
// Fetch cuts from GIDX
auto csr = [&] {
for (auto const& page : ref->GetBatches<GHistIndexMatrix>(ctx, p)) {
*p_cuts = page.cut;
break;
}
};
// Fetch cuts from Ellpack.
auto ellpack = [&] {
for (auto const& page : ref->GetBatches<EllpackPage>(ctx, p)) {
GetCutsFromEllpack(page, p_cuts);
break;
}
};
if (ref->PageExists<GHistIndexMatrix>() && ref->PageExists<EllpackPage>()) {
// Both exists
if (ctx->IsCUDA()) {
ellpack();
} else {
csr();
}
} else if (ref->PageExists<GHistIndexMatrix>()) {
csr();
} else if (ref->PageExists<EllpackPage>()) {
ellpack();
} else {
// None exist
if (ctx->IsCUDA()) {
ellpack();
} else {
csr();
}
}
CHECK_EQ(ref->Info().num_col_, n_features)
<< "Invalid ref DMatrix, different number of features.";
}
namespace {
// Synchronize feature type in case of empty DMatrix
void SyncFeatureType(Context const* ctx, std::vector<FeatureType>* p_h_ft) {
if (!collective::IsDistributed()) {
return;
}
auto& h_ft = *p_h_ft;
bst_idx_t n_ft = h_ft.size();
collective::SafeColl(collective::Allreduce(ctx, &n_ft, collective::Op::kMax));
if (!h_ft.empty()) {
// Check correct size if this is not an empty DMatrix.
CHECK_EQ(h_ft.size(), n_ft);
}
if (n_ft > 0) {
h_ft.resize(n_ft);
auto ptr = reinterpret_cast<std::underlying_type_t<FeatureType>*>(h_ft.data());
collective::SafeColl(
collective::Allreduce(ctx, linalg::MakeVec(ptr, h_ft.size()), collective::Op::kMax));
}
}
} // anonymous namespace
void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
DataIterHandle iter_handle, float missing,
std::shared_ptr<DMatrix> ref) {
@ -126,135 +61,39 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
auto iter =
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
common::HistogramCuts cuts;
auto num_rows = [&]() {
return HostAdapterDispatch(proxy, [](auto const& value) { return value.Size(); });
};
auto num_cols = [&]() {
return HostAdapterDispatch(proxy, [](auto const& value) { return value.NumCols(); });
};
std::vector<bst_idx_t> column_sizes;
auto const is_valid = data::IsValidFunctor{missing};
auto nnz_cnt = [&]() {
return HostAdapterDispatch(proxy, [&](auto const& value) {
size_t n_threads = ctx->Threads();
size_t n_features = column_sizes.size();
linalg::Tensor<std::size_t, 2> column_sizes_tloc({n_threads, n_features}, DeviceOrd::CPU());
column_sizes_tloc.Data()->Fill(0ul);
auto view = column_sizes_tloc.HostView();
common::ParallelFor(value.Size(), n_threads, common::Sched::Static(256), [&](auto i) {
auto const& line = value.GetLine(i);
for (size_t j = 0; j < line.Size(); ++j) {
data::COOTuple const& elem = line.GetElement(j);
if (is_valid(elem)) {
view(omp_get_thread_num(), elem.column_idx)++;
}
}
});
auto ptr = column_sizes_tloc.Data()->HostPointer();
auto result = std::accumulate(ptr, ptr + column_sizes_tloc.Size(), static_cast<size_t>(0));
for (size_t tidx = 0; tidx < n_threads; ++tidx) {
for (size_t fidx = 0; fidx < n_features; ++fidx) {
column_sizes[fidx] += view(tidx, fidx);
}
}
return result;
});
};
std::uint64_t n_features = 0;
std::size_t n_batches = 0;
std::uint64_t accumulated_rows{0};
std::uint64_t nnz{0};
/**
* CPU impl needs an additional loop for accumulating the column size.
*/
std::unique_ptr<common::HostSketchContainer> p_sketch;
std::vector<size_t> batch_nnz;
do {
// We use do while here as the first batch is fetched in ctor
if (n_features == 0) {
n_features = num_cols();
collective::SafeColl(collective::Allreduce(ctx, &n_features, collective::Op::kMax));
column_sizes.clear();
column_sizes.resize(n_features, 0);
info_.num_col_ = n_features;
} else {
CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
}
size_t batch_size = num_rows();
batch_nnz.push_back(nnz_cnt());
nnz += batch_nnz.back();
accumulated_rows += batch_size;
n_batches++;
} while (iter.Next());
iter.Reset();
ExternalDataInfo ext_info;
cpu_impl::GetDataShape(ctx, proxy, iter, missing, &ext_info);
// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
Info().SynchronizeNumberOfColumns(ctx);
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
this->Info().num_row_ = ext_info.accumulated_rows;
this->Info().num_col_ = ext_info.n_features;
this->Info().num_nonzero_ = ext_info.nnz;
this->Info().SynchronizeNumberOfColumns(ctx);
ext_info.Validate();
/**
* Generate quantiles
*/
accumulated_rows = 0;
std::vector<FeatureType> h_ft;
if (ref) {
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
h_ft = ref->Info().feature_types.HostVector();
} else {
size_t i = 0;
while (iter.Next()) {
if (!p_sketch) {
h_ft = proxy->Info().feature_types.ConstHostVector();
SyncFeatureType(ctx, &h_ft);
p_sketch = std::make_unique<common::HostSketchContainer>(ctx, p.max_bin, h_ft, column_sizes,
!proxy->Info().group_ptr_.empty());
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
// We don't need base row idx here as Info is from proxy and the number of rows in
// it is consistent with data batch.
p_sketch->PushAdapterBatch(batch, 0, proxy->Info(), missing);
});
accumulated_rows += num_rows();
++i;
}
iter.Reset();
CHECK_EQ(accumulated_rows, Info().num_row_);
CHECK(p_sketch);
p_sketch->MakeCuts(ctx, Info(), &cuts);
}
if (!h_ft.empty()) {
CHECK_EQ(h_ft.size(), n_features);
}
cpu_impl::MakeSketches(ctx, &iter, proxy, ref, missing, &cuts, p, this->Info(), ext_info, &h_ft);
/**
* Generate gradient index.
*/
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), p.max_bin);
this->ghist_ = std::make_unique<GHistIndexMatrix>(this->Info(), std::move(cuts), p.max_bin);
std::size_t rbegin = 0;
std::size_t prev_sum = 0;
std::size_t i = 0;
while (iter.Next()) {
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
proxy->Info().num_nonzero_ = ext_info.batch_nnz[i];
this->ghist_->PushAdapterBatch(ctx, rbegin, prev_sum, batch, missing, h_ft, p.sparse_thresh,
Info().num_row_);
});
if (n_batches != 1) {
if (ext_info.n_batches != 1) {
this->info_.Extend(std::move(proxy->Info()), false, true);
}
size_t batch_size = num_rows();
auto batch_size = BatchSamples(proxy);
prev_sum = this->ghist_->row_ptr[rbegin + batch_size];
rbegin += batch_size;
++i;
@ -266,20 +105,20 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
/**
* Generate column matrix
*/
accumulated_rows = 0;
bst_idx_t accumulated_rows = 0;
while (iter.Next()) {
HostAdapterDispatch(proxy, [&](auto const& batch) {
this->ghist_->PushAdapterBatchColumns(ctx, batch, missing, accumulated_rows);
});
accumulated_rows += num_rows();
accumulated_rows += BatchSamples(proxy);
}
iter.Reset();
CHECK_EQ(accumulated_rows, Info().num_row_);
if (n_batches == 1) {
if (ext_info.n_batches == 1) {
this->info_ = std::move(proxy->Info());
this->info_.num_nonzero_ = nnz;
this->info_.num_col_ = n_features; // proxy might be empty.
this->info_.num_nonzero_ = ext_info.nnz;
this->info_.num_col_ = ext_info.n_features; // proxy might be empty.
CHECK_EQ(proxy->Info().labels.Size(), 0);
}
@ -289,7 +128,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(Context const* ctx,
BatchParam const& param) {
if (param.Initialized()) {
CheckParam(param);
detail::CheckParam(this->batch_, param);
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
}
if (!ellpack_ && !ghist_) {
@ -374,9 +213,5 @@ inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const*,
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
}
inline void GetCutsFromEllpack(EllpackPage const&, common::HistogramCuts*) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace xgboost::data


@ -1,13 +1,15 @@
/**
* Copyright 2020-2024, XGBoost contributors
*/
#include <algorithm>
#include <memory>
#include <algorithm> // for max
#include <memory> // for shared_ptr
#include <utility> // for move
#include <vector> // for vector
#include "../collective/allreduce.h"
#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs
#include "../common/hist_util.cuh"
#include "batch_utils.h" // for RegenGHist
#include "batch_utils.h" // for RegenGHist, CheckParam
#include "device_adapter.cuh"
#include "ellpack_page.cuh"
#include "iterative_dmatrix.h"
@ -179,7 +181,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
BatchParam const& param) {
if (param.Initialized()) {
CheckParam(param);
detail::CheckParam(this->batch_, param);
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
}
if (!ellpack_ && !ghist_) {
@ -209,8 +211,4 @@ BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
return BatchSet<EllpackPage>(begin_iter);
}
void GetCutsFromEllpack(EllpackPage const& page, common::HistogramCuts* cuts) {
*cuts = page.Impl()->Cuts();
}
} // namespace xgboost::data


@ -1,5 +1,5 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
* Copyright 2020-2024, XGBoost Contributors
* \file iterative_dmatrix.h
*
* \brief Implementation of the higher-level `QuantileDMatrix`.
@ -7,18 +7,13 @@
#ifndef XGBOOST_DATA_ITERATIVE_DMATRIX_H_
#define XGBOOST_DATA_ITERATIVE_DMATRIX_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <memory> // for shared_ptr
#include "../common/error_msg.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
#include "xgboost/base.h"
#include "xgboost/c_api.h"
#include "xgboost/context.h" // for Context
#include "xgboost/data.h"
#include "quantile_dmatrix.h" // for QuantileDMatrix
#include "xgboost/base.h" // for bst_bin_t
#include "xgboost/c_api.h" // for DataIterHandle, DMatrixHandle
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchSet
namespace xgboost {
namespace common {
@ -27,26 +22,14 @@ class HistogramCuts;
namespace data {
/**
* \brief DMatrix type for `QuantileDMatrix`, the naming `IterativeDMatix` is due to its
* @brief DMatrix type for `QuantileDMatrix`; the naming `IterativeDMatrix` is due to its
* construction process.
*
* `QuantileDMatrix` is an intermediate storage for quantilization results including
* quantile cuts and histogram index. Quantilization is designed to be performed on stream
* of data (or batches of it). As a result, the `QuantileDMatrix` is also designed to work
* with batches of data. During initializaion, it walks through the data multiple times
* iteratively in order to perform quantilization. This design helps us reduce memory
* usage significantly by avoiding data concatenation along with removing the CSR matrix
* `SparsePage`. However, it has its limitation (can be fixed if needed):
*
* - It's only supported by hist tree method (both CPU and GPU) since approx requires a
* re-calculation of quantiles for each iteration. We can fix this by retaining a
* reference to the callback if there are feature requests.
*
* - The CPU format and the GPU format are different, the former uses a CSR + CSC for
* histogram index while the latter uses only Ellpack.
* During initialization, it walks through the data multiple times iteratively in order to
* perform quantilization. This design helps us reduce memory usage significantly by
* avoiding data concatenation along with removing the CSR matrix `SparsePage`.
*/
class IterativeDMatrix : public DMatrix {
MetaInfo info_;
class IterativeDMatrix : public QuantileDMatrix {
std::shared_ptr<EllpackPage> ellpack_;
std::shared_ptr<GHistIndexMatrix> ghist_;
BatchParam batch_;
@ -54,19 +37,6 @@ class IterativeDMatrix : public DMatrix {
DMatrixHandle proxy_;
DataIterResetCallback *reset_;
XGDMatrixCallbackNext *next_;
Context fmat_ctx_;
void CheckParam(BatchParam const &param) {
CHECK_EQ(param.max_bin, batch_.max_bin) << error::InconsistentMaxBin();
CHECK(!param.regen && param.hess.empty())
<< "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
}
template <typename Page>
static auto InvalidTreeMethod() {
LOG(FATAL) << "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
return BatchSet<Page>(BatchIterator<Page>(nullptr));
}
void InitFromCUDA(Context const *ctx, BatchParam const &p, DataIterHandle iter_handle,
float missing, std::shared_ptr<DMatrix> ref);
@ -82,54 +52,14 @@ class IterativeDMatrix : public DMatrix {
bool EllpackExists() const override { return static_cast<bool>(ellpack_); }
bool GHistIndexExists() const override { return static_cast<bool>(ghist_); }
bool SparsePageExists() const override { return false; }
DMatrix *Slice(common::Span<int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
return nullptr;
}
DMatrix *SliceCol(int, int) override {
LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
return nullptr;
}
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented for `QuantileDMatrix`.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
}
BatchSet<CSCPage> GetColumnBatches(Context const *) override {
return InvalidTreeMethod<CSCPage>();
}
BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const *) override {
return InvalidTreeMethod<SortedCSCPage>();
}
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const *ctx, BatchParam const &param) override;
BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx, BatchParam const &param) override;
bool SingleColBlock() const override { return true; }
MetaInfo &Info() override { return info_; }
MetaInfo const &Info() const override { return info_; }
Context const *Ctx() const override { return &fmat_ctx_; }
};
/**
* \brief Get quantile cuts from reference (Quantile)DMatrix.
*
* \param ctx The context of the new DMatrix.
* \param ref The reference DMatrix.
* \param n_features Number of features, used for validation only.
* \param p Batch parameter for the new DMatrix.
* \param p_cuts Output quantile cuts.
*/
void GetCutsFromRef(Context const *ctx, std::shared_ptr<DMatrix> ref, bst_feature_t n_features,
BatchParam p, common::HistogramCuts *p_cuts);
/**
* \brief Get quantile cuts from ellpack page.
*/
void GetCutsFromEllpack(EllpackPage const &page, common::HistogramCuts *cuts);
} // namespace data
} // namespace xgboost
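
The `ref` argument threaded through these constructors is how a validation `QuantileDMatrix` reuses the quantile cuts of the training matrix (`GetCutsFromRef`, now in the new `quantile_dmatrix.cc` below). A hedged sketch of the internal call, assuming the C-API callback types and an already built training matrix; in practice this is reached through the public C API rather than called directly:

```cpp
// Sketch only: internal API with simplified arguments. `iter`, `reset` and `next` follow
// the C-API data-iterator callback protocol; `train` is an existing QuantileDMatrix.
#include <limits>
#include <memory>
#include "xgboost/c_api.h"  // DataIterHandle, DMatrixHandle, callback typedefs
#include "xgboost/data.h"   // xgboost::DMatrix

xgboost::DMatrix* MakeValidationQDM(DataIterHandle iter, DMatrixHandle proxy,
                                    std::shared_ptr<xgboost::DMatrix> train,
                                    DataIterResetCallback* reset,
                                    XGDMatrixCallbackNext* next) {
  float missing = std::numeric_limits<float>::quiet_NaN();
  // Passing `train` as `ref` makes GetCutsFromRef() copy its cuts instead of re-sketching.
  return xgboost::DMatrix::Create(iter, proxy, train, reset, next, missing,
                                  /*nthread=*/0, /*max_bin=*/256);
}
```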


@ -1,6 +1,5 @@
/**
* Copyright 2021-2023, XGBoost Contributors
* \file proxy_dmatrix.cc
* Copyright 2021-2024, XGBoost Contributors
*/
#include "proxy_dmatrix.h"
@ -12,6 +11,10 @@
#include "xgboost/logging.h"
#include "xgboost/string_view.h" // for StringView
#if !defined(XGBOOST_USE_CUDA)
#include "../common/common.h" // for AssertGPUSupport
#endif
namespace xgboost::data {
void DMatrixProxy::SetColumnarData(StringView interface_str) {
std::shared_ptr<ColumnarAdapter> adapter{new ColumnarAdapter{interface_str}};
@ -48,6 +51,15 @@ std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *, std::shared_ptr
float) {
return nullptr;
}
[[nodiscard]] bst_idx_t BatchSamples(DMatrixProxy const *) {
common::AssertGPUSupport();
return 0;
}
[[nodiscard]] bst_idx_t BatchColumns(DMatrixProxy const *) {
common::AssertGPUSupport();
return 0;
}
#endif // XGBOOST_USE_CUDA
} // namespace cuda_impl


@ -43,5 +43,13 @@ std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const* ctx,
return p_fmat;
});
}
[[nodiscard]] bst_idx_t BatchSamples(DMatrixProxy const* proxy) {
return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumRows(); });
}
[[nodiscard]] bst_idx_t BatchColumns(DMatrixProxy const* proxy) {
return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumCols(); });
}
} // namespace cuda_impl
} // namespace xgboost::data


@ -4,10 +4,12 @@
#ifndef XGBOOST_DATA_PROXY_DMATRIX_H_
#define XGBOOST_DATA_PROXY_DMATRIX_H_
#include <any> // for any, any_cast
#include <memory>
#include <type_traits> // for invoke_result_t
#include <utility>
#include <algorithm> // for none_of
#include <any> // for any, any_cast
#include <cstdint> // for uint32_t
#include <memory> // for shared_ptr
#include <type_traits> // for invoke_result_t, declval
#include <vector> // for vector
#include "adapter.h"
#include "xgboost/c_api.h"
@ -15,25 +17,45 @@
#include "xgboost/data.h"
namespace xgboost::data {
/*
* \brief A proxy to external iterator.
/**
* @brief A proxy to external iterator.
*/
template <typename ResetFn, typename NextFn>
class DataIterProxy {
DataIterHandle iter_;
ResetFn* reset_;
NextFn* next_;
std::int32_t count_{0};
public:
DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next)
: iter_{iter}, reset_{reset}, next_{next} {}
DataIterProxy(DataIterProxy&& that) = default;
DataIterProxy& operator=(DataIterProxy&& that) = default;
DataIterProxy(DataIterProxy const& that) = default;
DataIterProxy& operator=(DataIterProxy const& that) = default;
bool Next() { return next_(iter_); }
void Reset() { reset_(iter_); }
[[nodiscard]] bool Next() {
bool ret = !!next_(iter_);
if (!ret) {
return ret;
}
count_++;
return ret;
}
void Reset() {
reset_(iter_);
count_ = 0;
}
[[nodiscard]] std::uint32_t Iter() const { return this->count_; }
DataIterProxy& operator++() {
CHECK(this->Next());
return *this;
}
};
/*
* \brief A proxy of DMatrix used by external iterator.
/**
* @brief A proxy of DMatrix used by external iterator.
*/
class DMatrixProxy : public DMatrix {
MetaInfo info_;
@ -116,6 +138,27 @@ inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
return typed;
}
/**
* @brief Shape and basic information for data fetched from an external data iterator.
*/
struct ExternalDataInfo {
std::uint64_t n_features = 0; // The number of columns
bst_idx_t n_batches = 0; // The number of batches
bst_idx_t accumulated_rows = 0; // The total number of rows
bst_idx_t nnz = 0; // The number of non-missing values
std::vector<bst_idx_t> column_sizes; // The nnz for each column
std::vector<bst_idx_t> batch_nnz; // nnz for each batch
std::vector<bst_idx_t> base_rows{0}; // base_rowid
void Validate() const {
CHECK(std::none_of(this->column_sizes.cbegin(), this->column_sizes.cend(), [&](auto f) {
return f > this->accumulated_rows;
})) << "Something went wrong during iteration.";
CHECK_GE(this->n_features, 1) << "Data must have at least 1 column.";
}
};
/**
* @brief Dispatch function call based on input type.
*
@ -131,6 +174,7 @@ inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
*/
template <bool get_value = true, typename Fn>
decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
CHECK(proxy->Adapter().has_value());
if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
if constexpr (get_value) {
auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
@ -185,5 +229,30 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
*/
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const* ctx,
std::shared_ptr<DMatrixProxy> proxy, float missing);
namespace cuda_impl {
[[nodiscard]] bst_idx_t BatchSamples(DMatrixProxy const*);
[[nodiscard]] bst_idx_t BatchColumns(DMatrixProxy const*);
} // namespace cuda_impl
[[nodiscard]] inline bst_idx_t BatchSamples(DMatrixProxy const* proxy) {
bool type_error = false;
auto n_samples =
HostAdapterDispatch(proxy, [](auto const& value) { return value.NumRows(); }, &type_error);
if (type_error) {
n_samples = cuda_impl::BatchSamples(proxy);
}
return n_samples;
}
[[nodiscard]] inline bst_feature_t BatchColumns(DMatrixProxy const* proxy) {
bool type_error = false;
auto n_features =
HostAdapterDispatch(proxy, [](auto const& value) { return value.NumCols(); }, &type_error);
if (type_error) {
n_features = cuda_impl::BatchColumns(proxy);
}
return n_features;
}
} // namespace xgboost::data
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_
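
`DataIterProxy` is the piece that lets the `QuantileDMatrix` builders make several passes over the user's iterator (data shape and sketching, then the gradient index, then the column matrix), counting batches as it goes. A self-contained analogue of its `Next`/`Reset`/`Iter` contract with a toy in-memory iterator:

```cpp
#include <cassert>

// Toy analogue of DataIterProxy: Next() advances and counts, Reset() rewinds, and
// Iter() reports how many batches have been produced since the last Reset().
struct ToyIterProxy {
  int n_batches{3};
  int count{0};
  bool Next() {
    if (count == n_batches) return false;
    ++count;
    return true;
  }
  void Reset() { count = 0; }
  int Iter() const { return count; }
};

int main() {
  ToyIterProxy iter;
  while (iter.Next()) {}  // first pass, e.g. GetDataShape() and sketching
  assert(iter.Iter() == 3);
  iter.Reset();
  while (iter.Next()) {}  // second pass, e.g. building the gradient index pages
  assert(iter.Iter() == 3);
}
```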


@ -0,0 +1,185 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include "quantile_dmatrix.h"
#include <numeric> // for accumulate
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/communicator-inl.h" // for IsDistributed
#include "../common/threading_utils.h" // for ParallelFor
#include "gradient_index.h" // for GHistIndexMatrix
#include "xgboost/collective/result.h" // for SafeColl
#include "xgboost/linalg.h" // for Tensor
namespace xgboost::data {
void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_feature_t n_features,
BatchParam p, common::HistogramCuts* p_cuts) {
CHECK(ref);
CHECK(p_cuts);
p.forbid_regen = true;
// Fetch cuts from GIDX
auto csr = [&] {
for (auto const& page : ref->GetBatches<GHistIndexMatrix>(ctx, p)) {
*p_cuts = page.cut;
break;
}
};
// Fetch cuts from Ellpack.
auto ellpack = [&] {
for (auto const& page : ref->GetBatches<EllpackPage>(ctx, p)) {
GetCutsFromEllpack(page, p_cuts);
break;
}
};
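// Prefer whichever page type already exists in the reference; when both or neither exist,
// follow the device of the new DMatrix.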
if (ref->PageExists<GHistIndexMatrix>() && ref->PageExists<EllpackPage>()) {
// Both exist
if (ctx->IsCUDA()) {
ellpack();
} else {
csr();
}
} else if (ref->PageExists<GHistIndexMatrix>()) {
csr();
} else if (ref->PageExists<EllpackPage>()) {
ellpack();
} else {
// None exist
if (ctx->IsCUDA()) {
ellpack();
} else {
csr();
}
}
CHECK_EQ(ref->Info().num_col_, n_features)
<< "Invalid ref DMatrix, different number of features.";
}
#if !defined(XGBOOST_USE_CUDA)
void GetCutsFromEllpack(EllpackPage const&, common::HistogramCuts*) {
common::AssertGPUSupport();
}
#endif
namespace cpu_impl {
// Synchronize feature type in case of empty DMatrix
void SyncFeatureType(Context const* ctx, std::vector<FeatureType>* p_h_ft) {
if (!collective::IsDistributed()) {
return;
}
auto& h_ft = *p_h_ft;
bst_idx_t n_ft = h_ft.size();
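// Workers holding an empty DMatrix report a size of 0 here; the max-reduction recovers the
// global number of feature types.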
collective::SafeColl(collective::Allreduce(ctx, &n_ft, collective::Op::kMax));
if (!h_ft.empty()) {
// Check correct size if this is not an empty DMatrix.
CHECK_EQ(h_ft.size(), n_ft);
}
if (n_ft > 0) {
h_ft.resize(n_ft);
auto ptr = reinterpret_cast<std::underlying_type_t<FeatureType>*>(h_ft.data());
collective::SafeColl(
collective::Allreduce(ctx, linalg::MakeVec(ptr, h_ft.size()), collective::Op::kMax));
}
}
void GetDataShape(Context const* ctx, DMatrixProxy* proxy,
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> iter, float missing,
ExternalDataInfo* p_info) {
auto& info = *p_info;
auto const is_valid = data::IsValidFunctor{missing};
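// Count the valid (non-missing) entries in the currently staged batch. Each thread
// accumulates per-column counts into its own row of `column_sizes_tloc`, which is then
// reduced into `info.column_sizes`; the total nnz of the batch is returned.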
auto nnz_cnt = [&]() {
return HostAdapterDispatch(proxy, [&](auto const& value) {
bst_idx_t n_threads = ctx->Threads();
bst_idx_t n_features = info.column_sizes.size();
linalg::Tensor<bst_idx_t, 2> column_sizes_tloc({n_threads, n_features}, DeviceOrd::CPU());
column_sizes_tloc.Data()->Fill(0ul);
auto view = column_sizes_tloc.HostView();
common::ParallelFor(value.Size(), n_threads, common::Sched::Static(256), [&](auto i) {
auto const& line = value.GetLine(i);
for (bst_idx_t j = 0; j < line.Size(); ++j) {
data::COOTuple const& elem = line.GetElement(j);
if (is_valid(elem)) {
view(omp_get_thread_num(), elem.column_idx)++;
}
}
});
auto ptr = column_sizes_tloc.Data()->HostPointer();
auto result = std::accumulate(ptr, ptr + column_sizes_tloc.Size(), static_cast<bst_idx_t>(0));
for (bst_idx_t tidx = 0; tidx < n_threads; ++tidx) {
for (bst_idx_t fidx = 0; fidx < n_features; ++fidx) {
info.column_sizes[fidx] += view(tidx, fidx);
}
}
return result;
});
};
/**
* The CPU impl needs an additional loop to accumulate the per-column sizes.
*/
do {
// We use a do-while loop here since the first batch was already fetched in the ctor.
if (info.n_features == 0) {
info.n_features = BatchColumns(proxy);
collective::SafeColl(collective::Allreduce(ctx, &info.n_features, collective::Op::kMax));
info.column_sizes.clear();
info.column_sizes.resize(info.n_features, 0);
} else {
CHECK_EQ(info.n_features, BatchColumns(proxy)) << "Inconsistent number of columns.";
}
bst_idx_t batch_size = BatchSamples(proxy);
info.batch_nnz.push_back(nnz_cnt());
info.base_rows.push_back(batch_size);
info.nnz += info.batch_nnz.back();
info.accumulated_rows += batch_size;
info.n_batches++;
} while (iter.Next());
iter.Reset();
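// Convert the collected per-batch sizes into base row ids via a prefix sum; the last entry
// equals the total number of rows.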
std::partial_sum(info.base_rows.cbegin(), info.base_rows.cend(), info.base_rows.begin());
}
void MakeSketches(Context const* ctx,
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>* iter,
DMatrixProxy* proxy, std::shared_ptr<DMatrix> ref, float missing,
common::HistogramCuts* cuts, BatchParam const& p, MetaInfo const& info,
ExternalDataInfo const& ext_info, std::vector<FeatureType>* p_h_ft) {
std::unique_ptr<common::HostSketchContainer> p_sketch;
auto& h_ft = *p_h_ft;
bst_idx_t accumulated_rows = 0;
if (ref) {
GetCutsFromRef(ctx, ref, info.num_col_, p, cuts);
h_ft = ref->Info().feature_types.HostVector();
} else {
size_t i = 0;
while (iter->Next()) {
if (!p_sketch) {
h_ft = proxy->Info().feature_types.ConstHostVector();
cpu_impl::SyncFeatureType(ctx, &h_ft);
p_sketch = std::make_unique<common::HostSketchContainer>(
ctx, p.max_bin, h_ft, ext_info.column_sizes, !proxy->Info().group_ptr_.empty());
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = ext_info.batch_nnz[i];
// We don't need the base row idx here since the Info comes from the proxy, whose number
// of rows is consistent with the current data batch.
p_sketch->PushAdapterBatch(batch, 0, proxy->Info(), missing);
});
accumulated_rows += BatchSamples(proxy);
++i;
}
iter->Reset();
CHECK_EQ(accumulated_rows, info.num_row_);
CHECK(p_sketch);
p_sketch->MakeCuts(ctx, info, cuts);
}
if (!h_ft.empty()) {
CHECK_EQ(h_ft.size(), ext_info.n_features);
}
}
} // namespace cpu_impl
} // namespace xgboost::data

View File

@ -0,0 +1,10 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include "ellpack_page.cuh"
namespace xgboost::data {
void GetCutsFromEllpack(EllpackPage const& page, common::HistogramCuts* cuts) {
*cuts = page.Impl()->Cuts();
}
} // namespace xgboost::data

src/data/quantile_dmatrix.h (new file, 107 lines)
View File

@ -0,0 +1,107 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <vector> // for vector
#include "proxy_dmatrix.h" // for DataIterProxy
#include "xgboost/data.h" // for DMatrix, BatchIterator, SparsePage
#include "xgboost/span.h" // for Span
namespace xgboost::common {
class HistogramCuts;
} // namespace xgboost::common
namespace xgboost::data {
/**
* @brief Base class for quantile-based DMatrix.
*
* `QuantileDMatrix` is an intermediate storage for quantilization results, including the
* quantile cuts and the histogram index. Quantilization is designed to be performed on a
* stream of data; in practice, we feed batches of data into the QuantileDMatrix.
*
* - It's only supported by the hist tree method (both CPU and GPU) since approx requires a
*   re-calculation of quantiles at each iteration. We can fix this by retaining a
*   reference to the callback if there are feature requests.
*
* - The CPU format and the GPU format are different: the former uses CSR + CSC for the
*   histogram index, while the latter uses only Ellpack.
*/
class QuantileDMatrix : public DMatrix {
template <typename Page>
static auto InvalidTreeMethod() {
LOG(FATAL) << "Only `hist` tree method can use `QuantileDMatrix`.";
return BatchSet<Page>(BatchIterator<Page>(nullptr));
}
public:
DMatrix *Slice(common::Span<std::int32_t const>) final {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
}
DMatrix *SliceCol(std::int32_t, std::int32_t) final {
LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
return nullptr;
}
[[nodiscard]] bool SparsePageExists() const final { return false; }
BatchSet<SparsePage> GetRowBatches() final {
LOG(FATAL) << "Not implemented for `QuantileDMatrix`.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
}
BatchSet<CSCPage> GetColumnBatches(Context const *) final { return InvalidTreeMethod<CSCPage>(); }
BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const *) final {
return InvalidTreeMethod<SortedCSCPage>();
}
[[nodiscard]] MetaInfo &Info() final { return info_; }
[[nodiscard]] MetaInfo const &Info() const final { return info_; }
[[nodiscard]] Context const *Ctx() const final { return &fmat_ctx_; }
protected:
Context fmat_ctx_;
MetaInfo info_;
};
/**
* @brief Get quantile cuts from reference (Quantile)DMatrix.
*
* @param ctx The context of the new DMatrix.
* @param ref The reference DMatrix.
* @param n_features Number of features, used for validation only.
* @param p Batch parameter for the new DMatrix.
* @param p_cuts Output quantile cuts.
*/
void GetCutsFromRef(Context const *ctx, std::shared_ptr<DMatrix> ref, bst_feature_t n_features,
BatchParam p, common::HistogramCuts *p_cuts);
/**
* @brief Get quantile cuts from ellpack page.
*/
void GetCutsFromEllpack(EllpackPage const &page, common::HistogramCuts *cuts);
namespace cpu_impl {
void SyncFeatureType(Context const *ctx, std::vector<FeatureType> *p_h_ft);
/**
* @brief Fetch the external data shape.
*/
void GetDataShape(Context const *ctx, DMatrixProxy *proxy,
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> iter, float missing,
ExternalDataInfo *p_info);
/**
* @brief Create quantile sketch for CPU from an external iterator or from a reference
* DMatrix.
*/
void MakeSketches(Context const *ctx,
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> *iter,
DMatrixProxy *proxy, std::shared_ptr<DMatrix> ref, float missing,
common::HistogramCuts *cuts, BatchParam const &p, MetaInfo const &info,
ExternalDataInfo const &ext_info, std::vector<FeatureType> *p_h_ft);
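//
// Rough construction sketch for a CPU quantile-based DMatrix built from an external
// iterator (local variable names are illustrative):
//
//   ExternalDataInfo ext_info;
//   cpu_impl::GetDataShape(ctx, proxy, iter, missing, &ext_info);  // 1st pass: shape & nnz
//   ext_info.Validate();
//   common::HistogramCuts cuts;
//   std::vector<FeatureType> h_ft;
//   cpu_impl::MakeSketches(ctx, &iter, proxy, ref, missing, &cuts, p, info, ext_info,
//                          &h_ft);                                 // 2nd pass: sketch cuts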
} // namespace cpu_impl
} // namespace xgboost::data

View File

@ -1,18 +1,16 @@
/*!
* Copyright 2019-2021 XGBoost contributors
/**
* Copyright 2019-2024, XGBoost contributors
*/
#ifndef XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_
#define XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_
#include <memory>
#include <utility>
#include <memory> // for shared_ptr
#include <utility> // for move
#include "xgboost/data.h"
#include "xgboost/data.h" // for BatchIteratorImpl
namespace xgboost {
namespace data {
template<typename T>
namespace xgboost::data {
template <typename T>
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
public:
explicit SimpleBatchIteratorImpl(std::shared_ptr<T const> page) : page_(std::move(page)) {}
@ -20,7 +18,7 @@ class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
CHECK(page_ != nullptr);
return *page_;
}
SimpleBatchIteratorImpl &operator++() override {
SimpleBatchIteratorImpl& operator++() override {
page_ = nullptr;
return *this;
}
@ -31,7 +29,5 @@ class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
private:
std::shared_ptr<T const> page_{nullptr};
};
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif // XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_

View File

@ -16,28 +16,6 @@ MetaInfo &SparsePageDMatrix::Info() { return info_; }
const MetaInfo &SparsePageDMatrix::Info() const { return info_; }
namespace detail {
// Use device dispatch
std::size_t NSamplesDevice(DMatrixProxy *) // NOLINT
#if defined(XGBOOST_USE_CUDA)
; // NOLINT
#else
{
common::AssertGPUSupport();
return 0;
}
#endif
std::size_t NFeaturesDevice(DMatrixProxy *) // NOLINT
#if defined(XGBOOST_USE_CUDA)
; // NOLINT
#else
{
common::AssertGPUSupport();
return 0;
}
#endif
} // namespace detail
SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle,
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
float missing, int32_t nthreads, std::string cache_prefix,
@ -65,31 +43,12 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
bst_idx_t n_samples = 0;
bst_idx_t nnz = 0;
auto num_rows = [&]() {
bool type_error {false};
size_t n_samples = HostAdapterDispatch(
proxy, [](auto const &value) { return value.NumRows(); }, &type_error);
if (type_error) {
n_samples = detail::NSamplesDevice(proxy);
}
return n_samples;
};
auto num_cols = [&]() {
bool type_error {false};
bst_feature_t n_features = HostAdapterDispatch(
proxy, [](auto const &value) { return value.NumCols(); }, &type_error);
if (type_error) {
n_features = detail::NFeaturesDevice(proxy);
}
return n_features;
};
// the proxy is iterated together with the sparse page source so we can obtain all
// The proxy is iterated together with the sparse page source so we can obtain all
// information in 1 pass.
for (auto const &page : this->GetRowBatchesImpl(&ctx)) {
this->info_.Extend(std::move(proxy->Info()), false, false);
n_features = std::max(n_features, num_cols());
n_samples += num_rows();
n_features = std::max(n_features, BatchColumns(proxy));
n_samples += BatchSamples(proxy);
nnz += page.data.Size();
n_batches++;
}
@ -115,14 +74,7 @@ SparsePageDMatrix::~SparsePageDMatrix() {
sorted_column_source_.reset();
ghist_index_source_.reset();
for (auto const &kv : cache_info_) {
CHECK(kv.second);
auto n = kv.second->ShardName();
if (kv.second->OnHost()) {
continue;
}
TryDeleteCacheFile(n);
}
DeleteCacheFiles(cache_info_);
}
void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
@ -194,7 +146,7 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
this->InitializeSparsePage(ctx);
cache_info_.erase(id);
MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
LOG(INFO) << "Generating new Gradient Index.";
// Use sorted sketch for approx.
auto sorted_sketch = param.regen;

View File

@ -28,7 +28,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
this->InitializeSparsePage(ctx);
// reinitialize the cache
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_);
id = MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_);
LOG(INFO) << "Generating new a Ellpack page.";
std::shared_ptr<common::HistogramCuts> cuts;
if (!param.hess.empty()) {

View File

@ -10,7 +10,6 @@
#include <cstdint> // for uint32_t, int32_t
#include <map> // for map
#include <memory> // for shared_ptr
#include <sstream> // for stringstream
#include <string> // for string
#include <variant> // for variant, visit
@ -137,28 +136,5 @@ class SparsePageDMatrix : public DMatrix {
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;
};
[[nodiscard]] inline std::string MakeId(std::string prefix, SparsePageDMatrix *ptr) {
std::stringstream ss;
ss << ptr;
return prefix + "-" + ss.str();
}
/**
* @brief Make cache if it doesn't exist yet.
*/
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, bool on_host,
std::string prefix,
std::map<std::string, std::shared_ptr<Cache>> *out) {
auto &cache_info = *out;
auto name = MakeId(prefix, ptr);
auto id = name + format;
auto it = cache_info.find(id);
if (it == cache_info.cend()) {
cache_info[id].reset(new Cache{false, name, format, on_host});
LOG(INFO) << "Make cache:" << cache_info[id]->ShardName();
}
return id;
}
} // namespace xgboost::data
#endif // XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2023, XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#include "../common/device_helpers.cuh" // for CurrentDevice
#include "proxy_dmatrix.cuh" // for Dispatch, DMatrixProxy
@ -8,16 +8,6 @@
#include "xgboost/data.h" // for SparsePage
namespace xgboost::data {
namespace detail {
std::size_t NSamplesDevice(DMatrixProxy *proxy) {
return cuda_impl::Dispatch(proxy, [](auto const &value) { return value.NumRows(); });
}
std::size_t NFeaturesDevice(DMatrixProxy *proxy) {
return cuda_impl::Dispatch(proxy, [](auto const &value) { return value.NumCols(); });
}
} // namespace detail
void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
auto device = proxy->Device();
if (!device.IsCUDA()) {

View File

@ -9,6 +9,7 @@
#include <atomic> // for atomic
#include <cstdint> // for uint64_t
#include <future> // for future
#include <map> // for map
#include <memory> // for unique_ptr
#include <mutex>   // for mutex
#include <sstream> // for stringstream
#include <string>  // for string
@ -80,6 +81,40 @@ struct Cache {
void Commit();
};
inline void DeleteCacheFiles(std::map<std::string, std::shared_ptr<Cache>> const& cache_info) {
for (auto const& kv : cache_info) {
CHECK(kv.second);
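// Host-resident shards keep their data in memory and have no backing file, so only
// on-disk shards are removed.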
auto n = kv.second->ShardName();
if (kv.second->OnHost()) {
continue;
}
TryDeleteCacheFile(n);
}
}
[[nodiscard]] inline std::string MakeId(std::string prefix, void const* ptr) {
std::stringstream ss;
ss << ptr;
return prefix + "-" + ss.str();
}
/**
* @brief Make cache if it doesn't exist yet.
*/
[[nodiscard]] inline std::string MakeCache(void const* ptr, std::string format, bool on_host,
std::string prefix,
std::map<std::string, std::shared_ptr<Cache>>* out) {
auto& cache_info = *out;
auto name = MakeId(prefix, ptr);
auto id = name + format;
auto it = cache_info.find(id);
if (it == cache_info.cend()) {
cache_info[id].reset(new Cache{false, name, format, on_host});
LOG(INFO) << "Make cache:" << cache_info[id]->ShardName();
}
return id;
}
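// Example (hypothetical address): with prefix "cache" and a DMatrix at 0x55aa, MakeId yields
// "cache-0x55aa", and MakeCache registers (or reuses) the entry "cache-0x55aa.ellpack.page",
// returning that id so callers can look the shard up again after a cache regeneration.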
// Prevents multi-threaded call to `GetBatches`.
class TryLockGuard {
std::mutex& lock_;

View File

@ -0,0 +1,112 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/data.h> // for BatchParam
#include <algorithm> // for equal
#include "../../../src/common/column_matrix.h" // for ColumnMatrix
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../../src/tree/param.h" // for TrainParam
#include "../helpers.h" // for RandomDataGenerator
namespace xgboost::data {
namespace {
class ExtMemQuantileDMatrixCpu : public ::testing::TestWithParam<float> {
public:
void Run(float sparsity) {
bst_idx_t n_samples = 256, n_features = 16, n_batches = 4;
bst_bin_t max_bin = 64;
bst_target_t n_targets = 3;
auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity}
.Bins(max_bin)
.Batches(n_batches)
.Targets(n_targets)
.GenerateExtMemQuantileDMatrix("temp", true);
ASSERT_FALSE(p_fmat->SingleColBlock());
BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
Context ctx;
// Loop over the batches and count the number of pages
bst_idx_t batch_cnt = 0;
bst_idx_t base_cnt = 0;
bst_idx_t row_cnt = 0;
for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, p)) {
ASSERT_EQ(page.base_rowid, base_cnt);
++batch_cnt;
base_cnt += n_samples / n_batches;
row_cnt += page.Size();
ASSERT_EQ((sparsity == 0.0f), page.IsDense());
}
ASSERT_EQ(n_batches, batch_cnt);
ASSERT_EQ(p_fmat->Info().num_row_, n_samples);
EXPECT_EQ(p_fmat->Info().num_row_, row_cnt);
ASSERT_EQ(p_fmat->Info().num_col_, n_features);
if (sparsity == 0.0f) {
ASSERT_EQ(p_fmat->Info().num_nonzero_, n_samples * n_features);
} else {
ASSERT_LT(p_fmat->Info().num_nonzero_, n_samples * n_features);
ASSERT_GT(p_fmat->Info().num_nonzero_, 0);
}
ASSERT_EQ(p_fmat->Info().labels.Shape(0), n_samples);
ASSERT_EQ(p_fmat->Info().labels.Shape(1), n_targets);
// Compare against the sparse page DMatrix
auto p_sparse = RandomDataGenerator{n_samples, n_features, sparsity}
.Bins(max_bin)
.Batches(n_batches)
.Targets(n_targets)
.GenerateSparsePageDMatrix("temp", true);
auto it = p_fmat->GetBatches<GHistIndexMatrix>(&ctx, p).begin();
for (auto const& page : p_sparse->GetBatches<GHistIndexMatrix>(&ctx, p)) {
auto orig = it.Page();
// Check the CSR matrix
auto orig_cuts = it.Page()->Cuts();
auto sparse_cuts = page.Cuts();
ASSERT_EQ(orig_cuts.Values(), sparse_cuts.Values());
ASSERT_EQ(orig_cuts.MinValues(), sparse_cuts.MinValues());
ASSERT_EQ(orig_cuts.Ptrs(), sparse_cuts.Ptrs());
auto orig_ptr = orig->data.data();
auto sparse_ptr = page.data.data();
ASSERT_EQ(orig->data.size(), page.data.size());
auto equal = std::equal(orig_ptr, orig_ptr + orig->data.size(), sparse_ptr);
ASSERT_TRUE(equal);
// Check the column matrix
common::ColumnMatrix const& orig_columns = orig->Transpose();
common::ColumnMatrix const& sparse_columns = page.Transpose();
std::string str_orig, str_sparse;
common::AlignedMemWriteStream fo_orig{&str_orig}, fo_sparse{&str_sparse};
auto n_bytes_orig = orig_columns.Write(&fo_orig);
auto n_bytes_sparse = sparse_columns.Write(&fo_sparse);
ASSERT_EQ(n_bytes_orig, n_bytes_sparse);
ASSERT_EQ(str_orig, str_sparse);
++it;
}
// Check meta info
auto h_y_sparse = p_sparse->Info().labels.HostView();
auto h_y = p_fmat->Info().labels.HostView();
for (std::size_t i = 0, m = h_y_sparse.Shape(0); i < m; ++i) {
for (std::size_t j = 0, n = h_y_sparse.Shape(1); j < n; ++j) {
ASSERT_EQ(h_y(i, j), h_y_sparse(i, j));
}
}
}
};
} // anonymous namespace
TEST_P(ExtMemQuantileDMatrixCpu, Basic) { this->Run(this->GetParam()); }
INSTANTIATE_TEST_SUITE_P(ExtMemQuantileDMatrix, ExtMemQuantileDMatrixCpu, ::testing::ValuesIn([] {
std::vector<float> sparsities{
0.0f, tree::TrainParam::DftSparseThreshold(), 0.4f, 0.8f};
return sparsities;
}()));
} // namespace xgboost::data

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2023 by XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/data.h> // for BatchIterator, BatchSet, DMatrix, BatchParam

View File

@ -1,5 +1,5 @@
/**
* Copyright 2016-2024 by XGBoost contributors
* Copyright 2016-2024, XGBoost contributors
*/
#include "helpers.h"
@ -12,6 +12,7 @@
#include <xgboost/objective.h>
#include <algorithm>
#include <limits> // for numeric_limits
#include <random>
#include "../../src/collective/communicator-inl.h" // for GetRank
@ -20,13 +21,13 @@
#include "../../src/data/simple_dmatrix.h"
#include "../../src/data/sparse_page_dmatrix.h"
#include "../../src/gbm/gbtree_model.h"
#include "filesystem.h" // dmlc::TemporaryDirectory
#include "../../src/tree/param.h" // for TrainParam
#include "filesystem.h" // dmlc::TemporaryDirectory
#include "xgboost/c_api.h"
#include "xgboost/predictor.h"
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
#include <memory>
#include <numeric>
#include <vector>
#include "rmm/mr/device/per_device_resource.hpp"
#include "rmm/mr/device/cuda_memory_resource.hpp"
@ -466,6 +467,38 @@ void RandomDataGenerator::GenerateCSR(
return dmat;
}
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateExtMemQuantileDMatrix(
std::string prefix, bool with_label) const {
CHECK_GE(this->rows_, this->n_batches_);
CHECK_GE(this->n_batches_, 1)
<< "Must set the n_batches before generating an external memory DMatrix.";
// The iterator only needs to stay alive during construction; it can be freed once the DMatrix has been created.
std::unique_ptr<ArrayIterForTest> iter;
if (device_.IsCPU()) {
iter = std::make_unique<NumpyArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
} else {
#if defined(XGBOOST_USE_CUDA)
iter = std::make_unique<CudaArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
#endif // defined(XGBOOST_USE_CUDA)
}
CHECK(iter);
std::shared_ptr<DMatrix> p_fmat{
DMatrix::Create(static_cast<DataIterHandle>(iter.get()), iter->Proxy(), nullptr, Reset, Next,
std::numeric_limits<float>::quiet_NaN(), 0, this->bins_, prefix)};
auto page_path = data::MakeId(prefix, p_fmat.get()) + ".gradient_index.page";
EXPECT_TRUE(FileExists(page_path)) << page_path;
if (with_label) {
RandomDataGenerator{static_cast<bst_idx_t>(p_fmat->Info().num_row_), this->n_targets_, 0.0f}
.GenerateDense(p_fmat->Info().labels.Data());
CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_);
p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_);
}
return p_fmat;
}
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_label) {
NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
auto m = std::make_shared<data::IterativeDMatrix>(
@ -747,7 +780,7 @@ RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
}
}
if (!use_rmm_pool) {
return RMMAllocatorPtr(nullptr, DeleteRMMResource);
return {nullptr, DeleteRMMResource};
}
LOG(INFO) << "Using RMM memory pool";
auto ptr = RMMAllocatorPtr(new RMMAllocator(), DeleteRMMResource);

View File

@ -321,6 +321,9 @@ class RandomDataGenerator {
[[nodiscard]] std::shared_ptr<DMatrix> GenerateSparsePageDMatrix(std::string prefix,
bool with_label) const;
[[nodiscard]] std::shared_ptr<DMatrix> GenerateExtMemQuantileDMatrix(std::string prefix,
bool with_label) const;
#if defined(XGBOOST_USE_CUDA)
std::shared_ptr<DMatrix> GenerateDeviceDMatrix(bool with_label);
#endif