[EM] CPU implementation for external memory QDM. (#10682)
- A new DMatrix type. - Extract common code into a new QDM base class. Not yet working: - Not exposed to the interface yet, will wait for the GPU implementation. - ~No meta info yet, still working on the source.~ - Exporting data to CSR is not supported yet.
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
#include <xgboost/string_view.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint> // for int32_t, uint8_t
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@@ -499,8 +500,12 @@ class BatchSet {
|
||||
|
||||
struct XGBAPIThreadLocalEntry;
|
||||
|
||||
/*!
|
||||
* \brief Internal data structured used by XGBoost during training.
|
||||
/**
|
||||
* @brief Internal data structured used by XGBoost to hold all external data.
|
||||
*
|
||||
* There are multiple variants of the DMatrix class and can be accessed through the
|
||||
* @ref Create() methods. The DMatrix itself holds the predictor `X`, and other data
|
||||
* including labels and sample weights are stored in the @ref MetaInfo class.
|
||||
*/
|
||||
class DMatrix {
|
||||
public:
|
||||
@@ -518,13 +523,13 @@ class DMatrix {
|
||||
/*! \brief Get thread local memory for returning data from DMatrix. */
|
||||
[[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
|
||||
/**
|
||||
* \brief Get the context object of this DMatrix. The context is created during construction of
|
||||
* @brief Get the context object of this DMatrix. The context is created during construction of
|
||||
* DMatrix with user specified `nthread` parameter.
|
||||
*/
|
||||
[[nodiscard]] virtual Context const* Ctx() const = 0;
|
||||
|
||||
/**
|
||||
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
|
||||
* @brief Gets batches. Use range based for loop over BatchSet to access individual batches.
|
||||
*/
|
||||
template <typename T>
|
||||
BatchSet<T> GetBatches();
|
||||
@@ -548,27 +553,27 @@ class DMatrix {
|
||||
[[nodiscard]] bool IsDense() const { return this->Info().IsDense(); }
|
||||
|
||||
/**
|
||||
* \brief Load DMatrix from URI.
|
||||
* @brief Load DMatrix from URI.
|
||||
*
|
||||
* \param uri The URI of input.
|
||||
* \param silent Whether print information during loading.
|
||||
* \param data_split_mode Indicate how the data was split beforehand.
|
||||
* \return The created DMatrix.
|
||||
* @param uri The URI of input.
|
||||
* @param silent Whether print information during loading.
|
||||
* @param data_split_mode Indicate how the data was split beforehand.
|
||||
* @return The created DMatrix.
|
||||
*/
|
||||
static DMatrix* Load(const std::string& uri, bool silent = true,
|
||||
DataSplitMode data_split_mode = DataSplitMode::kRow);
|
||||
|
||||
/**
|
||||
* \brief Creates a new DMatrix from an external data adapter.
|
||||
* @brief Creates a new DMatrix from an external data adapter.
|
||||
*
|
||||
* \tparam AdapterT Type of the adapter.
|
||||
* \param [in,out] adapter View onto an external data.
|
||||
* \param missing Values to count as missing.
|
||||
* \param nthread Number of threads for construction.
|
||||
* \param cache_prefix (Optional) The cache prefix for external memory.
|
||||
* \param data_split_mode (Optional) Data split mode.
|
||||
* @tparam AdapterT Type of the adapter.
|
||||
* @param [in,out] adapter View onto an external data.
|
||||
* @param missing Values to count as missing.
|
||||
* @param nthread Number of threads for construction.
|
||||
* @param cache_prefix (Optional) The cache prefix for external memory.
|
||||
* @param data_split_mode (Optional) Data split mode.
|
||||
*
|
||||
* \return a Created DMatrix.
|
||||
* @return a Created DMatrix.
|
||||
*/
|
||||
template <typename AdapterT>
|
||||
static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
|
||||
@@ -576,29 +581,29 @@ class DMatrix {
|
||||
DataSplitMode data_split_mode = DataSplitMode::kRow);
|
||||
|
||||
/**
|
||||
* \brief Create a new Quantile based DMatrix used for histogram based algorithm.
|
||||
* @brief Create a new Quantile based DMatrix used for histogram based algorithm.
|
||||
*
|
||||
* \tparam DataIterHandle External iterator type, defined in C API.
|
||||
* \tparam DMatrixHandle DMatrix handle, defined in C API.
|
||||
* \tparam DataIterResetCallback Callback for reset, prototype defined in C API.
|
||||
* \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
|
||||
* @tparam DataIterHandle External iterator type, defined in C API.
|
||||
* @tparam DMatrixHandle DMatrix handle, defined in C API.
|
||||
* @tparam DataIterResetCallback Callback for reset, prototype defined in C API.
|
||||
* @tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
|
||||
*
|
||||
* \param iter External data iterator
|
||||
* \param proxy A hanlde to ProxyDMatrix
|
||||
* \param ref Reference Quantile DMatrix.
|
||||
* \param reset Callback for reset
|
||||
* \param next Callback for next
|
||||
* \param missing Value that should be treated as missing.
|
||||
* \param nthread number of threads used for initialization.
|
||||
* \param max_bin Maximum number of bins.
|
||||
* @param iter External data iterator
|
||||
* @param proxy A hanlde to ProxyDMatrix
|
||||
* @param ref Reference Quantile DMatrix.
|
||||
* @param reset Callback for reset
|
||||
* @param next Callback for next
|
||||
* @param missing Value that should be treated as missing.
|
||||
* @param nthread number of threads used for initialization.
|
||||
* @param max_bin Maximum number of bins.
|
||||
*
|
||||
* \return A created quantile based DMatrix.
|
||||
* @return A created quantile based DMatrix.
|
||||
*/
|
||||
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
||||
typename XGDMatrixCallbackNext>
|
||||
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
|
||||
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
|
||||
int nthread, bst_bin_t max_bin);
|
||||
std::int32_t nthread, bst_bin_t max_bin);
|
||||
|
||||
/**
|
||||
* @brief Create an external memory DMatrix with callbacks.
|
||||
@@ -622,9 +627,22 @@ class DMatrix {
|
||||
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
||||
typename XGDMatrixCallbackNext>
|
||||
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
|
||||
XGDMatrixCallbackNext* next, float missing, int32_t nthread,
|
||||
XGDMatrixCallbackNext* next, float missing, std::int32_t nthread,
|
||||
std::string cache, bool on_host);
|
||||
|
||||
/**
|
||||
* @brief Create an external memory quantile DMatrix with callbacks.
|
||||
*
|
||||
* Parameters are a combination of the external memory DMatrix and the quantile DMatrix.
|
||||
*
|
||||
* @return A created external memory quantile DMatrix.
|
||||
*/
|
||||
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
||||
typename XGDMatrixCallbackNext>
|
||||
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
|
||||
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
|
||||
std::int32_t nthread, bst_bin_t max_bin, std::string cache);
|
||||
|
||||
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user