Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -684,9 +684,9 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_row_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -694,9 +694,52 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_col_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// We name the function non-missing instead of non-zero since zero is perfectly valid for XGBoost.
|
||||
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_nonzero_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
/**
 * Export the contents of a DMatrix into caller-provided CSR buffers.
 *
 * @param handle      DMatrix handle; validated by CHECK_HANDLE and CastDMatrixHandle.
 * @param config      JSON document. It is parsed here, but the parsed value is not
 *                    consulted by this function.
 * @param out_indptr  Receives a copy of the page's row offset vector.
 * @param out_indices Receives the feature index of every stored entry.
 * @param out_data    Receives the feature value of every stored entry.
 * @return Status code produced by the API_BEGIN / API_END macro pair (defined elsewhere).
 *
 * No buffer sizes are passed in, so nothing here bounds the writes; the caller
 * must supply buffers large enough for the offsets and entries.
 */
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
                                  xgboost::bst_ulong *out_indptr, unsigned *out_indices,
                                  float *out_data) {
  API_BEGIN();
  CHECK_HANDLE();

  xgboost_CHECK_C_ARG_PTR(config);
  // Parsed even though unused: a malformed config is still rejected by the loader.
  auto parsed_config = Json::Load(StringView{config});

  auto dmat = CastDMatrixHandle(handle);

  xgboost_CHECK_C_ARG_PTR(out_indptr);
  xgboost_CHECK_C_ARG_PTR(out_indices);
  xgboost_CHECK_C_ARG_PTR(out_data);

  // Column indices are exported as `unsigned`, so the column count must fit in that type.
  CHECK_LE(dmat->Info().num_col_, std::numeric_limits<unsigned>::max());

  // NOTE(review): every batch writes from index 0 of the output buffers, so a
  // second batch would overwrite the first. Presumably GetBatches<ExtSparsePage>
  // yields a single page -- confirm against the DMatrix implementations.
  for (auto const &batch : dmat->GetBatches<ExtSparsePage>()) {
    CHECK(batch.page);

    auto const &row_offsets = batch.page->offset.ConstHostVector();
    std::copy(row_offsets.cbegin(), row_offsets.cend(), out_indptr);

    auto view = batch.page->GetView();
    auto const n_entries = batch.page->data.Size();
    auto const n_threads = dmat->Ctx()->Threads();
    // Split each entry into the two parallel output arrays.
    common::ParallelFor(n_entries, n_threads, [&](std::size_t i) {
      out_data[i] = view.data[i].fvalue;
      out_indices[i] = view.data[i].index;
    });
  }

  API_END();
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user