[EM] Allow staging ellpack on host for GPU external memory. (#10488)

- New parameter `on_host`.
- Abstract format creation and stream creation into policy classes.
This commit is contained in:
Jiaming Yuan
2024-06-28 04:42:18 +08:00
committed by GitHub
parent 824fba783e
commit e8a962575a
36 changed files with 842 additions and 317 deletions

View File

@@ -50,7 +50,7 @@ class MetaInfo {
static constexpr uint64_t kNumField = 12;
/*! \brief number of rows in the data */
uint64_t num_row_{0}; // NOLINT
bst_idx_t num_row_{0}; // NOLINT
/*! \brief number of columns in the data */
uint64_t num_col_{0}; // NOLINT
/*! \brief number of nonzero entries in the data */
@@ -535,10 +535,11 @@ class DMatrix {
template <typename T>
[[nodiscard]] bool PageExists() const;
// the following are column meta data, should be able to answer them fast.
/*! \return Whether the data columns form a single column block. */
/**
* @return Whether the data columns form a single column block.
*/
[[nodiscard]] virtual bool SingleColBlock() const = 0;
/*! \brief virtual destructor */
virtual ~DMatrix();
/**
@@ -600,34 +601,34 @@ class DMatrix {
int nthread, bst_bin_t max_bin);
/**
* \brief Create an external memory DMatrix with callbacks.
* @brief Create an external memory DMatrix with callbacks.
*
* \tparam DataIterHandle External iterator type, defined in C API.
* \tparam DMatrixHandle DMatrix handle, defined in C API.
* \tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
* @tparam DataIterHandle External iterator type, defined in C API.
* @tparam DMatrixHandle DMatrix handle, defined in C API.
* @tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* @tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
*
* \param iter External data iterator
* \param proxy A handle to ProxyDMatrix
* \param reset Callback for reset
* \param next Callback for next
* \param missing Value that should be treated as missing.
* \param nthread number of threads used for initialization.
* \param cache Prefix of cache file path.
* @param iter External data iterator
* @param proxy A handle to ProxyDMatrix
* @param reset Callback for reset
* @param next Callback for next
* @param missing Value that should be treated as missing.
* @param nthread number of threads used for initialization.
* @param cache Prefix of cache file path.
* @param on_host Used for GPU, whether the data should be cached on host memory.
*
* \return A created external memory DMatrix.
* @return A created external memory DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle,
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
static DMatrix *Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int32_t nthread, std::string cache);
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int32_t nthread,
std::string cache, bool on_host);
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
/**
* \brief Slice a DMatrix by columns.
* @brief Slice a DMatrix by columns.
*
* @param num_slices Total number of slices
* @param slice_id Index of the current slice