Enhance inplace prediction. (#6653)
* Accept array interface for csr and array. * Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2019~2020 by Contributors
|
||||
* Copyright (c) 2019~2021 by Contributors
|
||||
* \file adapter.h
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_ADAPTER_H_
|
||||
@@ -228,6 +228,128 @@ class DenseAdapter : public detail::SingleBatchDataIter<DenseAdapterBatch> {
|
||||
size_t num_columns_;
|
||||
};
|
||||
|
||||
class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface array_interface_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface array_interface_;
|
||||
size_t ridx_;
|
||||
public:
|
||||
Line(ArrayInterface array_interface, size_t ridx)
|
||||
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}
|
||||
|
||||
size_t Size() const { return array_interface_.num_cols; }
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, idx, array_interface_.GetElement(idx)};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
ArrayAdapterBatch() = default;
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto line = array_interface_.SliceRow(idx);
|
||||
return Line{line, idx};
|
||||
}
|
||||
|
||||
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
||||
: array_interface_{std::move(array_interface)} {}
|
||||
};
|
||||
|
||||
/**
|
||||
* Adapter for dense array on host, in Python that's `numpy.ndarray`. This is similar to
|
||||
* `DenseAdapter`, but supports __array_interface__ instead of raw pointers. An
|
||||
* advantage is this can handle various data type without making a copy.
|
||||
*/
|
||||
class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
|
||||
public:
|
||||
explicit ArrayAdapter(StringView array_interface) {
|
||||
auto j = Json::Load(array_interface);
|
||||
array_interface_ = ArrayInterface(get<Object const>(j));
|
||||
batch_ = ArrayAdapterBatch{array_interface_};
|
||||
}
|
||||
ArrayAdapterBatch const& Value() const override { return batch_; }
|
||||
size_t NumRows() const { return array_interface_.num_rows; }
|
||||
size_t NumColumns() const { return array_interface_.num_cols; }
|
||||
|
||||
private:
|
||||
ArrayAdapterBatch batch_;
|
||||
ArrayInterface array_interface_;
|
||||
};
|
||||
|
||||
class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
size_t ridx_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
|
||||
}
|
||||
size_t Size() const {
|
||||
return values_.num_rows * values_.num_cols;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CSRArrayAdapterBatch() = default;
|
||||
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
||||
ArrayInterface values)
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)} {}
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
|
||||
auto indices = indices_.SliceOffset(begin_offset);
|
||||
auto values = values_.SliceOffset(begin_offset);
|
||||
values.num_cols = end_offset - begin_offset;
|
||||
values.num_rows = 1;
|
||||
indices.num_cols = values.num_cols;
|
||||
indices.num_rows = values.num_rows;
|
||||
return Line{indices, values, idx};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Adapter for CSR array on host, in Python that's `scipy.sparse.csr_matrix`. This is
|
||||
* similar to `CSRAdapter`, but supports __array_interface__ instead of raw pointers. An
|
||||
* advantage is this can handle various data type without making a copy.
|
||||
*/
|
||||
class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch> {
|
||||
public:
|
||||
CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
|
||||
size_t num_cols)
|
||||
: indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
|
||||
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
|
||||
}
|
||||
|
||||
CSRArrayAdapterBatch const& Value() const override {
|
||||
return batch_;
|
||||
}
|
||||
size_t NumRows() const {
|
||||
size_t size = indptr_.num_cols * indptr_.num_rows;
|
||||
size = size == 0 ? 0 : size - 1;
|
||||
return size;
|
||||
}
|
||||
size_t NumColumns() const { return num_cols_; }
|
||||
|
||||
private:
|
||||
CSRArrayAdapterBatch batch_;
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
size_t num_cols_;
|
||||
};
|
||||
|
||||
class CSCAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx,
|
||||
|
||||
Reference in New Issue
Block a user