Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -761,6 +761,39 @@ XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
|
||||
*/
|
||||
XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
|
||||
bst_ulong *out);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Get number of valid values from DMatrix.
|
||||
*
|
||||
* \param handle the handle to the DMatrix
|
||||
* \param out The output number of non-missing values
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
|
||||
|
||||
/*!
|
||||
* \brief Get the predictors from DMatrix as a CSR matrix for testing. If this is a
|
||||
* quantized DMatrix, quantized values are returned instead.
|
||||
*
|
||||
* Unlike most XGBoost C functions, the caller of `XGDMatrixGetDataAsCSR` is required to
|
||||
* allocate the memory for the return buffers instead of using thread-local memory from
|
||||
* XGBoost. This is to avoid allocating a huge memory buffer that cannot be freed until
|
||||
* exiting the thread.
|
||||
*
|
||||
* \param handle the handle to the DMatrix
|
||||
* \param config JSON configuration string. At the moment it should be an empty document,
|
||||
* reserved for future use.
|
||||
* \param out_indptr indptr of output CSR matrix.
|
||||
* \param out_indices Column index of output CSR matrix.
|
||||
* \param out_data Data value of CSR matrix.
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
|
||||
bst_ulong *out_indptr, unsigned *out_indices, float *out_data);
|
||||
|
||||
// --- start XGBoost class
|
||||
/*!
|
||||
* \brief create xgboost learner
|
||||
|
||||
@@ -284,12 +284,17 @@ class SparsePage {
|
||||
return {offset.ConstHostSpan(), data.ConstHostSpan()};
|
||||
}
|
||||
|
||||
|
||||
/*! \brief constructor */
|
||||
SparsePage() {
|
||||
this->Clear();
|
||||
}
|
||||
|
||||
SparsePage(SparsePage const& that) = delete;
|
||||
SparsePage(SparsePage&& that) = default;
|
||||
SparsePage& operator=(SparsePage const& that) = delete;
|
||||
SparsePage& operator=(SparsePage&& that) = default;
|
||||
virtual ~SparsePage() = default;
|
||||
|
||||
/*! \return Number of instances in the page. */
|
||||
inline size_t Size() const {
|
||||
return offset.Size() == 0 ? 0 : offset.Size() - 1;
|
||||
@@ -358,6 +363,16 @@ class CSCPage: public SparsePage {
|
||||
explicit CSCPage(SparsePage page) : SparsePage(std::move(page)) {}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Sparse page for exporting DMatrix. Same as SparsePage, just a different type to
|
||||
* prevent it from being used internally.
|
||||
*/
|
||||
class ExtSparsePage {
|
||||
public:
|
||||
std::shared_ptr<SparsePage const> page;
|
||||
explicit ExtSparsePage(std::shared_ptr<SparsePage const> p) : page{std::move(p)} {}
|
||||
};
|
||||
|
||||
class SortedCSCPage : public SparsePage {
|
||||
public:
|
||||
SortedCSCPage() : SparsePage() {}
|
||||
@@ -610,6 +625,7 @@ class DMatrix {
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
|
||||
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
|
||||
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
|
||||
|
||||
virtual bool EllpackExists() const = 0;
|
||||
virtual bool GHistIndexExists() const = 0;
|
||||
@@ -651,10 +667,15 @@ inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetEllpackBatches(param);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetGradientIndex(param);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
|
||||
return GetExtBatches(BatchParam{});
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace dmlc {
|
||||
|
||||
Reference in New Issue
Block a user