Use __array_interface__ for creating DMatrix from CSR. (#6675)
* Use __array_interface__ for creating DMatrix from CSR. * Add configuration.
This commit is contained in:
@@ -246,6 +246,21 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
|
||||
char const *indices, char const *data,
|
||||
xgboost::bst_ulong ncol,
|
||||
char const* c_json_config,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
data::CSRArrayAdapter adapter(StringView{indptr}, StringView{indices},
|
||||
StringView{data}, ncol);
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = get<Number const>(config["missing"]);
|
||||
auto nthread = get<Integer const>(config["nthread"]);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
|
||||
const unsigned* indices,
|
||||
const bst_float* data,
|
||||
|
||||
@@ -306,6 +306,12 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)} {}
|
||||
|
||||
size_t Size() const {
|
||||
size_t size = indptr_.num_rows * indptr_.num_cols;
|
||||
size = size == 0 ? 0 : size - 1;
|
||||
return size;
|
||||
}
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
|
||||
|
||||
@@ -812,6 +812,9 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
|
||||
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
||||
data::FileAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::CSRArrayAdapter>(
|
||||
data::CSRArrayAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix *
|
||||
DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||
XGBoostBatchCSR> *adapter,
|
||||
|
||||
@@ -209,6 +209,8 @@ template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
||||
|
||||
Reference in New Issue
Block a user