Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -684,9 +684,9 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
// Writes the num_row_ field of the DMatrix's Info() to *out.
// NOTE(review): this span looks like a diff hunk whose +/- markers were stripped, so the
// removed and the added form of the assignment both appear below — confirm against the file.
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
// Presumably the removed form: casts the raw handle to std::shared_ptr<DMatrix>* inline.
*out = static_cast<xgboost::bst_ulong>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
// Presumably the added form: reads num_row_ through p_m obtained from CastDMatrixHandle.
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_row_);
API_END();
}
@@ -694,9 +694,52 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
// Writes the num_col_ field of the DMatrix's Info() to *out.
// NOTE(review): this span looks like a diff hunk whose +/- markers were stripped, so the
// removed and the added form of the assignment both appear below — confirm against the file.
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
// Presumably the removed form: casts the raw handle to std::shared_ptr<DMatrix>* inline.
*out = static_cast<xgboost::bst_ulong>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
// Presumably the added form: reads num_col_ through p_m obtained from CastDMatrixHandle.
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_col_);
API_END();
}
// Called "non-missing" rather than "non-zero": zero is an ordinary, valid value for XGBoost.
//
// Stores the num_nonzero_ field of the DMatrix's Info() into *out. The handle is checked
// first, then the output pointer, before anything is written.
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulong *out) {
  API_BEGIN();
  CHECK_HANDLE();
  auto dmat = CastDMatrixHandle(handle);
  xgboost_CHECK_C_ARG_PTR(out);
  auto const &meta = dmat->Info();
  *out = static_cast<xgboost::bst_ulong>(meta.num_nonzero_);
  API_END();
}
/**
 * Copy the contents of a DMatrix into caller-provided CSR buffers.
 *
 * @param handle      DMatrix handle; validated by CHECK_HANDLE.
 * @param config      JSON document; must be non-null and is parsed, though the parsed
 *                    result is not read anywhere in this function.
 * @param out_indptr  Receives every element of each page's offset vector.
 * @param out_indices Receives the `index` of each entry.
 * @param out_data    Receives the `fvalue` of each entry.
 *
 * All three output buffers must be non-null. This function performs no size check on them.
 */
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
                                  xgboost::bst_ulong *out_indptr, unsigned *out_indices,
                                  float *out_data) {
  API_BEGIN();
  CHECK_HANDLE();

  xgboost_CHECK_C_ARG_PTR(config);
  // Parsed even though nothing below reads it, so the JSON loader still runs on the input.
  auto parsed_config = Json::Load(StringView{config});

  auto dmat = CastDMatrixHandle(handle);

  xgboost_CHECK_C_ARG_PTR(out_indptr);
  xgboost_CHECK_C_ARG_PTR(out_indices);
  xgboost_CHECK_C_ARG_PTR(out_data);

  // Column indices are written into an `unsigned` buffer, so the column count must fit.
  CHECK_LE(dmat->Info().num_col_, std::numeric_limits<unsigned>::max());

  for (auto const &batch : dmat->GetBatches<ExtSparsePage>()) {
    CHECK(batch.page);
    auto const &csr = *batch.page;

    // Row pointers: copied verbatim, starting at the beginning of out_indptr.
    auto const &row_ptr = csr.offset.ConstHostVector();
    std::copy(row_ptr.cbegin(), row_ptr.cend(), out_indptr);

    // Entries: split each (index, fvalue) pair into the two output arrays.
    auto view = csr.GetView();
    auto const n_entries = csr.data.Size();
    auto const n_threads = dmat->Ctx()->Threads();
    auto scatter_entry = [&view, out_indices, out_data](std::size_t idx) {
      auto const &entry = view.data[idx];
      out_data[idx] = entry.fvalue;
      out_indices[idx] = entry.index;
    };
    common::ParallelFor(n_entries, n_threads, scatter_entry);
  }
  API_END();
}