Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -5,16 +5,18 @@
#include <rabit/rabit.h>
#include <algorithm> // std::copy
#include "../common/categorical.h" // common::IsCat
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "../common/hist_util.h" // common::HistogramCuts
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
namespace xgboost {
namespace data {
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int nthread,
@@ -144,7 +146,6 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
} else {
CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
}
size_t batch_size = num_rows();
batch_nnz.push_back(nnz_cnt());
nnz += batch_nnz.back();
@@ -161,6 +162,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
return f > accumulated_rows;
})) << "Something went wrong during iteration.";
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
/**
* Generate quantiles
*/
@@ -249,9 +252,47 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
"of `DMatrix`.";
}
auto begin_iter =
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
return BatchSet<GHistIndexMatrix>(begin_iter);
}
/**
 * \brief Materialize the gradient index as a sparse page.
 *
 * Takes the first page produced by GetGradientIndex(param), copies its row pointer into
 * the offset vector of a freshly allocated SparsePage, and fills every entry with a value
 * looked up from the page's histogram cuts by bin index.  For categorical features the
 * cut value at the bin index is used directly; for all other features the value comes
 * from common::HistogramCuts::NumericBinValue.  The page is then wrapped in an
 * ExtSparsePage and handed to a SimpleBatchIteratorImpl.
 *
 * \param param Batch parameter forwarded unchanged to GetGradientIndex.
 *
 * \return Batch set built on top of the generated page.
 */
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
  for (auto const& gidx : this->GetGradientIndex(param)) {
    auto const& info = this->Info();

    // Storage is sized from the meta info: one entry per non-zero and one offset per
    // row plus the trailing sentinel.
    auto sparse = std::make_shared<SparsePage>();
    sparse->data.Resize(info.num_nonzero_);
    sparse->offset.Resize(info.num_row_ + 1);

    // The row pointer of the gradient index is reused verbatim as the offset.
    auto& offsets = sparse->offset.HostVector();
    CHECK_EQ(gidx.row_ptr.size(), offsets.size());
    std::copy(gidx.row_ptr.cbegin(), gidx.row_ptr.cend(), offsets.begin());

    auto& entries = sparse->data.HostVector();
    auto const& cut_values = gidx.cut.Values();
    auto const& cut_mins = gidx.cut.MinValues();
    auto const& cut_ptrs = gidx.cut.Ptrs();
    auto feature_types = info.feature_types.ConstHostSpan();

    // Invoked with a bin index, the destination position in `entries` and the feature
    // index; the third argument is not needed here.
    auto write_entry = [&](auto bin_idx, std::size_t idx, std::size_t, bst_feature_t fidx) {
      float const value =
          common::IsCat(feature_types, fidx)
              ? cut_values[bin_idx]
              : common::HistogramCuts::NumericBinValue(cut_ptrs, cut_values, cut_mins, fidx,
                                                       bin_idx);
      entries[idx] = Entry{fidx, value};
    };
    AssignColumnBinIndex(gidx, write_entry);

    // Return as soon as the first page has been converted.
    auto ext_page = std::make_shared<ExtSparsePage>(sparse);
    return BatchSet<ExtSparsePage>(
        BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(ext_page)));
  }

  // Only reached when the gradient index yields no page at all.
  LOG(FATAL) << "Unreachable";
  return BatchSet<ExtSparsePage>(
      BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr)));
}
} // namespace data
} // namespace xgboost