Use __array_interface__ for creating DMatrix from CSR. (#6675)

* Use __array_interface__ for creating DMatrix from CSR.
* Add configuration.
This commit is contained in:
Jiaming Yuan
2021-02-05 21:09:47 +08:00
committed by GitHub
parent 1e949110da
commit dbb5208a0a
7 changed files with 97 additions and 18 deletions

View File

@@ -306,6 +306,12 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
values_{std::move(values)} {}
size_t Size() const {
  // Total element count of the indptr array, minus one, clamped at zero.
  // GetLine(idx) reads indptr_ at idx and idx + 1, so a non-empty indptr_
  // describes one fewer line than it has elements.
  size_t const n_indptr = indptr_.num_rows * indptr_.num_cols;
  if (n_indptr == 0) {
    return 0;
  }
  return n_indptr - 1;
}
Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx);
auto end_offset = indptr_.GetElement<size_t>(idx + 1);

View File

@@ -812,6 +812,9 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
// Explicit instantiation of DMatrix::Create for data::FileAdapter.
template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
// Explicit instantiation of DMatrix::Create for data::CSRArrayAdapter, taking
// the same parameter list as the other adapter instantiations here.
template DMatrix* DMatrix::Create<data::CSRArrayAdapter>(
data::CSRArrayAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix *
DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> *adapter,

View File

@@ -209,6 +209,8 @@ template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing,
int nthread);
// Explicit instantiations of the templated SimpleDMatrix constructor, one per
// adapter type; each takes the adapter pointer, the `missing` value and the
// thread count.
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing,
int nthread);
// CSRArrayAdapter variant.
template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing,
int nthread);
// CSCAdapter variant.
template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,