[EM] CPU implementation for external memory QDM. (#10682)

- A new DMatrix type.
- Extract common code into a new QDM base class.

Not yet working:
- Not exposed to the interface yet, will wait for the GPU implementation.
- ~No meta info yet, still working on the source.~
- Exporting data to CSR is not supported yet.
This commit is contained in:
Jiaming Yuan
2024-08-09 09:38:02 +08:00
committed by GitHub
parent ac8366654b
commit 7bccc1ea2c
33 changed files with 1198 additions and 497 deletions

View File

@@ -17,6 +17,7 @@
#include <xgboost/string_view.h>
#include <algorithm>
#include <cstdint> // for int32_t, uint8_t
#include <limits>
#include <memory>
#include <string>
@@ -499,8 +500,12 @@ class BatchSet {
struct XGBAPIThreadLocalEntry;
/*!
* \brief Internal data structured used by XGBoost during training.
/**
* @brief Internal data structured used by XGBoost to hold all external data.
*
* There are multiple variants of the DMatrix class and can be accessed through the
* @ref Create() methods. The DMatrix itself holds the predictor `X`, and other data
* including labels and sample weights are stored in the @ref MetaInfo class.
*/
class DMatrix {
public:
@@ -518,13 +523,13 @@ class DMatrix {
/*! \brief Get thread local memory for returning data from DMatrix. */
[[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
/**
* \brief Get the context object of this DMatrix. The context is created during construction of
* @brief Get the context object of this DMatrix. The context is created during construction of
* DMatrix with user specified `nthread` parameter.
*/
[[nodiscard]] virtual Context const* Ctx() const = 0;
/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
* @brief Gets batches. Use range based for loop over BatchSet to access individual batches.
*/
template <typename T>
BatchSet<T> GetBatches();
@@ -548,27 +553,27 @@ class DMatrix {
[[nodiscard]] bool IsDense() const { return this->Info().IsDense(); }
/**
* \brief Load DMatrix from URI.
* @brief Load DMatrix from URI.
*
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode Indicate how the data was split beforehand.
* \return The created DMatrix.
* @param uri The URI of input.
* @param silent Whether print information during loading.
* @param data_split_mode Indicate how the data was split beforehand.
* @return The created DMatrix.
*/
static DMatrix* Load(const std::string& uri, bool silent = true,
DataSplitMode data_split_mode = DataSplitMode::kRow);
/**
* \brief Creates a new DMatrix from an external data adapter.
* @brief Creates a new DMatrix from an external data adapter.
*
* \tparam AdapterT Type of the adapter.
* \param [in,out] adapter View onto an external data.
* \param missing Values to count as missing.
* \param nthread Number of threads for construction.
* \param cache_prefix (Optional) The cache prefix for external memory.
* \param data_split_mode (Optional) Data split mode.
* @tparam AdapterT Type of the adapter.
* @param [in,out] adapter View onto an external data.
* @param missing Values to count as missing.
* @param nthread Number of threads for construction.
* @param cache_prefix (Optional) The cache prefix for external memory.
* @param data_split_mode (Optional) Data split mode.
*
* \return a Created DMatrix.
* @return a Created DMatrix.
*/
template <typename AdapterT>
static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
@@ -576,29 +581,29 @@ class DMatrix {
DataSplitMode data_split_mode = DataSplitMode::kRow);
/**
* \brief Create a new Quantile based DMatrix used for histogram based algorithm.
* @brief Create a new Quantile based DMatrix used for histogram based algorithm.
*
* \tparam DataIterHandle External iterator type, defined in C API.
* \tparam DMatrixHandle DMatrix handle, defined in C API.
* \tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
* @tparam DataIterHandle External iterator type, defined in C API.
* @tparam DMatrixHandle DMatrix handle, defined in C API.
* @tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* @tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
*
* \param iter External data iterator
* \param proxy A hanlde to ProxyDMatrix
* \param ref Reference Quantile DMatrix.
* \param reset Callback for reset
* \param next Callback for next
* \param missing Value that should be treated as missing.
* \param nthread number of threads used for initialization.
* \param max_bin Maximum number of bins.
* @param iter External data iterator
* @param proxy A hanlde to ProxyDMatrix
* @param ref Reference Quantile DMatrix.
* @param reset Callback for reset
* @param next Callback for next
* @param missing Value that should be treated as missing.
* @param nthread number of threads used for initialization.
* @param max_bin Maximum number of bins.
*
* \return A created quantile based DMatrix.
* @return A created quantile based DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
int nthread, bst_bin_t max_bin);
std::int32_t nthread, bst_bin_t max_bin);
/**
* @brief Create an external memory DMatrix with callbacks.
@@ -622,9 +627,22 @@ class DMatrix {
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int32_t nthread,
XGDMatrixCallbackNext* next, float missing, std::int32_t nthread,
std::string cache, bool on_host);
/**
* @brief Create an external memory quantile DMatrix with callbacks.
*
* Parameters are a combination of the external memory DMatrix and the quantile DMatrix.
*
* @return A created external memory quantile DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
std::int32_t nthread, bst_bin_t max_bin, std::string cache);
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
/**