Support numpy array interface (#6998)

This commit is contained in:
Jiaming Yuan
2021-05-27 16:08:22 +08:00
committed by GitHub
parent ab6fd304c4
commit 4cf95a6041
6 changed files with 59 additions and 38 deletions

View File

@@ -231,6 +231,10 @@ class DenseAdapter : public detail::SingleBatchDataIter<DenseAdapterBatch> {
};
class ArrayAdapterBatch : public detail::NoMetaInfo {
public:
static constexpr bool kIsRowMajor = true;
private:
ArrayInterface array_interface_;
class Line {
@@ -253,6 +257,7 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
Line const GetLine(size_t idx) const {
return Line{array_interface_, idx};
}
size_t Size() const { return array_interface_.num_rows; }
explicit ArrayAdapterBatch(ArrayInterface array_interface)
: array_interface_{std::move(array_interface)} {}

View File

@@ -803,6 +803,9 @@ DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
template DMatrix* DMatrix::Create<data::DenseAdapter>(
data::DenseAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::ArrayAdapter>(
data::ArrayAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::CSRAdapter>(
data::CSRAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
@@ -1037,6 +1040,8 @@ void SparsePage::PushCSC(const SparsePage &batch) {
template uint64_t
SparsePage::Push(const data::DenseAdapterBatch& batch, float missing, int nthread);
template uint64_t
SparsePage::Push(const data::ArrayAdapterBatch& batch, float missing, int nthread);
template uint64_t
SparsePage::Push(const data::CSRAdapterBatch& batch, float missing, int nthread);
template uint64_t
SparsePage::Push(const data::CSCAdapterBatch& batch, float missing, int nthread);

View File

@@ -203,6 +203,8 @@ void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(ArrayAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing,