Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -761,6 +761,39 @@ XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
*/
XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
bst_ulong *out);
/*!
* \brief Get the number of valid (non-missing) values stored in a DMatrix.
*
* \param handle The handle to the DMatrix.
* \param out Output pointer that receives the number of non-missing values.
*
* \return 0 on success, -1 on failure.
*/
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
/*!
* \brief Get the predictors from a DMatrix as a CSR matrix; intended for testing. If this
* is a quantized DMatrix, the quantized values are returned instead.
*
* Unlike most XGBoost C functions, the caller of `XGDMatrixGetDataAsCSR` is required to
* allocate the memory for the returned buffers instead of relying on thread-local memory
* owned by XGBoost. This avoids allocating a huge memory buffer that cannot be freed
* until the thread exits.
*
* NOTE(review): the required buffer lengths are not stated here. Presumably `out_indptr`
* needs (number of rows + 1) entries and `out_indices` / `out_data` need one entry per
* non-missing value (see `XGDMatrixNumRow` and `XGDMatrixNumNonMissing`) -- confirm
* against the implementation.
*
* \param handle The handle to the DMatrix.
* \param config JSON configuration string. At the moment it should be an empty document;
* it is reserved for future use.
* \param out_indptr Caller-allocated buffer receiving the indptr of the output CSR matrix.
* \param out_indices Caller-allocated buffer receiving the column indices of the output
* CSR matrix.
* \param out_data Caller-allocated buffer receiving the data values of the output CSR
* matrix.
*
* \return 0 on success, -1 on failure.
*/
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
bst_ulong *out_indptr, unsigned *out_indices, float *out_data);
// --- start XGBoost class
/*!
* \brief create xgboost learner

View File

@@ -284,12 +284,17 @@ class SparsePage {
return {offset.ConstHostSpan(), data.ConstHostSpan()};
}
/*! \brief constructor */
SparsePage() {
this->Clear();
}
// Copying is disabled while moving is allowed, so a page can only be transferred.
SparsePage(SparsePage const& that) = delete;
SparsePage(SparsePage&& that) = default;
SparsePage& operator=(SparsePage const& that) = delete;
SparsePage& operator=(SparsePage&& that) = default;
// Virtual destructor: CSCPage and SortedCSCPage derive from this class.
virtual ~SparsePage() = default;
/*! \return Number of instances in the page. */
inline size_t Size() const {
return offset.Size() == 0 ? 0 : offset.Size() - 1;
@@ -358,6 +363,16 @@ class CSCPage: public SparsePage {
explicit CSCPage(SparsePage page) : SparsePage(std::move(page)) {}
};
/**
 * \brief Page type used when exporting a DMatrix. It wraps a shared, read-only
 *        SparsePage and carries the same data, but is a distinct type so that it is
 *        not used by internal code.
 */
class ExtSparsePage {
 public:
  /*! \brief The wrapped page; it cannot be modified through this pointer. */
  std::shared_ptr<SparsePage const> page;

  /*! \brief Take shared ownership of an existing page. */
  explicit ExtSparsePage(std::shared_ptr<SparsePage const> p) : page(std::move(p)) {}
};
class SortedCSCPage : public SparsePage {
public:
SortedCSCPage() : SparsePage() {}
@@ -610,6 +625,7 @@ class DMatrix {
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
/*! \brief Batches of ExtSparsePage, the page type used for exporting the DMatrix. */
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
@@ -651,10 +667,15 @@ inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
return GetEllpackBatches(param);
}
template<>
template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
  // GHistIndexMatrix batches are served by the gradient-index accessor.
  return this->GetGradientIndex(param);
}
template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
  // This overload takes no parameter, so a value-initialized BatchParam is forwarded.
  BatchParam const default_param{};
  return this->GetExtBatches(default_param);
}
} // namespace xgboost
namespace dmlc {