Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -684,9 +684,9 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_row_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -694,9 +694,52 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_col_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// We name the function non-missing instead of non-zero since zero is perfectly valid for XGBoost.
|
||||
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_nonzero_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
/**
 * \brief Copy the content of a DMatrix into caller-provided CSR buffers.
 *
 * \param handle      DMatrix handle.
 * \param config      JSON document. It is parsed, so malformed JSON is still reported as an
 *                    error, but no field of it is read by this function.
 * \param out_indptr  Output row pointers.
 * \param out_indices Output feature (column) index of every stored entry.
 * \param out_data    Output value of every stored entry.
 *
 * The matrix is visited batch by batch.  Running row and entry offsets are carried across
 * batches so that a DMatrix yielding more than one page is concatenated into a single CSR
 * matrix rather than every page being written to the start of the output buffers.  With a
 * single batch the output is the same as a plain copy of that page.
 *
 * NOTE(review): buffer sizes are not validated here; presumably the caller sizes them from
 * XGDMatrixNumRow (+1 for indptr) and XGDMatrixNumNonMissing -- confirm against the public
 * header.
 */
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
                                  xgboost::bst_ulong *out_indptr, unsigned *out_indices,
                                  float *out_data) {
  API_BEGIN();
  CHECK_HANDLE();

  xgboost_CHECK_C_ARG_PTR(config);
  // Kept even though no option is consumed: loading rejects malformed configuration.
  auto jconfig = Json::Load(StringView{config});

  auto p_m = CastDMatrixHandle(handle);

  xgboost_CHECK_C_ARG_PTR(out_indptr);
  xgboost_CHECK_C_ARG_PTR(out_indices);
  xgboost_CHECK_C_ARG_PTR(out_data);

  // Feature indices are exported as `unsigned`; make sure the column count fits.
  CHECK_LE(p_m->Info().num_col_, std::numeric_limits<unsigned>::max());

  std::size_t row_begin = 0;    // number of rows written by previous batches
  std::size_t entry_begin = 0;  // number of entries written by previous batches
  for (auto const &page : p_m->GetBatches<ExtSparsePage>()) {
    CHECK(page.page);
    auto const &h_offset = page.page->offset.ConstHostVector();
    // Assumes each page's offsets are relative to that page's own data (TODO confirm), hence
    // the shift by the number of entries already emitted.  The first pointer of this page
    // lands on the last pointer of the previous one and carries the same value.
    for (std::size_t j = 0; j < h_offset.size(); ++j) {
      out_indptr[row_begin + j] = static_cast<xgboost::bst_ulong>(h_offset[j] + entry_begin);
    }

    auto pv = page.page->GetView();
    common::ParallelFor(page.page->data.Size(), p_m->Ctx()->Threads(), [&](std::size_t i) {
      out_data[entry_begin + i] = pv.data[i].fvalue;
      out_indices[entry_begin + i] = pv.data[i].index;
    });

    // A page with k rows carries k + 1 row pointers.
    if (!h_offset.empty()) {
      row_begin += h_offset.size() - 1;
    }
    entry_begin += page.page->data.Size();
  }

  API_END();
}
|
||||
|
||||
|
||||
@@ -6,14 +6,16 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <memory> // std::shared_ptr
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/data.h" // DMatrix
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/string_view.h" // StringView
|
||||
|
||||
namespace xgboost {
|
||||
/* \brief Determine the output shape of prediction.
|
||||
@@ -259,5 +261,17 @@ auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
|
||||
}
|
||||
return dft;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get shared ptr from DMatrix C handle with additional checks.
|
||||
*/
|
||||
inline std::shared_ptr<DMatrix> CastDMatrixHandle(DMatrixHandle const handle) {
|
||||
auto pp_m = static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
StringView msg{"Invalid DMatrix handle"};
|
||||
CHECK(pp_m) << msg;
|
||||
auto p_m = *pp_m;
|
||||
CHECK(p_m) << msg;
|
||||
return p_m;
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
Reference in New Issue
Block a user